Only use events for devices supporting nonblocking

This commit is contained in:
Austin Mroz 2024-12-23 16:43:43 -06:00
parent 813b8df872
commit ce5afecc36
No known key found for this signature in database

View File

@ -18,9 +18,10 @@ class LatentPreviewer:
# A previous preview is still being processed # A previous preview is still being processed
return None return None
preview_tensor = self.decode_latent_to_preview(x0) preview_tensor = self.decode_latent_to_preview(x0)
if comfy.model_management.device_supports_non_blocking(preview_tensor.device):
latents_ubyte = (((preview_tensor + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1 latents_ubyte = (((preview_tensor + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255 .mul(0xFF) # to 0..255
).to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(preview_tensor.device)) ).to(device="cpu", dtype=torch.uint8, non_blocking=True)
latents_rgbx = torch.zeros(latents_ubyte.shape[:2] + (4,), device="cpu", dtype=torch.uint8) latents_rgbx = torch.zeros(latents_ubyte.shape[:2] + (4,), device="cpu", dtype=torch.uint8)
latents_rgbx[:,:,:3] = latents_ubyte latents_rgbx[:,:,:3] = latents_ubyte
self.event = torch.cuda.Event() self.event = torch.cuda.Event()
@ -28,6 +29,10 @@ class LatentPreviewer:
preview_image = Image.frombuffer('RGBX', (latents_ubyte.shape[1], latents_ubyte.shape[0]), preview_image = Image.frombuffer('RGBX', (latents_ubyte.shape[1], latents_ubyte.shape[0]),
latents_rgbx.numpy().data, 'raw', 'RGBX', 0, 1) latents_rgbx.numpy().data, 'raw', 'RGBX', 0, 1)
return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION, self.event) return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION, self.event)
latents_ubyte = (((preview_tensor + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255
).to(device="cpu", dtype=torch.uint8, non_blocking=False)
return ("JPEG", Image.fromarray(latents_ubyte.numpy()), MAX_PREVIEW_RESOLUTION)
class TAESDPreviewerImpl(LatentPreviewer): class TAESDPreviewerImpl(LatentPreviewer):
def __init__(self, taesd): def __init__(self, taesd):