ComfyUI/comfy_extras/chainner_models/architecture/SCUNet.py
comfyanonymous 766c7b3815 Update upscale model code to latest Chainner model code.
Don't add SRFormer because the code license is incompatible with the GPL.

Remove MAT because it's unused and the license is incompatible with GPL.
2023-09-02 22:27:40 -04:00

456 lines
14 KiB
Python

# pylint: skip-file
# -----------------------------------------------------------------------------------
# SCUNet: Practical Blind Denoising via Swin-Conv-UNet and Data Synthesis, https://arxiv.org/abs/2203.13278
# Zhang, Kai and Li, Yawei and Liang, Jingyun and Cao, Jiezhang and Zhang, Yulun and Tang, Hao and Timofte, Radu and Van Gool, Luc
# -----------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
from .timm.drop import DropPath
from .timm.weight_init import trunc_normal_
# Borrowed from https://github.com/cszn/SCUNet/blob/main/models/network_scunet.py
class WMSA(nn.Module):
    """Self-attention module in Swin Transformer.

    Window multi-head self-attention. With ``type == "W"`` attention is computed
    inside non-overlapping ``window_size x window_size`` windows. Any other type
    ("SW") first cyclically shifts the feature map by half a window and masks
    attention between pixels that were not neighbours before the shift.
    """

    def __init__(self, input_dim, output_dim, head_dim, window_size, type):
        """
        Args:
            input_dim: channel count of the input (last axis in forward).
            output_dim: channel count produced by the output projection.
            head_dim: channels per head; n_heads = input_dim // head_dim.
            window_size: side length of the square attention windows.
            type: "W" for regular windows, anything else for shifted windows.
        """
        super(WMSA, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.head_dim = head_dim
        self.scale = self.head_dim**-0.5
        self.n_heads = input_dim // head_dim
        self.window_size = window_size
        self.type = type
        self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)
        # One learnable bias per head for every (dy, dx) offset between two
        # pixels of a window; each offset lies in [-(ws - 1), ws - 1].
        self.relative_position_params = nn.Parameter(
            torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads)
        )
        self.linear = nn.Linear(self.input_dim, self.output_dim)
        trunc_normal_(self.relative_position_params, std=0.02)
        # Re-lay the table out heads-first: (n_heads, 2*ws-1, 2*ws-1), which is
        # the layout relative_embedding() indexes into.
        self.relative_position_params = torch.nn.Parameter(
            self.relative_position_params.view(
                2 * window_size - 1, 2 * window_size - 1, self.n_heads
            )
            .transpose(1, 2)
            .transpose(0, 1)
        )

    def generate_mask(self, h, w, p, shift):
        """Build the attention mask for shifted-window (SW) attention.

        Args:
            h: number of windows along the height.
            w: number of windows along the width.
            p: window side length (windows are p x p pixels).
            shift: cyclic shift applied before windowing.

        Returns:
            Bool tensor of shape (1, 1, h*w, p*p, p*p). It is True where a query
            pixel must not attend to a key pixel (forward fills those with -inf).
            For type "W" the mask is all False.
        """
        attn_mask = torch.zeros(
            h,
            w,
            p,
            p,
            p,
            p,
            dtype=torch.bool,
            device=self.relative_position_params.device,
        )
        if self.type != "W":
            s = p - shift
            # Only the last window row/column mixes pixels that were wrapped
            # around by the cyclic shift; block attention across that seam.
            attn_mask[-1, :, :s, :, s:, :] = True
            attn_mask[-1, :, s:, :, :s, :] = True
            attn_mask[:, -1, :, :s, :, s:] = True
            attn_mask[:, -1, :, s:, :, :s] = True
        # Axes keep their order, so merging adjacent axes with reshape matches
        # "w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)".
        return attn_mask.reshape(1, 1, h * w, p * p, p * p)

    def forward(self, x):
        """Forward pass of Window Multi-head Self-attention module.

        Args:
            x: input tensor with shape [b, h, w, c]. h and w must be multiples of
                window_size for the window rearrange below.

        Returns:
            Tensor of shape [b, h, w, output_dim].
        """
        ws = self.window_size
        if self.type != "W":
            # Shift by half a window so the new windows straddle the old borders.
            x = torch.roll(x, shifts=(-(ws // 2), -(ws // 2)), dims=(1, 2))
        x = rearrange(x, "b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c", p1=ws, p2=ws)
        h_windows = x.size(1)
        w_windows = x.size(2)
        x = rearrange(x, "b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c", p1=ws, p2=ws)
        qkv = self.embedding_layer(x)
        # Channels form 3 * n_heads groups of head_dim; chunk yields q, k, v.
        q, k, v = rearrange(
            qkv, "b nw np (threeh c) -> threeh b nw np c", c=self.head_dim
        ).chunk(3, dim=0)
        sim = torch.einsum("hbwpc,hbwqc->hbwpq", q, k) * self.scale
        # Adding learnable relative embedding
        sim = sim + rearrange(self.relative_embedding(), "h p q -> h 1 1 p q")
        # Using Attn Mask to distinguish different subwindows.
        if self.type != "W":
            attn_mask = self.generate_mask(h_windows, w_windows, ws, shift=ws // 2)
            sim = sim.masked_fill_(attn_mask, float("-inf"))
        probs = nn.functional.softmax(sim, dim=-1)
        output = torch.einsum("hbwij,hbwjc->hbwic", probs, v)
        output = rearrange(output, "h b w p c -> b w p (h c)")
        output = self.linear(output)
        output = rearrange(
            output,
            "b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c",
            w1=h_windows,
            p1=ws,
        )
        if self.type != "W":
            # Undo the cyclic shift.
            output = torch.roll(output, shifts=(ws // 2, ws // 2), dims=(1, 2))
        return output

    def relative_embedding(self):
        """Gather the relative-position bias for every pixel pair of a window.

        Returns:
            Tensor of shape (n_heads, ws*ws, ws*ws). Entry [h, p, q] is
            relative_position_params[h, dy + ws - 1, dx + ws - 1], where (dy, dx)
            is the offset from key pixel q to query pixel p. Pixels are
            flattened row-major.
        """
        ws = self.window_size
        params = self.relative_position_params
        axis = torch.arange(ws, device=params.device)
        # (ws*ws, 2) row/col coordinates in row-major order: (0,0), (0,1), ...
        cord = torch.stack((axis.repeat_interleave(ws), axis.repeat(ws)), dim=1)
        # Pairwise offsets shifted into [0, 2*ws - 2] so they index the table.
        relation = cord[:, None, :] - cord[None, :, :] + ws - 1
        return params[:, relation[:, :, 0], relation[:, :, 1]]
class Block(nn.Module):
    """Pre-norm Swin transformer block: window attention followed by an MLP.

    Both sub-layers sit inside residual connections with optional DropPath.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        head_dim,
        window_size,
        drop_path,
        type="W",
        input_resolution=None,
    ):
        """SwinTransformer Block.

        Args:
            input_dim: channel count of the input (last axis of x in forward).
            output_dim: output channels of the MLP. It must equal input_dim for
                the residual add in forward to line up.
            head_dim: channels per attention head.
            window_size: attention window side length.
            drop_path: DropPath probability; 0 disables it (nn.Identity).
            type: "W" (regular windows) or "SW" (shifted windows).
            input_resolution: expected feature-map side length. When it is not
                larger than window_size, "W" is forced, since at most one window
                fits. None means unknown and keeps the requested type.
        """
        super(Block, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        assert type in ["W", "SW"]
        self.type = type
        # Guard against None: `None <= int` raises TypeError in Python 3, which
        # previously made the documented default unusable.
        if input_resolution is not None and input_resolution <= window_size:
            self.type = "W"
        self.ln1 = nn.LayerNorm(input_dim)
        self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.ln2 = nn.LayerNorm(input_dim)
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 4 * input_dim),
            nn.GELU(),
            nn.Linear(4 * input_dim, output_dim),
        )

    def forward(self, x):
        """Apply attention and MLP residual sub-layers to channels-last x."""
        x = x + self.drop_path(self.msa(self.ln1(x)))
        x = x + self.drop_path(self.mlp(self.ln2(x)))
        return x
class ConvTransBlock(nn.Module):
    """Parallel convolution / Swin-transformer block with a fused residual.

    The input channels are mixed by a 1x1 conv, then split into a conv branch
    (conv_dim channels) and a transformer branch (trans_dim channels). The
    branches are processed in parallel, re-concatenated, fused by another 1x1
    conv and added back to the input.
    """

    def __init__(
        self,
        conv_dim,
        trans_dim,
        head_dim,
        window_size,
        drop_path,
        type="W",
        input_resolution=None,
    ):
        """SwinTransformer and Conv Block.

        Args:
            conv_dim: channels routed to the residual 3x3 conv branch.
            trans_dim: channels routed to the transformer Block.
            head_dim: channels per attention head (forwarded to Block).
            window_size: attention window side length (forwarded to Block).
            drop_path: DropPath probability (forwarded to Block).
            type: "W" or "SW" window attention.
            input_resolution: expected feature-map side length. When it is not
                larger than window_size, "W" is forced. None keeps the
                requested type.
        """
        super(ConvTransBlock, self).__init__()
        self.conv_dim = conv_dim
        self.trans_dim = trans_dim
        self.head_dim = head_dim
        self.window_size = window_size
        self.drop_path = drop_path
        self.type = type
        self.input_resolution = input_resolution
        assert self.type in ["W", "SW"]
        # Guard against None: `None <= int` raises TypeError in Python 3.
        if (
            self.input_resolution is not None
            and self.input_resolution <= self.window_size
        ):
            self.type = "W"
        self.trans_block = Block(
            self.trans_dim,
            self.trans_dim,
            self.head_dim,
            self.window_size,
            self.drop_path,
            self.type,
            self.input_resolution,
        )
        total_dim = self.conv_dim + self.trans_dim
        self.conv1_1 = nn.Conv2d(total_dim, total_dim, 1, 1, 0, bias=True)
        self.conv1_2 = nn.Conv2d(total_dim, total_dim, 1, 1, 0, bias=True)
        self.conv_block = nn.Sequential(
            nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
        )

    def forward(self, x):
        """Run both branches on x of shape (b, conv_dim + trans_dim, h, w)."""
        conv_x, trans_x = torch.split(
            self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1
        )
        conv_x = self.conv_block(conv_x) + conv_x
        # Block works channels-last: (b, c, h, w) -> (b, h, w, c) and back.
        # Plain permute avoids building two einops Rearrange modules per call.
        trans_x = self.trans_block(trans_x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
        return x + res
class SCUNet(nn.Module):
    """Swin-Conv-UNet (SCUNet), constructed and loaded from a state dict.

    A 3-level U-Net whose stages are stacks of ConvTransBlock, with additive
    skip connections. Output has the same channel count and spatial size as
    the input (scale 1).
    """

    def __init__(
        self,
        state_dict,
        in_nc=3,
        config=None,
        dim=64,
        drop_path_rate=0.0,
        input_resolution=256,
    ):
        """
        Args:
            state_dict: weights, loaded with strict=True at the end of __init__.
            in_nc: input (and output) channel count.
            config: ConvTransBlock count for each of the 7 stages (down1, down2,
                down3, body, up3, up2, up1). None means [4, 4, 4, 4, 4, 4, 4].
            dim: channel width after m_head; doubled by each downsampling conv.
            drop_path_rate: final DropPath rate, linearly ramped from 0 across
                all blocks.
            input_resolution: nominal input side length, used per stage to
                decide whether shifted windows are possible.
        """
        super(SCUNet, self).__init__()
        if config is None:
            # Fresh list per instance; a shared mutable default stored on
            # self.config would leak mutations between models.
            config = [4, 4, 4, 4, 4, 4, 4]
        self.model_arch = "SCUNet"
        self.sub_type = "SR"
        self.num_filters: int = 0
        # NOTE(review): this keeps a reference to every checkpoint tensor after
        # they are copied into the parameters; confirm whether anything reads
        # `.state` before dropping it.
        self.state = state_dict
        self.config = config
        self.dim = dim
        self.head_dim = 32
        self.window_size = 8
        self.in_nc = in_nc
        self.out_nc = self.in_nc
        self.scale = 1
        self.supports_fp16 = True

        # drop path rate for each ConvTransBlock, in construction order
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]

        def stage(index, half_dim, resolution):
            """Build stage `index`'s ConvTransBlocks, alternating W / SW windows."""
            start = sum(config[:index])
            return [
                ConvTransBlock(
                    half_dim,
                    half_dim,
                    self.head_dim,
                    self.window_size,
                    dpr[start + i],
                    "W" if i % 2 == 0 else "SW",
                    resolution,
                )
                for i in range(config[index])
            ]

        # Construction order and Sequential indices match the checkpoint layout
        # (e.g. m_down1.{0..n-1} blocks then m_down1.n conv; m_up3.0 transpose
        # conv then m_up3.{1..n} blocks).
        self.m_head = nn.Sequential(nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False))
        self.m_down1 = nn.Sequential(
            *stage(0, dim // 2, input_resolution),
            nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False),
        )
        self.m_down2 = nn.Sequential(
            *stage(1, dim, input_resolution // 2),
            nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False),
        )
        self.m_down3 = nn.Sequential(
            *stage(2, 2 * dim, input_resolution // 4),
            nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False),
        )
        self.m_body = nn.Sequential(*stage(3, 4 * dim, input_resolution // 8))
        self.m_up3 = nn.Sequential(
            nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False),
            *stage(4, 2 * dim, input_resolution // 4),
        )
        self.m_up2 = nn.Sequential(
            nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False),
            *stage(5, dim, input_resolution // 2),
        )
        self.m_up1 = nn.Sequential(
            nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False),
            *stage(6, dim // 2, input_resolution),
        )
        self.m_tail = nn.Sequential(nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False))
        self.load_state_dict(state_dict, strict=True)

    def check_image_size(self, x):
        """Pad x on the bottom/right so its height and width are multiples of 64.

        64 = window_size (8) * 2**3 stride-2 downsamplings, so every stage sees
        whole attention windows. Reflect padding is used when possible. Reflect
        requires each pad to be smaller than its dimension, so very small inputs
        fall back to replicate padding instead of raising.
        """
        _, _, h, w = x.size()
        mod_pad_h = (64 - h % 64) % 64
        mod_pad_w = (64 - w % 64) % 64
        if mod_pad_h == 0 and mod_pad_w == 0:
            return x
        mode = "reflect" if mod_pad_h < h and mod_pad_w < w else "replicate"
        return F.pad(x, (0, mod_pad_w, 0, mod_pad_h), mode)

    def forward(self, x0):
        """Run the network on x0 of shape (b, in_nc, h, w); output keeps h, w."""
        h, w = x0.size()[-2:]
        x0 = self.check_image_size(x0)
        x1 = self.m_head(x0)
        x2 = self.m_down1(x1)
        x3 = self.m_down2(x2)
        x4 = self.m_down3(x3)
        x = self.m_body(x4)
        # Decoder with additive skips from the matching encoder level.
        x = self.m_up3(x + x4)
        x = self.m_up2(x + x3)
        x = self.m_up1(x + x2)
        x = self.m_tail(x + x1)
        # Crop away the padding added by check_image_size.
        return x[:, :, :h, :w]

    def _init_weights(self, m):
        """Timm-style init for Linear/LayerNorm layers (not applied in __init__)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)