mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Load TE model straight to vram if certain conditions are met.
This commit is contained in:
parent
e9589d6d92
commit
5c69cde037
@ -684,6 +684,17 @@ def text_encoder_device():
|
|||||||
else:
|
else:
|
||||||
return torch.device("cpu")
|
return torch.device("cpu")
|
||||||
|
|
||||||
|
def text_encoder_initial_device(load_device, offload_device, model_size=0):
|
||||||
|
if load_device == offload_device or model_size <= 1024 * 1024 * 1024:
|
||||||
|
return offload_device
|
||||||
|
|
||||||
|
mem_l = get_free_memory(load_device)
|
||||||
|
mem_o = get_free_memory(offload_device)
|
||||||
|
if mem_l > (mem_o * 0.5) and model_size * 1.2 < mem_l:
|
||||||
|
return load_device
|
||||||
|
else:
|
||||||
|
return offload_device
|
||||||
|
|
||||||
def text_encoder_dtype(device=None):
|
def text_encoder_dtype(device=None):
|
||||||
if args.fp8_e4m3fn_text_enc:
|
if args.fp8_e4m3fn_text_enc:
|
||||||
return torch.float8_e4m3fn
|
return torch.float8_e4m3fn
|
||||||
|
12
comfy/sd.py
12
comfy/sd.py
@ -62,7 +62,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
|||||||
|
|
||||||
|
|
||||||
class CLIP:
|
class CLIP:
|
||||||
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}):
|
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0):
|
||||||
if no_init:
|
if no_init:
|
||||||
return
|
return
|
||||||
params = target.params.copy()
|
params = target.params.copy()
|
||||||
@ -71,10 +71,9 @@ class CLIP:
|
|||||||
|
|
||||||
load_device = model_management.text_encoder_device()
|
load_device = model_management.text_encoder_device()
|
||||||
offload_device = model_management.text_encoder_offload_device()
|
offload_device = model_management.text_encoder_offload_device()
|
||||||
params['device'] = offload_device
|
|
||||||
dtype = model_management.text_encoder_dtype(load_device)
|
dtype = model_management.text_encoder_dtype(load_device)
|
||||||
params['dtype'] = dtype
|
params['dtype'] = dtype
|
||||||
|
params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype))
|
||||||
self.cond_stage_model = clip(**(params))
|
self.cond_stage_model = clip(**(params))
|
||||||
|
|
||||||
for dt in self.cond_stage_model.dtypes:
|
for dt in self.cond_stage_model.dtypes:
|
||||||
@ -84,7 +83,7 @@ class CLIP:
|
|||||||
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||||
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||||
self.layer_idx = None
|
self.layer_idx = None
|
||||||
logging.debug("CLIP model load device: {}, offload device: {}".format(load_device, offload_device))
|
logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
n = CLIP(no_init=True)
|
n = CLIP(no_init=True)
|
||||||
@ -456,7 +455,7 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
|
|||||||
clip_target.clip = comfy.text_encoders.sd3_clip.SD3ClipModel
|
clip_target.clip = comfy.text_encoders.sd3_clip.SD3ClipModel
|
||||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||||
|
|
||||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
clip = CLIP(clip_target, embedding_directory=embedding_directory, state_dicts=clip_data)
|
||||||
for c in clip_data:
|
for c in clip_data:
|
||||||
m, u = clip.load_sd(c)
|
m, u = clip.load_sd(c)
|
||||||
if len(m) > 0:
|
if len(m) > 0:
|
||||||
@ -554,7 +553,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||||||
if clip_target is not None:
|
if clip_target is not None:
|
||||||
clip_sd = model_config.process_clip_state_dict(sd)
|
clip_sd = model_config.process_clip_state_dict(sd)
|
||||||
if len(clip_sd) > 0:
|
if len(clip_sd) > 0:
|
||||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd)
|
parameters = comfy.utils.calculate_parameters(clip_sd)
|
||||||
|
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters)
|
||||||
m, u = clip.load_sd(clip_sd, full_model=True)
|
m, u = clip.load_sd(clip_sd, full_model=True)
|
||||||
if len(m) > 0:
|
if len(m) > 0:
|
||||||
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
|
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
|
||||||
|
Loading…
Reference in New Issue
Block a user