mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Prepare for t2i adapter.
This commit is contained in:
parent
cf5a211efc
commit
fcb25d37db
@ -774,17 +774,23 @@ class UNetModel(nn.Module):
|
|||||||
emb = emb + self.label_emb(y)
|
emb = emb + self.label_emb(y)
|
||||||
|
|
||||||
h = x.type(self.dtype)
|
h = x.type(self.dtype)
|
||||||
for module in self.input_blocks:
|
for id, module in enumerate(self.input_blocks):
|
||||||
h = module(h, emb, context)
|
h = module(h, emb, context)
|
||||||
|
if control is not None and 'input' in control and len(control['input']) > 0:
|
||||||
|
ctrl = control['input'].pop()
|
||||||
|
if ctrl is not None:
|
||||||
|
h += ctrl
|
||||||
hs.append(h)
|
hs.append(h)
|
||||||
h = self.middle_block(h, emb, context)
|
h = self.middle_block(h, emb, context)
|
||||||
if control is not None:
|
if control is not None and 'middle' in control and len(control['middle']) > 0:
|
||||||
h += control.pop()
|
h += control['middle'].pop()
|
||||||
|
|
||||||
for module in self.output_blocks:
|
for module in self.output_blocks:
|
||||||
hsp = hs.pop()
|
hsp = hs.pop()
|
||||||
if control is not None:
|
if control is not None and 'output' in control and len(control['output']) > 0:
|
||||||
hsp += control.pop()
|
ctrl = control['output'].pop()
|
||||||
|
if ctrl is not None:
|
||||||
|
hsp += ctrl
|
||||||
h = th.cat([h, hsp], dim=1)
|
h = th.cat([h, hsp], dim=1)
|
||||||
del hsp
|
del hsp
|
||||||
h = module(h, emb, context)
|
h = module(h, emb, context)
|
||||||
|
18
comfy/sd.py
18
comfy/sd.py
@ -388,18 +388,28 @@ class ControlNet:
|
|||||||
self.control_model = model_management.load_if_low_vram(self.control_model)
|
self.control_model = model_management.load_if_low_vram(self.control_model)
|
||||||
control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt)
|
control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt)
|
||||||
self.control_model = model_management.unload_if_low_vram(self.control_model)
|
self.control_model = model_management.unload_if_low_vram(self.control_model)
|
||||||
out = []
|
out = {'input':[], 'middle':[], 'output': []}
|
||||||
autocast_enabled = torch.is_autocast_enabled()
|
autocast_enabled = torch.is_autocast_enabled()
|
||||||
|
|
||||||
for i in range(len(control)):
|
for i in range(len(control)):
|
||||||
|
if i == (len(control) - 1):
|
||||||
|
key = 'middle'
|
||||||
|
index = 0
|
||||||
|
else:
|
||||||
|
key = 'output'
|
||||||
|
index = i
|
||||||
x = control[i]
|
x = control[i]
|
||||||
x *= self.strength
|
x *= self.strength
|
||||||
if x.dtype != output_dtype and not autocast_enabled:
|
if x.dtype != output_dtype and not autocast_enabled:
|
||||||
x = x.to(output_dtype)
|
x = x.to(output_dtype)
|
||||||
|
|
||||||
if control_prev is not None:
|
if control_prev is not None and key in control_prev:
|
||||||
x += control_prev[i]
|
prev = control_prev[key][index]
|
||||||
out.append(x)
|
if prev is not None:
|
||||||
|
x += prev
|
||||||
|
out[key].append(x)
|
||||||
|
if control_prev is not None and 'input' in control_prev:
|
||||||
|
out['input'] = control_prev['input']
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def set_cond_hint(self, cond_hint, strength=1.0):
|
def set_cond_hint(self, cond_hint, strength=1.0):
|
||||||
|
Loading…
Reference in New Issue
Block a user