mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Use fp16 as the default vae dtype for the audio VAE.
This commit is contained in:
parent
8ddc151a4c
commit
6425252c4f
@ -167,7 +167,7 @@ if args.use_pytorch_cross_attention:
|
|||||||
ENABLE_PYTORCH_ATTENTION = True
|
ENABLE_PYTORCH_ATTENTION = True
|
||||||
XFORMERS_IS_AVAILABLE = False
|
XFORMERS_IS_AVAILABLE = False
|
||||||
|
|
||||||
VAE_DTYPE = torch.float32
|
VAE_DTYPES = [torch.float32]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if is_nvidia():
|
if is_nvidia():
|
||||||
@ -176,7 +176,7 @@ try:
|
|||||||
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||||
ENABLE_PYTORCH_ATTENTION = True
|
ENABLE_PYTORCH_ATTENTION = True
|
||||||
if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
|
if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
|
||||||
VAE_DTYPE = torch.bfloat16
|
VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
|
||||||
if is_intel_xpu():
|
if is_intel_xpu():
|
||||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||||
ENABLE_PYTORCH_ATTENTION = True
|
ENABLE_PYTORCH_ATTENTION = True
|
||||||
@ -184,17 +184,10 @@ except:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
if is_intel_xpu():
|
if is_intel_xpu():
|
||||||
VAE_DTYPE = torch.bfloat16
|
VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
|
||||||
|
|
||||||
if args.cpu_vae:
|
if args.cpu_vae:
|
||||||
VAE_DTYPE = torch.float32
|
VAE_DTYPES = [torch.float32]
|
||||||
|
|
||||||
if args.fp16_vae:
|
|
||||||
VAE_DTYPE = torch.float16
|
|
||||||
elif args.bf16_vae:
|
|
||||||
VAE_DTYPE = torch.bfloat16
|
|
||||||
elif args.fp32_vae:
|
|
||||||
VAE_DTYPE = torch.float32
|
|
||||||
|
|
||||||
|
|
||||||
if ENABLE_PYTORCH_ATTENTION:
|
if ENABLE_PYTORCH_ATTENTION:
|
||||||
@ -258,7 +251,6 @@ try:
|
|||||||
except:
|
except:
|
||||||
logging.warning("Could not pick default device.")
|
logging.warning("Could not pick default device.")
|
||||||
|
|
||||||
logging.info("VAE dtype: {}".format(VAE_DTYPE))
|
|
||||||
|
|
||||||
current_loaded_models = []
|
current_loaded_models = []
|
||||||
|
|
||||||
@ -619,9 +611,22 @@ def vae_offload_device():
|
|||||||
else:
|
else:
|
||||||
return torch.device("cpu")
|
return torch.device("cpu")
|
||||||
|
|
||||||
def vae_dtype():
|
def vae_dtype(device=None, allowed_dtypes=[]):
|
||||||
global VAE_DTYPE
|
global VAE_DTYPES
|
||||||
return VAE_DTYPE
|
if args.fp16_vae:
|
||||||
|
return torch.float16
|
||||||
|
elif args.bf16_vae:
|
||||||
|
return torch.bfloat16
|
||||||
|
elif args.fp32_vae:
|
||||||
|
return torch.float32
|
||||||
|
|
||||||
|
for d in allowed_dtypes:
|
||||||
|
if d == torch.float16 and should_use_fp16(device, prioritize_performance=False):
|
||||||
|
return d
|
||||||
|
if d in VAE_DTYPES:
|
||||||
|
return d
|
||||||
|
|
||||||
|
return VAE_DTYPES[0]
|
||||||
|
|
||||||
def get_autocast_device(dev):
|
def get_autocast_device(dev):
|
||||||
if hasattr(dev, 'type'):
|
if hasattr(dev, 'type'):
|
||||||
|
@ -178,6 +178,7 @@ class VAE:
|
|||||||
self.output_channels = 3
|
self.output_channels = 3
|
||||||
self.process_input = lambda image: image * 2.0 - 1.0
|
self.process_input = lambda image: image * 2.0 - 1.0
|
||||||
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
if "decoder.mid.block_1.mix_factor" in sd:
|
if "decoder.mid.block_1.mix_factor" in sd:
|
||||||
@ -245,6 +246,7 @@ class VAE:
|
|||||||
self.downscale_ratio = 2048
|
self.downscale_ratio = 2048
|
||||||
self.process_output = lambda audio: audio
|
self.process_output = lambda audio: audio
|
||||||
self.process_input = lambda audio: audio
|
self.process_input = lambda audio: audio
|
||||||
|
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||||
else:
|
else:
|
||||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||||
self.first_stage_model = None
|
self.first_stage_model = None
|
||||||
@ -265,12 +267,13 @@ class VAE:
|
|||||||
self.device = device
|
self.device = device
|
||||||
offload_device = model_management.vae_offload_device()
|
offload_device = model_management.vae_offload_device()
|
||||||
if dtype is None:
|
if dtype is None:
|
||||||
dtype = model_management.vae_dtype()
|
dtype = model_management.vae_dtype(self.device, self.working_dtypes)
|
||||||
self.vae_dtype = dtype
|
self.vae_dtype = dtype
|
||||||
self.first_stage_model.to(self.vae_dtype)
|
self.first_stage_model.to(self.vae_dtype)
|
||||||
self.output_device = model_management.intermediate_device()
|
self.output_device = model_management.intermediate_device()
|
||||||
|
|
||||||
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
||||||
|
logging.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
|
||||||
|
|
||||||
def vae_encode_crop_pixels(self, pixels):
|
def vae_encode_crop_pixels(self, pixels):
|
||||||
dims = pixels.shape[1:-1]
|
dims = pixels.shape[1:-1]
|
||||||
|
Loading…
Reference in New Issue
Block a user