mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-19 19:03:51 +00:00
Make it easy for models to process the unet state dict on load.
This commit is contained in:
parent
2dd5b4dd78
commit
ce67dcbcda
@ -121,6 +121,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
if k.startswith(unet_prefix):
|
if k.startswith(unet_prefix):
|
||||||
to_load[k[len(unet_prefix):]] = sd.pop(k)
|
to_load[k[len(unet_prefix):]] = sd.pop(k)
|
||||||
|
|
||||||
|
to_load = self.model_config.process_unet_state_dict(to_load)
|
||||||
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
|
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
|
||||||
if len(m) > 0:
|
if len(m) > 0:
|
||||||
print("unet missing:", m)
|
print("unet missing:", m)
|
||||||
|
@ -53,6 +53,9 @@ class BASE:
|
|||||||
def process_clip_state_dict(self, state_dict):
|
def process_clip_state_dict(self, state_dict):
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
def process_unet_state_dict(self, state_dict):
|
||||||
|
return state_dict
|
||||||
|
|
||||||
def process_clip_state_dict_for_saving(self, state_dict):
|
def process_clip_state_dict_for_saving(self, state_dict):
|
||||||
replace_prefix = {"": "cond_stage_model."}
|
replace_prefix = {"": "cond_stage_model."}
|
||||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||||
|
Loading…
Reference in New Issue
Block a user