mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Remove useless code.
This commit is contained in:
parent
274dff3257
commit
2b14041d4b
@ -1,105 +0,0 @@
|
|||||||
from functools import reduce
|
|
||||||
import math
|
|
||||||
import operator
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from skimage import transform
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
|
|
||||||
def translate2d(tx, ty):
|
|
||||||
mat = [[1, 0, tx],
|
|
||||||
[0, 1, ty],
|
|
||||||
[0, 0, 1]]
|
|
||||||
return torch.tensor(mat, dtype=torch.float32)
|
|
||||||
|
|
||||||
|
|
||||||
def scale2d(sx, sy):
|
|
||||||
mat = [[sx, 0, 0],
|
|
||||||
[ 0, sy, 0],
|
|
||||||
[ 0, 0, 1]]
|
|
||||||
return torch.tensor(mat, dtype=torch.float32)
|
|
||||||
|
|
||||||
|
|
||||||
def rotate2d(theta):
|
|
||||||
mat = [[torch.cos(theta), torch.sin(-theta), 0],
|
|
||||||
[torch.sin(theta), torch.cos(theta), 0],
|
|
||||||
[ 0, 0, 1]]
|
|
||||||
return torch.tensor(mat, dtype=torch.float32)
|
|
||||||
|
|
||||||
|
|
||||||
class KarrasAugmentationPipeline:
|
|
||||||
def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1/8):
|
|
||||||
self.a_prob = a_prob
|
|
||||||
self.a_scale = a_scale
|
|
||||||
self.a_aniso = a_aniso
|
|
||||||
self.a_trans = a_trans
|
|
||||||
|
|
||||||
def __call__(self, image):
|
|
||||||
h, w = image.size
|
|
||||||
mats = [translate2d(h / 2 - 0.5, w / 2 - 0.5)]
|
|
||||||
|
|
||||||
# x-flip
|
|
||||||
a0 = torch.randint(2, []).float()
|
|
||||||
mats.append(scale2d(1 - 2 * a0, 1))
|
|
||||||
# y-flip
|
|
||||||
do = (torch.rand([]) < self.a_prob).float()
|
|
||||||
a1 = torch.randint(2, []).float() * do
|
|
||||||
mats.append(scale2d(1, 1 - 2 * a1))
|
|
||||||
# scaling
|
|
||||||
do = (torch.rand([]) < self.a_prob).float()
|
|
||||||
a2 = torch.randn([]) * do
|
|
||||||
mats.append(scale2d(self.a_scale ** a2, self.a_scale ** a2))
|
|
||||||
# rotation
|
|
||||||
do = (torch.rand([]) < self.a_prob).float()
|
|
||||||
a3 = (torch.rand([]) * 2 * math.pi - math.pi) * do
|
|
||||||
mats.append(rotate2d(-a3))
|
|
||||||
# anisotropy
|
|
||||||
do = (torch.rand([]) < self.a_prob).float()
|
|
||||||
a4 = (torch.rand([]) * 2 * math.pi - math.pi) * do
|
|
||||||
a5 = torch.randn([]) * do
|
|
||||||
mats.append(rotate2d(a4))
|
|
||||||
mats.append(scale2d(self.a_aniso ** a5, self.a_aniso ** -a5))
|
|
||||||
mats.append(rotate2d(-a4))
|
|
||||||
# translation
|
|
||||||
do = (torch.rand([]) < self.a_prob).float()
|
|
||||||
a6 = torch.randn([]) * do
|
|
||||||
a7 = torch.randn([]) * do
|
|
||||||
mats.append(translate2d(self.a_trans * w * a6, self.a_trans * h * a7))
|
|
||||||
|
|
||||||
# form the transformation matrix and conditioning vector
|
|
||||||
mats.append(translate2d(-h / 2 + 0.5, -w / 2 + 0.5))
|
|
||||||
mat = reduce(operator.matmul, mats)
|
|
||||||
cond = torch.stack([a0, a1, a2, a3.cos() - 1, a3.sin(), a5 * a4.cos(), a5 * a4.sin(), a6, a7])
|
|
||||||
|
|
||||||
# apply the transformation
|
|
||||||
image_orig = np.array(image, dtype=np.float32) / 255
|
|
||||||
if image_orig.ndim == 2:
|
|
||||||
image_orig = image_orig[..., None]
|
|
||||||
tf = transform.AffineTransform(mat.numpy())
|
|
||||||
image = transform.warp(image_orig, tf.inverse, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True)
|
|
||||||
image_orig = torch.as_tensor(image_orig).movedim(2, 0) * 2 - 1
|
|
||||||
image = torch.as_tensor(image).movedim(2, 0) * 2 - 1
|
|
||||||
return image, image_orig, cond
|
|
||||||
|
|
||||||
|
|
||||||
class KarrasAugmentWrapper(nn.Module):
|
|
||||||
def __init__(self, model):
|
|
||||||
super().__init__()
|
|
||||||
self.inner_model = model
|
|
||||||
|
|
||||||
def forward(self, input, sigma, aug_cond=None, mapping_cond=None, **kwargs):
|
|
||||||
if aug_cond is None:
|
|
||||||
aug_cond = input.new_zeros([input.shape[0], 9])
|
|
||||||
if mapping_cond is None:
|
|
||||||
mapping_cond = aug_cond
|
|
||||||
else:
|
|
||||||
mapping_cond = torch.cat([aug_cond, mapping_cond], dim=1)
|
|
||||||
return self.inner_model(input, sigma, mapping_cond=mapping_cond, **kwargs)
|
|
||||||
|
|
||||||
def set_skip_stages(self, skip_stages):
|
|
||||||
return self.inner_model.set_skip_stages(skip_stages)
|
|
||||||
|
|
||||||
def set_patch_size(self, patch_size):
|
|
||||||
return self.inner_model.set_patch_size(patch_size)
|
|
@ -1,110 +0,0 @@
|
|||||||
from functools import partial
|
|
||||||
import json
|
|
||||||
import math
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
from jsonmerge import merge
|
|
||||||
|
|
||||||
from . import augmentation, layers, models, utils
|
|
||||||
|
|
||||||
|
|
||||||
def load_config(file):
|
|
||||||
defaults = {
|
|
||||||
'model': {
|
|
||||||
'sigma_data': 1.,
|
|
||||||
'patch_size': 1,
|
|
||||||
'dropout_rate': 0.,
|
|
||||||
'augment_wrapper': True,
|
|
||||||
'augment_prob': 0.,
|
|
||||||
'mapping_cond_dim': 0,
|
|
||||||
'unet_cond_dim': 0,
|
|
||||||
'cross_cond_dim': 0,
|
|
||||||
'cross_attn_depths': None,
|
|
||||||
'skip_stages': 0,
|
|
||||||
'has_variance': False,
|
|
||||||
},
|
|
||||||
'dataset': {
|
|
||||||
'type': 'imagefolder',
|
|
||||||
},
|
|
||||||
'optimizer': {
|
|
||||||
'type': 'adamw',
|
|
||||||
'lr': 1e-4,
|
|
||||||
'betas': [0.95, 0.999],
|
|
||||||
'eps': 1e-6,
|
|
||||||
'weight_decay': 1e-3,
|
|
||||||
},
|
|
||||||
'lr_sched': {
|
|
||||||
'type': 'inverse',
|
|
||||||
'inv_gamma': 20000.,
|
|
||||||
'power': 1.,
|
|
||||||
'warmup': 0.99,
|
|
||||||
},
|
|
||||||
'ema_sched': {
|
|
||||||
'type': 'inverse',
|
|
||||||
'power': 0.6667,
|
|
||||||
'max_value': 0.9999
|
|
||||||
},
|
|
||||||
}
|
|
||||||
config = json.load(file)
|
|
||||||
return merge(defaults, config)
|
|
||||||
|
|
||||||
|
|
||||||
def make_model(config):
|
|
||||||
config = config['model']
|
|
||||||
assert config['type'] == 'image_v1'
|
|
||||||
model = models.ImageDenoiserModelV1(
|
|
||||||
config['input_channels'],
|
|
||||||
config['mapping_out'],
|
|
||||||
config['depths'],
|
|
||||||
config['channels'],
|
|
||||||
config['self_attn_depths'],
|
|
||||||
config['cross_attn_depths'],
|
|
||||||
patch_size=config['patch_size'],
|
|
||||||
dropout_rate=config['dropout_rate'],
|
|
||||||
mapping_cond_dim=config['mapping_cond_dim'] + (9 if config['augment_wrapper'] else 0),
|
|
||||||
unet_cond_dim=config['unet_cond_dim'],
|
|
||||||
cross_cond_dim=config['cross_cond_dim'],
|
|
||||||
skip_stages=config['skip_stages'],
|
|
||||||
has_variance=config['has_variance'],
|
|
||||||
)
|
|
||||||
if config['augment_wrapper']:
|
|
||||||
model = augmentation.KarrasAugmentWrapper(model)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def make_denoiser_wrapper(config):
|
|
||||||
config = config['model']
|
|
||||||
sigma_data = config.get('sigma_data', 1.)
|
|
||||||
has_variance = config.get('has_variance', False)
|
|
||||||
if not has_variance:
|
|
||||||
return partial(layers.Denoiser, sigma_data=sigma_data)
|
|
||||||
return partial(layers.DenoiserWithVariance, sigma_data=sigma_data)
|
|
||||||
|
|
||||||
|
|
||||||
def make_sample_density(config):
|
|
||||||
sd_config = config['sigma_sample_density']
|
|
||||||
sigma_data = config['sigma_data']
|
|
||||||
if sd_config['type'] == 'lognormal':
|
|
||||||
loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc']
|
|
||||||
scale = sd_config['std'] if 'std' in sd_config else sd_config['scale']
|
|
||||||
return partial(utils.rand_log_normal, loc=loc, scale=scale)
|
|
||||||
if sd_config['type'] == 'loglogistic':
|
|
||||||
loc = sd_config['loc'] if 'loc' in sd_config else math.log(sigma_data)
|
|
||||||
scale = sd_config['scale'] if 'scale' in sd_config else 0.5
|
|
||||||
min_value = sd_config['min_value'] if 'min_value' in sd_config else 0.
|
|
||||||
max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf')
|
|
||||||
return partial(utils.rand_log_logistic, loc=loc, scale=scale, min_value=min_value, max_value=max_value)
|
|
||||||
if sd_config['type'] == 'loguniform':
|
|
||||||
min_value = sd_config['min_value'] if 'min_value' in sd_config else config['sigma_min']
|
|
||||||
max_value = sd_config['max_value'] if 'max_value' in sd_config else config['sigma_max']
|
|
||||||
return partial(utils.rand_log_uniform, min_value=min_value, max_value=max_value)
|
|
||||||
if sd_config['type'] == 'v-diffusion':
|
|
||||||
min_value = sd_config['min_value'] if 'min_value' in sd_config else 0.
|
|
||||||
max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf')
|
|
||||||
return partial(utils.rand_v_diffusion, sigma_data=sigma_data, min_value=min_value, max_value=max_value)
|
|
||||||
if sd_config['type'] == 'split-lognormal':
|
|
||||||
loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc']
|
|
||||||
scale_1 = sd_config['std_1'] if 'std_1' in sd_config else sd_config['scale_1']
|
|
||||||
scale_2 = sd_config['std_2'] if 'std_2' in sd_config else sd_config['scale_2']
|
|
||||||
return partial(utils.rand_split_log_normal, loc=loc, scale_1=scale_1, scale_2=scale_2)
|
|
||||||
raise ValueError('Unknown sample density type')
|
|
@ -1,134 +0,0 @@
|
|||||||
import math
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from cleanfid.inception_torchscript import InceptionV3W
|
|
||||||
import clip
|
|
||||||
from resize_right import resize
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from torch.nn import functional as F
|
|
||||||
from torchvision import transforms
|
|
||||||
from tqdm.auto import trange
|
|
||||||
|
|
||||||
from . import utils
|
|
||||||
|
|
||||||
|
|
||||||
class InceptionV3FeatureExtractor(nn.Module):
|
|
||||||
def __init__(self, device='cpu'):
|
|
||||||
super().__init__()
|
|
||||||
path = Path(os.environ.get('XDG_CACHE_HOME', Path.home() / '.cache')) / 'k-diffusion'
|
|
||||||
url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
|
|
||||||
digest = 'f58cb9b6ec323ed63459aa4fb441fe750cfe39fafad6da5cb504a16f19e958f4'
|
|
||||||
utils.download_file(path / 'inception-2015-12-05.pt', url, digest)
|
|
||||||
self.model = InceptionV3W(str(path), resize_inside=False).to(device)
|
|
||||||
self.size = (299, 299)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if x.shape[2:4] != self.size:
|
|
||||||
x = resize(x, out_shape=self.size, pad_mode='reflect')
|
|
||||||
if x.shape[1] == 1:
|
|
||||||
x = torch.cat([x] * 3, dim=1)
|
|
||||||
x = (x * 127.5 + 127.5).clamp(0, 255)
|
|
||||||
return self.model(x)
|
|
||||||
|
|
||||||
|
|
||||||
class CLIPFeatureExtractor(nn.Module):
|
|
||||||
def __init__(self, name='ViT-L/14@336px', device='cpu'):
|
|
||||||
super().__init__()
|
|
||||||
self.model = clip.load(name, device=device)[0].eval().requires_grad_(False)
|
|
||||||
self.normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
|
|
||||||
std=(0.26862954, 0.26130258, 0.27577711))
|
|
||||||
self.size = (self.model.visual.input_resolution, self.model.visual.input_resolution)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if x.shape[2:4] != self.size:
|
|
||||||
x = resize(x.add(1).div(2), out_shape=self.size, pad_mode='reflect').clamp(0, 1)
|
|
||||||
x = self.normalize(x)
|
|
||||||
x = self.model.encode_image(x).float()
|
|
||||||
x = F.normalize(x) * x.shape[1] ** 0.5
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def compute_features(accelerator, sample_fn, extractor_fn, n, batch_size):
|
|
||||||
n_per_proc = math.ceil(n / accelerator.num_processes)
|
|
||||||
feats_all = []
|
|
||||||
try:
|
|
||||||
for i in trange(0, n_per_proc, batch_size, disable=not accelerator.is_main_process):
|
|
||||||
cur_batch_size = min(n - i, batch_size)
|
|
||||||
samples = sample_fn(cur_batch_size)[:cur_batch_size]
|
|
||||||
feats_all.append(accelerator.gather(extractor_fn(samples)))
|
|
||||||
except StopIteration:
|
|
||||||
pass
|
|
||||||
return torch.cat(feats_all)[:n]
|
|
||||||
|
|
||||||
|
|
||||||
def polynomial_kernel(x, y):
|
|
||||||
d = x.shape[-1]
|
|
||||||
dot = x @ y.transpose(-2, -1)
|
|
||||||
return (dot / d + 1) ** 3
|
|
||||||
|
|
||||||
|
|
||||||
def squared_mmd(x, y, kernel=polynomial_kernel):
|
|
||||||
m = x.shape[-2]
|
|
||||||
n = y.shape[-2]
|
|
||||||
kxx = kernel(x, x)
|
|
||||||
kyy = kernel(y, y)
|
|
||||||
kxy = kernel(x, y)
|
|
||||||
kxx_sum = kxx.sum([-1, -2]) - kxx.diagonal(dim1=-1, dim2=-2).sum(-1)
|
|
||||||
kyy_sum = kyy.sum([-1, -2]) - kyy.diagonal(dim1=-1, dim2=-2).sum(-1)
|
|
||||||
kxy_sum = kxy.sum([-1, -2])
|
|
||||||
term_1 = kxx_sum / m / (m - 1)
|
|
||||||
term_2 = kyy_sum / n / (n - 1)
|
|
||||||
term_3 = kxy_sum * 2 / m / n
|
|
||||||
return term_1 + term_2 - term_3
|
|
||||||
|
|
||||||
|
|
||||||
@utils.tf32_mode(matmul=False)
|
|
||||||
def kid(x, y, max_size=5000):
|
|
||||||
x_size, y_size = x.shape[0], y.shape[0]
|
|
||||||
n_partitions = math.ceil(max(x_size / max_size, y_size / max_size))
|
|
||||||
total_mmd = x.new_zeros([])
|
|
||||||
for i in range(n_partitions):
|
|
||||||
cur_x = x[round(i * x_size / n_partitions):round((i + 1) * x_size / n_partitions)]
|
|
||||||
cur_y = y[round(i * y_size / n_partitions):round((i + 1) * y_size / n_partitions)]
|
|
||||||
total_mmd = total_mmd + squared_mmd(cur_x, cur_y)
|
|
||||||
return total_mmd / n_partitions
|
|
||||||
|
|
||||||
|
|
||||||
class _MatrixSquareRootEig(torch.autograd.Function):
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, a):
|
|
||||||
vals, vecs = torch.linalg.eigh(a)
|
|
||||||
ctx.save_for_backward(vals, vecs)
|
|
||||||
return vecs @ vals.abs().sqrt().diag_embed() @ vecs.transpose(-2, -1)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, grad_output):
|
|
||||||
vals, vecs = ctx.saved_tensors
|
|
||||||
d = vals.abs().sqrt().unsqueeze(-1).repeat_interleave(vals.shape[-1], -1)
|
|
||||||
vecs_t = vecs.transpose(-2, -1)
|
|
||||||
return vecs @ (vecs_t @ grad_output @ vecs / (d + d.transpose(-2, -1))) @ vecs_t
|
|
||||||
|
|
||||||
|
|
||||||
def sqrtm_eig(a):
|
|
||||||
if a.ndim < 2:
|
|
||||||
raise RuntimeError('tensor of matrices must have at least 2 dimensions')
|
|
||||||
if a.shape[-2] != a.shape[-1]:
|
|
||||||
raise RuntimeError('tensor must be batches of square matrices')
|
|
||||||
return _MatrixSquareRootEig.apply(a)
|
|
||||||
|
|
||||||
|
|
||||||
@utils.tf32_mode(matmul=False)
|
|
||||||
def fid(x, y, eps=1e-8):
|
|
||||||
x_mean = x.mean(dim=0)
|
|
||||||
y_mean = y.mean(dim=0)
|
|
||||||
mean_term = (x_mean - y_mean).pow(2).sum()
|
|
||||||
x_cov = torch.cov(x.T)
|
|
||||||
y_cov = torch.cov(y.T)
|
|
||||||
eps_eye = torch.eye(x_cov.shape[0], device=x_cov.device, dtype=x_cov.dtype) * eps
|
|
||||||
x_cov = x_cov + eps_eye
|
|
||||||
y_cov = y_cov + eps_eye
|
|
||||||
x_cov_sqrt = sqrtm_eig(x_cov)
|
|
||||||
cov_term = torch.trace(x_cov + y_cov - 2 * sqrtm_eig(x_cov_sqrt @ y_cov @ x_cov_sqrt))
|
|
||||||
return mean_term + cov_term
|
|
@ -1,99 +0,0 @@
|
|||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
|
|
||||||
class DDPGradientStatsHook:
|
|
||||||
def __init__(self, ddp_module):
|
|
||||||
try:
|
|
||||||
ddp_module.register_comm_hook(self, self._hook_fn)
|
|
||||||
except AttributeError:
|
|
||||||
raise ValueError('DDPGradientStatsHook does not support non-DDP wrapped modules')
|
|
||||||
self._clear_state()
|
|
||||||
|
|
||||||
def _clear_state(self):
|
|
||||||
self.bucket_sq_norms_small_batch = []
|
|
||||||
self.bucket_sq_norms_large_batch = []
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _hook_fn(self, bucket):
|
|
||||||
buf = bucket.buffer()
|
|
||||||
self.bucket_sq_norms_small_batch.append(buf.pow(2).sum())
|
|
||||||
fut = torch.distributed.all_reduce(buf, op=torch.distributed.ReduceOp.AVG, async_op=True).get_future()
|
|
||||||
def callback(fut):
|
|
||||||
buf = fut.value()[0]
|
|
||||||
self.bucket_sq_norms_large_batch.append(buf.pow(2).sum())
|
|
||||||
return buf
|
|
||||||
return fut.then(callback)
|
|
||||||
|
|
||||||
def get_stats(self):
|
|
||||||
sq_norm_small_batch = sum(self.bucket_sq_norms_small_batch)
|
|
||||||
sq_norm_large_batch = sum(self.bucket_sq_norms_large_batch)
|
|
||||||
self._clear_state()
|
|
||||||
stats = torch.stack([sq_norm_small_batch, sq_norm_large_batch])
|
|
||||||
torch.distributed.all_reduce(stats, op=torch.distributed.ReduceOp.AVG)
|
|
||||||
return stats[0].item(), stats[1].item()
|
|
||||||
|
|
||||||
|
|
||||||
class GradientNoiseScale:
|
|
||||||
"""Calculates the gradient noise scale (1 / SNR), or critical batch size,
|
|
||||||
from _An Empirical Model of Large-Batch Training_,
|
|
||||||
https://arxiv.org/abs/1812.06162).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
beta (float): The decay factor for the exponential moving averages used to
|
|
||||||
calculate the gradient noise scale.
|
|
||||||
Default: 0.9998
|
|
||||||
eps (float): Added for numerical stability.
|
|
||||||
Default: 1e-8
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, beta=0.9998, eps=1e-8):
|
|
||||||
self.beta = beta
|
|
||||||
self.eps = eps
|
|
||||||
self.ema_sq_norm = 0.
|
|
||||||
self.ema_var = 0.
|
|
||||||
self.beta_cumprod = 1.
|
|
||||||
self.gradient_noise_scale = float('nan')
|
|
||||||
|
|
||||||
def state_dict(self):
|
|
||||||
"""Returns the state of the object as a :class:`dict`."""
|
|
||||||
return dict(self.__dict__.items())
|
|
||||||
|
|
||||||
def load_state_dict(self, state_dict):
|
|
||||||
"""Loads the object's state.
|
|
||||||
Args:
|
|
||||||
state_dict (dict): object state. Should be an object returned
|
|
||||||
from a call to :meth:`state_dict`.
|
|
||||||
"""
|
|
||||||
self.__dict__.update(state_dict)
|
|
||||||
|
|
||||||
def update(self, sq_norm_small_batch, sq_norm_large_batch, n_small_batch, n_large_batch):
|
|
||||||
"""Updates the state with a new batch's gradient statistics, and returns the
|
|
||||||
current gradient noise scale.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sq_norm_small_batch (float): The mean of the squared 2-norms of microbatch or
|
|
||||||
per sample gradients.
|
|
||||||
sq_norm_large_batch (float): The squared 2-norm of the mean of the microbatch or
|
|
||||||
per sample gradients.
|
|
||||||
n_small_batch (int): The batch size of the individual microbatch or per sample
|
|
||||||
gradients (1 if per sample).
|
|
||||||
n_large_batch (int): The total batch size of the mean of the microbatch or
|
|
||||||
per sample gradients.
|
|
||||||
"""
|
|
||||||
est_sq_norm = (n_large_batch * sq_norm_large_batch - n_small_batch * sq_norm_small_batch) / (n_large_batch - n_small_batch)
|
|
||||||
est_var = (sq_norm_small_batch - sq_norm_large_batch) / (1 / n_small_batch - 1 / n_large_batch)
|
|
||||||
self.ema_sq_norm = self.beta * self.ema_sq_norm + (1 - self.beta) * est_sq_norm
|
|
||||||
self.ema_var = self.beta * self.ema_var + (1 - self.beta) * est_var
|
|
||||||
self.beta_cumprod *= self.beta
|
|
||||||
self.gradient_noise_scale = max(self.ema_var, self.eps) / max(self.ema_sq_norm, self.eps)
|
|
||||||
return self.gradient_noise_scale
|
|
||||||
|
|
||||||
def get_gns(self):
|
|
||||||
"""Returns the current gradient noise scale."""
|
|
||||||
return self.gradient_noise_scale
|
|
||||||
|
|
||||||
def get_stats(self):
|
|
||||||
"""Returns the current (debiased) estimates of the squared mean gradient
|
|
||||||
and gradient variance."""
|
|
||||||
return self.ema_sq_norm / (1 - self.beta_cumprod), self.ema_var / (1 - self.beta_cumprod)
|
|
@ -1,246 +0,0 @@
|
|||||||
import math
|
|
||||||
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from torch.nn import functional as F
|
|
||||||
|
|
||||||
from . import utils
|
|
||||||
|
|
||||||
# Karras et al. preconditioned denoiser
|
|
||||||
|
|
||||||
class Denoiser(nn.Module):
|
|
||||||
"""A Karras et al. preconditioner for denoising diffusion models."""
|
|
||||||
|
|
||||||
def __init__(self, inner_model, sigma_data=1.):
|
|
||||||
super().__init__()
|
|
||||||
self.inner_model = inner_model
|
|
||||||
self.sigma_data = sigma_data
|
|
||||||
|
|
||||||
def get_scalings(self, sigma):
|
|
||||||
c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
|
|
||||||
c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
|
||||||
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
|
||||||
return c_skip, c_out, c_in
|
|
||||||
|
|
||||||
def loss(self, input, noise, sigma, **kwargs):
|
|
||||||
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
|
||||||
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
|
|
||||||
model_output = self.inner_model(noised_input * c_in, sigma, **kwargs)
|
|
||||||
target = (input - c_skip * noised_input) / c_out
|
|
||||||
return (model_output - target).pow(2).flatten(1).mean(1)
|
|
||||||
|
|
||||||
def forward(self, input, sigma, **kwargs):
|
|
||||||
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
|
||||||
return self.inner_model(input * c_in, sigma, **kwargs) * c_out + input * c_skip
|
|
||||||
|
|
||||||
|
|
||||||
class DenoiserWithVariance(Denoiser):
|
|
||||||
def loss(self, input, noise, sigma, **kwargs):
|
|
||||||
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
|
||||||
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
|
|
||||||
model_output, logvar = self.inner_model(noised_input * c_in, sigma, return_variance=True, **kwargs)
|
|
||||||
logvar = utils.append_dims(logvar, model_output.ndim)
|
|
||||||
target = (input - c_skip * noised_input) / c_out
|
|
||||||
losses = ((model_output - target) ** 2 / logvar.exp() + logvar) / 2
|
|
||||||
return losses.flatten(1).mean(1)
|
|
||||||
|
|
||||||
|
|
||||||
# Residual blocks
|
|
||||||
|
|
||||||
class ResidualBlock(nn.Module):
|
|
||||||
def __init__(self, *main, skip=None):
|
|
||||||
super().__init__()
|
|
||||||
self.main = nn.Sequential(*main)
|
|
||||||
self.skip = skip if skip else nn.Identity()
|
|
||||||
|
|
||||||
def forward(self, input):
|
|
||||||
return self.main(input) + self.skip(input)
|
|
||||||
|
|
||||||
|
|
||||||
# Noise level (and other) conditioning
|
|
||||||
|
|
||||||
class ConditionedModule(nn.Module):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class UnconditionedModule(ConditionedModule):
|
|
||||||
def __init__(self, module):
|
|
||||||
super().__init__()
|
|
||||||
self.module = module
|
|
||||||
|
|
||||||
def forward(self, input, cond=None):
|
|
||||||
return self.module(input)
|
|
||||||
|
|
||||||
|
|
||||||
class ConditionedSequential(nn.Sequential, ConditionedModule):
|
|
||||||
def forward(self, input, cond):
|
|
||||||
for module in self:
|
|
||||||
if isinstance(module, ConditionedModule):
|
|
||||||
input = module(input, cond)
|
|
||||||
else:
|
|
||||||
input = module(input)
|
|
||||||
return input
|
|
||||||
|
|
||||||
|
|
||||||
class ConditionedResidualBlock(ConditionedModule):
|
|
||||||
def __init__(self, *main, skip=None):
|
|
||||||
super().__init__()
|
|
||||||
self.main = ConditionedSequential(*main)
|
|
||||||
self.skip = skip if skip else nn.Identity()
|
|
||||||
|
|
||||||
def forward(self, input, cond):
|
|
||||||
skip = self.skip(input, cond) if isinstance(self.skip, ConditionedModule) else self.skip(input)
|
|
||||||
return self.main(input, cond) + skip
|
|
||||||
|
|
||||||
|
|
||||||
class AdaGN(ConditionedModule):
|
|
||||||
def __init__(self, feats_in, c_out, num_groups, eps=1e-5, cond_key='cond'):
|
|
||||||
super().__init__()
|
|
||||||
self.num_groups = num_groups
|
|
||||||
self.eps = eps
|
|
||||||
self.cond_key = cond_key
|
|
||||||
self.mapper = nn.Linear(feats_in, c_out * 2)
|
|
||||||
|
|
||||||
def forward(self, input, cond):
|
|
||||||
weight, bias = self.mapper(cond[self.cond_key]).chunk(2, dim=-1)
|
|
||||||
input = F.group_norm(input, self.num_groups, eps=self.eps)
|
|
||||||
return torch.addcmul(utils.append_dims(bias, input.ndim), input, utils.append_dims(weight, input.ndim) + 1)
|
|
||||||
|
|
||||||
|
|
||||||
# Attention
|
|
||||||
|
|
||||||
class SelfAttention2d(ConditionedModule):
|
|
||||||
def __init__(self, c_in, n_head, norm, dropout_rate=0.):
|
|
||||||
super().__init__()
|
|
||||||
assert c_in % n_head == 0
|
|
||||||
self.norm_in = norm(c_in)
|
|
||||||
self.n_head = n_head
|
|
||||||
self.qkv_proj = nn.Conv2d(c_in, c_in * 3, 1)
|
|
||||||
self.out_proj = nn.Conv2d(c_in, c_in, 1)
|
|
||||||
self.dropout = nn.Dropout(dropout_rate)
|
|
||||||
|
|
||||||
def forward(self, input, cond):
|
|
||||||
n, c, h, w = input.shape
|
|
||||||
qkv = self.qkv_proj(self.norm_in(input, cond))
|
|
||||||
qkv = qkv.view([n, self.n_head * 3, c // self.n_head, h * w]).transpose(2, 3)
|
|
||||||
q, k, v = qkv.chunk(3, dim=1)
|
|
||||||
scale = k.shape[3] ** -0.25
|
|
||||||
att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
|
|
||||||
att = self.dropout(att)
|
|
||||||
y = (att @ v).transpose(2, 3).contiguous().view([n, c, h, w])
|
|
||||||
return input + self.out_proj(y)
|
|
||||||
|
|
||||||
|
|
||||||
class CrossAttention2d(ConditionedModule):
|
|
||||||
def __init__(self, c_dec, c_enc, n_head, norm_dec, dropout_rate=0.,
|
|
||||||
cond_key='cross', cond_key_padding='cross_padding'):
|
|
||||||
super().__init__()
|
|
||||||
assert c_dec % n_head == 0
|
|
||||||
self.cond_key = cond_key
|
|
||||||
self.cond_key_padding = cond_key_padding
|
|
||||||
self.norm_enc = nn.LayerNorm(c_enc)
|
|
||||||
self.norm_dec = norm_dec(c_dec)
|
|
||||||
self.n_head = n_head
|
|
||||||
self.q_proj = nn.Conv2d(c_dec, c_dec, 1)
|
|
||||||
self.kv_proj = nn.Linear(c_enc, c_dec * 2)
|
|
||||||
self.out_proj = nn.Conv2d(c_dec, c_dec, 1)
|
|
||||||
self.dropout = nn.Dropout(dropout_rate)
|
|
||||||
|
|
||||||
def forward(self, input, cond):
|
|
||||||
n, c, h, w = input.shape
|
|
||||||
q = self.q_proj(self.norm_dec(input, cond))
|
|
||||||
q = q.view([n, self.n_head, c // self.n_head, h * w]).transpose(2, 3)
|
|
||||||
kv = self.kv_proj(self.norm_enc(cond[self.cond_key]))
|
|
||||||
kv = kv.view([n, -1, self.n_head * 2, c // self.n_head]).transpose(1, 2)
|
|
||||||
k, v = kv.chunk(2, dim=1)
|
|
||||||
scale = k.shape[3] ** -0.25
|
|
||||||
att = ((q * scale) @ (k.transpose(2, 3) * scale))
|
|
||||||
att = att - (cond[self.cond_key_padding][:, None, None, :]) * 10000
|
|
||||||
att = att.softmax(3)
|
|
||||||
att = self.dropout(att)
|
|
||||||
y = (att @ v).transpose(2, 3)
|
|
||||||
y = y.contiguous().view([n, c, h, w])
|
|
||||||
return input + self.out_proj(y)
|
|
||||||
|
|
||||||
|
|
||||||
# Downsampling/upsampling
|
|
||||||
|
|
||||||
_kernels = {
|
|
||||||
'linear':
|
|
||||||
[1 / 8, 3 / 8, 3 / 8, 1 / 8],
|
|
||||||
'cubic':
|
|
||||||
[-0.01171875, -0.03515625, 0.11328125, 0.43359375,
|
|
||||||
0.43359375, 0.11328125, -0.03515625, -0.01171875],
|
|
||||||
'lanczos3':
|
|
||||||
[0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
|
|
||||||
-0.066637322306633, 0.13550527393817902, 0.44638532400131226,
|
|
||||||
0.44638532400131226, 0.13550527393817902, -0.066637322306633,
|
|
||||||
-0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
|
|
||||||
}
|
|
||||||
_kernels['bilinear'] = _kernels['linear']
|
|
||||||
_kernels['bicubic'] = _kernels['cubic']
|
|
||||||
|
|
||||||
|
|
||||||
class Downsample2d(nn.Module):
|
|
||||||
def __init__(self, kernel='linear', pad_mode='reflect'):
|
|
||||||
super().__init__()
|
|
||||||
self.pad_mode = pad_mode
|
|
||||||
kernel_1d = torch.tensor([_kernels[kernel]])
|
|
||||||
self.pad = kernel_1d.shape[1] // 2 - 1
|
|
||||||
self.register_buffer('kernel', kernel_1d.T @ kernel_1d)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = F.pad(x, (self.pad,) * 4, self.pad_mode)
|
|
||||||
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
|
|
||||||
indices = torch.arange(x.shape[1], device=x.device)
|
|
||||||
weight[indices, indices] = self.kernel.to(weight)
|
|
||||||
return F.conv2d(x, weight, stride=2)
|
|
||||||
|
|
||||||
|
|
||||||
class Upsample2d(nn.Module):
|
|
||||||
def __init__(self, kernel='linear', pad_mode='reflect'):
|
|
||||||
super().__init__()
|
|
||||||
self.pad_mode = pad_mode
|
|
||||||
kernel_1d = torch.tensor([_kernels[kernel]]) * 2
|
|
||||||
self.pad = kernel_1d.shape[1] // 2 - 1
|
|
||||||
self.register_buffer('kernel', kernel_1d.T @ kernel_1d)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
|
|
||||||
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
|
|
||||||
indices = torch.arange(x.shape[1], device=x.device)
|
|
||||||
weight[indices, indices] = self.kernel.to(weight)
|
|
||||||
return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)
|
|
||||||
|
|
||||||
|
|
||||||
# Embeddings
|
|
||||||
|
|
||||||
class FourierFeatures(nn.Module):
|
|
||||||
def __init__(self, in_features, out_features, std=1.):
|
|
||||||
super().__init__()
|
|
||||||
assert out_features % 2 == 0
|
|
||||||
self.register_buffer('weight', torch.randn([out_features // 2, in_features]) * std)
|
|
||||||
|
|
||||||
def forward(self, input):
|
|
||||||
f = 2 * math.pi * input @ self.weight.T
|
|
||||||
return torch.cat([f.cos(), f.sin()], dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
# U-Nets
|
|
||||||
|
|
||||||
class UNet(ConditionedModule):
|
|
||||||
def __init__(self, d_blocks, u_blocks, skip_stages=0):
|
|
||||||
super().__init__()
|
|
||||||
self.d_blocks = nn.ModuleList(d_blocks)
|
|
||||||
self.u_blocks = nn.ModuleList(u_blocks)
|
|
||||||
self.skip_stages = skip_stages
|
|
||||||
|
|
||||||
def forward(self, input, cond):
|
|
||||||
skips = []
|
|
||||||
for block in self.d_blocks[self.skip_stages:]:
|
|
||||||
input = block(input, cond)
|
|
||||||
skips.append(input)
|
|
||||||
for i, (block, skip) in enumerate(zip(self.u_blocks, reversed(skips))):
|
|
||||||
input = block(input, cond, skip if i > 0 else None)
|
|
||||||
return input
|
|
@ -1 +0,0 @@
|
|||||||
from .image_v1 import ImageDenoiserModelV1
|
|
@ -1,156 +0,0 @@
|
|||||||
import math
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from torch.nn import functional as F
|
|
||||||
|
|
||||||
from .. import layers, utils
|
|
||||||
|
|
||||||
|
|
||||||
def orthogonal_(module):
|
|
||||||
nn.init.orthogonal_(module.weight)
|
|
||||||
return module
|
|
||||||
|
|
||||||
|
|
||||||
class ResConvBlock(layers.ConditionedResidualBlock):
|
|
||||||
def __init__(self, feats_in, c_in, c_mid, c_out, group_size=32, dropout_rate=0.):
|
|
||||||
skip = None if c_in == c_out else orthogonal_(nn.Conv2d(c_in, c_out, 1, bias=False))
|
|
||||||
super().__init__(
|
|
||||||
layers.AdaGN(feats_in, c_in, max(1, c_in // group_size)),
|
|
||||||
nn.GELU(),
|
|
||||||
nn.Conv2d(c_in, c_mid, 3, padding=1),
|
|
||||||
nn.Dropout2d(dropout_rate, inplace=True),
|
|
||||||
layers.AdaGN(feats_in, c_mid, max(1, c_mid // group_size)),
|
|
||||||
nn.GELU(),
|
|
||||||
nn.Conv2d(c_mid, c_out, 3, padding=1),
|
|
||||||
nn.Dropout2d(dropout_rate, inplace=True),
|
|
||||||
skip=skip)
|
|
||||||
|
|
||||||
|
|
||||||
class DBlock(layers.ConditionedSequential):
|
|
||||||
def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., downsample=False, self_attn=False, cross_attn=False, c_enc=0):
|
|
||||||
modules = [nn.Identity()]
|
|
||||||
for i in range(n_layers):
|
|
||||||
my_c_in = c_in if i == 0 else c_mid
|
|
||||||
my_c_out = c_mid if i < n_layers - 1 else c_out
|
|
||||||
modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate))
|
|
||||||
if self_attn:
|
|
||||||
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
|
|
||||||
modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate))
|
|
||||||
if cross_attn:
|
|
||||||
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
|
|
||||||
modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate))
|
|
||||||
super().__init__(*modules)
|
|
||||||
self.set_downsample(downsample)
|
|
||||||
|
|
||||||
def set_downsample(self, downsample):
|
|
||||||
self[0] = layers.Downsample2d() if downsample else nn.Identity()
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
class UBlock(layers.ConditionedSequential):
|
|
||||||
def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., upsample=False, self_attn=False, cross_attn=False, c_enc=0):
|
|
||||||
modules = []
|
|
||||||
for i in range(n_layers):
|
|
||||||
my_c_in = c_in if i == 0 else c_mid
|
|
||||||
my_c_out = c_mid if i < n_layers - 1 else c_out
|
|
||||||
modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate))
|
|
||||||
if self_attn:
|
|
||||||
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
|
|
||||||
modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate))
|
|
||||||
if cross_attn:
|
|
||||||
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
|
|
||||||
modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate))
|
|
||||||
modules.append(nn.Identity())
|
|
||||||
super().__init__(*modules)
|
|
||||||
self.set_upsample(upsample)
|
|
||||||
|
|
||||||
def forward(self, input, cond, skip=None):
|
|
||||||
if skip is not None:
|
|
||||||
input = torch.cat([input, skip], dim=1)
|
|
||||||
return super().forward(input, cond)
|
|
||||||
|
|
||||||
def set_upsample(self, upsample):
|
|
||||||
self[-1] = layers.Upsample2d() if upsample else nn.Identity()
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
class MappingNet(nn.Sequential):
|
|
||||||
def __init__(self, feats_in, feats_out, n_layers=2):
|
|
||||||
layers = []
|
|
||||||
for i in range(n_layers):
|
|
||||||
layers.append(orthogonal_(nn.Linear(feats_in if i == 0 else feats_out, feats_out)))
|
|
||||||
layers.append(nn.GELU())
|
|
||||||
super().__init__(*layers)
|
|
||||||
|
|
||||||
|
|
||||||
class ImageDenoiserModelV1(nn.Module):
|
|
||||||
def __init__(self, c_in, feats_in, depths, channels, self_attn_depths, cross_attn_depths=None, mapping_cond_dim=0, unet_cond_dim=0, cross_cond_dim=0, dropout_rate=0., patch_size=1, skip_stages=0, has_variance=False):
|
|
||||||
super().__init__()
|
|
||||||
self.c_in = c_in
|
|
||||||
self.channels = channels
|
|
||||||
self.unet_cond_dim = unet_cond_dim
|
|
||||||
self.patch_size = patch_size
|
|
||||||
self.has_variance = has_variance
|
|
||||||
self.timestep_embed = layers.FourierFeatures(1, feats_in)
|
|
||||||
if mapping_cond_dim > 0:
|
|
||||||
self.mapping_cond = nn.Linear(mapping_cond_dim, feats_in, bias=False)
|
|
||||||
self.mapping = MappingNet(feats_in, feats_in)
|
|
||||||
self.proj_in = nn.Conv2d((c_in + unet_cond_dim) * self.patch_size ** 2, channels[max(0, skip_stages - 1)], 1)
|
|
||||||
self.proj_out = nn.Conv2d(channels[max(0, skip_stages - 1)], c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1)
|
|
||||||
nn.init.zeros_(self.proj_out.weight)
|
|
||||||
nn.init.zeros_(self.proj_out.bias)
|
|
||||||
if cross_cond_dim == 0:
|
|
||||||
cross_attn_depths = [False] * len(self_attn_depths)
|
|
||||||
d_blocks, u_blocks = [], []
|
|
||||||
for i in range(len(depths)):
|
|
||||||
my_c_in = channels[max(0, i - 1)]
|
|
||||||
d_blocks.append(DBlock(depths[i], feats_in, my_c_in, channels[i], channels[i], downsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate))
|
|
||||||
for i in range(len(depths)):
|
|
||||||
my_c_in = channels[i] * 2 if i < len(depths) - 1 else channels[i]
|
|
||||||
my_c_out = channels[max(0, i - 1)]
|
|
||||||
u_blocks.append(UBlock(depths[i], feats_in, my_c_in, channels[i], my_c_out, upsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate))
|
|
||||||
self.u_net = layers.UNet(d_blocks, reversed(u_blocks), skip_stages=skip_stages)
|
|
||||||
|
|
||||||
def forward(self, input, sigma, mapping_cond=None, unet_cond=None, cross_cond=None, cross_cond_padding=None, return_variance=False):
|
|
||||||
c_noise = sigma.log() / 4
|
|
||||||
timestep_embed = self.timestep_embed(utils.append_dims(c_noise, 2))
|
|
||||||
mapping_cond_embed = torch.zeros_like(timestep_embed) if mapping_cond is None else self.mapping_cond(mapping_cond)
|
|
||||||
mapping_out = self.mapping(timestep_embed + mapping_cond_embed)
|
|
||||||
cond = {'cond': mapping_out}
|
|
||||||
if unet_cond is not None:
|
|
||||||
input = torch.cat([input, unet_cond], dim=1)
|
|
||||||
if cross_cond is not None:
|
|
||||||
cond['cross'] = cross_cond
|
|
||||||
cond['cross_padding'] = cross_cond_padding
|
|
||||||
if self.patch_size > 1:
|
|
||||||
input = F.pixel_unshuffle(input, self.patch_size)
|
|
||||||
input = self.proj_in(input)
|
|
||||||
input = self.u_net(input, cond)
|
|
||||||
input = self.proj_out(input)
|
|
||||||
if self.has_variance:
|
|
||||||
input, logvar = input[:, :-1], input[:, -1].flatten(1).mean(1)
|
|
||||||
if self.patch_size > 1:
|
|
||||||
input = F.pixel_shuffle(input, self.patch_size)
|
|
||||||
if self.has_variance and return_variance:
|
|
||||||
return input, logvar
|
|
||||||
return input
|
|
||||||
|
|
||||||
def set_skip_stages(self, skip_stages):
|
|
||||||
self.proj_in = nn.Conv2d(self.proj_in.in_channels, self.channels[max(0, skip_stages - 1)], 1)
|
|
||||||
self.proj_out = nn.Conv2d(self.channels[max(0, skip_stages - 1)], self.proj_out.out_channels, 1)
|
|
||||||
nn.init.zeros_(self.proj_out.weight)
|
|
||||||
nn.init.zeros_(self.proj_out.bias)
|
|
||||||
self.u_net.skip_stages = skip_stages
|
|
||||||
for i, block in enumerate(self.u_net.d_blocks):
|
|
||||||
block.set_downsample(i > skip_stages)
|
|
||||||
for i, block in enumerate(reversed(self.u_net.u_blocks)):
|
|
||||||
block.set_upsample(i > skip_stages)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def set_patch_size(self, patch_size):
|
|
||||||
self.patch_size = patch_size
|
|
||||||
self.proj_in = nn.Conv2d((self.c_in + self.unet_cond_dim) * self.patch_size ** 2, self.channels[max(0, self.u_net.skip_stages - 1)], 1)
|
|
||||||
self.proj_out = nn.Conv2d(self.channels[max(0, self.u_net.skip_stages - 1)], self.c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1)
|
|
||||||
nn.init.zeros_(self.proj_out.weight)
|
|
||||||
nn.init.zeros_(self.proj_out.bias)
|
|
@ -10,25 +10,6 @@ from PIL import Image
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn, optim
|
from torch import nn, optim
|
||||||
from torch.utils import data
|
from torch.utils import data
|
||||||
from torchvision.transforms import functional as TF
|
|
||||||
|
|
||||||
|
|
||||||
def from_pil_image(x):
|
|
||||||
"""Converts from a PIL image to a tensor."""
|
|
||||||
x = TF.to_tensor(x)
|
|
||||||
if x.ndim == 2:
|
|
||||||
x = x[..., None]
|
|
||||||
return x * 2 - 1
|
|
||||||
|
|
||||||
|
|
||||||
def to_pil_image(x):
|
|
||||||
"""Converts from a tensor to a PIL image."""
|
|
||||||
if x.ndim == 4:
|
|
||||||
assert x.shape[0] == 1
|
|
||||||
x = x[0]
|
|
||||||
if x.shape[0] == 1:
|
|
||||||
x = x[0]
|
|
||||||
return TF.to_pil_image((x.clamp(-1, 1) + 1) / 2)
|
|
||||||
|
|
||||||
|
|
||||||
def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):
|
def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):
|
||||||
|
Loading…
Reference in New Issue
Block a user