mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-20 03:13:30 +00:00
--gpu-only now keeps the VAE on the device.
This commit is contained in:
parent
ce35d8c659
commit
1c1b0e7299
@ -349,6 +349,15 @@ def text_encoder_device():
|
|||||||
else:
|
else:
|
||||||
return torch.device("cpu")
|
return torch.device("cpu")
|
||||||
|
|
||||||
|
def vae_device():
|
||||||
|
return get_torch_device()
|
||||||
|
|
||||||
|
def vae_offload_device():
|
||||||
|
if args.gpu_only or vram_state == VRAMState.SHARED:
|
||||||
|
return get_torch_device()
|
||||||
|
else:
|
||||||
|
return torch.device("cpu")
|
||||||
|
|
||||||
def get_autocast_device(dev):
|
def get_autocast_device(dev):
|
||||||
if hasattr(dev, 'type'):
|
if hasattr(dev, 'type'):
|
||||||
return dev.type
|
return dev.type
|
||||||
|
11
comfy/sd.py
11
comfy/sd.py
@ -605,8 +605,9 @@ class VAE:
|
|||||||
self.first_stage_model.load_state_dict(sd, strict=False)
|
self.first_stage_model.load_state_dict(sd, strict=False)
|
||||||
|
|
||||||
if device is None:
|
if device is None:
|
||||||
device = model_management.get_torch_device()
|
device = model_management.vae_device()
|
||||||
self.device = device
|
self.device = device
|
||||||
|
self.offload_device = model_management.vae_offload_device()
|
||||||
|
|
||||||
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||||
steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
|
steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
|
||||||
@ -651,7 +652,7 @@ class VAE:
|
|||||||
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||||
pixel_samples = self.decode_tiled_(samples_in)
|
pixel_samples = self.decode_tiled_(samples_in)
|
||||||
|
|
||||||
self.first_stage_model = self.first_stage_model.cpu()
|
self.first_stage_model = self.first_stage_model.to(self.offload_device)
|
||||||
pixel_samples = pixel_samples.cpu().movedim(1,-1)
|
pixel_samples = pixel_samples.cpu().movedim(1,-1)
|
||||||
return pixel_samples
|
return pixel_samples
|
||||||
|
|
||||||
@ -659,7 +660,7 @@ class VAE:
|
|||||||
model_management.unload_model()
|
model_management.unload_model()
|
||||||
self.first_stage_model = self.first_stage_model.to(self.device)
|
self.first_stage_model = self.first_stage_model.to(self.device)
|
||||||
output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
|
output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
|
||||||
self.first_stage_model = self.first_stage_model.cpu()
|
self.first_stage_model = self.first_stage_model.to(self.offload_device)
|
||||||
return output.movedim(1,-1)
|
return output.movedim(1,-1)
|
||||||
|
|
||||||
def encode(self, pixel_samples):
|
def encode(self, pixel_samples):
|
||||||
@ -679,7 +680,7 @@ class VAE:
|
|||||||
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||||
samples = self.encode_tiled_(pixel_samples)
|
samples = self.encode_tiled_(pixel_samples)
|
||||||
|
|
||||||
self.first_stage_model = self.first_stage_model.cpu()
|
self.first_stage_model = self.first_stage_model.to(self.offload_device)
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||||
@ -687,7 +688,7 @@ class VAE:
|
|||||||
self.first_stage_model = self.first_stage_model.to(self.device)
|
self.first_stage_model = self.first_stage_model.to(self.device)
|
||||||
pixel_samples = pixel_samples.movedim(-1,1)
|
pixel_samples = pixel_samples.movedim(-1,1)
|
||||||
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
|
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
|
||||||
self.first_stage_model = self.first_stage_model.cpu()
|
self.first_stage_model = self.first_stage_model.to(self.offload_device)
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
def get_sd(self):
|
def get_sd(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user