diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index eabae088..1c271b82 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -145,6 +145,82 @@ class LatentBatchSeedBehavior: return (samples_out,) +class LatentApplyOperation: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), + "operation": ("LATENT_OPERATION",), + }} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "op" + + CATEGORY = "latent/advanced/operations" + EXPERIMENTAL = True + + def op(self, samples, operation): + samples_out = samples.copy() + + s1 = samples["samples"] + samples_out["samples"] = operation(latent=s1) + return (samples_out,) + +class LatentApplyOperationCFG: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "operation": ("LATENT_OPERATION",), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "latent/advanced/operations" + EXPERIMENTAL = True + + def patch(self, model, operation): + m = model.clone() + + def pre_cfg_function(args): + conds_out = args["conds_out"] + if len(conds_out) == 2: + conds_out[0] = operation(latent=(conds_out[0] - conds_out[1])) + conds_out[1] + else: + conds_out[0] = operation(latent=conds_out[0]) + return conds_out + + m.set_model_sampler_pre_cfg_function(pre_cfg_function) + return (m, ) + +class LatentOperationTonemapReinhard: + @classmethod + def INPUT_TYPES(s): + return {"required": { "multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), + }} + + RETURN_TYPES = ("LATENT_OPERATION",) + FUNCTION = "op" + + CATEGORY = "latent/advanced/operations" + EXPERIMENTAL = True + + def op(self, multiplier): + def tonemap_reinhard(latent, **kwargs): + latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None] + normalized_latent = latent / latent_vector_magnitude + + mean = torch.mean(latent_vector_magnitude, dim=(1,2,3), keepdim=True) + std = torch.std(latent_vector_magnitude, dim=(1,2,3), keepdim=True) + + top = (std * 5 + mean) * multiplier + + #reinhard + latent_vector_magnitude *= (1.0 / top) + new_magnitude = latent_vector_magnitude / (latent_vector_magnitude + 1.0) + new_magnitude *= top + + return normalized_latent * new_magnitude + return (tonemap_reinhard,) + NODE_CLASS_MAPPINGS = { "LatentAdd": LatentAdd, "LatentSubtract": LatentSubtract, @@ -152,4 +228,7 @@ NODE_CLASS_MAPPINGS = { "LatentInterpolate": LatentInterpolate, "LatentBatch": LatentBatch, "LatentBatchSeedBehavior": LatentBatchSeedBehavior, + "LatentApplyOperation": LatentApplyOperation, + "LatentApplyOperationCFG": LatentApplyOperationCFG, + "LatentOperationTonemapReinhard": LatentOperationTonemapReinhard, }