mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Compare commits
4 Commits
3271d8fffc
...
9fd3351de3
Author | SHA1 | Date | |
---|---|---|---|
|
9fd3351de3 | ||
|
2307ff6746 | ||
|
ce5afecc36 | ||
|
813b8df872 |
@ -111,7 +111,7 @@ class CLIP:
|
|||||||
model_management.load_models_gpu([self.patcher], force_full_load=True)
|
model_management.load_models_gpu([self.patcher], force_full_load=True)
|
||||||
self.layer_idx = None
|
self.layer_idx = None
|
||||||
self.use_clip_schedule = False
|
self.use_clip_schedule = False
|
||||||
logging.info("CLIP model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
|
logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
n = CLIP(no_init=True)
|
n = CLIP(no_init=True)
|
||||||
@ -898,7 +898,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||||||
if output_model:
|
if output_model:
|
||||||
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
||||||
if inital_load_device != torch.device("cpu"):
|
if inital_load_device != torch.device("cpu"):
|
||||||
logging.info("loaded straight to GPU")
|
logging.info("loaded diffusion model directly to GPU")
|
||||||
model_management.load_models_gpu([model_patcher], force_full_load=True)
|
model_management.load_models_gpu([model_patcher], force_full_load=True)
|
||||||
|
|
||||||
return (model_patcher, clip, vae, clipvision)
|
return (model_patcher, clip, vae, clipvision)
|
||||||
|
@ -9,20 +9,30 @@ import logging
|
|||||||
|
|
||||||
MAX_PREVIEW_RESOLUTION = args.preview_size
|
MAX_PREVIEW_RESOLUTION = args.preview_size
|
||||||
|
|
||||||
def preview_to_image(latent_image):
|
|
||||||
latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
|
|
||||||
.mul(0xFF) # to 0..255
|
|
||||||
).to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device))
|
|
||||||
|
|
||||||
return Image.fromarray(latents_ubyte.numpy())
|
|
||||||
|
|
||||||
class LatentPreviewer:
|
class LatentPreviewer:
|
||||||
def decode_latent_to_preview(self, x0):
|
def decode_latent_to_preview(self, x0):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def decode_latent_to_preview_image(self, preview_format, x0):
|
def decode_latent_to_preview_image(self, preview_format, x0):
|
||||||
preview_image = self.decode_latent_to_preview(x0)
|
if hasattr(self, 'event') and not self.event.query():
|
||||||
return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION)
|
# A previous preview is still being processed
|
||||||
|
return None
|
||||||
|
preview_tensor = self.decode_latent_to_preview(x0)
|
||||||
|
if comfy.model_management.device_supports_non_blocking(preview_tensor.device):
|
||||||
|
latents_ubyte = (((preview_tensor + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
|
||||||
|
.mul(0xFF) # to 0..255
|
||||||
|
).to(device="cpu", dtype=torch.uint8, non_blocking=True)
|
||||||
|
latents_rgbx = torch.zeros(latents_ubyte.shape[:2] + (4,), device="cpu", dtype=torch.uint8)
|
||||||
|
latents_rgbx[:,:,:3] = latents_ubyte
|
||||||
|
self.event = torch.cuda.Event()
|
||||||
|
self.event.record()
|
||||||
|
preview_image = Image.frombuffer('RGBX', (latents_ubyte.shape[1], latents_ubyte.shape[0]),
|
||||||
|
latents_rgbx.numpy().data, 'raw', 'RGBX', 0, 1)
|
||||||
|
return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION, self.event)
|
||||||
|
latents_ubyte = (((preview_tensor + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
|
||||||
|
.mul(0xFF) # to 0..255
|
||||||
|
).to(device="cpu", dtype=torch.uint8, non_blocking=False)
|
||||||
|
return ("JPEG", Image.fromarray(latents_ubyte.numpy()), MAX_PREVIEW_RESOLUTION)
|
||||||
|
|
||||||
class TAESDPreviewerImpl(LatentPreviewer):
|
class TAESDPreviewerImpl(LatentPreviewer):
|
||||||
def __init__(self, taesd):
|
def __init__(self, taesd):
|
||||||
@ -30,7 +40,7 @@ class TAESDPreviewerImpl(LatentPreviewer):
|
|||||||
|
|
||||||
def decode_latent_to_preview(self, x0):
|
def decode_latent_to_preview(self, x0):
|
||||||
x_sample = self.taesd.decode(x0[:1])[0].movedim(0, 2)
|
x_sample = self.taesd.decode(x0[:1])[0].movedim(0, 2)
|
||||||
return preview_to_image(x_sample)
|
return x_sample
|
||||||
|
|
||||||
|
|
||||||
class Latent2RGBPreviewer(LatentPreviewer):
|
class Latent2RGBPreviewer(LatentPreviewer):
|
||||||
@ -53,7 +63,7 @@ class Latent2RGBPreviewer(LatentPreviewer):
|
|||||||
latent_image = torch.nn.functional.linear(x0.movedim(0, -1), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
|
latent_image = torch.nn.functional.linear(x0.movedim(0, -1), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
|
||||||
# latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
|
# latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
|
||||||
|
|
||||||
return preview_to_image(latent_image)
|
return latent_image
|
||||||
|
|
||||||
|
|
||||||
def get_previewer(device, latent_format):
|
def get_previewer(device, latent_format):
|
||||||
|
@ -4,7 +4,8 @@ lint.ignore = ["ALL"]
|
|||||||
# Enable specific rules
|
# Enable specific rules
|
||||||
lint.select = [
|
lint.select = [
|
||||||
"S307", # suspicious-eval-usage
|
"S307", # suspicious-eval-usage
|
||||||
"T201", # print-usage
|
"S102", # exec
|
||||||
|
"T", # print-usage
|
||||||
"W",
|
"W",
|
||||||
# The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
|
# The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
|
||||||
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
|
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
|
||||||
|
@ -752,6 +752,10 @@ class PromptServer():
|
|||||||
image_type = image_data[0]
|
image_type = image_data[0]
|
||||||
image = image_data[1]
|
image = image_data[1]
|
||||||
max_size = image_data[2]
|
max_size = image_data[2]
|
||||||
|
if len(image_data) > 3:
|
||||||
|
event = image_data[3]
|
||||||
|
while not event.query():
|
||||||
|
await asyncio.sleep(.01)
|
||||||
if max_size is not None:
|
if max_size is not None:
|
||||||
if hasattr(Image, 'Resampling'):
|
if hasattr(Image, 'Resampling'):
|
||||||
resampling = Image.Resampling.BILINEAR
|
resampling = Image.Resampling.BILINEAR
|
||||||
|
Loading…
Reference in New Issue
Block a user