diff --git a/comfy/model_management.py b/comfy/model_management.py
index a0105131..6e273839 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -684,6 +684,26 @@ def text_encoder_device():
     else:
         return torch.device("cpu")
 
+def text_encoder_initial_device(load_device, offload_device, model_size=0):
+    """Pick the device a text encoder should be created on.
+
+    Returns offload_device when both devices are the same or when
+    model_size is at most 1024 * 1024 * 1024 (callers pass
+    parameters * dtype_size, so presumably bytes). Otherwise returns
+    load_device only if its free memory is more than half of the
+    offload device's free memory and is larger than 1.2x model_size.
+    """
+    if load_device == offload_device or model_size <= 1024 * 1024 * 1024:
+        return offload_device
+
+    mem_l = get_free_memory(load_device)
+    mem_o = get_free_memory(offload_device)
+    # Require 20% headroom over the model size on the load device.
+    if mem_l > (mem_o * 0.5) and model_size * 1.2 < mem_l:
+        return load_device
+    else:
+        return offload_device
+
 def text_encoder_dtype(device=None):
     if args.fp8_e4m3fn_text_enc:
         return torch.float8_e4m3fn
diff --git a/comfy/sd.py b/comfy/sd.py
index ee91ad53..6d729929 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -62,7 +62,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
 
 
 class CLIP:
-    def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}):
+    def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0):
         if no_init:
             return
         params = target.params.copy()
@@ -71,10 +71,9 @@ class CLIP:
 
         load_device = model_management.text_encoder_device()
         offload_device = model_management.text_encoder_offload_device()
-        params['device'] = offload_device
         dtype = model_management.text_encoder_dtype(load_device)
         params['dtype'] = dtype
-
+        params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype))
         self.cond_stage_model = clip(**(params))
 
         for dt in self.cond_stage_model.dtypes:
@@ -84,7 +83,7 @@ class CLIP:
         self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
         self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
         self.layer_idx = None
-        logging.debug("CLIP model load device: {}, offload device: {}".format(load_device, offload_device))
+        logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))
 
     def clone(self):
         n = CLIP(no_init=True)
@@ -456,7 +455,7 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
         clip_target.clip = comfy.text_encoders.sd3_clip.SD3ClipModel
         clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
 
-    clip = CLIP(clip_target, embedding_directory=embedding_directory)
+    clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=sum(comfy.utils.calculate_parameters(c) for c in clip_data))
     for c in clip_data:
         m, u = clip.load_sd(c)
         if len(m) > 0:
@@ -554,7 +553,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
         if clip_target is not None:
             clip_sd = model_config.process_clip_state_dict(sd)
             if len(clip_sd) > 0:
-                clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd)
+                parameters = comfy.utils.calculate_parameters(clip_sd)
+                clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters)
                 m, u = clip.load_sd(clip_sd, full_model=True)
                 if len(m) > 0:
                     m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))