diff --git a/comfy/sd.py b/comfy/sd.py index bf67f128..a74d1e82 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -334,8 +334,13 @@ class ControlNet: self.cond_hint = None self.strength = 1.0 self.device = device + self.previous_controlnet = None def get_control(self, x_noisy, t, cond_txt): + control_prev = None + if self.previous_controlnet is not None: + control_prev = self.previous_controlnet.get_control(x_noisy, t, cond_txt) + output_dtype = x_noisy.dtype if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: if self.cond_hint is not None: @@ -354,10 +359,15 @@ class ControlNet: self.control_model = model_management.unload_if_low_vram(self.control_model) out = [] autocast_enabled = torch.is_autocast_enabled() - for x in control: + + for i in range(len(control)): + x = control[i] x *= self.strength if x.dtype != output_dtype and not autocast_enabled: x = x.to(output_dtype) + + if control_prev is not None: + x += control_prev[i] out.append(x) return out @@ -366,7 +376,13 @@ class ControlNet: self.strength = strength return self + def set_previous_controlnet(self, controlnet): + self.previous_controlnet = controlnet + return self + def cleanup(self): + if self.previous_controlnet is not None: + self.previous_controlnet.cleanup() if self.cond_hint is not None: del self.cond_hint self.cond_hint = None @@ -377,6 +393,13 @@ class ControlNet: c.strength = self.strength return c + def get_control_models(self): + out = [] + if self.previous_controlnet is not None: + out += self.previous_controlnet.get_control_models() + out.append(self.control_model) + return out + def load_controlnet(ckpt_path): controlnet_data = load_torch_file(ckpt_path) pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight' diff --git a/nodes.py b/nodes.py index ef1201de..3bdad71b 100644 --- a/nodes.py +++ b/nodes.py @@ -252,7 +252,10 @@ class ControlNetApply: print(control_hint.shape) for t in conditioning: n = [t[0], t[1].copy()] - n[1]['control'] = control_net.copy().set_cond_hint(control_hint, strength) + c_net = control_net.copy().set_cond_hint(control_hint, strength) + if 'control' in t[1]: + c_net.set_previous_controlnet(t[1]['control']) + n[1]['control'] = c_net c.append(n) return (c, ) @@ -510,7 +513,10 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po control_nets += [p[1]['control']] negative_copy += [[t] + n[1:]] - model_management.load_controlnet_gpu(list(map(lambda a: a.control_model, control_nets))) + control_net_models = [] + for x in control_nets: + control_net_models += x.get_control_models() + model_management.load_controlnet_gpu(control_net_models) if sampler_name in comfy.samplers.KSampler.SAMPLERS: sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)