diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 35d44164..d8655764 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -53,6 +53,8 @@ fp_group = parser.add_mutually_exclusive_group() fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.") +parser.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.") + fpvae_group = parser.add_mutually_exclusive_group() fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.") fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.") diff --git a/comfy/model_management.py b/comfy/model_management.py index 1161c244..c24c7b27 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -449,6 +449,8 @@ def unet_inital_load_device(parameters, dtype): return cpu_dev def unet_dtype(device=None, model_params=0): + if args.bf16_unet: + return torch.bfloat16 if should_use_fp16(device=device, model_params=model_params): return torch.float16 return torch.float32