mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Add Condition by Mask node
This PR adds support for a Condition by Mask node. This node allows conditioning to be limited to a non-rectangle area.
This commit is contained in:
parent
e958dfdd4d
commit
e214c917ae
@ -6,6 +6,7 @@ import contextlib
|
|||||||
from comfy import model_management
|
from comfy import model_management
|
||||||
from .ldm.models.diffusion.ddim import DDIMSampler
|
from .ldm.models.diffusion.ddim import DDIMSampler
|
||||||
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
|
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
|
||||||
|
from torchvision.ops import masks_to_boxes
|
||||||
|
|
||||||
#The main sampling function shared by all the samplers
|
#The main sampling function shared by all the samplers
|
||||||
#Returns predicted noise
|
#Returns predicted noise
|
||||||
@ -23,8 +24,20 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|||||||
adm_cond = cond[1]['adm_encoded']
|
adm_cond = cond[1]['adm_encoded']
|
||||||
|
|
||||||
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
||||||
mult = torch.ones_like(input_x) * strength
|
if 'mask' in cond[1]:
|
||||||
|
# Scale the mask to the size of the input
|
||||||
|
# The mask should have been resized as we began the sampling process
|
||||||
|
mask = cond[1]['mask']
|
||||||
|
assert(mask.shape[1] == x_in.shape[2])
|
||||||
|
assert(mask.shape[2] == x_in.shape[3])
|
||||||
|
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
||||||
|
if mask.shape[0] != input_x.shape[0]:
|
||||||
|
mask = mask.repeat(input_x.shape[0], 1, 1)
|
||||||
|
else:
|
||||||
|
mask = torch.ones_like(input_x)
|
||||||
|
mult = mask * strength
|
||||||
|
|
||||||
|
if 'mask' not in cond[1]:
|
||||||
rr = 8
|
rr = 8
|
||||||
if area[2] != 0:
|
if area[2] != 0:
|
||||||
for t in range(rr):
|
for t in range(rr):
|
||||||
@ -38,6 +51,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|||||||
if (area[1] + area[3]) < x_in.shape[3]:
|
if (area[1] + area[3]) < x_in.shape[3]:
|
||||||
for t in range(rr):
|
for t in range(rr):
|
||||||
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
|
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
|
||||||
|
|
||||||
conditionning = {}
|
conditionning = {}
|
||||||
conditionning['c_crossattn'] = cond[0]
|
conditionning['c_crossattn'] = cond[0]
|
||||||
if cond_concat_in is not None and len(cond_concat_in) > 0:
|
if cond_concat_in is not None and len(cond_concat_in) > 0:
|
||||||
@ -301,6 +315,47 @@ def blank_inpaint_image_like(latent_image):
|
|||||||
blank_image[:,3] *= 0.1380
|
blank_image[:,3] *= 0.1380
|
||||||
return blank_image
|
return blank_image
|
||||||
|
|
||||||
|
def resolve_cond_masks(conditions, h, w, device):
|
||||||
|
# We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes.
|
||||||
|
# While we're doing this, we can also resolve the mask device and scaling for performance reasons
|
||||||
|
for i in range(len(conditions)):
|
||||||
|
c = conditions[i]
|
||||||
|
if 'mask' in c[1]:
|
||||||
|
mask = c[1]['mask']
|
||||||
|
mask = mask.to(device=device)
|
||||||
|
modified = c[1].copy()
|
||||||
|
if len(mask.shape) == 2:
|
||||||
|
mask = mask.unsqueeze(0)
|
||||||
|
if mask.shape[2] != h or mask.shape[3] != w:
|
||||||
|
mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1)
|
||||||
|
|
||||||
|
if 'area' not in modified:
|
||||||
|
bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0)
|
||||||
|
if torch.max(bounds) == 0:
|
||||||
|
# Handle the edge-case of an all black mask (where masks_to_boxes would error)
|
||||||
|
area = (0, 0, 0, 0)
|
||||||
|
else:
|
||||||
|
box = masks_to_boxes(bounds)[0].type(torch.int)
|
||||||
|
H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0])
|
||||||
|
# Make sure the height and width are divisible by 8
|
||||||
|
if X % 8 != 0:
|
||||||
|
newx = X // 8 * 8
|
||||||
|
W = W + (X - newx)
|
||||||
|
X = newx
|
||||||
|
if Y % 8 != 0:
|
||||||
|
newy = Y // 8 * 8
|
||||||
|
H = H + (Y - newy)
|
||||||
|
Y = newy
|
||||||
|
if H % 8 != 0:
|
||||||
|
H = H + (8 - (H % 8))
|
||||||
|
if W % 8 != 0:
|
||||||
|
W = W + (8 - (W % 8))
|
||||||
|
area = (int(H), int(W), int(Y), (X))
|
||||||
|
modified['area'] = area
|
||||||
|
|
||||||
|
modified['mask'] = mask
|
||||||
|
conditions[i] = [c[0], modified]
|
||||||
|
|
||||||
def create_cond_with_same_area_if_none(conds, c):
|
def create_cond_with_same_area_if_none(conds, c):
|
||||||
if 'area' not in c[1]:
|
if 'area' not in c[1]:
|
||||||
return
|
return
|
||||||
@ -461,7 +516,6 @@ class KSampler:
|
|||||||
sigmas = self.calculate_sigmas(new_steps).to(self.device)
|
sigmas = self.calculate_sigmas(new_steps).to(self.device)
|
||||||
self.sigmas = sigmas[-(steps + 1):]
|
self.sigmas = sigmas[-(steps + 1):]
|
||||||
|
|
||||||
|
|
||||||
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None):
|
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None):
|
||||||
if sigmas is None:
|
if sigmas is None:
|
||||||
sigmas = self.sigmas
|
sigmas = self.sigmas
|
||||||
@ -484,6 +538,10 @@ class KSampler:
|
|||||||
|
|
||||||
positive = positive[:]
|
positive = positive[:]
|
||||||
negative = negative[:]
|
negative = negative[:]
|
||||||
|
|
||||||
|
resolve_cond_masks(positive, noise.shape[2], noise.shape[3], self.device)
|
||||||
|
resolve_cond_masks(negative, noise.shape[2], noise.shape[3], self.device)
|
||||||
|
|
||||||
#make sure each cond area has an opposite one with the same area
|
#make sure each cond area has an opposite one with the same area
|
||||||
for c in positive:
|
for c in positive:
|
||||||
create_cond_with_same_area_if_none(negative, c)
|
create_cond_with_same_area_if_none(negative, c)
|
||||||
|
28
nodes.py
28
nodes.py
@ -85,6 +85,32 @@ class ConditioningSetArea:
|
|||||||
c.append(n)
|
c.append(n)
|
||||||
return (c, )
|
return (c, )
|
||||||
|
|
||||||
|
class ConditioningSetMask:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"conditioning": ("CONDITIONING", ),
|
||||||
|
"mask": ("MASK", ),
|
||||||
|
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("CONDITIONING",)
|
||||||
|
FUNCTION = "append"
|
||||||
|
|
||||||
|
CATEGORY = "conditioning"
|
||||||
|
|
||||||
|
def append(self, conditioning, mask, strength, min_sigma=0.0, max_sigma=99.0):
|
||||||
|
c = []
|
||||||
|
if len(mask.shape) < 3:
|
||||||
|
mask = mask.unsqueeze(0)
|
||||||
|
for t in conditioning:
|
||||||
|
n = [t[0], t[1].copy()]
|
||||||
|
_, h, w = mask.shape
|
||||||
|
n[1]['mask'] = mask
|
||||||
|
n[1]['strength'] = strength
|
||||||
|
n[1]['min_sigma'] = min_sigma
|
||||||
|
n[1]['max_sigma'] = max_sigma
|
||||||
|
c.append(n)
|
||||||
|
return (c, )
|
||||||
|
|
||||||
class VAEDecode:
|
class VAEDecode:
|
||||||
def __init__(self, device="cpu"):
|
def __init__(self, device="cpu"):
|
||||||
self.device = device
|
self.device = device
|
||||||
@ -1115,6 +1141,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"ImagePadForOutpaint": ImagePadForOutpaint,
|
"ImagePadForOutpaint": ImagePadForOutpaint,
|
||||||
"ConditioningCombine": ConditioningCombine,
|
"ConditioningCombine": ConditioningCombine,
|
||||||
"ConditioningSetArea": ConditioningSetArea,
|
"ConditioningSetArea": ConditioningSetArea,
|
||||||
|
"ConditioningSetMask": ConditioningSetMask,
|
||||||
"KSamplerAdvanced": KSamplerAdvanced,
|
"KSamplerAdvanced": KSamplerAdvanced,
|
||||||
"SetLatentNoiseMask": SetLatentNoiseMask,
|
"SetLatentNoiseMask": SetLatentNoiseMask,
|
||||||
"LatentComposite": LatentComposite,
|
"LatentComposite": LatentComposite,
|
||||||
@ -1164,6 +1191,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"CLIPSetLastLayer": "CLIP Set Last Layer",
|
"CLIPSetLastLayer": "CLIP Set Last Layer",
|
||||||
"ConditioningCombine": "Conditioning (Combine)",
|
"ConditioningCombine": "Conditioning (Combine)",
|
||||||
"ConditioningSetArea": "Conditioning (Set Area)",
|
"ConditioningSetArea": "Conditioning (Set Area)",
|
||||||
|
"ConditioningSetMask": "Conditioning (Set Mask)",
|
||||||
"ControlNetApply": "Apply ControlNet",
|
"ControlNetApply": "Apply ControlNet",
|
||||||
# Latent
|
# Latent
|
||||||
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
|
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
|
||||||
|
Loading…
Reference in New Issue
Block a user