Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-01-25 15:55:18 +00:00)
Automatically use fp8 for the diffusion model weights if:
the checkpoint contains the weights in fp8, and there is not enough memory to load the diffusion model in GPU VRAM.
commit ba9095e5bd, parent f123328b82
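Taken together, the changes below mean: the checkpoint's stored weight dtype becomes a candidate for unet_dtype(), and if that candidate is fp8 while the fp16-sized weights (2 bytes per parameter) exceed the VRAM budget computed by the new maximum_vram_for_weights() helper, the weights are kept in fp8. A rough worked example with illustrative numbers (the roughly 1 GiB reserve assumed for minimum_inference_memory() is not taken from this diff): on a 12 GiB card the weight budget is about 12 * 0.8 - 1 = 8.6 GiB, so an 8-billion-parameter fp8 checkpoint (about 16 GiB in fp16, 8 GiB in fp8) stays in fp8 instead of being upcast.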
@@ -94,6 +94,7 @@ class BaseModel(torch.nn.Module):
         self.concat_keys = ()
         logging.info("model_type {}".format(model_type.name))
         logging.debug("adm {}".format(self.adm_channels))
+        logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
         self.memory_usage_factor = model_config.memory_usage_factor
 
     def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
@@ -527,6 +527,9 @@ def unet_inital_load_device(parameters, dtype):
     else:
         return cpu_dev
 
+def maximum_vram_for_weights(device=None):
+    return (get_total_memory(device) * 0.8 - minimum_inference_memory())
+
 def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
     if args.bf16_unet:
         return torch.bfloat16
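For intuition, the same budget calculation as a standalone sketch (weight_budget_bytes is a hypothetical name, and the 1 GiB reserve stands in for minimum_inference_memory(), whose exact value is an assumption; the 0.8 factor comes from the diff):

def weight_budget_bytes(total_vram_bytes, inference_reserve_bytes=1 << 30):
    # Keep 20% of total VRAM as headroom and set aside a fixed reserve for
    # inference activations; whatever remains may be spent on model weights.
    return total_vram_bytes * 0.8 - inference_reserve_bytes

# A 24 GiB GPU leaves roughly 18.2 GiB for weights under these assumptions.
print(weight_budget_bytes(24 * (1 << 30)) / (1 << 30))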
@@ -536,6 +539,21 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
         return torch.float8_e4m3fn
     if args.fp8_e5m2_unet:
         return torch.float8_e5m2
+
+    fp8_dtype = None
+    try:
+        for dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
+            if dtype in supported_dtypes:
+                fp8_dtype = dtype
+                break
+    except:
+        pass
+
+    if fp8_dtype is not None:
+        free_model_memory = maximum_vram_for_weights(device)
+        if model_params * 2 > free_model_memory:
+            return fp8_dtype
+
     if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
         if torch.float16 in supported_dtypes:
             return torch.float16
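To see the new branch in isolation, here is a self-contained sketch of the same rule with the ComfyUI helpers replaced by plain arguments (pick_unet_dtype, the example numbers, and the fp16 fallback are assumptions for illustration; it needs a PyTorch build that defines the fp8 dtypes):

import torch

def pick_unet_dtype(model_params, supported_dtypes, free_model_memory):
    # Mirror of the added logic: take the first fp8 dtype offered by the
    # checkpoint/config, and use it only when fp16-sized weights
    # (2 bytes per parameter) would overflow the weight budget.
    fp8_dtype = None
    for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        if dtype in supported_dtypes:
            fp8_dtype = dtype
            break
    if fp8_dtype is not None and model_params * 2 > free_model_memory:
        return fp8_dtype
    return torch.float16  # stand-in for the existing fp16/bf16/fp32 fallback

# 8e9 parameters, fp8 weights available, ~8.6 GiB budget -> fp8 is chosen.
print(pick_unet_dtype(8e9, [torch.float8_e4m3fn, torch.float16], 8.6 * (1 << 30)))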
@@ -871,7 +889,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
         fp16_works = True
 
     if fp16_works or manual_cast:
-        free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
+        free_model_memory = maximum_vram_for_weights(device)
         if (not prioritize_performance) or model_params * 4 > free_model_memory:
             return True
 
@@ -920,7 +938,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
         bf16_works = torch.cuda.is_bf16_supported()
 
     if bf16_works or manual_cast:
-        free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
+        free_model_memory = maximum_vram_for_weights(device)
         if (not prioritize_performance) or model_params * 4 > free_model_memory:
             return True
 
@@ -510,13 +510,14 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
 
     diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
     parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
+    weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
     load_device = model_management.get_torch_device()
 
     model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
     if model_config is None:
         raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
 
-    unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
+    unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=[weight_dtype] + model_config.supported_inference_dtypes)
     manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
     model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
 
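The effect of the prepended dtype, shown with hypothetical values (this is not a ComfyUI API call; it needs a PyTorch build that defines the fp8 dtypes): a checkpoint whose weights are mostly fp8 now puts that dtype first in the candidate list handed to unet_dtype(), which is exactly what the new fp8 scan above looks for, while fp16/bf16 checkpoints are unaffected because no fp8 entry appears.

import torch

stored_dtype = torch.float8_e4m3fn  # what comfy.utils.weight_dtype() reported (hypothetical)
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]

# Same list construction as the changed call in load_checkpoint_guess_config.
candidates = [stored_dtype] + supported_inference_dtypes
print(candidates)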
@@ -40,9 +40,19 @@ def calculate_parameters(sd, prefix=""):
     params = 0
     for k in sd.keys():
         if k.startswith(prefix):
-            params += sd[k].nelement()
+            w = sd[k]
+            params += w.nelement()
     return params
 
+def weight_dtype(sd, prefix=""):
+    dtypes = {}
+    for k in sd.keys():
+        if k.startswith(prefix):
+            w = sd[k]
+            dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1
+
+    return max(dtypes, key=dtypes.get)
+
 def state_dict_key_replace(state_dict, keys_to_replace):
     for x in keys_to_replace:
         if x in state_dict:
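A quick check of the new weight_dtype() helper on a toy state dict (the helper is re-declared here so the snippet runs standalone; the key names are made up):

import torch

def weight_dtype(sd, prefix=""):
    # Count how many tensors under the prefix use each dtype and return the
    # most common one, i.e. the dtype the checkpoint is effectively stored in.
    dtypes = {}
    for k in sd.keys():
        if k.startswith(prefix):
            w = sd[k]
            dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1
    return max(dtypes, key=dtypes.get)

sd = {
    "model.diffusion_model.a.weight": torch.zeros(4, dtype=torch.float16),
    "model.diffusion_model.b.weight": torch.zeros(4, dtype=torch.float16),
    "model.diffusion_model.c.bias": torch.zeros(4, dtype=torch.float32),
}
print(weight_dtype(sd, "model.diffusion_model."))  # torch.float16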