diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index cedf39d6..2eefc4c5 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -3,9 +3,7 @@ import torch def reshape_latent_to(target_shape, latent): if latent.shape[1:] != target_shape[1:]: - latent.movedim(1, -1) latent = comfy.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center") - latent.movedim(-1, 1) return comfy.utils.repeat_to_batch_size(latent, target_shape[0]) @@ -102,9 +100,32 @@ class LatentInterpolate: samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio)) return (samples_out,) +class LatentBatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "batch" + + CATEGORY = "latent/batch" + + def batch(self, samples1, samples2): + samples_out = samples1.copy() + s1 = samples1["samples"] + s2 = samples2["samples"] + + if s1.shape[1:] != s2.shape[1:]: + s2 = comfy.utils.common_upscale(s2, s1.shape[3], s1.shape[2], "bilinear", "center") + s = torch.cat((s1, s2), dim=0) + samples_out["samples"] = s + samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])]) + return (samples_out,) + NODE_CLASS_MAPPINGS = { "LatentAdd": LatentAdd, "LatentSubtract": LatentSubtract, "LatentMultiply": LatentMultiply, "LatentInterpolate": LatentInterpolate, + "LatentBatch": LatentBatch, }