mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-20 11:23:29 +00:00
Pull latest tomesd code from upstream.
This commit is contained in:
parent
f50b1fec69
commit
539ff487a8
@ -1,4 +1,4 @@
|
|||||||
|
#Taken from: https://github.com/dbolya/tomesd
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from typing import Tuple, Callable
|
from typing import Tuple, Callable
|
||||||
@ -8,13 +8,23 @@ def do_nothing(x: torch.Tensor, mode:str=None):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def mps_gather_workaround(input, dim, index):
|
||||||
|
if input.shape[-1] == 1:
|
||||||
|
return torch.gather(
|
||||||
|
input.unsqueeze(-1),
|
||||||
|
dim - 1 if dim < 0 else dim,
|
||||||
|
index.unsqueeze(-1)
|
||||||
|
).squeeze(-1)
|
||||||
|
else:
|
||||||
|
return torch.gather(input, dim, index)
|
||||||
|
|
||||||
|
|
||||||
def bipartite_soft_matching_random2d(metric: torch.Tensor,
|
def bipartite_soft_matching_random2d(metric: torch.Tensor,
|
||||||
w: int, h: int, sx: int, sy: int, r: int,
|
w: int, h: int, sx: int, sy: int, r: int,
|
||||||
no_rand: bool = False) -> Tuple[Callable, Callable]:
|
no_rand: bool = False) -> Tuple[Callable, Callable]:
|
||||||
"""
|
"""
|
||||||
Partitions the tokens into src and dst and merges r tokens from src to dst.
|
Partitions the tokens into src and dst and merges r tokens from src to dst.
|
||||||
Dst tokens are partitioned by choosing one randomy in each (sx, sy) region.
|
Dst tokens are partitioned by choosing one randomy in each (sx, sy) region.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
- metric [B, N, C]: metric to use for similarity
|
- metric [B, N, C]: metric to use for similarity
|
||||||
- w: image width in tokens
|
- w: image width in tokens
|
||||||
@ -29,32 +39,48 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
|
|||||||
if r <= 0:
|
if r <= 0:
|
||||||
return do_nothing, do_nothing
|
return do_nothing, do_nothing
|
||||||
|
|
||||||
|
gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|
||||||
hsy, wsx = h // sy, w // sx
|
hsy, wsx = h // sy, w // sx
|
||||||
|
|
||||||
# For each sy by sx kernel, randomly assign one token to be dst and the rest src
|
# For each sy by sx kernel, randomly assign one token to be dst and the rest src
|
||||||
idx_buffer = torch.zeros(1, hsy, wsx, sy*sx, 1, device=metric.device)
|
|
||||||
|
|
||||||
if no_rand:
|
if no_rand:
|
||||||
rand_idx = torch.zeros(1, hsy, wsx, 1, 1, device=metric.device, dtype=torch.int64)
|
rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
|
||||||
else:
|
else:
|
||||||
rand_idx = torch.randint(sy*sx, size=(1, hsy, wsx, 1, 1), device=metric.device)
|
rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device)
|
||||||
|
|
||||||
idx_buffer.scatter_(dim=3, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=idx_buffer.dtype))
|
# The image might not divide sx and sy, so we need to work on a view of the top left if the idx buffer instead
|
||||||
idx_buffer = idx_buffer.view(1, hsy, wsx, sy, sx, 1).transpose(2, 3).reshape(1, N, 1)
|
idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64)
|
||||||
rand_idx = idx_buffer.argsort(dim=1)
|
idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
|
||||||
|
idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)
|
||||||
|
|
||||||
num_dst = int((1 / (sx*sy)) * N)
|
# Image is not divisible by sx or sy so we need to move it into a new buffer
|
||||||
|
if (hsy * sy) < h or (wsx * sx) < w:
|
||||||
|
idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
|
||||||
|
idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
|
||||||
|
else:
|
||||||
|
idx_buffer = idx_buffer_view
|
||||||
|
|
||||||
|
# We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
|
||||||
|
rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)
|
||||||
|
|
||||||
|
# We're finished with these
|
||||||
|
del idx_buffer, idx_buffer_view
|
||||||
|
|
||||||
|
# rand_idx is currently dst|src, so split them
|
||||||
|
num_dst = hsy * wsx
|
||||||
a_idx = rand_idx[:, num_dst:, :] # src
|
a_idx = rand_idx[:, num_dst:, :] # src
|
||||||
b_idx = rand_idx[:, :num_dst, :] # dst
|
b_idx = rand_idx[:, :num_dst, :] # dst
|
||||||
|
|
||||||
def split(x):
|
def split(x):
|
||||||
C = x.shape[-1]
|
C = x.shape[-1]
|
||||||
src = x.gather(dim=1, index=a_idx.expand(B, N - num_dst, C))
|
src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
|
||||||
dst = x.gather(dim=1, index=b_idx.expand(B, num_dst, C))
|
dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
|
||||||
return src, dst
|
return src, dst
|
||||||
|
|
||||||
|
# Cosine similarity between A and B
|
||||||
metric = metric / metric.norm(dim=-1, keepdim=True)
|
metric = metric / metric.norm(dim=-1, keepdim=True)
|
||||||
a, b = split(metric)
|
a, b = split(metric)
|
||||||
scores = a @ b.transpose(-1, -2)
|
scores = a @ b.transpose(-1, -2)
|
||||||
@ -62,19 +88,20 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
|
|||||||
# Can't reduce more than the # tokens in src
|
# Can't reduce more than the # tokens in src
|
||||||
r = min(a.shape[1], r)
|
r = min(a.shape[1], r)
|
||||||
|
|
||||||
|
# Find the most similar greedily
|
||||||
node_max, node_idx = scores.max(dim=-1)
|
node_max, node_idx = scores.max(dim=-1)
|
||||||
edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
|
edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
|
||||||
|
|
||||||
unm_idx = edge_idx[..., r:, :] # Unmerged Tokens
|
unm_idx = edge_idx[..., r:, :] # Unmerged Tokens
|
||||||
src_idx = edge_idx[..., :r, :] # Merged Tokens
|
src_idx = edge_idx[..., :r, :] # Merged Tokens
|
||||||
dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)
|
dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)
|
||||||
|
|
||||||
def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
|
def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
|
||||||
src, dst = split(x)
|
src, dst = split(x)
|
||||||
n, t1, c = src.shape
|
n, t1, c = src.shape
|
||||||
|
|
||||||
unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
|
unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
|
||||||
src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
|
src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
|
||||||
dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
|
dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
|
||||||
|
|
||||||
return torch.cat([unm, dst], dim=1)
|
return torch.cat([unm, dst], dim=1)
|
||||||
@ -84,13 +111,13 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
|
|||||||
unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
|
unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
|
||||||
_, _, c = unm.shape
|
_, _, c = unm.shape
|
||||||
|
|
||||||
src = dst.gather(dim=-2, index=dst_idx.expand(B, r, c))
|
src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))
|
||||||
|
|
||||||
# Combine back to the original shape
|
# Combine back to the original shape
|
||||||
out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
|
out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
|
||||||
out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
|
out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
|
||||||
out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
|
out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
|
||||||
out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=src_idx).expand(B, r, c), src=src)
|
out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@ -100,14 +127,14 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
|
|||||||
def get_functions(x, ratio, original_shape):
|
def get_functions(x, ratio, original_shape):
|
||||||
b, c, original_h, original_w = original_shape
|
b, c, original_h, original_w = original_shape
|
||||||
original_tokens = original_h * original_w
|
original_tokens = original_h * original_w
|
||||||
downsample = int(math.sqrt(original_tokens // x.shape[1]))
|
downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))
|
||||||
stride_x = 2
|
stride_x = 2
|
||||||
stride_y = 2
|
stride_y = 2
|
||||||
max_downsample = 1
|
max_downsample = 1
|
||||||
|
|
||||||
if downsample <= max_downsample:
|
if downsample <= max_downsample:
|
||||||
w = original_w // downsample
|
w = int(math.ceil(original_w / downsample))
|
||||||
h = original_h // downsample
|
h = int(math.ceil(original_h / downsample))
|
||||||
r = int(x.shape[1] * ratio)
|
r = int(x.shape[1] * ratio)
|
||||||
no_rand = False
|
no_rand = False
|
||||||
m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand)
|
m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand)
|
||||||
|
Loading…
Reference in New Issue
Block a user