Prepare for t2i adapter.

This commit is contained in:
comfyanonymous 2023-02-24 23:36:17 -05:00
parent cf5a211efc
commit fcb25d37db
2 changed files with 25 additions and 9 deletions

View File

@ -774,17 +774,23 @@ class UNetModel(nn.Module):
emb = emb + self.label_emb(y) emb = emb + self.label_emb(y)
h = x.type(self.dtype) h = x.type(self.dtype)
for module in self.input_blocks: for id, module in enumerate(self.input_blocks):
h = module(h, emb, context) h = module(h, emb, context)
if control is not None and 'input' in control and len(control['input']) > 0:
ctrl = control['input'].pop()
if ctrl is not None:
h += ctrl
hs.append(h) hs.append(h)
h = self.middle_block(h, emb, context) h = self.middle_block(h, emb, context)
if control is not None: if control is not None and 'middle' in control and len(control['middle']) > 0:
h += control.pop() h += control['middle'].pop()
for module in self.output_blocks: for module in self.output_blocks:
hsp = hs.pop() hsp = hs.pop()
if control is not None: if control is not None and 'output' in control and len(control['output']) > 0:
hsp += control.pop() ctrl = control['output'].pop()
if ctrl is not None:
hsp += ctrl
h = th.cat([h, hsp], dim=1) h = th.cat([h, hsp], dim=1)
del hsp del hsp
h = module(h, emb, context) h = module(h, emb, context)

View File

@ -388,18 +388,28 @@ class ControlNet:
self.control_model = model_management.load_if_low_vram(self.control_model) self.control_model = model_management.load_if_low_vram(self.control_model)
control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt) control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt)
self.control_model = model_management.unload_if_low_vram(self.control_model) self.control_model = model_management.unload_if_low_vram(self.control_model)
out = [] out = {'input':[], 'middle':[], 'output': []}
autocast_enabled = torch.is_autocast_enabled() autocast_enabled = torch.is_autocast_enabled()
for i in range(len(control)): for i in range(len(control)):
if i == (len(control) - 1):
key = 'middle'
index = 0
else:
key = 'output'
index = i
x = control[i] x = control[i]
x *= self.strength x *= self.strength
if x.dtype != output_dtype and not autocast_enabled: if x.dtype != output_dtype and not autocast_enabled:
x = x.to(output_dtype) x = x.to(output_dtype)
if control_prev is not None: if control_prev is not None and key in control_prev:
x += control_prev[i] prev = control_prev[key][index]
out.append(x) if prev is not None:
x += prev
out[key].append(x)
if control_prev is not None and 'input' in control_prev:
out['input'] = control_prev['input']
return out return out
def set_cond_hint(self, cond_hint, strength=1.0): def set_cond_hint(self, cond_hint, strength=1.0):