diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index 1c271b82..952db753 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -221,6 +221,55 @@ class LatentOperationTonemapReinhard: return normalized_latent * new_magnitude return (tonemap_reinhard,) +class LatentOperationSharpen: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "sharpen_radius": ("INT", { + "default": 9, + "min": 1, + "max": 31, + "step": 1 + }), + "sigma": ("FLOAT", { + "default": 1.0, + "min": 0.1, + "max": 10.0, + "step": 0.1 + }), + "alpha": ("FLOAT", { + "default": 0.1, + "min": 0.0, + "max": 5.0, + "step": 0.01 + }), + }} + + RETURN_TYPES = ("LATENT_OPERATION",) + FUNCTION = "op" + + CATEGORY = "latent/advanced/operations" + EXPERIMENTAL = True + + def op(self, sharpen_radius, sigma, alpha): + def sharpen(latent, **kwargs): + luminance = (torch.linalg.vector_norm(latent, dim=(1)) + 1e-6)[:,None] + normalized_latent = latent / luminance + channels = latent.shape[1] + + kernel_size = sharpen_radius * 2 + 1 + kernel = comfy_extras.nodes_post_processing.gaussian_kernel(kernel_size, sigma, device=luminance.device) + center = kernel_size // 2 + + kernel *= alpha * -10 + kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0 + + padded_image = torch.nn.functional.pad(normalized_latent, (sharpen_radius,sharpen_radius,sharpen_radius,sharpen_radius), 'reflect') + sharpened = torch.nn.functional.conv2d(padded_image, kernel.repeat(channels, 1, 1).unsqueeze(1), padding=kernel_size // 2, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius] + + return luminance * sharpened + return (sharpen,) + NODE_CLASS_MAPPINGS = { "LatentAdd": LatentAdd, "LatentSubtract": LatentSubtract, @@ -231,4 +280,5 @@ NODE_CLASS_MAPPINGS = { "LatentApplyOperation": LatentApplyOperation, "LatentApplyOperationCFG": LatentApplyOperationCFG, "LatentOperationTonemapReinhard": LatentOperationTonemapReinhard, + "LatentOperationSharpen": LatentOperationSharpen, }