Pull latest tomesd code from upstream.

This commit is contained in:
comfyanonymous 2023-04-03 15:49:28 -04:00
parent f50b1fec69
commit 539ff487a8

View File

@ -1,4 +1,4 @@
#Taken from: https://github.com/dbolya/tomesd
import torch import torch
from typing import Tuple, Callable from typing import Tuple, Callable
@ -8,13 +8,23 @@ def do_nothing(x: torch.Tensor, mode:str=None):
return x return x
def mps_gather_workaround(input, dim, index):
if input.shape[-1] == 1:
return torch.gather(
input.unsqueeze(-1),
dim - 1 if dim < 0 else dim,
index.unsqueeze(-1)
).squeeze(-1)
else:
return torch.gather(input, dim, index)
def bipartite_soft_matching_random2d(metric: torch.Tensor, def bipartite_soft_matching_random2d(metric: torch.Tensor,
w: int, h: int, sx: int, sy: int, r: int, w: int, h: int, sx: int, sy: int, r: int,
no_rand: bool = False) -> Tuple[Callable, Callable]: no_rand: bool = False) -> Tuple[Callable, Callable]:
""" """
Partitions the tokens into src and dst and merges r tokens from src to dst. Partitions the tokens into src and dst and merges r tokens from src to dst.
Dst tokens are partitioned by choosing one randomy in each (sx, sy) region. Dst tokens are partitioned by choosing one randomy in each (sx, sy) region.
Args: Args:
- metric [B, N, C]: metric to use for similarity - metric [B, N, C]: metric to use for similarity
- w: image width in tokens - w: image width in tokens
@ -29,32 +39,48 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
if r <= 0: if r <= 0:
return do_nothing, do_nothing return do_nothing, do_nothing
gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
with torch.no_grad(): with torch.no_grad():
hsy, wsx = h // sy, w // sx hsy, wsx = h // sy, w // sx
# For each sy by sx kernel, randomly assign one token to be dst and the rest src # For each sy by sx kernel, randomly assign one token to be dst and the rest src
idx_buffer = torch.zeros(1, hsy, wsx, sy*sx, 1, device=metric.device)
if no_rand: if no_rand:
rand_idx = torch.zeros(1, hsy, wsx, 1, 1, device=metric.device, dtype=torch.int64) rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
else: else:
rand_idx = torch.randint(sy*sx, size=(1, hsy, wsx, 1, 1), device=metric.device) rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device)
idx_buffer.scatter_(dim=3, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=idx_buffer.dtype)) # The image might not divide sx and sy, so we need to work on a view of the top left if the idx buffer instead
idx_buffer = idx_buffer.view(1, hsy, wsx, sy, sx, 1).transpose(2, 3).reshape(1, N, 1) idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64)
rand_idx = idx_buffer.argsort(dim=1) idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)
num_dst = int((1 / (sx*sy)) * N) # Image is not divisible by sx or sy so we need to move it into a new buffer
if (hsy * sy) < h or (wsx * sx) < w:
idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
else:
idx_buffer = idx_buffer_view
# We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)
# We're finished with these
del idx_buffer, idx_buffer_view
# rand_idx is currently dst|src, so split them
num_dst = hsy * wsx
a_idx = rand_idx[:, num_dst:, :] # src a_idx = rand_idx[:, num_dst:, :] # src
b_idx = rand_idx[:, :num_dst, :] # dst b_idx = rand_idx[:, :num_dst, :] # dst
def split(x): def split(x):
C = x.shape[-1] C = x.shape[-1]
src = x.gather(dim=1, index=a_idx.expand(B, N - num_dst, C)) src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
dst = x.gather(dim=1, index=b_idx.expand(B, num_dst, C)) dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
return src, dst return src, dst
# Cosine similarity between A and B
metric = metric / metric.norm(dim=-1, keepdim=True) metric = metric / metric.norm(dim=-1, keepdim=True)
a, b = split(metric) a, b = split(metric)
scores = a @ b.transpose(-1, -2) scores = a @ b.transpose(-1, -2)
@ -62,19 +88,20 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
# Can't reduce more than the # tokens in src # Can't reduce more than the # tokens in src
r = min(a.shape[1], r) r = min(a.shape[1], r)
# Find the most similar greedily
node_max, node_idx = scores.max(dim=-1) node_max, node_idx = scores.max(dim=-1)
edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
unm_idx = edge_idx[..., r:, :] # Unmerged Tokens unm_idx = edge_idx[..., r:, :] # Unmerged Tokens
src_idx = edge_idx[..., :r, :] # Merged Tokens src_idx = edge_idx[..., :r, :] # Merged Tokens
dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx) dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)
def merge(x: torch.Tensor, mode="mean") -> torch.Tensor: def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
src, dst = split(x) src, dst = split(x)
n, t1, c = src.shape n, t1, c = src.shape
unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c)) unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
src = src.gather(dim=-2, index=src_idx.expand(n, r, c)) src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode) dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
return torch.cat([unm, dst], dim=1) return torch.cat([unm, dst], dim=1)
@ -84,13 +111,13 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
unm, dst = x[..., :unm_len, :], x[..., unm_len:, :] unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
_, _, c = unm.shape _, _, c = unm.shape
src = dst.gather(dim=-2, index=dst_idx.expand(B, r, c)) src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))
# Combine back to the original shape # Combine back to the original shape
out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype) out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst) out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=unm_idx).expand(B, unm_len, c), src=unm) out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=src_idx).expand(B, r, c), src=src) out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src)
return out return out
@ -100,14 +127,14 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
def get_functions(x, ratio, original_shape): def get_functions(x, ratio, original_shape):
b, c, original_h, original_w = original_shape b, c, original_h, original_w = original_shape
original_tokens = original_h * original_w original_tokens = original_h * original_w
downsample = int(math.sqrt(original_tokens // x.shape[1])) downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))
stride_x = 2 stride_x = 2
stride_y = 2 stride_y = 2
max_downsample = 1 max_downsample = 1
if downsample <= max_downsample: if downsample <= max_downsample:
w = original_w // downsample w = int(math.ceil(original_w / downsample))
h = original_h // downsample h = int(math.ceil(original_h / downsample))
r = int(x.shape[1] * ratio) r = int(x.shape[1] * ratio)
no_rand = False no_rand = False
m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand) m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand)