From 2feb8d0b77ce80a471ecab84b92d5bbcaa37f8fe Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 15 Jan 2025 03:50:27 -0500 Subject: [PATCH] Force safe loading of files in torch format on pytorch 2.4+ If this breaks something for you make an issue. --- comfy/utils.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index b486b2de..bcefa180 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -29,17 +29,29 @@ import itertools from torch.nn.functional import interpolate from einops import rearrange +ALWAYS_SAFE_LOAD = False +if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated + class ModelCheckpoint: + pass + ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint" + + from numpy.core.multiarray import scalar + from numpy import dtype + from numpy.dtypes import Float64DType + from _codecs import encode + + torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode]) + ALWAYS_SAFE_LOAD = True + logging.info("Checkpoint files will always be loaded safely.") + + def load_torch_file(ckpt, safe_load=False, device=None): if device is None: device = torch.device("cpu") if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): sd = safetensors.torch.load_file(ckpt, device=device.type) else: - if safe_load: - if not 'weights_only' in torch.load.__code__.co_varnames: - logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.") - safe_load = False - if safe_load: + if safe_load or ALWAYS_SAFE_LOAD: pl_sd = torch.load(ckpt, map_location=device, weights_only=True) else: pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)