diff --git a/comfy/model_management.py b/comfy/model_management.py
index 8b8d3ff0..04719329 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -167,7 +167,7 @@ if args.use_pytorch_cross_attention:
     ENABLE_PYTORCH_ATTENTION = True
     XFORMERS_IS_AVAILABLE = False
 
-VAE_DTYPE = torch.float32
+VAE_DTYPES = [torch.float32]
 
 try:
     if is_nvidia():
@@ -176,7 +176,7 @@ try:
             if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
                 ENABLE_PYTORCH_ATTENTION = True
             if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
-                VAE_DTYPE = torch.bfloat16
+                VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
     if is_intel_xpu():
         if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
             ENABLE_PYTORCH_ATTENTION = True
@@ -184,17 +184,10 @@ except:
     pass
 
 if is_intel_xpu():
-    VAE_DTYPE = torch.bfloat16
+    VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
 
 if args.cpu_vae:
-    VAE_DTYPE = torch.float32
-
-if args.fp16_vae:
-    VAE_DTYPE = torch.float16
-elif args.bf16_vae:
-    VAE_DTYPE = torch.bfloat16
-elif args.fp32_vae:
-    VAE_DTYPE = torch.float32
+    VAE_DTYPES = [torch.float32]
 
 
 if ENABLE_PYTORCH_ATTENTION:
@@ -258,7 +251,6 @@ try:
 except:
     logging.warning("Could not pick default device.")
 
-logging.info("VAE dtype: {}".format(VAE_DTYPE))
 
 current_loaded_models = []
 
@@ -619,9 +611,30 @@ def vae_offload_device():
     else:
         return torch.device("cpu")
 
-def vae_dtype():
-    global VAE_DTYPE
-    return VAE_DTYPE
+def vae_dtype(device=None, allowed_dtypes=()):
+    """Return the torch dtype a VAE should run in.
+
+    args.fp16_vae, args.bf16_vae and args.fp32_vae are checked first,
+    in that order, and override everything else. Otherwise the first
+    entry of allowed_dtypes that is usable is returned: an entry is
+    usable when it is in VAE_DTYPES, or when it is torch.float16 and
+    should_use_fp16(device, prioritize_performance=False) is true. If
+    no entry qualifies, VAE_DTYPES[0] is returned.
+    """
+    if args.fp16_vae:
+        return torch.float16
+    elif args.bf16_vae:
+        return torch.bfloat16
+    elif args.fp32_vae:
+        return torch.float32
+
+    for d in allowed_dtypes:
+        if d == torch.float16 and should_use_fp16(device, prioritize_performance=False):
+            return d
+        if d in VAE_DTYPES:
+            return d
+
+    return VAE_DTYPES[0]
 
 def get_autocast_device(dev):
     if hasattr(dev, 'type'):
diff --git a/comfy/sd.py b/comfy/sd.py
index e16cd8e5..58a858aa 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -178,6 +178,8 @@ class VAE:
         self.output_channels = 3
         self.process_input = lambda image: image * 2.0 - 1.0
         self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
+        # Default candidate dtypes, in the order model_management.vae_dtype() tries them (no float16).
+        self.working_dtypes = [torch.bfloat16, torch.float32]
 
         if config is None:
             if "decoder.mid.block_1.mix_factor" in sd:
@@ -245,6 +247,8 @@ class VAE:
                 self.downscale_ratio = 2048
                 self.process_output = lambda audio: audio
                 self.process_input = lambda audio: audio
+                # This branch also offers float16, tried before bfloat16 and float32.
+                self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
             else:
                 logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
                 self.first_stage_model = None
@@ -265,12 +269,13 @@ class VAE:
             self.device = device
         offload_device = model_management.vae_offload_device()
         if dtype is None:
-            dtype = model_management.vae_dtype()
+            dtype = model_management.vae_dtype(self.device, self.working_dtypes)
         self.vae_dtype = dtype
         self.first_stage_model.to(self.vae_dtype)
         self.output_device = model_management.intermediate_device()
 
         self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
+        logging.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
 
     def vae_encode_crop_pixels(self, pixels):
         dims = pixels.shape[1:-1]