Support loading some checkpoint files with nested dicts.

This commit is contained in:
comfyanonymous 2024-12-11 08:04:54 -05:00
parent 44db978531
commit 7a7efe8424

View File

@ -45,6 +45,12 @@ def load_torch_file(ckpt, safe_load=False, device=None):
logging.debug(f"Global Step: {pl_sd['global_step']}")
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
else:
if len(pl_sd) == 1:
key = list(pl_sd.keys())[0]
sd = pl_sd[key]
if not isinstance(sd, dict):
sd = pl_sd
else:
sd = pl_sd
return sd