Compare commits

...

4 Commits

Author SHA1 Message Date
AustinMroz
9fd3351de3
Merge ce5afecc36 into 2307ff6746 2025-01-08 19:17:00 -05:00
comfyanonymous
2307ff6746 Improve some logging messages. 2025-01-08 19:05:22 -05:00
Austin Mroz
ce5afecc36
Only use events for devices supporting nonblocking 2024-12-23 16:43:43 -06:00
Austin Mroz
813b8df872
Fix race condition in preview code.
In the previous preview code, when possible, a non-blocking `to`
operation is performed and, immediately after, the output tensor is used
to create an image. If this non-blocking operation has not completed,
PIL makes a copy of the uninitialized memory to produce an image.
Generally, this will either contain zeros, or the result of a previously
generated preview. This results in both incorrect output, and wasted
computation (unless the memory this output was eventually copied to is
reallocated and displayed instead of a future preview).

To resolve this, the state of the preview generation is tracked with an
event.
- The PIL image is created with no copy
- The preview image is not sent to from the server until ready
- Completion of this event is polled with a reasonably slow frequency
- A new preview is not created if a previous preview has not completed
2024-12-21 04:13:35 -06:00
4 changed files with 29 additions and 14 deletions

View File

@ -111,7 +111,7 @@ class CLIP:
model_management.load_models_gpu([self.patcher], force_full_load=True) model_management.load_models_gpu([self.patcher], force_full_load=True)
self.layer_idx = None self.layer_idx = None
self.use_clip_schedule = False self.use_clip_schedule = False
logging.info("CLIP model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype)) logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
def clone(self): def clone(self):
n = CLIP(no_init=True) n = CLIP(no_init=True)
@ -898,7 +898,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
if output_model: if output_model:
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
if inital_load_device != torch.device("cpu"): if inital_load_device != torch.device("cpu"):
logging.info("loaded straight to GPU") logging.info("loaded diffusion model directly to GPU")
model_management.load_models_gpu([model_patcher], force_full_load=True) model_management.load_models_gpu([model_patcher], force_full_load=True)
return (model_patcher, clip, vae, clipvision) return (model_patcher, clip, vae, clipvision)

View File

@ -9,20 +9,30 @@ import logging
MAX_PREVIEW_RESOLUTION = args.preview_size MAX_PREVIEW_RESOLUTION = args.preview_size
def preview_to_image(latent_image):
latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255
).to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device))
return Image.fromarray(latents_ubyte.numpy())
class LatentPreviewer: class LatentPreviewer:
def decode_latent_to_preview(self, x0): def decode_latent_to_preview(self, x0):
pass pass
def decode_latent_to_preview_image(self, preview_format, x0): def decode_latent_to_preview_image(self, preview_format, x0):
preview_image = self.decode_latent_to_preview(x0) if hasattr(self, 'event') and not self.event.query():
return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION) # A previous preview is still being processed
return None
preview_tensor = self.decode_latent_to_preview(x0)
if comfy.model_management.device_supports_non_blocking(preview_tensor.device):
latents_ubyte = (((preview_tensor + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255
).to(device="cpu", dtype=torch.uint8, non_blocking=True)
latents_rgbx = torch.zeros(latents_ubyte.shape[:2] + (4,), device="cpu", dtype=torch.uint8)
latents_rgbx[:,:,:3] = latents_ubyte
self.event = torch.cuda.Event()
self.event.record()
preview_image = Image.frombuffer('RGBX', (latents_ubyte.shape[1], latents_ubyte.shape[0]),
latents_rgbx.numpy().data, 'raw', 'RGBX', 0, 1)
return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION, self.event)
latents_ubyte = (((preview_tensor + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255
).to(device="cpu", dtype=torch.uint8, non_blocking=False)
return ("JPEG", Image.fromarray(latents_ubyte.numpy()), MAX_PREVIEW_RESOLUTION)
class TAESDPreviewerImpl(LatentPreviewer): class TAESDPreviewerImpl(LatentPreviewer):
def __init__(self, taesd): def __init__(self, taesd):
@ -30,7 +40,7 @@ class TAESDPreviewerImpl(LatentPreviewer):
def decode_latent_to_preview(self, x0): def decode_latent_to_preview(self, x0):
x_sample = self.taesd.decode(x0[:1])[0].movedim(0, 2) x_sample = self.taesd.decode(x0[:1])[0].movedim(0, 2)
return preview_to_image(x_sample) return x_sample
class Latent2RGBPreviewer(LatentPreviewer): class Latent2RGBPreviewer(LatentPreviewer):
@ -53,7 +63,7 @@ class Latent2RGBPreviewer(LatentPreviewer):
latent_image = torch.nn.functional.linear(x0.movedim(0, -1), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias) latent_image = torch.nn.functional.linear(x0.movedim(0, -1), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
# latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors # latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
return preview_to_image(latent_image) return latent_image
def get_previewer(device, latent_format): def get_previewer(device, latent_format):

View File

@ -4,7 +4,8 @@ lint.ignore = ["ALL"]
# Enable specific rules # Enable specific rules
lint.select = [ lint.select = [
"S307", # suspicious-eval-usage "S307", # suspicious-eval-usage
"T201", # print-usage "S102", # exec
"T", # print-usage
"W", "W",
# The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names. # The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f # See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f

View File

@ -752,6 +752,10 @@ class PromptServer():
image_type = image_data[0] image_type = image_data[0]
image = image_data[1] image = image_data[1]
max_size = image_data[2] max_size = image_data[2]
if len(image_data) > 3:
event = image_data[3]
while not event.query():
await asyncio.sleep(.01)
if max_size is not None: if max_size is not None:
if hasattr(Image, 'Resampling'): if hasattr(Image, 'Resampling'):
resampling = Image.Resampling.BILINEAR resampling = Image.Resampling.BILINEAR