InstantX depth flux controlnet.

This commit is contained in:
comfyanonymous 2024-08-29 02:14:19 -04:00
parent b33cd61070
commit ea3f39bd69
2 changed files with 32 additions and 16 deletions

View File

@ -148,7 +148,7 @@ class ControlBase:
elif self.strength_type == StrengthType.LINEAR_UP: elif self.strength_type == StrengthType.LINEAR_UP:
x *= (self.strength ** float(len(control_output) - i)) x *= (self.strength ** float(len(control_output) - i))
if x.dtype != output_dtype: if output_dtype is not None and x.dtype != output_dtype:
x = x.to(output_dtype) x = x.to(output_dtype)
out[key].append(x) out[key].append(x)
@ -206,7 +206,6 @@ class ControlNet(ControlBase):
if self.manual_cast_dtype is not None: if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype dtype = self.manual_cast_dtype
output_dtype = x_noisy.dtype
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]: if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
if self.cond_hint is not None: if self.cond_hint is not None:
del self.cond_hint del self.cond_hint
@ -236,7 +235,7 @@ class ControlNet(ControlBase):
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra) control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
return self.control_merge(control, control_prev, output_dtype) return self.control_merge(control, control_prev, output_dtype=None)
def copy(self): def copy(self):
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)

View File

@ -23,8 +23,12 @@ class ControlNetFlux(Flux):
self.controlnet_blocks = nn.ModuleList([]) self.controlnet_blocks = nn.ModuleList([])
for _ in range(self.params.depth): for _ in range(self.params.depth):
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device) controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
# controlnet_block = zero_module(controlnet_block)
self.controlnet_blocks.append(controlnet_block) self.controlnet_blocks.append(controlnet_block)
self.controlnet_single_blocks = nn.ModuleList([])
for _ in range(self.params.depth_single_blocks):
self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device))
self.gradient_checkpointing = False self.gradient_checkpointing = False
self.latent_input = latent_input self.latent_input = latent_input
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device) self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
@ -78,26 +82,39 @@ class ControlNetFlux(Flux):
ids = torch.cat((txt_ids, img_ids), dim=1) ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids) pe = self.pe_embedder(ids)
block_res_samples = () controlnet_double = ()
for block in self.double_blocks: for i in range(len(self.double_blocks)):
img, txt = block(img=img, txt=txt, vec=vec, pe=pe) img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
block_res_samples = block_res_samples + (img,) controlnet_double = controlnet_double + (self.controlnet_blocks[i](img),)
controlnet_block_res_samples = () img = torch.cat((txt, img), 1)
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
block_res_sample = controlnet_block(block_res_sample)
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
controlnet_single = ()
repeat = math.ceil(self.main_model_double / len(controlnet_block_res_samples)) for i in range(len(self.single_blocks)):
img = self.single_blocks[i](img, vec=vec, pe=pe)
controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1] :, ...]),)
repeat = math.ceil(self.main_model_double / len(controlnet_double))
if self.latent_input: if self.latent_input:
out_input = () out_input = ()
for x in controlnet_block_res_samples: for x in controlnet_double:
out_input += (x,) * repeat out_input += (x,) * repeat
else: else:
out_input = (controlnet_block_res_samples * repeat) out_input = (controlnet_double * repeat)
return {"input": out_input[:self.main_model_double]}
out = {"input": out_input[:self.main_model_double]}
if len(controlnet_single) > 0:
repeat = math.ceil(self.main_model_single / len(controlnet_single))
out_output = ()
if self.latent_input:
for x in controlnet_single:
out_output += (x,) * repeat
else:
out_output = (controlnet_single * repeat)
out["output"] = out_output[:self.main_model_single]
return out
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs): def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
patch_size = 2 patch_size = 2