diff --git a/comfy/ops.py b/comfy/ops.py index d7596634..43ed55ad 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -20,9 +20,13 @@ import torch import comfy.model_management from comfy.cli_args import args -def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=True): - if not copy and (dtype is None or weight.dtype == dtype) and (device is None or weight.device == device): - return weight +def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False): + if device is None or weight.device == device: + if not copy: + if dtype is None or weight.dtype == dtype: + return weight + return weight.to(dtype=dtype, copy=copy) + r = torch.empty_like(weight, dtype=dtype, device=device) r.copy_(weight, non_blocking=non_blocking) return r