diff --git a/comfy/model_management.py b/comfy/model_management.py
index cbaa8087..aa40c502 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -496,7 +496,7 @@ def unet_dtype(device=None, model_params=0):
         return torch.float8_e4m3fn
     if args.fp8_e5m2_unet:
         return torch.float8_e5m2
-    if should_use_fp16(device=device, model_params=model_params):
+    if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
         return torch.float16
     return torch.float32
 
@@ -696,7 +696,7 @@ def is_device_mps(device):
         return True
     return False
 
-def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
+def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
     global directml_enabled
 
     if device is not None:
@@ -738,7 +738,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
             if x in props.name.lower():
                 fp16_works = True
 
-    if fp16_works:
+    if fp16_works or manual_cast:
         free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
         if (not prioritize_performance) or model_params * 4 > free_model_memory:
             return True
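
Below is a minimal, self-contained sketch of the decision this patch changes. It is not the real `should_use_fp16` from `model_management.py`: the helper name, the free-memory figure, and the parameter counts are all made up for illustration. The point it demonstrates is that with `manual_cast=True` (which `unet_dtype` now passes), the memory check can still select fp16 *storage* on GPUs where `fp16_works` is false, on the assumption that fp16-stored weights are cast to a working dtype at compute time.

```python
# Minimal sketch of the fp16 decision after this patch; standalone, with
# hypothetical values -- not the actual should_use_fp16 implementation.
def should_use_fp16_sketch(fp16_works, manual_cast=False, model_params=0,
                           free_model_memory=8e9, prioritize_performance=True):
    # The patch widens the gate: manual_cast lets fp16 storage be chosen
    # even when the GPU's native fp16 support is unreliable.
    if fp16_works or manual_cast:
        # model_params * 4 estimates the fp32 weight size in bytes.
        if (not prioritize_performance) or model_params * 4 > free_model_memory:
            return True
    return False

# Before the patch: a card without working fp16 (e.g. a GTX 10xx) would
# never get fp16 weights, regardless of memory pressure.
assert not should_use_fp16_sketch(fp16_works=False, model_params=3e9)
# After: manual_cast=True picks fp16 storage when the fp32 weights
# (3e9 params * 4 bytes = 12 GB) would not fit in the 8 GB assumed free.
assert should_use_fp16_sketch(fp16_works=False, manual_cast=True, model_params=3e9)
```

Net effect of the `unet_dtype` hunk: the storage dtype can come out as `torch.float16` on such cards, with the actual compute dtype presumably handled by a casting path elsewhere in the codebase.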