mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 18:35:17 +00:00
766c7b3815
Don't add SRFormer because the code license is incompatible with the GPL. Remove MAT because it's unused and the license is incompatible with GPL.
144 lines
4.5 KiB
Python
144 lines
4.5 KiB
Python
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: OmniSR.py
# Created Date: Tuesday April 28th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Sunday, 23rd April 2023 3:06:36 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
|
|
|
|
import math
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from .OSAG import OSAG
|
|
from .pixelshuffle import pixelshuffle_block
|
|
|
|
|
|
class OmniSR(nn.Module):
    """OmniSR model (``sub_type`` "SR") configured from a checkpoint.

    The constructor inspects tensor shapes and key names in the given state
    dict to choose the hyper-parameters (feature width, input channels,
    upscale factor, number of OSAG groups, attention window size), builds the
    layers accordingly and finally loads the weights with ``strict=False``.
    """

    # State-dict key whose first dimension is used to infer the window size.
    _REL_POS_BIAS_KEY = (
        "residual_layer.0.residual_layer.0.layer.2.fn.rel_pos_bias.weight"
    )

    def __init__(
        self,
        state_dict,
        **kwargs,
    ):
        """Build the network from ``state_dict`` and load its weights.

        Args:
            state_dict: Mapping of parameter names to tensors. It must contain
                ``"input.weight"`` and ``"up.0.weight"``; the remaining keys
                are scanned to infer the depth and the window size.
            **kwargs: Accepted for interface compatibility; not used here.

        Raises:
            KeyError: If ``"input.weight"`` or ``"up.0.weight"`` is missing.
        """
        super().__init__()
        self.state = state_dict

        bias = True  # Fine to assume this for now
        block_num = 1  # Fine to assume this for now
        ffn_bias = True
        pe = True

        input_weight_shape = state_dict["input.weight"].shape
        num_feat = input_weight_shape[0] or 64
        num_in_ch = input_weight_shape[1] or 3
        num_out_ch = num_in_ch  # we can just assume this for now. pixelshuffle smh

        # The scale is recovered as sqrt(rows of up.0.weight / num_out_ch).
        # A non-integer result means the out_nc == in_nc assumption above is
        # probably wrong, so warn before truncating.
        pixelshuffle_shape = state_dict["up.0.weight"].shape[0]
        up_scale = math.sqrt(pixelshuffle_shape / num_out_ch)
        if up_scale - int(up_scale) > 0:
            print(
                "out_nc is probably different than in_nc, scale calculation might be wrong"
            )
        up_scale = int(up_scale)

        # The second dotted component of every "residual_layer" key is the
        # (zero-indexed) group index; the group count is the largest one + 1.
        group_indices = (
            int(key.split(".")[1]) for key in state_dict if "residual_layer" in key
        )
        res_num = max(group_indices, default=0) + 1
        self.res_num = res_num

        if self._REL_POS_BIAS_KEY in state_dict:
            # Inverts n = (2 * window_size - 1) ** 2, where n is the first
            # dimension of the table (presumably one row per relative offset
            # inside a window; the layout is defined in OSAG - TODO confirm).
            rel_pos_bias_weight = state_dict[self._REL_POS_BIAS_KEY].shape[0]
            self.window_size = int((math.sqrt(rel_pos_bias_weight) + 1) / 2)
        else:
            self.window_size = 8

        self.up_scale = up_scale

        # Submodules are registered in the same order as before
        # (residual_layer, input, output, up).
        self.residual_layer = nn.Sequential(
            *(
                OSAG(
                    channel_num=num_feat,
                    bias=bias,
                    block_num=block_num,
                    ffn_bias=ffn_bias,
                    window_size=self.window_size,
                    pe=pe,
                )
                for _ in range(res_num)
            )
        )
        self.input = nn.Conv2d(
            in_channels=num_in_ch,
            out_channels=num_feat,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=bias,
        )
        self.output = nn.Conv2d(
            in_channels=num_feat,
            out_channels=num_feat,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=bias,
        )
        self.up = pixelshuffle_block(num_feat, num_out_ch, up_scale, bias=bias)

        # chaiNNer specific stuff
        self.model_arch = "OmniSR"
        self.sub_type = "SR"
        self.in_nc = num_in_ch
        self.out_nc = num_out_ch
        self.num_feat = num_feat
        self.scale = up_scale

        self.supports_fp16 = True  # TODO: Test this
        self.supports_bfp16 = True
        self.min_size_restriction = 16

        # strict=False: missing or unexpected keys do not raise here.
        self.load_state_dict(state_dict, strict=False)

    def check_image_size(self, x):
        """Zero-pad ``x`` so its last two dims are multiples of the window size.

        Args:
            x: 4-D tensor; the last two dimensions are treated as height and
                width.

        Returns:
            ``x`` padded with zeros at the end of the height and width
            dimensions only (unchanged in size when both are already
            multiples of ``self.window_size``).
        """
        _, _, h, w = x.size()
        window = self.window_size
        mod_pad_h = (window - h % window) % window
        mod_pad_w = (window - w % window) % window
        # F.pad takes the pads for the last dimension first: (left, right,
        # top, bottom).
        return F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "constant", 0)

    def forward(self, x):
        """Upscale ``x`` and crop away the area produced by the padding.

        Args:
            x: 4-D input tensor; the last two dimensions are height and width.

        Returns:
            The output of ``self.up`` cropped to ``H * up_scale`` by
            ``W * up_scale``, where ``H`` and ``W`` are the unpadded input
            height and width.
        """
        H, W = x.shape[2:]
        x = self.check_image_size(x)

        residual = self.input(x)
        out = self.residual_layer(residual)

        # Global skip connection around the residual groups.
        out = torch.add(self.output(out), residual)
        out = self.up(out)

        return out[:, :, : H * self.up_scale, : W * self.up_scale]
|