Mirror of https://github.com/comfyanonymous/ComfyUI.git
Add some command line arguments to store text encoder weights in fp8.

PyTorch supports two variants of fp8: --fp8_e4m3fn-text-enc (the one that seems to give better results) and --fp8_e5m2-text-enc.
commit 0cf4e86939
parent 107e78b1cb
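For background on the two formats: e4m3fn spends 4 bits on exponent and 3 on mantissa (finer precision, max normal value around 448), while e5m2 spends 5 on exponent and 2 on mantissa (coarser precision, max normal value around 57344). A minimal sketch of the difference, assuming a PyTorch build new enough to expose the float8 dtypes (2.1 or later):

    import torch

    # Round-trip the same value through both fp8 variants.
    # e4m3fn: 4 exponent / 3 mantissa bits -> finer steps, less range.
    # e5m2:   5 exponent / 2 mantissa bits -> coarser steps, more range.
    x = torch.tensor([0.7], dtype=torch.float32)
    print(x.to(torch.float8_e4m3fn).to(torch.float32))  # tensor([0.6875])
    print(x.to(torch.float8_e5m2).to(torch.float32))    # tensor([0.7500])

The extra mantissa bit is plausibly why e4m3fn "seems to give better results" for weight storage, where precision matters more than range.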
comfy/cli_args.py
@@ -62,6 +62,13 @@ fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in
 fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
 fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
 
+fpte_group = parser.add_mutually_exclusive_group()
+fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
+fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
+fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
+fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
+
+
 parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
 
 parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
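Two details of the argparse pattern above are worth noting: flags registered in a mutually exclusive group cannot be combined, and argparse converts the hyphens in a flag name to underscores when building the attribute name, which is why the next hunk reads args.fp8_e4m3fn_text_enc. A standalone sketch of both behaviors:

    import argparse

    # Rebuild just the fp8 part of the mutually exclusive group from the hunk above.
    parser = argparse.ArgumentParser()
    fpte_group = parser.add_mutually_exclusive_group()
    fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true")
    fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true")

    # Hyphens in the flag become underscores in the parsed namespace.
    args = parser.parse_args(["--fp8_e4m3fn-text-enc"])
    print(args.fp8_e4m3fn_text_enc)  # True

    # Passing both flags at once makes argparse exit with:
    # "argument --fp8_e5m2-text-enc: not allowed with argument --fp8_e4m3fn-text-enc"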
comfy/model_management.py
@@ -482,6 +482,21 @@ def text_encoder_device():
     else:
         return torch.device("cpu")
 
+def text_encoder_dtype(device=None):
+    if args.fp8_e4m3fn_text_enc:
+        return torch.float8_e4m3fn
+    elif args.fp8_e5m2_text_enc:
+        return torch.float8_e5m2
+    elif args.fp16_text_enc:
+        return torch.float16
+    elif args.fp32_text_enc:
+        return torch.float32
+
+    if should_use_fp16(device, prioritize_performance=False):
+        return torch.float16
+    else:
+        return torch.float32
+
 def vae_device():
     return get_torch_device()
 
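Note that the help strings say "store" deliberately: fp8 here is a storage format for the weights, and text_encoder_dtype() falls back to fp16/fp32 when no flag is set; compute happens after casting up, since there is no general fp8 matmul path. A sketch of that storage-versus-compute split (an assumption about how the weights get used, not code from this commit):

    import torch

    # Weights held in fp8 take 1 byte per element, half of fp16.
    w = torch.randn(768, 768).to(torch.float8_e4m3fn)
    x = torch.randn(1, 768)

    # Cast up per use for the actual math (fp32 here; fp16 on GPU in practice).
    y = x @ w.to(torch.float32)
    print(w.element_size(), y.dtype)  # 1 torch.float32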
comfy/sd.py
@@ -95,10 +95,7 @@ class CLIP:
         load_device = model_management.text_encoder_device()
         offload_device = model_management.text_encoder_offload_device()
         params['device'] = offload_device
-        if model_management.should_use_fp16(load_device, prioritize_performance=False):
-            params['dtype'] = torch.float16
-        else:
-            params['dtype'] = torch.float32
+        params['dtype'] = model_management.text_encoder_dtype(load_device)
 
         self.cond_stage_model = clip(**(params))
 
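For readers unfamiliar with the clip(**(params)) call at the end of the hunk: the dtype chosen by text_encoder_dtype() simply travels through the params dict into the text encoder's constructor. A toy illustration of that plumbing (TinyTextEnc is a hypothetical stand-in, not a ComfyUI class):

    import torch

    class TinyTextEnc(torch.nn.Module):
        # Stand-in constructor taking the same device/dtype keywords as params.
        def __init__(self, device=None, dtype=None):
            super().__init__()
            self.proj = torch.nn.Linear(16, 16, device=device, dtype=dtype)

    params = {"device": torch.device("cpu"), "dtype": torch.float16}
    model = TinyTextEnc(**params)
    print(next(model.parameters()).dtype)  # torch.float16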