Support loading some checkpoint files with nested dicts.

This commit is contained in:
comfyanonymous 2024-12-11 08:04:54 -05:00
parent 44db978531
commit 7a7efe8424

View File

@ -46,7 +46,13 @@ def load_torch_file(ckpt, safe_load=False, device=None):
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
else:
sd = pl_sd
if len(pl_sd) == 1:
key = list(pl_sd.keys())[0]
sd = pl_sd[key]
if not isinstance(sd, dict):
sd = pl_sd
else:
sd = pl_sd
return sd
def save_torch_file(sd, ckpt, metadata=None):