ControlNetApply now stacks.

It can be used to apply multiple control nets at the same time.
This commit is contained in:
comfyanonymous 2023-02-21 01:18:53 -05:00
parent 8683ea4248
commit d80af7ca30
2 changed files with 32 additions and 3 deletions

View File

@ -334,8 +334,13 @@ class ControlNet:
self.cond_hint = None self.cond_hint = None
self.strength = 1.0 self.strength = 1.0
self.device = device self.device = device
self.previous_controlnet = None
def get_control(self, x_noisy, t, cond_txt): def get_control(self, x_noisy, t, cond_txt):
control_prev = None
if self.previous_controlnet is not None:
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond_txt)
output_dtype = x_noisy.dtype output_dtype = x_noisy.dtype
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
if self.cond_hint is not None: if self.cond_hint is not None:
@ -354,10 +359,15 @@ class ControlNet:
self.control_model = model_management.unload_if_low_vram(self.control_model) self.control_model = model_management.unload_if_low_vram(self.control_model)
out = [] out = []
autocast_enabled = torch.is_autocast_enabled() autocast_enabled = torch.is_autocast_enabled()
for x in control:
for i in range(len(control)):
x = control[i]
x *= self.strength x *= self.strength
if x.dtype != output_dtype and not autocast_enabled: if x.dtype != output_dtype and not autocast_enabled:
x = x.to(output_dtype) x = x.to(output_dtype)
if control_prev is not None:
x += control_prev[i]
out.append(x) out.append(x)
return out return out
@ -366,7 +376,13 @@ class ControlNet:
self.strength = strength self.strength = strength
return self return self
def set_previous_controlnet(self, controlnet):
self.previous_controlnet = controlnet
return self
def cleanup(self): def cleanup(self):
if self.previous_controlnet is not None:
self.previous_controlnet.cleanup()
if self.cond_hint is not None: if self.cond_hint is not None:
del self.cond_hint del self.cond_hint
self.cond_hint = None self.cond_hint = None
@ -377,6 +393,13 @@ class ControlNet:
c.strength = self.strength c.strength = self.strength
return c return c
def get_control_models(self):
out = []
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_control_models()
out.append(self.control_model)
return out
def load_controlnet(ckpt_path): def load_controlnet(ckpt_path):
controlnet_data = load_torch_file(ckpt_path) controlnet_data = load_torch_file(ckpt_path)
pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight' pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'

View File

@ -252,7 +252,10 @@ class ControlNetApply:
print(control_hint.shape) print(control_hint.shape)
for t in conditioning: for t in conditioning:
n = [t[0], t[1].copy()] n = [t[0], t[1].copy()]
n[1]['control'] = control_net.copy().set_cond_hint(control_hint, strength) c_net = control_net.copy().set_cond_hint(control_hint, strength)
if 'control' in t[1]:
c_net.set_previous_controlnet(t[1]['control'])
n[1]['control'] = c_net
c.append(n) c.append(n)
return (c, ) return (c, )
@ -510,7 +513,10 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po
control_nets += [p[1]['control']] control_nets += [p[1]['control']]
negative_copy += [[t] + n[1:]] negative_copy += [[t] + n[1:]]
model_management.load_controlnet_gpu(list(map(lambda a: a.control_model, control_nets))) control_net_models = []
for x in control_nets:
control_net_models += x.get_control_models()
model_management.load_controlnet_gpu(control_net_models)
if sampler_name in comfy.samplers.KSampler.SAMPLERS: if sampler_name in comfy.samplers.KSampler.SAMPLERS:
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise) sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)