Use fp16 as the default vae dtype for the audio VAE.

This commit is contained in:
comfyanonymous 2024-06-16 13:12:54 -04:00
parent 8ddc151a4c
commit 6425252c4f
2 changed files with 24 additions and 16 deletions

View File

@ -167,7 +167,7 @@ if args.use_pytorch_cross_attention:
ENABLE_PYTORCH_ATTENTION = True ENABLE_PYTORCH_ATTENTION = True
XFORMERS_IS_AVAILABLE = False XFORMERS_IS_AVAILABLE = False
VAE_DTYPE = torch.float32 VAE_DTYPES = [torch.float32]
try: try:
if is_nvidia(): if is_nvidia():
@ -176,7 +176,7 @@ try:
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False: if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True ENABLE_PYTORCH_ATTENTION = True
if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8: if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
VAE_DTYPE = torch.bfloat16 VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
if is_intel_xpu(): if is_intel_xpu():
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True ENABLE_PYTORCH_ATTENTION = True
@ -184,17 +184,10 @@ except:
pass pass
if is_intel_xpu(): if is_intel_xpu():
VAE_DTYPE = torch.bfloat16 VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
if args.cpu_vae: if args.cpu_vae:
VAE_DTYPE = torch.float32 VAE_DTYPES = [torch.float32]
if args.fp16_vae:
VAE_DTYPE = torch.float16
elif args.bf16_vae:
VAE_DTYPE = torch.bfloat16
elif args.fp32_vae:
VAE_DTYPE = torch.float32
if ENABLE_PYTORCH_ATTENTION: if ENABLE_PYTORCH_ATTENTION:
@ -258,7 +251,6 @@ try:
except: except:
logging.warning("Could not pick default device.") logging.warning("Could not pick default device.")
logging.info("VAE dtype: {}".format(VAE_DTYPE))
current_loaded_models = [] current_loaded_models = []
@ -619,9 +611,22 @@ def vae_offload_device():
else: else:
return torch.device("cpu") return torch.device("cpu")
def vae_dtype(): def vae_dtype(device=None, allowed_dtypes=[]):
global VAE_DTYPE global VAE_DTYPES
return VAE_DTYPE if args.fp16_vae:
return torch.float16
elif args.bf16_vae:
return torch.bfloat16
elif args.fp32_vae:
return torch.float32
for d in allowed_dtypes:
if d == torch.float16 and should_use_fp16(device, prioritize_performance=False):
return d
if d in VAE_DTYPES:
return d
return VAE_DTYPES[0]
def get_autocast_device(dev): def get_autocast_device(dev):
if hasattr(dev, 'type'): if hasattr(dev, 'type'):

View File

@ -178,6 +178,7 @@ class VAE:
self.output_channels = 3 self.output_channels = 3
self.process_input = lambda image: image * 2.0 - 1.0 self.process_input = lambda image: image * 2.0 - 1.0
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
self.working_dtypes = [torch.bfloat16, torch.float32]
if config is None: if config is None:
if "decoder.mid.block_1.mix_factor" in sd: if "decoder.mid.block_1.mix_factor" in sd:
@ -245,6 +246,7 @@ class VAE:
self.downscale_ratio = 2048 self.downscale_ratio = 2048
self.process_output = lambda audio: audio self.process_output = lambda audio: audio
self.process_input = lambda audio: audio self.process_input = lambda audio: audio
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
else: else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.") logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None self.first_stage_model = None
@ -265,12 +267,13 @@ class VAE:
self.device = device self.device = device
offload_device = model_management.vae_offload_device() offload_device = model_management.vae_offload_device()
if dtype is None: if dtype is None:
dtype = model_management.vae_dtype() dtype = model_management.vae_dtype(self.device, self.working_dtypes)
self.vae_dtype = dtype self.vae_dtype = dtype
self.first_stage_model.to(self.vae_dtype) self.first_stage_model.to(self.vae_dtype)
self.output_device = model_management.intermediate_device() self.output_device = model_management.intermediate_device()
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logging.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
def vae_encode_crop_pixels(self, pixels): def vae_encode_crop_pixels(self, pixels):
dims = pixels.shape[1:-1] dims = pixels.shape[1:-1]