mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Add a --bf16-unet to test running the unet in bf16.
This commit is contained in:
parent
9a55dadb4c
commit
fd4c5f07e7
@ -53,6 +53,8 @@ fp_group = parser.add_mutually_exclusive_group()
|
|||||||
fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
|
fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
|
||||||
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
|
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
|
||||||
|
|
||||||
|
parser.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
|
||||||
|
|
||||||
fpvae_group = parser.add_mutually_exclusive_group()
|
fpvae_group = parser.add_mutually_exclusive_group()
|
||||||
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
|
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
|
||||||
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
|
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
|
||||||
|
@ -449,6 +449,8 @@ def unet_inital_load_device(parameters, dtype):
|
|||||||
return cpu_dev
|
return cpu_dev
|
||||||
|
|
||||||
def unet_dtype(device=None, model_params=0):
|
def unet_dtype(device=None, model_params=0):
|
||||||
|
if args.bf16_unet:
|
||||||
|
return torch.bfloat16
|
||||||
if should_use_fp16(device=device, model_params=model_params):
|
if should_use_fp16(device=device, model_params=model_params):
|
||||||
return torch.float16
|
return torch.float16
|
||||||
return torch.float32
|
return torch.float32
|
||||||
|
Loading…
Reference in New Issue
Block a user