diff --git a/comfy/model_base.py b/comfy/model_base.py index 898adb66..8c89adf5 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -503,8 +503,10 @@ class SD15_instructpix2pix(IP2P, BaseModel): class SDXL_instructpix2pix(IP2P, SDXL): def __init__(self, model_config, model_type=ModelType.EPS, device=None): super().__init__(model_config, model_type, device=device) - # self.process_ip2p_image_in = lambda image: comfy.latent_formats.SDXL().process_in(image) - self.process_ip2p_image_in = lambda image: image + if model_type == ModelType.V_PREDICTION_EDM: + self.process_ip2p_image_in = lambda image: comfy.latent_formats.SDXL().process_in(image) #cosxl ip2p + else: + self.process_ip2p_image_in = lambda image: image #diffusers ip2p class StableCascade_C(BaseModel): diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 9bfe3ea1..b3b69e05 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -174,6 +174,11 @@ class SDXL(supported_models_base.BASE): self.sampling_settings["sigma_max"] = 80.0 self.sampling_settings["sigma_min"] = 0.002 return model_base.ModelType.EDM + elif "edm_vpred.sigma_max" in state_dict: + self.sampling_settings["sigma_max"] = float(state_dict["edm_vpred.sigma_max"].item()) + if "edm_vpred.sigma_min" in state_dict: + self.sampling_settings["sigma_min"] = float(state_dict["edm_vpred.sigma_min"].item()) + return model_base.ModelType.V_PREDICTION_EDM elif "v_pred" in state_dict: return model_base.ModelType.V_PREDICTION else: @@ -469,7 +474,7 @@ class SDXL_instructpix2pix(SDXL): } def get_model(self, state_dict, prefix="", device=None): - return model_base.SDXL_instructpix2pix(self, device=device) + return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device) models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p]