mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 10:25:16 +00:00
T2I adapter SDXL.
This commit is contained in:
parent
f2a7cc9121
commit
85fde89d7f
21
comfy/sd.py
21
comfy/sd.py
@ -1128,7 +1128,11 @@ class T2IAdapter(ControlBase):
|
|||||||
self.t2i_model.cpu()
|
self.t2i_model.cpu()
|
||||||
|
|
||||||
control_input = list(map(lambda a: None if a is None else a.clone(), self.control_input))
|
control_input = list(map(lambda a: None if a is None else a.clone(), self.control_input))
|
||||||
return self.control_merge(control_input, None, control_prev, x_noisy.dtype)
|
mid = None
|
||||||
|
if self.t2i_model.xl == True:
|
||||||
|
mid = control_input[-1:]
|
||||||
|
control_input = control_input[:-1]
|
||||||
|
return self.control_merge(control_input, mid, control_prev, x_noisy.dtype)
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
c = T2IAdapter(self.t2i_model, self.channels_in)
|
c = T2IAdapter(self.t2i_model, self.channels_in)
|
||||||
@ -1151,11 +1155,20 @@ def load_t2i_adapter(t2i_data):
|
|||||||
down_opts = list(filter(lambda a: a.endswith("down_opt.op.weight"), keys))
|
down_opts = list(filter(lambda a: a.endswith("down_opt.op.weight"), keys))
|
||||||
if len(down_opts) > 0:
|
if len(down_opts) > 0:
|
||||||
use_conv = True
|
use_conv = True
|
||||||
model_ad = adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv)
|
xl = False
|
||||||
|
if cin == 256:
|
||||||
|
xl = True
|
||||||
|
model_ad = adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
model_ad.load_state_dict(t2i_data)
|
missing, unexpected = model_ad.load_state_dict(t2i_data)
|
||||||
return T2IAdapter(model_ad, cin // 64)
|
if len(missing) > 0:
|
||||||
|
print("t2i missing", missing)
|
||||||
|
|
||||||
|
if len(unexpected) > 0:
|
||||||
|
print("t2i unexpected", unexpected)
|
||||||
|
|
||||||
|
return T2IAdapter(model_ad, model_ad.input_channels)
|
||||||
|
|
||||||
|
|
||||||
class StyleModel:
|
class StyleModel:
|
||||||
|
@ -101,17 +101,30 @@ class ResnetBlock(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class Adapter(nn.Module):
|
class Adapter(nn.Module):
|
||||||
def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True):
|
def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True, xl=True):
|
||||||
super(Adapter, self).__init__()
|
super(Adapter, self).__init__()
|
||||||
self.unshuffle = nn.PixelUnshuffle(8)
|
unshuffle = 8
|
||||||
|
resblock_no_downsample = []
|
||||||
|
resblock_downsample = [3, 2, 1]
|
||||||
|
self.xl = xl
|
||||||
|
if self.xl:
|
||||||
|
unshuffle = 16
|
||||||
|
resblock_no_downsample = [1]
|
||||||
|
resblock_downsample = [2]
|
||||||
|
|
||||||
|
self.input_channels = cin // (unshuffle * unshuffle)
|
||||||
|
self.unshuffle = nn.PixelUnshuffle(unshuffle)
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.nums_rb = nums_rb
|
self.nums_rb = nums_rb
|
||||||
self.body = []
|
self.body = []
|
||||||
for i in range(len(channels)):
|
for i in range(len(channels)):
|
||||||
for j in range(nums_rb):
|
for j in range(nums_rb):
|
||||||
if (i != 0) and (j == 0):
|
if (i in resblock_downsample) and (j == 0):
|
||||||
self.body.append(
|
self.body.append(
|
||||||
ResnetBlock(channels[i - 1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv))
|
ResnetBlock(channels[i - 1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv))
|
||||||
|
elif (i in resblock_no_downsample) and (j == 0):
|
||||||
|
self.body.append(
|
||||||
|
ResnetBlock(channels[i - 1], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
|
||||||
else:
|
else:
|
||||||
self.body.append(
|
self.body.append(
|
||||||
ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
|
ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
|
||||||
@ -128,6 +141,14 @@ class Adapter(nn.Module):
|
|||||||
for j in range(self.nums_rb):
|
for j in range(self.nums_rb):
|
||||||
idx = i * self.nums_rb + j
|
idx = i * self.nums_rb + j
|
||||||
x = self.body[idx](x)
|
x = self.body[idx](x)
|
||||||
|
if self.xl:
|
||||||
|
features.append(None)
|
||||||
|
if i == 0:
|
||||||
|
features.append(None)
|
||||||
|
features.append(None)
|
||||||
|
if i == 2:
|
||||||
|
features.append(None)
|
||||||
|
else:
|
||||||
features.append(None)
|
features.append(None)
|
||||||
features.append(None)
|
features.append(None)
|
||||||
features.append(x)
|
features.append(x)
|
||||||
@ -243,10 +264,14 @@ class extractor(nn.Module):
|
|||||||
class Adapter_light(nn.Module):
|
class Adapter_light(nn.Module):
|
||||||
def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64):
|
def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64):
|
||||||
super(Adapter_light, self).__init__()
|
super(Adapter_light, self).__init__()
|
||||||
self.unshuffle = nn.PixelUnshuffle(8)
|
unshuffle = 8
|
||||||
|
self.unshuffle = nn.PixelUnshuffle(unshuffle)
|
||||||
|
self.input_channels = cin // (unshuffle * unshuffle)
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.nums_rb = nums_rb
|
self.nums_rb = nums_rb
|
||||||
self.body = []
|
self.body = []
|
||||||
|
self.xl = False
|
||||||
|
|
||||||
for i in range(len(channels)):
|
for i in range(len(channels)):
|
||||||
if i == 0:
|
if i == 0:
|
||||||
self.body.append(extractor(in_c=cin, inter_c=channels[i]//4, out_c=channels[i], nums_rb=nums_rb, down=False))
|
self.body.append(extractor(in_c=cin, inter_c=channels[i]//4, out_c=channels[i], nums_rb=nums_rb, down=False))
|
||||||
|
Loading…
Reference in New Issue
Block a user