T2I adapter SDXL.

This commit is contained in:
comfyanonymous 2023-08-22 14:38:34 -04:00
parent f2a7cc9121
commit 85fde89d7f
2 changed files with 48 additions and 10 deletions

View File

@ -1128,7 +1128,11 @@ class T2IAdapter(ControlBase):
self.t2i_model.cpu() self.t2i_model.cpu()
control_input = list(map(lambda a: None if a is None else a.clone(), self.control_input)) control_input = list(map(lambda a: None if a is None else a.clone(), self.control_input))
return self.control_merge(control_input, None, control_prev, x_noisy.dtype) mid = None
if self.t2i_model.xl == True:
mid = control_input[-1:]
control_input = control_input[:-1]
return self.control_merge(control_input, mid, control_prev, x_noisy.dtype)
def copy(self): def copy(self):
c = T2IAdapter(self.t2i_model, self.channels_in) c = T2IAdapter(self.t2i_model, self.channels_in)
@ -1151,11 +1155,20 @@ def load_t2i_adapter(t2i_data):
down_opts = list(filter(lambda a: a.endswith("down_opt.op.weight"), keys)) down_opts = list(filter(lambda a: a.endswith("down_opt.op.weight"), keys))
if len(down_opts) > 0: if len(down_opts) > 0:
use_conv = True use_conv = True
model_ad = adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv) xl = False
if cin == 256:
xl = True
model_ad = adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
else: else:
return None return None
model_ad.load_state_dict(t2i_data) missing, unexpected = model_ad.load_state_dict(t2i_data)
return T2IAdapter(model_ad, cin // 64) if len(missing) > 0:
print("t2i missing", missing)
if len(unexpected) > 0:
print("t2i unexpected", unexpected)
return T2IAdapter(model_ad, model_ad.input_channels)
class StyleModel: class StyleModel:

View File

@ -101,17 +101,30 @@ class ResnetBlock(nn.Module):
class Adapter(nn.Module): class Adapter(nn.Module):
def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True): def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True, xl=True):
super(Adapter, self).__init__() super(Adapter, self).__init__()
self.unshuffle = nn.PixelUnshuffle(8) unshuffle = 8
resblock_no_downsample = []
resblock_downsample = [3, 2, 1]
self.xl = xl
if self.xl:
unshuffle = 16
resblock_no_downsample = [1]
resblock_downsample = [2]
self.input_channels = cin // (unshuffle * unshuffle)
self.unshuffle = nn.PixelUnshuffle(unshuffle)
self.channels = channels self.channels = channels
self.nums_rb = nums_rb self.nums_rb = nums_rb
self.body = [] self.body = []
for i in range(len(channels)): for i in range(len(channels)):
for j in range(nums_rb): for j in range(nums_rb):
if (i != 0) and (j == 0): if (i in resblock_downsample) and (j == 0):
self.body.append( self.body.append(
ResnetBlock(channels[i - 1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv)) ResnetBlock(channels[i - 1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv))
elif (i in resblock_no_downsample) and (j == 0):
self.body.append(
ResnetBlock(channels[i - 1], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
else: else:
self.body.append( self.body.append(
ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv)) ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
@ -128,6 +141,14 @@ class Adapter(nn.Module):
for j in range(self.nums_rb): for j in range(self.nums_rb):
idx = i * self.nums_rb + j idx = i * self.nums_rb + j
x = self.body[idx](x) x = self.body[idx](x)
if self.xl:
features.append(None)
if i == 0:
features.append(None)
features.append(None)
if i == 2:
features.append(None)
else:
features.append(None) features.append(None)
features.append(None) features.append(None)
features.append(x) features.append(x)
@ -243,10 +264,14 @@ class extractor(nn.Module):
class Adapter_light(nn.Module): class Adapter_light(nn.Module):
def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64): def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64):
super(Adapter_light, self).__init__() super(Adapter_light, self).__init__()
self.unshuffle = nn.PixelUnshuffle(8) unshuffle = 8
self.unshuffle = nn.PixelUnshuffle(unshuffle)
self.input_channels = cin // (unshuffle * unshuffle)
self.channels = channels self.channels = channels
self.nums_rb = nums_rb self.nums_rb = nums_rb
self.body = [] self.body = []
self.xl = False
for i in range(len(channels)): for i in range(len(channels)):
if i == 0: if i == 0:
self.body.append(extractor(in_c=cin, inter_c=channels[i]//4, out_c=channels[i], nums_rb=nums_rb, down=False)) self.body.append(extractor(in_c=cin, inter_c=channels[i]//4, out_c=channels[i], nums_rb=nums_rb, down=False))