diff --git a/comfy/sd.py b/comfy/sd.py index dafa6f52..6948fcac 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -17,7 +17,10 @@ def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]): pl_sd = torch.load(ckpt, map_location="cpu") if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") - sd = pl_sd["state_dict"] + if "state_dict" in pl_sd: + sd = pl_sd["state_dict"] + else: + sd = pl_sd model = instantiate_from_config(config.model) m, u = model.load_state_dict(sd, strict=False)