Force safe loading of files in torch format on pytorch 2.4+
If this breaks something for you, make an issue.
commit 2feb8d0b77 (parent 5b657f8c15)
@@ -29,17 +29,29 @@ import itertools
 from torch.nn.functional import interpolate
 from einops import rearrange
 
+ALWAYS_SAFE_LOAD = False
+if hasattr(torch.serialization, "add_safe_globals"):  # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
+    class ModelCheckpoint:
+        pass
+    ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
+
+    from numpy.core.multiarray import scalar
+    from numpy import dtype
+    from numpy.dtypes import Float64DType
+    from _codecs import encode
+
+    torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
+    ALWAYS_SAFE_LOAD = True
+    logging.info("Checkpoint files will always be loaded safely.")
+
+
 def load_torch_file(ckpt, safe_load=False, device=None):
     if device is None:
         device = torch.device("cpu")
     if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
         sd = safetensors.torch.load_file(ckpt, device=device.type)
     else:
-        if safe_load:
-            if not 'weights_only' in torch.load.__code__.co_varnames:
-                logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
-                safe_load = False
-        if safe_load:
+        if safe_load or ALWAYS_SAFE_LOAD:
             pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
         else:
             pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
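For context, a minimal self-contained sketch of the mechanism the change relies on. It is not part of the commit and assumes PyTorch 2.4+ and a NumPy recent enough to expose numpy.dtypes; the file name "example.ckpt" is made up for the demonstration. torch.load(..., weights_only=True) unpickles with a restricted unpickler that only reconstructs allow-listed globals, and torch.serialization.add_safe_globals() extends that allow list so older pickle-based .ckpt/.pt files (which often embed NumPy scalars) can still be read on the safe path.

# Sketch only, not part of the commit. Assumes PyTorch >= 2.4 and a NumPy
# that exposes numpy.dtypes; "example.ckpt" is a hypothetical file created
# here purely for demonstration.
import numpy as np
import torch

# Older training code (pytorch_lightning checkpoints, for example) often
# pickles NumPy scalars into the .ckpt file alongside the tensors.
torch.save({"epoch": np.float64(3.0), "w": torch.zeros(2)}, "example.ckpt")

try:
    # weights_only=True swaps in a restricted unpickler that may reject the
    # numpy globals needed to rebuild np.float64 (exact behaviour depends on
    # the PyTorch/NumPy versions).
    torch.load("example.ckpt", weights_only=True)
except Exception as err:
    print("safe load rejected the checkpoint:", err)

# Allow-listing those globals, as the commit does, lets the safe path handle
# such files; afterwards every torch-format load can go through
# weights_only=True (ALWAYS_SAFE_LOAD = True).
from numpy.core.multiarray import scalar
from numpy import dtype
from numpy.dtypes import Float64DType
from _codecs import encode

torch.serialization.add_safe_globals([scalar, dtype, Float64DType, encode])
print(torch.load("example.ckpt", weights_only=True))

The ModelCheckpoint stub in the commit, with its __module__ pointed at pytorch_lightning.callbacks.model_checkpoint, serves the same purpose for checkpoints that pickled a reference to that Lightning class, without requiring pytorch_lightning to be installed.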