diff --git a/comfy/sd.py b/comfy/sd.py
index 09eab505..d75bbd9a 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -1128,7 +1128,11 @@ class T2IAdapter(ControlBase):
             self.t2i_model.cpu()
 
         control_input = list(map(lambda a: None if a is None else a.clone(), self.control_input))
-        return self.control_merge(control_input, None, control_prev, x_noisy.dtype)
+        mid = None
+        if self.t2i_model.xl == True:
+            mid = control_input[-1:]
+            control_input = control_input[:-1]
+        return self.control_merge(control_input, mid, control_prev, x_noisy.dtype)
 
     def copy(self):
         c = T2IAdapter(self.t2i_model, self.channels_in)
@@ -1151,11 +1155,20 @@ def load_t2i_adapter(t2i_data):
         down_opts = list(filter(lambda a: a.endswith("down_opt.op.weight"), keys))
         if len(down_opts) > 0:
             use_conv = True
-        model_ad = adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv)
+        xl = False
+        if cin == 256:
+            xl = True
+        model_ad = adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
     else:
         return None
-    model_ad.load_state_dict(t2i_data)
-    return T2IAdapter(model_ad, cin // 64)
+    missing, unexpected = model_ad.load_state_dict(t2i_data)
+    if len(missing) > 0:
+        print("t2i missing", missing)
+
+    if len(unexpected) > 0:
+        print("t2i unexpected", unexpected)
+
+    return T2IAdapter(model_ad, model_ad.input_channels)
 
 
 class StyleModel:
diff --git a/comfy/t2i_adapter/adapter.py b/comfy/t2i_adapter/adapter.py
index 3647c4cf..000cf041 100644
--- a/comfy/t2i_adapter/adapter.py
+++ b/comfy/t2i_adapter/adapter.py
@@ -101,17 +101,30 @@ class ResnetBlock(nn.Module):
 
 
 class Adapter(nn.Module):
-    def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True):
+    def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True, xl=True):
         super(Adapter, self).__init__()
-        self.unshuffle = nn.PixelUnshuffle(8)
+        unshuffle = 8
+        resblock_no_downsample = []
+        resblock_downsample = [3, 2, 1]
+        self.xl = xl
+        if self.xl:
+            unshuffle = 16
+            resblock_no_downsample = [1]
+            resblock_downsample = [2]
+
+        self.input_channels = cin // (unshuffle * unshuffle)
+        self.unshuffle = nn.PixelUnshuffle(unshuffle)
         self.channels = channels
         self.nums_rb = nums_rb
         self.body = []
         for i in range(len(channels)):
             for j in range(nums_rb):
-                if (i != 0) and (j == 0):
+                if (i in resblock_downsample) and (j == 0):
                     self.body.append(
                         ResnetBlock(channels[i - 1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv))
+                elif (i in resblock_no_downsample) and (j == 0):
+                    self.body.append(
+                        ResnetBlock(channels[i - 1], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
                 else:
                     self.body.append(
                         ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
@@ -128,8 +141,16 @@ class Adapter(nn.Module):
             for j in range(self.nums_rb):
                 idx = i * self.nums_rb + j
                 x = self.body[idx](x)
-            features.append(None)
-            features.append(None)
+            if self.xl:
+                features.append(None)
+                if i == 0:
+                    features.append(None)
+                    features.append(None)
+                if i == 2:
+                    features.append(None)
+            else:
+                features.append(None)
+                features.append(None)
             features.append(x)
 
         return features
@@ -243,10 +264,14 @@ class extractor(nn.Module):
 class Adapter_light(nn.Module):
     def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64):
         super(Adapter_light, self).__init__()
-        self.unshuffle = nn.PixelUnshuffle(8)
+        unshuffle = 8
+        self.unshuffle = nn.PixelUnshuffle(unshuffle)
+        self.input_channels = cin // (unshuffle * unshuffle)
         self.channels = channels
         self.nums_rb = nums_rb
         self.body = []
+        self.xl = False
+
         for i in range(len(channels)):
             if i == 0:
                 self.body.append(extractor(in_c=cin, inter_c=channels[i]//4, out_c=channels[i], nums_rb=nums_rb, down=False))