Fix issue with regular torch version.

This commit is contained in:
comfyanonymous 2023-08-17 01:58:54 -04:00
parent 89a0767abf
commit 2be2742711

View File

@ -390,7 +390,11 @@ def unet_inital_load_device(parameters, dtype):
return torch_dev
cpu_dev = torch.device("cpu")
model_size = dtype.itemsize * parameters
dtype_size = 4
if dtype == torch.float16 or dtype == torch.bfloat16:
dtype_size = 2
model_size = dtype_size * parameters
mem_dev = get_free_memory(torch_dev)
mem_cpu = get_free_memory(cpu_dev)