mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
ControlNetApply now stacks.
It can be used to apply multiple control nets at the same time.
This commit is contained in:
parent
8683ea4248
commit
d80af7ca30
25
comfy/sd.py
25
comfy/sd.py
@ -334,8 +334,13 @@ class ControlNet:
|
|||||||
self.cond_hint = None
|
self.cond_hint = None
|
||||||
self.strength = 1.0
|
self.strength = 1.0
|
||||||
self.device = device
|
self.device = device
|
||||||
|
self.previous_controlnet = None
|
||||||
|
|
||||||
def get_control(self, x_noisy, t, cond_txt):
|
def get_control(self, x_noisy, t, cond_txt):
|
||||||
|
control_prev = None
|
||||||
|
if self.previous_controlnet is not None:
|
||||||
|
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond_txt)
|
||||||
|
|
||||||
output_dtype = x_noisy.dtype
|
output_dtype = x_noisy.dtype
|
||||||
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
||||||
if self.cond_hint is not None:
|
if self.cond_hint is not None:
|
||||||
@ -354,10 +359,15 @@ class ControlNet:
|
|||||||
self.control_model = model_management.unload_if_low_vram(self.control_model)
|
self.control_model = model_management.unload_if_low_vram(self.control_model)
|
||||||
out = []
|
out = []
|
||||||
autocast_enabled = torch.is_autocast_enabled()
|
autocast_enabled = torch.is_autocast_enabled()
|
||||||
for x in control:
|
|
||||||
|
for i in range(len(control)):
|
||||||
|
x = control[i]
|
||||||
x *= self.strength
|
x *= self.strength
|
||||||
if x.dtype != output_dtype and not autocast_enabled:
|
if x.dtype != output_dtype and not autocast_enabled:
|
||||||
x = x.to(output_dtype)
|
x = x.to(output_dtype)
|
||||||
|
|
||||||
|
if control_prev is not None:
|
||||||
|
x += control_prev[i]
|
||||||
out.append(x)
|
out.append(x)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@ -366,7 +376,13 @@ class ControlNet:
|
|||||||
self.strength = strength
|
self.strength = strength
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def set_previous_controlnet(self, controlnet):
|
||||||
|
self.previous_controlnet = controlnet
|
||||||
|
return self
|
||||||
|
|
||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
|
if self.previous_controlnet is not None:
|
||||||
|
self.previous_controlnet.cleanup()
|
||||||
if self.cond_hint is not None:
|
if self.cond_hint is not None:
|
||||||
del self.cond_hint
|
del self.cond_hint
|
||||||
self.cond_hint = None
|
self.cond_hint = None
|
||||||
@ -377,6 +393,13 @@ class ControlNet:
|
|||||||
c.strength = self.strength
|
c.strength = self.strength
|
||||||
return c
|
return c
|
||||||
|
|
||||||
|
def get_control_models(self):
|
||||||
|
out = []
|
||||||
|
if self.previous_controlnet is not None:
|
||||||
|
out += self.previous_controlnet.get_control_models()
|
||||||
|
out.append(self.control_model)
|
||||||
|
return out
|
||||||
|
|
||||||
def load_controlnet(ckpt_path):
|
def load_controlnet(ckpt_path):
|
||||||
controlnet_data = load_torch_file(ckpt_path)
|
controlnet_data = load_torch_file(ckpt_path)
|
||||||
pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
|
pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
|
||||||
|
10
nodes.py
10
nodes.py
@ -252,7 +252,10 @@ class ControlNetApply:
|
|||||||
print(control_hint.shape)
|
print(control_hint.shape)
|
||||||
for t in conditioning:
|
for t in conditioning:
|
||||||
n = [t[0], t[1].copy()]
|
n = [t[0], t[1].copy()]
|
||||||
n[1]['control'] = control_net.copy().set_cond_hint(control_hint, strength)
|
c_net = control_net.copy().set_cond_hint(control_hint, strength)
|
||||||
|
if 'control' in t[1]:
|
||||||
|
c_net.set_previous_controlnet(t[1]['control'])
|
||||||
|
n[1]['control'] = c_net
|
||||||
c.append(n)
|
c.append(n)
|
||||||
return (c, )
|
return (c, )
|
||||||
|
|
||||||
@ -510,7 +513,10 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po
|
|||||||
control_nets += [p[1]['control']]
|
control_nets += [p[1]['control']]
|
||||||
negative_copy += [[t] + n[1:]]
|
negative_copy += [[t] + n[1:]]
|
||||||
|
|
||||||
model_management.load_controlnet_gpu(list(map(lambda a: a.control_model, control_nets)))
|
control_net_models = []
|
||||||
|
for x in control_nets:
|
||||||
|
control_net_models += x.get_control_models()
|
||||||
|
model_management.load_controlnet_gpu(control_net_models)
|
||||||
|
|
||||||
if sampler_name in comfy.samplers.KSampler.SAMPLERS:
|
if sampler_name in comfy.samplers.KSampler.SAMPLERS:
|
||||||
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
|
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
|
||||||
|
Loading…
Reference in New Issue
Block a user