diff --git a/comfy/sd.py b/comfy/sd.py index 8c056e4e..220637a0 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -151,7 +151,7 @@ class CLIP: return self.patcher.get_key_patches() class VAE: - def __init__(self, sd=None, device=None, config=None): + def __init__(self, sd=None, device=None, config=None, dtype=None): if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format sd = diffusers_convert.convert_vae_state_dict(sd) @@ -188,7 +188,9 @@ class VAE: device = model_management.vae_device() self.device = device offload_device = model_management.vae_offload_device() - self.vae_dtype = model_management.vae_dtype() + if dtype is None: + dtype = model_management.vae_dtype() + self.vae_dtype = dtype self.first_stage_model.to(self.vae_dtype) self.output_device = model_management.intermediate_device()