Add a --bf16-unet to test running the unet in bf16.

This commit is contained in:
comfyanonymous 2023-10-13 14:51:10 -04:00
parent 9a55dadb4c
commit fd4c5f07e7
2 changed files with 4 additions and 0 deletions

View File

@ -53,6 +53,8 @@ fp_group = parser.add_mutually_exclusive_group()
fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.") fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
parser.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
fpvae_group = parser.add_mutually_exclusive_group() fpvae_group = parser.add_mutually_exclusive_group()
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.") fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.") fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")

View File

@ -449,6 +449,8 @@ def unet_inital_load_device(parameters, dtype):
return cpu_dev return cpu_dev
def unet_dtype(device=None, model_params=0): def unet_dtype(device=None, model_params=0):
if args.bf16_unet:
return torch.bfloat16
if should_use_fp16(device=device, model_params=model_params): if should_use_fp16(device=device, model_params=model_params):
return torch.float16 return torch.float16
return torch.float32 return torch.float32