diff --git a/comfy/model_base.py b/comfy/model_base.py
index ec15e9fc..94f4d333 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -94,6 +94,7 @@ class BaseModel(torch.nn.Module):
         self.concat_keys = ()
         logging.info("model_type {}".format(model_type.name))
         logging.debug("adm {}".format(self.adm_channels))
+        logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
         self.memory_usage_factor = model_config.memory_usage_factor
 
     def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
diff --git a/comfy/model_management.py b/comfy/model_management.py
index da0b989a..c0fb1509 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -527,6 +527,9 @@ def unet_inital_load_device(parameters, dtype):
     else:
         return cpu_dev
 
+def maximum_vram_for_weights(device=None):
+    return (get_total_memory(device) * 0.8 - minimum_inference_memory())
+
 def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
     if args.bf16_unet:
         return torch.bfloat16
@@ -536,6 +539,21 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
         return torch.float8_e4m3fn
     if args.fp8_e5m2_unet:
         return torch.float8_e5m2
+
+    fp8_dtype = None
+    try:
+        for dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
+            if dtype in supported_dtypes:
+                fp8_dtype = dtype
+                break
+    except:
+        pass
+
+    if fp8_dtype is not None:
+        free_model_memory = maximum_vram_for_weights(device)
+        if model_params * 2 > free_model_memory:
+            return fp8_dtype
+
     if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
         if torch.float16 in supported_dtypes:
             return torch.float16
@@ -871,7 +889,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
             fp16_works = True
 
     if fp16_works or manual_cast:
-        free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
+        free_model_memory = maximum_vram_for_weights(device)
         if (not prioritize_performance) or model_params * 4 > free_model_memory:
             return True
 
@@ -920,7 +938,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     bf16_works = torch.cuda.is_bf16_supported()
 
     if bf16_works or manual_cast:
-        free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
+        free_model_memory = maximum_vram_for_weights(device)
         if (not prioritize_performance) or model_params * 4 > free_model_memory:
             return True
 
diff --git a/comfy/sd.py b/comfy/sd.py
index 41ce18c8..bf336c85 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -510,13 +510,14 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
 
     diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
     parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
+    weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
     load_device = model_management.get_torch_device()
 
     model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
     if model_config is None:
         raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
 
-    unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
+    unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=[weight_dtype] + model_config.supported_inference_dtypes)
     manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
     model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
 
diff --git a/comfy/utils.py b/comfy/utils.py
index 0db9fbb6..d9fe36f9 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -40,9 +40,19 @@ def calculate_parameters(sd, prefix=""):
     params = 0
     for k in sd.keys():
         if k.startswith(prefix):
-            params += sd[k].nelement()
+            w = sd[k]
+            params += w.nelement()
     return params
 
+def weight_dtype(sd, prefix=""):
+    dtypes = {}
+    for k in sd.keys():
+        if k.startswith(prefix):
+            w = sd[k]
+            dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1
+
+    return max(dtypes, key=dtypes.get)
+
 def state_dict_key_replace(state_dict, keys_to_replace):
     for x in keys_to_replace:
         if x in state_dict: