From a5e6a632f9f16e5b3c72c428820bce67b05446bf Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 10 Jun 2024 01:05:53 -0400 Subject: [PATCH] Support sampling non 2D latents. --- comfy/samplers.py | 93 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 63 insertions(+), 30 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 29962a91..656e0a28 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -8,7 +8,8 @@ import logging import comfy.sampler_helpers def get_area_and_mult(conds, x_in, timestep_in): - area = (x_in.shape[2], x_in.shape[3], 0, 0) + dims = tuple(x_in.shape[2:]) + area = None strength = 1.0 if 'timestep_start' in conds: @@ -20,11 +21,16 @@ def get_area_and_mult(conds, x_in, timestep_in): if timestep_in[0] < timestep_end: return None if 'area' in conds: - area = conds['area'] + area = list(conds['area']) if 'strength' in conds: strength = conds['strength'] - input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] + input_x = x_in + if area is not None: + for i in range(len(dims)): + area[i] = min(input_x.shape[i + 2] - area[len(dims) + i], area[i]) + input_x = input_x.narrow(i + 2, area[len(dims) + i], area[i]) + if 'mask' in conds: # Scale the mask to the size of the input # The mask should have been resized as we began the sampling process @@ -32,28 +38,30 @@ def get_area_and_mult(conds, x_in, timestep_in): if "mask_strength" in conds: mask_strength = conds["mask_strength"] mask = conds['mask'] - assert(mask.shape[1] == x_in.shape[2]) - assert(mask.shape[2] == x_in.shape[3]) - mask = mask[:input_x.shape[0],area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength + assert(mask.shape[1:] == x_in.shape[2:]) + + mask = mask[:input_x.shape[0]] + if area is not None: + for i in range(len(dims)): + mask = mask.narrow(i + 1, area[len(dims) + i], area[i]) + + mask = mask * mask_strength mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1) else: mask = torch.ones_like(input_x) mult = mask * strength - if 'mask' not in conds: + if 'mask' not in conds and area is not None: rr = 8 - if area[2] != 0: - for t in range(rr): - mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1)) - if (area[0] + area[2]) < x_in.shape[2]: - for t in range(rr): - mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1)) - if area[3] != 0: - for t in range(rr): - mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1)) - if (area[1] + area[3]) < x_in.shape[3]: - for t in range(rr): - mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1)) + for i in range(len(dims)): + if area[len(dims) + i] != 0: + for t in range(rr): + m = mult.narrow(i + 2, t, 1) + m *= ((1.0/rr) * (t + 1)) + if (area[i] + area[len(dims) + i]) < x_in.shape[i + 2]: + for t in range(rr): + m = mult.narrow(i + 2, area[i] - 1 - t, 1) + m *= ((1.0/rr) * (t + 1)) conditioning = {} model_conds = conds["model_conds"] @@ -219,8 +227,19 @@ def calc_cond_batch(model, conds, x_in, timestep, model_options): for o in range(batch_chunks): cond_index = cond_or_uncond[o] - out_conds[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] - out_counts[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] + a = area[o] + if a is None: + out_conds[cond_index] += output[o] * mult[o] + out_counts[cond_index] += mult[o] + else: + out_c = out_conds[cond_index] + out_cts = out_counts[cond_index] + dims = len(a) // 2 + for i in range(dims): + out_c = out_c.narrow(i + 2, a[i + dims], a[i]) + out_cts = out_cts.narrow(i + 2, a[i + dims], a[i]) + out_c += output[o] * mult[o] + out_cts += mult[o] for i in range(len(out_conds)): out_conds[i] /= out_counts[i] @@ -335,7 +354,7 @@ def get_mask_aabb(masks): return bounding_boxes, is_empty -def resolve_areas_and_cond_masks(conditions, h, w, device): +def resolve_areas_and_cond_masks_multidim(conditions, dims, device): # We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes. # While we're doing this, we can also resolve the mask device and scaling for performance reasons for i in range(len(conditions)): @@ -344,7 +363,14 @@ def resolve_areas_and_cond_masks(conditions, h, w, device): area = c['area'] if area[0] == "percentage": modified = c.copy() - area = (max(1, round(area[1] * h)), max(1, round(area[2] * w)), round(area[3] * h), round(area[4] * w)) + a = area[1:] + a_len = len(a) // 2 + area = () + for d in range(len(dims)): + area += (max(1, round(a[d] * dims[d])),) + for d in range(len(dims)): + area += (round(a[d + a_len] * dims[d]),) + modified['area'] = area c = modified conditions[i] = c @@ -353,12 +379,12 @@ def resolve_areas_and_cond_masks(conditions, h, w, device): mask = c['mask'] mask = mask.to(device=device) modified = c.copy() - if len(mask.shape) == 2: + if len(mask.shape) == len(dims): mask = mask.unsqueeze(0) - if mask.shape[1] != h or mask.shape[2] != w: - mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1) + if mask.shape[1:] != dims: + mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=dims, mode='bilinear', align_corners=False).squeeze(1) - if modified.get("set_area_to_bounds", False): + if modified.get("set_area_to_bounds", False): #TODO: handle dim != 2 bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0) boxes, is_empty = get_mask_aabb(bounds) if is_empty[0]: @@ -375,7 +401,11 @@ def resolve_areas_and_cond_masks(conditions, h, w, device): modified['mask'] = mask conditions[i] = modified -def create_cond_with_same_area_if_none(conds, c): +def resolve_areas_and_cond_masks(conditions, h, w, device): + logging.warning("WARNING: The comfy.samplers.resolve_areas_and_cond_masks function is deprecated please use the resolve_areas_and_cond_masks_multidim one instead.") + return resolve_areas_and_cond_masks_multidim(conditions, [h, w], device) + +def create_cond_with_same_area_if_none(conds, c): #TODO: handle dim != 2 if 'area' not in c: return @@ -479,7 +509,10 @@ def encode_model_conds(model_function, conds, noise, device, prompt_type, **kwar params = x.copy() params["device"] = device params["noise"] = noise - params["width"] = params.get("width", noise.shape[3] * 8) + default_width = None + if len(noise.shape) >= 4: #TODO: 8 multiple should be set by the model + default_width = noise.shape[3] * 8 + params["width"] = params.get("width", default_width) params["height"] = params.get("height", noise.shape[2] * 8) params["prompt_type"] = params.get("prompt_type", prompt_type) for k in kwargs: @@ -567,7 +600,7 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}): def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None): for k in conds: conds[k] = conds[k][:] - resolve_areas_and_cond_masks(conds[k], noise.shape[2], noise.shape[3], device) + resolve_areas_and_cond_masks_multidim(conds[k], noise.shape[2:], device) for k in conds: calculate_start_end_timesteps(model, conds[k])