diff --git a/comfy/model_management.py b/comfy/model_management.py index 6e273839..686f124c 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -688,6 +688,9 @@ def text_encoder_initial_device(load_device, offload_device, model_size=0): if load_device == offload_device or model_size <= 1024 * 1024 * 1024: return offload_device + if is_device_mps(load_device): + return offload_device + mem_l = get_free_memory(load_device) mem_o = get_free_memory(offload_device) if mem_l > (mem_o * 0.5) and model_size * 1.2 < mem_l: diff --git a/comfy/sd.py b/comfy/sd.py index c88f5a30..c8a2f086 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -79,6 +79,9 @@ class CLIP: for dt in self.cond_stage_model.dtypes: if not model_management.supports_cast(load_device, dt): load_device = offload_device + if params['device'] != offload_device: + self.cond_stage_model.to(offload_device) + logging.warning("Had to shift TE back.") self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)