Controlnet refactor.

This commit is contained in:
comfyanonymous 2024-06-25 17:02:05 -04:00
parent 97b409cd48
commit 66aaa14001
4 changed files with 24 additions and 32 deletions

View File

@ -289,7 +289,8 @@ class ControlNet(nn.Module):
guided_hint = self.input_hint_block(hint, emb, context) guided_hint = self.input_hint_block(hint, emb, context)
outs = [] out_output = []
out_middle = []
hs = [] hs = []
if self.num_classes is not None: if self.num_classes is not None:
@ -304,10 +305,10 @@ class ControlNet(nn.Module):
guided_hint = None guided_hint = None
else: else:
h = module(h, emb, context) h = module(h, emb, context)
outs.append(zero_conv(h, emb, context)) out_output.append(zero_conv(h, emb, context))
h = self.middle_block(h, emb, context) h = self.middle_block(h, emb, context)
outs.append(self.middle_block_out(h, emb, context)) out_middle.append(self.middle_block_out(h, emb, context))
return outs return {"middle": out_middle, "output": out_output}

View File

@ -89,27 +89,12 @@ class ControlBase:
return self.previous_controlnet.inference_memory_requirements(dtype) return self.previous_controlnet.inference_memory_requirements(dtype)
return 0 return 0
def control_merge(self, control_input, control_output, control_prev, output_dtype): def control_merge(self, control, control_prev, output_dtype):
out = {'input':[], 'middle':[], 'output': []} out = {'input':[], 'middle':[], 'output': []}
if control_input is not None: for key in control:
for i in range(len(control_input)): control_output = control[key]
key = 'input'
x = control_input[i]
if x is not None:
x *= self.strength
if x.dtype != output_dtype:
x = x.to(output_dtype)
out[key].insert(0, x)
if control_output is not None:
for i in range(len(control_output)): for i in range(len(control_output)):
if i == (len(control_output) - 1):
key = 'middle'
index = 0
else:
key = 'output'
index = i
x = control_output[i] x = control_output[i]
if x is not None: if x is not None:
if self.global_average_pooling: if self.global_average_pooling:
@ -120,6 +105,7 @@ class ControlBase:
x = x.to(output_dtype) x = x.to(output_dtype)
out[key].append(x) out[key].append(x)
if control_prev is not None: if control_prev is not None:
for x in ['input', 'middle', 'output']: for x in ['input', 'middle', 'output']:
o = out[x] o = out[x]
@ -182,7 +168,7 @@ class ControlNet(ControlBase):
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y) control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
return self.control_merge(None, control, control_prev, output_dtype) return self.control_merge(control, control_prev, output_dtype)
def copy(self): def copy(self):
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
@ -490,12 +476,11 @@ class T2IAdapter(ControlBase):
self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype)) self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
self.t2i_model.cpu() self.t2i_model.cpu()
control_input = list(map(lambda a: None if a is None else a.clone(), self.control_input)) control_input = {}
mid = None for k in self.control_input:
if self.t2i_model.xl == True: control_input[k] = list(map(lambda a: None if a is None else a.clone(), self.control_input[k]))
mid = control_input[-1:]
control_input = control_input[:-1] return self.control_merge(control_input, control_prev, x_noisy.dtype)
return self.control_merge(control_input, mid, control_prev, x_noisy.dtype)
def copy(self): def copy(self):
c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm) c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm)

View File

@ -90,4 +90,4 @@ class ControlNet(nn.Module):
proj_outputs = [None for _ in range(max(self.proj_blocks) + 1)] proj_outputs = [None for _ in range(max(self.proj_blocks) + 1)]
for i, idx in enumerate(self.proj_blocks): for i, idx in enumerate(self.proj_blocks):
proj_outputs[idx] = self.projections[i](x) proj_outputs[idx] = self.projections[i](x)
return proj_outputs return {"input": proj_outputs[::-1]}

View File

@ -153,7 +153,13 @@ class Adapter(nn.Module):
features.append(None) features.append(None)
features.append(x) features.append(x)
return features features = features[::-1]
if self.xl:
return {"input": features[1:], "middle": features[:1]}
else:
return {"input": features}
class LayerNorm(nn.LayerNorm): class LayerNorm(nn.LayerNorm):
@ -290,4 +296,4 @@ class Adapter_light(nn.Module):
features.append(None) features.append(None)
features.append(x) features.append(x)
return features return {"input": features[::-1]}