From ce5afecc362efb1bc9670744e18aeafe730a99d8 Mon Sep 17 00:00:00 2001 From: Austin Mroz Date: Mon, 23 Dec 2024 16:43:43 -0600 Subject: [PATCH] Only use events for devices supporting nonblocking --- latent_preview.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/latent_preview.py b/latent_preview.py index e3f620de..06ebf14a 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -18,16 +18,21 @@ class LatentPreviewer: # A previous preview is still being processed return None preview_tensor = self.decode_latent_to_preview(x0) + if comfy.model_management.device_supports_non_blocking(preview_tensor.device): + latents_ubyte = (((preview_tensor + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1 + .mul(0xFF) # to 0..255 + ).to(device="cpu", dtype=torch.uint8, non_blocking=True) + latents_rgbx = torch.zeros(latents_ubyte.shape[:2] + (4,), device="cpu", dtype=torch.uint8) + latents_rgbx[:,:,:3] = latents_ubyte + self.event = torch.cuda.Event() + self.event.record() + preview_image = Image.frombuffer('RGBX', (latents_ubyte.shape[1], latents_ubyte.shape[0]), + latents_rgbx.numpy().data, 'raw', 'RGBX', 0, 1) + return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION, self.event) latents_ubyte = (((preview_tensor + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1 .mul(0xFF) # to 0..255 - ).to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(preview_tensor.device)) - latents_rgbx = torch.zeros(latents_ubyte.shape[:2] + (4,), device="cpu", dtype=torch.uint8) - latents_rgbx[:,:,:3] = latents_ubyte - self.event = torch.cuda.Event() - self.event.record() - preview_image = Image.frombuffer('RGBX', (latents_ubyte.shape[1], latents_ubyte.shape[0]), - latents_rgbx.numpy().data, 'raw', 'RGBX', 0, 1) - return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION, self.event) + ).to(device="cpu", dtype=torch.uint8, non_blocking=False) + return ("JPEG", Image.fromarray(latents_ubyte.numpy()), MAX_PREVIEW_RESOLUTION) class TAESDPreviewerImpl(LatentPreviewer): def __init__(self, taesd):