mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 10:25:16 +00:00
Add dtype parameter to VAE object.
This commit is contained in:
parent
32b7e7e769
commit
824e4935f5
@ -151,7 +151,7 @@ class CLIP:
|
|||||||
return self.patcher.get_key_patches()
|
return self.patcher.get_key_patches()
|
||||||
|
|
||||||
class VAE:
|
class VAE:
|
||||||
def __init__(self, sd=None, device=None, config=None):
|
def __init__(self, sd=None, device=None, config=None, dtype=None):
|
||||||
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||||
|
|
||||||
@ -188,7 +188,9 @@ class VAE:
|
|||||||
device = model_management.vae_device()
|
device = model_management.vae_device()
|
||||||
self.device = device
|
self.device = device
|
||||||
offload_device = model_management.vae_offload_device()
|
offload_device = model_management.vae_offload_device()
|
||||||
self.vae_dtype = model_management.vae_dtype()
|
if dtype is None:
|
||||||
|
dtype = model_management.vae_dtype()
|
||||||
|
self.vae_dtype = dtype
|
||||||
self.first_stage_model.to(self.vae_dtype)
|
self.first_stage_model.to(self.vae_dtype)
|
||||||
self.output_device = model_management.intermediate_device()
|
self.output_device = model_management.intermediate_device()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user