mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-02-28 14:40:27 +00:00
Force safe loading of files in torch format on pytorch 2.4+
If this breaks something for you make an issue.
This commit is contained in:
parent
5b657f8c15
commit
2feb8d0b77
@ -29,17 +29,29 @@ import itertools
|
||||
from torch.nn.functional import interpolate
|
||||
from einops import rearrange
|
||||
|
||||
ALWAYS_SAFE_LOAD = False
|
||||
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
|
||||
class ModelCheckpoint:
|
||||
pass
|
||||
ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
|
||||
|
||||
from numpy.core.multiarray import scalar
|
||||
from numpy import dtype
|
||||
from numpy.dtypes import Float64DType
|
||||
from _codecs import encode
|
||||
|
||||
torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
|
||||
ALWAYS_SAFE_LOAD = True
|
||||
logging.info("Checkpoint files will always be loaded safely.")
|
||||
|
||||
|
||||
def load_torch_file(ckpt, safe_load=False, device=None):
|
||||
if device is None:
|
||||
device = torch.device("cpu")
|
||||
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
|
||||
sd = safetensors.torch.load_file(ckpt, device=device.type)
|
||||
else:
|
||||
if safe_load:
|
||||
if not 'weights_only' in torch.load.__code__.co_varnames:
|
||||
logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
|
||||
safe_load = False
|
||||
if safe_load:
|
||||
if safe_load or ALWAYS_SAFE_LOAD:
|
||||
pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
|
||||
else:
|
||||
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
|
||||
|
Loading…
Reference in New Issue
Block a user