mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
vecorized bislerp
This commit is contained in:
parent
9b1396e93a
commit
8b4b0c3188
117
comfy/utils.py
117
comfy/utils.py
@ -1,5 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
import math
|
import math
|
||||||
|
import einops
|
||||||
|
|
||||||
def load_torch_file(ckpt, safe_load=False):
|
def load_torch_file(ckpt, safe_load=False):
|
||||||
if ckpt.lower().endswith(".safetensors"):
|
if ckpt.lower().endswith(".safetensors"):
|
||||||
@ -46,71 +47,81 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
|
|||||||
sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
|
sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
|
||||||
return sd
|
return sd
|
||||||
|
|
||||||
#slow and inefficient, should be optimized
|
|
||||||
def bislerp(samples, width, height):
|
def bislerp(samples, width, height):
|
||||||
shape = list(samples.shape)
|
def slerp(b1, b2, r):
|
||||||
width_scale = (shape[3]) / (width )
|
'''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
|
||||||
height_scale = (shape[2]) / (height )
|
|
||||||
|
c = b1.shape[-1]
|
||||||
|
|
||||||
shape[3] = width
|
#norms
|
||||||
shape[2] = height
|
b1_norms = torch.norm(b1, dim=-1, keepdim=True)
|
||||||
out1 = torch.empty(shape, dtype=samples.dtype, layout=samples.layout, device=samples.device)
|
b2_norms = torch.norm(b2, dim=-1, keepdim=True)
|
||||||
|
|
||||||
def algorithm(in1, in2, t):
|
#normalize
|
||||||
dims = in1.shape
|
b1_normalized = b1 / b1_norms
|
||||||
val = t
|
b2_normalized = b2 / b2_norms
|
||||||
|
|
||||||
#flatten to batches
|
#zero when norms are zero
|
||||||
low = in1.reshape(dims[0], -1)
|
b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0
|
||||||
high = in2.reshape(dims[0], -1)
|
b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0
|
||||||
|
|
||||||
low_weight = torch.norm(low, dim=1, keepdim=True)
|
#slerp
|
||||||
low_weight[low_weight == 0] = 0.0000000001
|
dot = (b1_normalized*b2_normalized).sum(1)
|
||||||
low_norm = low/low_weight
|
omega = torch.acos(dot)
|
||||||
high_weight = torch.norm(high, dim=1, keepdim=True)
|
|
||||||
high_weight[high_weight == 0] = 0.0000000001
|
|
||||||
high_norm = high/high_weight
|
|
||||||
|
|
||||||
dot_prod = (low_norm*high_norm).sum(1)
|
|
||||||
dot_prod[dot_prod > 0.9995] = 0.9995
|
|
||||||
dot_prod[dot_prod < -0.9995] = -0.9995
|
|
||||||
omega = torch.acos(dot_prod)
|
|
||||||
so = torch.sin(omega)
|
so = torch.sin(omega)
|
||||||
res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low_norm + (torch.sin(val*omega)/so).unsqueeze(1) * high_norm
|
|
||||||
res *= (low_weight * (1.0-val) + high_weight * val)
|
|
||||||
return res.reshape(dims)
|
|
||||||
|
|
||||||
for x_dest in range(shape[3]):
|
#technically not mathematically correct, but more pleasing?
|
||||||
for y_dest in range(shape[2]):
|
res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized
|
||||||
y = (y_dest + 0.5) * height_scale - 0.5
|
res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c)
|
||||||
x = (x_dest + 0.5) * width_scale - 0.5
|
|
||||||
|
|
||||||
x1 = max(math.floor(x), 0)
|
#edge cases for same or polar opposites
|
||||||
x2 = min(x1 + 1, samples.shape[3] - 1)
|
res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
|
||||||
wx = x - math.floor(x)
|
res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
|
||||||
|
return res
|
||||||
|
|
||||||
|
def generate_bilinear_data(length_old, length_new):
|
||||||
|
coords_1 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32)
|
||||||
|
coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
|
||||||
|
ratios = coords_1 - coords_1.floor()
|
||||||
|
coords_1 = coords_1.to(torch.int64)
|
||||||
|
|
||||||
|
coords_2 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + 1
|
||||||
|
coords_2[:,:,:,-1] -= 1
|
||||||
|
coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
|
||||||
|
coords_2 = coords_2.to(torch.int64)
|
||||||
|
return ratios, coords_1, coords_2
|
||||||
|
|
||||||
|
n,c,h,w = samples.shape
|
||||||
|
h_new, w_new = (height, width)
|
||||||
|
|
||||||
|
#linear h
|
||||||
|
ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new)
|
||||||
|
|
||||||
y1 = max(math.floor(y), 0)
|
coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w))
|
||||||
y2 = min(y1 + 1, samples.shape[2] - 1)
|
coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w))
|
||||||
wy = y - math.floor(y)
|
ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w))
|
||||||
|
|
||||||
in1 = samples[:,:,y1,x1]
|
pass_1 = einops.rearrange(samples.gather(-2,coords_1), 'n c h w -> (n h w) c')
|
||||||
in2 = samples[:,:,y1,x2]
|
pass_2 = einops.rearrange(samples.gather(-2,coords_2), 'n c h w -> (n h w) c')
|
||||||
in3 = samples[:,:,y2,x1]
|
ratios = einops.rearrange(ratios, 'n c h w -> (n h w) c')
|
||||||
in4 = samples[:,:,y2,x2]
|
|
||||||
|
|
||||||
if (x1 == x2) and (y1 == y2):
|
result = slerp(pass_1, pass_2, ratios)
|
||||||
out_value = in1
|
result = einops.rearrange(result, '(n h w) c -> n c h w',n=n, h=h_new, w=w)
|
||||||
elif (x1 == x2):
|
|
||||||
out_value = algorithm(in1, in3, wy)
|
|
||||||
elif (y1 == y2):
|
|
||||||
out_value = algorithm(in1, in2, wx)
|
|
||||||
else:
|
|
||||||
o1 = algorithm(in1, in2, wx)
|
|
||||||
o2 = algorithm(in3, in4, wx)
|
|
||||||
out_value = algorithm(o1, o2, wy)
|
|
||||||
|
|
||||||
out1[:,:,y_dest,x_dest] = out_value
|
#linear w
|
||||||
return out1
|
ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new)
|
||||||
|
|
||||||
|
coords_1 = coords_1.expand((n, c, h_new, -1))
|
||||||
|
coords_2 = coords_2.expand((n, c, h_new, -1))
|
||||||
|
ratios = ratios.expand((n, 1, h_new, -1))
|
||||||
|
|
||||||
|
pass_1 = einops.rearrange(result.gather(-1,coords_1), 'n c h w -> (n h w) c')
|
||||||
|
pass_2 = einops.rearrange(result.gather(-1,coords_2), 'n c h w -> (n h w) c')
|
||||||
|
ratios = einops.rearrange(ratios, 'n c h w -> (n h w) c')
|
||||||
|
|
||||||
|
result = slerp(pass_1, pass_2, ratios)
|
||||||
|
result = einops.rearrange(result, '(n h w) c -> n c h w',n=n, h=h_new, w=w_new)
|
||||||
|
return result
|
||||||
|
|
||||||
def common_upscale(samples, width, height, upscale_method, crop):
|
def common_upscale(samples, width, height, upscale_method, crop):
|
||||||
if crop == "center":
|
if crop == "center":
|
||||||
|
Loading…
Reference in New Issue
Block a user