Add a SetLatentNoiseMask node.

LATENT is now a dict that can contain properties.
This commit is contained in:
comfyanonymous 2023-02-15 16:58:55 -05:00
parent cef2cc3cb0
commit d75003001a

View File

@ -98,7 +98,7 @@ class VAEDecode:
CATEGORY = "latent" CATEGORY = "latent"
def decode(self, vae, samples): def decode(self, vae, samples):
return (vae.decode(samples), ) return (vae.decode(samples["samples"]), )
class VAEEncode: class VAEEncode:
def __init__(self, device="cpu"): def __init__(self, device="cpu"):
@ -117,7 +117,9 @@ class VAEEncode:
y = (pixels.shape[2] // 64) * 64 y = (pixels.shape[2] // 64) * 64
if pixels.shape[1] != x or pixels.shape[2] != y: if pixels.shape[1] != x or pixels.shape[2] != y:
pixels = pixels[:,:x,:y,:] pixels = pixels[:,:x,:y,:]
return (vae.encode(pixels), ) t = vae.encode(pixels[:,:,:,:3])
return ({"samples":t}, )
class CheckpointLoader: class CheckpointLoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
@ -212,7 +214,7 @@ class EmptyLatentImage:
def generate(self, width, height, batch_size=1): def generate(self, width, height, batch_size=1):
latent = torch.zeros([batch_size, 4, height // 8, width // 8]) latent = torch.zeros([batch_size, 4, height // 8, width // 8])
return (latent, ) return ({"samples":latent}, )
def common_upscale(samples, width, height, upscale_method, crop): def common_upscale(samples, width, height, upscale_method, crop):
if crop == "center": if crop == "center":
@ -247,7 +249,8 @@ class LatentUpscale:
CATEGORY = "latent" CATEGORY = "latent"
def upscale(self, samples, upscale_method, width, height, crop): def upscale(self, samples, upscale_method, width, height, crop):
s = common_upscale(samples, width // 8, height // 8, upscale_method, crop) s = samples.copy()
s["samples"] = common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop)
return (s,) return (s,)
class LatentRotate: class LatentRotate:
@ -262,6 +265,7 @@ class LatentRotate:
CATEGORY = "latent" CATEGORY = "latent"
def rotate(self, samples, rotation): def rotate(self, samples, rotation):
s = samples.copy()
rotate_by = 0 rotate_by = 0
if rotation.startswith("90"): if rotation.startswith("90"):
rotate_by = 1 rotate_by = 1
@ -270,7 +274,7 @@ class LatentRotate:
elif rotation.startswith("270"): elif rotation.startswith("270"):
rotate_by = 3 rotate_by = 3
s = torch.rot90(samples, k=rotate_by, dims=[3, 2]) s["samples"] = torch.rot90(samples["samples"], k=rotate_by, dims=[3, 2])
return (s,) return (s,)
class LatentFlip: class LatentFlip:
@ -285,12 +289,11 @@ class LatentFlip:
CATEGORY = "latent" CATEGORY = "latent"
def flip(self, samples, flip_method): def flip(self, samples, flip_method):
s = samples.copy()
if flip_method.startswith("x"): if flip_method.startswith("x"):
s = torch.flip(samples, dims=[2]) s["samples"] = torch.flip(samples["samples"], dims=[2])
elif flip_method.startswith("y"): elif flip_method.startswith("y"):
s = torch.flip(samples, dims=[3]) s["samples"] = torch.flip(samples["samples"], dims=[3])
else:
s = samples
return (s,) return (s,)
@ -312,12 +315,15 @@ class LatentComposite:
x = x // 8 x = x // 8
y = y // 8 y = y // 8
feather = feather // 8 feather = feather // 8
s = samples_to.clone() samples_out = samples_to.copy()
s = samples_to["samples"].clone()
samples_to = samples_to["samples"]
samples_from = samples_from["samples"]
if feather == 0: if feather == 0:
s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
else: else:
s_from = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] samples_from = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
mask = torch.ones_like(s_from) mask = torch.ones_like(samples_from)
for t in range(feather): for t in range(feather):
if y != 0: if y != 0:
mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1)) mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1))
@ -330,7 +336,8 @@ class LatentComposite:
mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1)) mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1))
rev_mask = torch.ones_like(mask) - mask rev_mask = torch.ones_like(mask) - mask
s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] * mask + s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] * rev_mask s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] * mask + s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] * rev_mask
return (s,) samples_out["samples"] = s
return (samples_out,)
class LatentCrop: class LatentCrop:
@classmethod @classmethod
@ -347,6 +354,8 @@ class LatentCrop:
CATEGORY = "latent" CATEGORY = "latent"
def crop(self, samples, width, height, x, y): def crop(self, samples, width, height, x, y):
s = samples.copy()
samples = samples['samples']
x = x // 8 x = x // 8
y = y // 8 y = y // 8
@ -370,15 +379,46 @@ class LatentCrop:
#make sure size is always multiple of 64 #make sure size is always multiple of 64
x, to_x = enforce_image_dim(x, to_x, samples.shape[3]) x, to_x = enforce_image_dim(x, to_x, samples.shape[3])
y, to_y = enforce_image_dim(y, to_y, samples.shape[2]) y, to_y = enforce_image_dim(y, to_y, samples.shape[2])
s = samples[:,:,y:to_y, x:to_x] s['samples'] = samples[:,:,y:to_y, x:to_x]
return (s,) return (s,)
def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): class SetLatentNoiseMask:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"mask": ("MASK",),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "set_mask"
CATEGORY = "latent"
def set_mask(self, samples, mask):
s = samples.copy()
s["noise_mask"] = mask
return (s,)
def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
latent_image = latent["samples"]
noise_mask = None
if disable_noise: if disable_noise:
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
else: else:
noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=torch.manual_seed(seed), device="cpu") noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=torch.manual_seed(seed), device="cpu")
if "noise_mask" in latent:
noise_mask = latent['noise_mask']
print(noise_mask.shape, noise.shape)
noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear")
noise_mask = noise_mask.floor()
noise_mask = torch.ones_like(noise_mask) - noise_mask
noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1)
noise_mask = torch.cat([noise_mask] * noise.shape[0])
noise_mask = noise_mask.to(device)
real_model = None real_model = None
if device != "cpu": if device != "cpu":
model_management.load_model_gpu(model) model_management.load_model_gpu(model)
@ -411,10 +451,11 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po
#other samplers #other samplers
pass pass
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise) samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask)
samples = samples.cpu() samples = samples.cpu()
out = latent.copy()
return (samples, ) out["samples"] = samples
return (out, )
class KSampler: class KSampler:
def __init__(self, device="cuda"): def __init__(self, device="cuda"):
@ -589,6 +630,7 @@ NODE_CLASS_MAPPINGS = {
"ConditioningCombine": ConditioningCombine, "ConditioningCombine": ConditioningCombine,
"ConditioningSetArea": ConditioningSetArea, "ConditioningSetArea": ConditioningSetArea,
"KSamplerAdvanced": KSamplerAdvanced, "KSamplerAdvanced": KSamplerAdvanced,
"SetLatentNoiseMask": SetLatentNoiseMask,
"LatentComposite": LatentComposite, "LatentComposite": LatentComposite,
"LatentRotate": LatentRotate, "LatentRotate": LatentRotate,
"LatentFlip": LatentFlip, "LatentFlip": LatentFlip,