From 813b8df872fe9e6f5b36dcdadab3d69372d8b2d2 Mon Sep 17 00:00:00 2001 From: Austin Mroz Date: Thu, 19 Dec 2024 13:17:00 -0600 Subject: [PATCH] Fix race condition in preview code. In the previous preview code, when possible, a non-blocking `to` operation is performed and, immediately after, the output tensor is used to create an image. If this non-blocking operation has not completed, PIL makes a copy of the uninitialized memory to produce an image. Generally, this will either contain zeros, or the result of a previously generated preview. This results in both incorrect output, and wasted computation (unless the memory this output was eventually copied to is reallocated and displayed instead of a future preview). To resolve this, the state of the preview generation is tracked with an event. - The PIL image is created with no copy - The preview image is not sent to from the server until ready - Completion of this event is polled with a reasonably slow frequency - A new preview is not created if a previous preview has not completed --- latent_preview.py | 27 ++++++++++++++++----------- server.py | 4 ++++ 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/latent_preview.py b/latent_preview.py index 07f9cc68..e3f620de 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -9,20 +9,25 @@ import logging MAX_PREVIEW_RESOLUTION = args.preview_size -def preview_to_image(latent_image): - latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1 - .mul(0xFF) # to 0..255 - ).to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device)) - - return Image.fromarray(latents_ubyte.numpy()) - class LatentPreviewer: def decode_latent_to_preview(self, x0): pass def decode_latent_to_preview_image(self, preview_format, x0): - preview_image = self.decode_latent_to_preview(x0) - return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION) + if hasattr(self, 'event') and not self.event.query(): + # A previous preview is still being processed + return None + preview_tensor = self.decode_latent_to_preview(x0) + latents_ubyte = (((preview_tensor + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1 + .mul(0xFF) # to 0..255 + ).to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(preview_tensor.device)) + latents_rgbx = torch.zeros(latents_ubyte.shape[:2] + (4,), device="cpu", dtype=torch.uint8) + latents_rgbx[:,:,:3] = latents_ubyte + self.event = torch.cuda.Event() + self.event.record() + preview_image = Image.frombuffer('RGBX', (latents_ubyte.shape[1], latents_ubyte.shape[0]), + latents_rgbx.numpy().data, 'raw', 'RGBX', 0, 1) + return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION, self.event) class TAESDPreviewerImpl(LatentPreviewer): def __init__(self, taesd): @@ -30,7 +35,7 @@ class TAESDPreviewerImpl(LatentPreviewer): def decode_latent_to_preview(self, x0): x_sample = self.taesd.decode(x0[:1])[0].movedim(0, 2) - return preview_to_image(x_sample) + return x_sample class Latent2RGBPreviewer(LatentPreviewer): @@ -53,7 +58,7 @@ class Latent2RGBPreviewer(LatentPreviewer): latent_image = torch.nn.functional.linear(x0.movedim(0, -1), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias) # latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors - return preview_to_image(latent_image) + return latent_image def get_previewer(device, latent_format): diff --git a/server.py b/server.py index ddd71e06..f31d792b 100644 --- a/server.py +++ b/server.py @@ -750,6 +750,10 @@ class PromptServer(): image_type = image_data[0] image = image_data[1] max_size = image_data[2] + if len(image_data) > 3: + event = image_data[3] + while not event.query(): + await asyncio.sleep(.01) if max_size is not None: if hasattr(Image, 'Resampling'): resampling = Image.Resampling.BILINEAR