diff --git a/comfy/utils.py b/comfy/utils.py index d58320b4..33c1c3dd 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,5 +1,6 @@ import torch import math +import einops def load_torch_file(ckpt, safe_load=False): if ckpt.lower().endswith(".safetensors"): @@ -50,71 +51,81 @@ def transformers_convert(sd, prefix_from, prefix_to, number): sd[k_to] = weights[shape_from*x:shape_from*(x + 1)] return sd -#slow and inefficient, should be optimized def bislerp(samples, width, height): - shape = list(samples.shape) - width_scale = (shape[3]) / (width ) - height_scale = (shape[2]) / (height ) + def slerp(b1, b2, r): + '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC''' + + c = b1.shape[-1] - shape[3] = width - shape[2] = height - out1 = torch.empty(shape, dtype=samples.dtype, layout=samples.layout, device=samples.device) + #norms + b1_norms = torch.norm(b1, dim=-1, keepdim=True) + b2_norms = torch.norm(b2, dim=-1, keepdim=True) - def algorithm(in1, in2, t): - dims = in1.shape - val = t + #normalize + b1_normalized = b1 / b1_norms + b2_normalized = b2 / b2_norms - #flatten to batches - low = in1.reshape(dims[0], -1) - high = in2.reshape(dims[0], -1) + #zero when norms are zero + b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0 + b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0 - low_weight = torch.norm(low, dim=1, keepdim=True) - low_weight[low_weight == 0] = 0.0000000001 - low_norm = low/low_weight - high_weight = torch.norm(high, dim=1, keepdim=True) - high_weight[high_weight == 0] = 0.0000000001 - high_norm = high/high_weight - - dot_prod = (low_norm*high_norm).sum(1) - dot_prod[dot_prod > 0.9995] = 0.9995 - dot_prod[dot_prod < -0.9995] = -0.9995 - omega = torch.acos(dot_prod) + #slerp + dot = (b1_normalized*b2_normalized).sum(1) + omega = torch.acos(dot) so = torch.sin(omega) - res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low_norm + (torch.sin(val*omega)/so).unsqueeze(1) * high_norm - res *= (low_weight * (1.0-val) + high_weight * val) - return res.reshape(dims) - for x_dest in range(shape[3]): - for y_dest in range(shape[2]): - y = (y_dest + 0.5) * height_scale - 0.5 - x = (x_dest + 0.5) * width_scale - 0.5 + #technically not mathematically correct, but more pleasing? + res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized + res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c) - x1 = max(math.floor(x), 0) - x2 = min(x1 + 1, samples.shape[3] - 1) - wx = x - math.floor(x) + #edge cases for same or polar opposites + res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5] + res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1] + return res + + def generate_bilinear_data(length_old, length_new): + coords_1 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear") + ratios = coords_1 - coords_1.floor() + coords_1 = coords_1.to(torch.int64) + + coords_2 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + 1 + coords_2[:,:,:,-1] -= 1 + coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear") + coords_2 = coords_2.to(torch.int64) + return ratios, coords_1, coords_2 + + n,c,h,w = samples.shape + h_new, w_new = (height, width) + + #linear h + ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new) - y1 = max(math.floor(y), 0) - y2 = min(y1 + 1, samples.shape[2] - 1) - wy = y - math.floor(y) + coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w)) + coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w)) + ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w)) - in1 = samples[:,:,y1,x1] - in2 = samples[:,:,y1,x2] - in3 = samples[:,:,y2,x1] - in4 = samples[:,:,y2,x2] + pass_1 = einops.rearrange(samples.gather(-2,coords_1), 'n c h w -> (n h w) c') + pass_2 = einops.rearrange(samples.gather(-2,coords_2), 'n c h w -> (n h w) c') + ratios = einops.rearrange(ratios, 'n c h w -> (n h w) c') - if (x1 == x2) and (y1 == y2): - out_value = in1 - elif (x1 == x2): - out_value = algorithm(in1, in3, wy) - elif (y1 == y2): - out_value = algorithm(in1, in2, wx) - else: - o1 = algorithm(in1, in2, wx) - o2 = algorithm(in3, in4, wx) - out_value = algorithm(o1, o2, wy) + result = slerp(pass_1, pass_2, ratios) + result = einops.rearrange(result, '(n h w) c -> n c h w',n=n, h=h_new, w=w) - out1[:,:,y_dest,x_dest] = out_value - return out1 + #linear w + ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new) + + coords_1 = coords_1.expand((n, c, h_new, -1)) + coords_2 = coords_2.expand((n, c, h_new, -1)) + ratios = ratios.expand((n, 1, h_new, -1)) + + pass_1 = einops.rearrange(result.gather(-1,coords_1), 'n c h w -> (n h w) c') + pass_2 = einops.rearrange(result.gather(-1,coords_2), 'n c h w -> (n h w) c') + ratios = einops.rearrange(ratios, 'n c h w -> (n h w) c') + + result = slerp(pass_1, pass_2, ratios) + result = einops.rearrange(result, '(n h w) c -> n c h w',n=n, h=h_new, w=w_new) + return result def common_upscale(samples, width, height, upscale_method, crop): if crop == "center":