diff --git a/comfy/ldm/models/diffusion/ddim.py b/comfy/ldm/models/diffusion/ddim.py index 5e2d7364..e00ffd3f 100644 --- a/comfy/ldm/models/diffusion/ddim.py +++ b/comfy/ldm/models/diffusion/ddim.py @@ -78,7 +78,7 @@ class DDIMSampler(object): dynamic_threshold=None, ucg_schedule=None, denoise_function=None, - cond_concat=None, + extra_args=None, to_zero=True, end_step=None, **kwargs @@ -101,7 +101,7 @@ class DDIMSampler(object): dynamic_threshold=dynamic_threshold, ucg_schedule=ucg_schedule, denoise_function=denoise_function, - cond_concat=cond_concat, + extra_args=extra_args, to_zero=to_zero, end_step=end_step ) @@ -174,7 +174,7 @@ class DDIMSampler(object): dynamic_threshold=dynamic_threshold, ucg_schedule=ucg_schedule, denoise_function=None, - cond_concat=None + extra_args=None ) return samples, intermediates @@ -185,7 +185,7 @@ class DDIMSampler(object): mask=None, x0=None, img_callback=None, log_every_t=100, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, - ucg_schedule=None, denoise_function=None, cond_concat=None, to_zero=True, end_step=None): + ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None): device = self.model.betas.device b = shape[0] if x_T is None: @@ -225,7 +225,7 @@ class DDIMSampler(object): corrector_kwargs=corrector_kwargs, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, cond_concat=cond_concat) + dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, extra_args=extra_args) img, pred_x0 = outs if callback: callback(i) if img_callback: img_callback(pred_x0, i) @@ -249,11 +249,11 @@ class DDIMSampler(object): def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, unconditional_guidance_scale=1., unconditional_conditioning=None, - dynamic_threshold=None, denoise_function=None, cond_concat=None): + dynamic_threshold=None, denoise_function=None, extra_args=None): b, *_, device = *x.shape, x.device if denoise_function is not None: - model_output = denoise_function(self.model.apply_model, x, t, unconditional_conditioning, c, unconditional_guidance_scale, cond_concat) + model_output = denoise_function(self.model.apply_model, x, t, **extra_args) elif unconditional_conditioning is None or unconditional_guidance_scale == 1.: model_output = self.model.apply_model(x, t, c) else: diff --git a/comfy/ldm/models/diffusion/ddpm.py b/comfy/ldm/models/diffusion/ddpm.py index 42ed2add..6af96124 100644 --- a/comfy/ldm/models/diffusion/ddpm.py +++ b/comfy/ldm/models/diffusion/ddpm.py @@ -1317,12 +1317,12 @@ class DiffusionWrapper(torch.nn.Module): self.conditioning_key = conditioning_key assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm'] - def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None): + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None, transformer_options={}): if self.conditioning_key is None: - out = self.diffusion_model(x, t, control=control) + out = self.diffusion_model(x, t, control=control, transformer_options=transformer_options) elif self.conditioning_key == 'concat': xc = torch.cat([x] + c_concat, dim=1) - out = self.diffusion_model(xc, t, control=control) + out = self.diffusion_model(xc, t, control=control, transformer_options=transformer_options) elif self.conditioning_key == 'crossattn': if not self.sequential_cross_attn: cc = torch.cat(c_crossattn, 1) @@ -1332,25 +1332,25 @@ class DiffusionWrapper(torch.nn.Module): # TorchScript changes names of the arguments # with argument cc defined as context=cc scripted model will produce # an error: RuntimeError: forward() is missing value for argument 'argument_3'. - out = self.scripted_diffusion_model(x, t, cc, control=control) + out = self.scripted_diffusion_model(x, t, cc, control=control, transformer_options=transformer_options) else: - out = self.diffusion_model(x, t, context=cc, control=control) + out = self.diffusion_model(x, t, context=cc, control=control, transformer_options=transformer_options) elif self.conditioning_key == 'hybrid': xc = torch.cat([x] + c_concat, dim=1) cc = torch.cat(c_crossattn, 1) - out = self.diffusion_model(xc, t, context=cc, control=control) + out = self.diffusion_model(xc, t, context=cc, control=control, transformer_options=transformer_options) elif self.conditioning_key == 'hybrid-adm': assert c_adm is not None xc = torch.cat([x] + c_concat, dim=1) cc = torch.cat(c_crossattn, 1) - out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control) + out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control, transformer_options=transformer_options) elif self.conditioning_key == 'crossattn-adm': assert c_adm is not None cc = torch.cat(c_crossattn, 1) - out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control) + out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control, transformer_options=transformer_options) elif self.conditioning_key == 'adm': cc = c_crossattn[0] - out = self.diffusion_model(x, t, y=cc, control=control) + out = self.diffusion_model(x, t, y=cc, control=control, transformer_options=transformer_options) else: raise NotImplementedError() diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 23b04734..25051b33 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -504,10 +504,10 @@ class BasicTransformerBlock(nn.Module): self.norm3 = nn.LayerNorm(dim) self.checkpoint = checkpoint - def forward(self, x, context=None): - return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + def forward(self, x, context=None, transformer_options={}): + return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint) - def _forward(self, x, context=None): + def _forward(self, x, context=None, transformer_options={}): x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x x = self.attn2(self.norm2(x), context=context) + x x = self.ff(self.norm3(x)) + x @@ -557,7 +557,7 @@ class SpatialTransformer(nn.Module): self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) self.use_linear = use_linear - def forward(self, x, context=None): + def forward(self, x, context=None, transformer_options={}): # note: if no context is given, cross-attention defaults to self-attention if not isinstance(context, list): context = [context] @@ -570,7 +570,7 @@ class SpatialTransformer(nn.Module): if self.use_linear: x = self.proj_in(x) for i, block in enumerate(self.transformer_blocks): - x = block(x, context=context[i]) + x = block(x, context=context[i], transformer_options=transformer_options) if self.use_linear: x = self.proj_out(x) x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 09ab1a06..7b2f5b53 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -76,12 +76,12 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock): support it as an extra input. """ - def forward(self, x, emb, context=None): + def forward(self, x, emb, context=None, transformer_options={}): for layer in self: if isinstance(layer, TimestepBlock): x = layer(x, emb) elif isinstance(layer, SpatialTransformer): - x = layer(x, context) + x = layer(x, context, transformer_options) else: x = layer(x) return x @@ -753,7 +753,7 @@ class UNetModel(nn.Module): self.middle_block.apply(convert_module_to_f32) self.output_blocks.apply(convert_module_to_f32) - def forward(self, x, timesteps=None, context=None, y=None, control=None, **kwargs): + def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs): """ Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. @@ -762,6 +762,7 @@ class UNetModel(nn.Module): :param y: an [N] Tensor of labels, if class-conditional. :return: an [N x C x ...] Tensor of outputs. """ + transformer_options["original_shape"] = list(x.shape) assert (y is not None) == ( self.num_classes is not None ), "must specify y if and only if the model is class-conditional" @@ -775,13 +776,13 @@ class UNetModel(nn.Module): h = x.type(self.dtype) for id, module in enumerate(self.input_blocks): - h = module(h, emb, context) + h = module(h, emb, context, transformer_options) if control is not None and 'input' in control and len(control['input']) > 0: ctrl = control['input'].pop() if ctrl is not None: h += ctrl hs.append(h) - h = self.middle_block(h, emb, context) + h = self.middle_block(h, emb, context, transformer_options) if control is not None and 'middle' in control and len(control['middle']) > 0: h += control['middle'].pop() @@ -793,7 +794,7 @@ class UNetModel(nn.Module): hsp += ctrl h = th.cat([h, hsp], dim=1) del hsp - h = module(h, emb, context) + h = module(h, emb, context, transformer_options) h = h.type(x.dtype) if self.predict_codebook_ids: return self.id_predictor(h) diff --git a/comfy/samplers.py b/comfy/samplers.py index 66218f88..40d5d332 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -26,7 +26,7 @@ class CFGDenoiser(torch.nn.Module): #The main sampling function shared by all the samplers #Returns predicted noise -def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None): +def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}): def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in): area = (x_in.shape[2], x_in.shape[3], 0, 0) strength = 1.0 @@ -169,6 +169,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con if control is not None: c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'], len(cond_or_uncond)) + if 'transformer_options' in model_options: + c['transformer_options'] = model_options['transformer_options'] + output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks) del input_x @@ -467,7 +470,7 @@ class KSampler: x_T=z_enc, x0=latent_image, denoise_function=sampling_function, - cond_concat=cond_concat, + extra_args=extra_args, mask=noise_mask, to_zero=sigmas[-1]==0, end_step=sigmas.shape[0] - 1)