Default to sampling entire image

By default, when applying a mask to a condition, the entire image will
still be used for sampling. The new "set_area_to_bounds" option on the
node will allow the user to automatically limit conditioning to the
bounds of the mask.

I've also removed the dependency on torchvision for calculating bounding
boxes. I've taken the opportunity to fix some frustrating details in the
other version:
1. An all-0 mask will no longer cause an error
2. Indices are returned as integers instead of floats so they can be
   used to index into tensors.
This commit is contained in:
Jacob Segal 2023-04-29 00:16:58 -07:00
parent e214c917ae
commit af02393c2a
2 changed files with 35 additions and 11 deletions

View File

@ -6,7 +6,6 @@ import contextlib
from comfy import model_management
from .ldm.models.diffusion.ddim import DDIMSampler
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
from torchvision.ops import masks_to_boxes
#The main sampling function shared by all the samplers
#Returns predicted noise
@ -31,8 +30,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
assert(mask.shape[1] == x_in.shape[2])
assert(mask.shape[2] == x_in.shape[3])
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
if mask.shape[0] != input_x.shape[0]:
mask = mask.repeat(input_x.shape[0], 1, 1)
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
else:
mask = torch.ones_like(input_x)
mult = mask * strength
@ -315,6 +313,29 @@ def blank_inpaint_image_like(latent_image):
blank_image[:,3] *= 0.1380
return blank_image
def get_mask_aabb(masks):
if masks.numel() == 0:
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
b = masks.shape[0]
bounding_boxes = torch.zeros((b, 4), device=masks.device, dtype=torch.int)
is_empty = torch.zeros((b), device=masks.device, dtype=torch.bool)
for i in range(b):
mask = masks[i]
if mask.numel() == 0:
continue
if torch.max(mask != 0) == False:
is_empty[i] = True
continue
y, x = torch.where(mask)
bounding_boxes[i, 0] = torch.min(x)
bounding_boxes[i, 1] = torch.min(y)
bounding_boxes[i, 2] = torch.max(x)
bounding_boxes[i, 3] = torch.max(y)
return bounding_boxes, is_empty
def resolve_cond_masks(conditions, h, w, device):
# We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes.
# While we're doing this, we can also resolve the mask device and scaling for performance reasons
@ -329,13 +350,14 @@ def resolve_cond_masks(conditions, h, w, device):
if mask.shape[2] != h or mask.shape[3] != w:
mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1)
if 'area' not in modified:
if modified.get("set_area_to_bounds", False):
bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0)
if torch.max(bounds) == 0:
# Handle the edge-case of an all black mask (where masks_to_boxes would error)
area = (0, 0, 0, 0)
boxes, is_empty = get_mask_aabb(bounds)
if is_empty[0]:
# Use the minimum possible size for efficiency reasons. (Since the mask is all-0, this becomes a noop anyway)
modified['area'] = (8, 8, 0, 0)
else:
box = masks_to_boxes(bounds)[0].type(torch.int)
box = boxes[0]
H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0])
# Make sure the height and width are divisible by 8
if X % 8 != 0:
@ -350,7 +372,7 @@ def resolve_cond_masks(conditions, h, w, device):
H = H + (8 - (H % 8))
if W % 8 != 0:
W = W + (8 - (W % 8))
area = (int(H), int(W), int(Y), (X))
area = (int(H), int(W), int(Y), int(X))
modified['area'] = area
modified['mask'] = mask

View File

@ -90,6 +90,7 @@ class ConditioningSetMask:
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
"mask": ("MASK", ),
"set_area_to_bounds": ([False, True],),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("CONDITIONING",)
@ -97,7 +98,7 @@ class ConditioningSetMask:
CATEGORY = "conditioning"
def append(self, conditioning, mask, strength, min_sigma=0.0, max_sigma=99.0):
def append(self, conditioning, mask, set_area_to_bounds, strength, min_sigma=0.0, max_sigma=99.0):
c = []
if len(mask.shape) < 3:
mask = mask.unsqueeze(0)
@ -105,6 +106,7 @@ class ConditioningSetMask:
n = [t[0], t[1].copy()]
_, h, w = mask.shape
n[1]['mask'] = mask
n[1]['set_area_to_bounds'] = set_area_to_bounds
n[1]['strength'] = strength
n[1]['min_sigma'] = min_sigma
n[1]['max_sigma'] = max_sigma