Add a way to pass options to the transformer blocks.

This commit is contained in:
comfyanonymous 2023-03-31 13:04:39 -04:00
parent 04b42bad87
commit 61ec3c9d5d
5 changed files with 33 additions and 29 deletions

View File

@ -78,7 +78,7 @@ class DDIMSampler(object):
dynamic_threshold=None, dynamic_threshold=None,
ucg_schedule=None, ucg_schedule=None,
denoise_function=None, denoise_function=None,
cond_concat=None, extra_args=None,
to_zero=True, to_zero=True,
end_step=None, end_step=None,
**kwargs **kwargs
@ -101,7 +101,7 @@ class DDIMSampler(object):
dynamic_threshold=dynamic_threshold, dynamic_threshold=dynamic_threshold,
ucg_schedule=ucg_schedule, ucg_schedule=ucg_schedule,
denoise_function=denoise_function, denoise_function=denoise_function,
cond_concat=cond_concat, extra_args=extra_args,
to_zero=to_zero, to_zero=to_zero,
end_step=end_step end_step=end_step
) )
@ -174,7 +174,7 @@ class DDIMSampler(object):
dynamic_threshold=dynamic_threshold, dynamic_threshold=dynamic_threshold,
ucg_schedule=ucg_schedule, ucg_schedule=ucg_schedule,
denoise_function=None, denoise_function=None,
cond_concat=None extra_args=None
) )
return samples, intermediates return samples, intermediates
@ -185,7 +185,7 @@ class DDIMSampler(object):
mask=None, x0=None, img_callback=None, log_every_t=100, mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
ucg_schedule=None, denoise_function=None, cond_concat=None, to_zero=True, end_step=None): ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None):
device = self.model.betas.device device = self.model.betas.device
b = shape[0] b = shape[0]
if x_T is None: if x_T is None:
@ -225,7 +225,7 @@ class DDIMSampler(object):
corrector_kwargs=corrector_kwargs, corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale, unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning, unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, cond_concat=cond_concat) dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, extra_args=extra_args)
img, pred_x0 = outs img, pred_x0 = outs
if callback: callback(i) if callback: callback(i)
if img_callback: img_callback(pred_x0, i) if img_callback: img_callback(pred_x0, i)
@ -249,11 +249,11 @@ class DDIMSampler(object):
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, unconditional_guidance_scale=1., unconditional_conditioning=None,
dynamic_threshold=None, denoise_function=None, cond_concat=None): dynamic_threshold=None, denoise_function=None, extra_args=None):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
if denoise_function is not None: if denoise_function is not None:
model_output = denoise_function(self.model.apply_model, x, t, unconditional_conditioning, c, unconditional_guidance_scale, cond_concat) model_output = denoise_function(self.model.apply_model, x, t, **extra_args)
elif unconditional_conditioning is None or unconditional_guidance_scale == 1.: elif unconditional_conditioning is None or unconditional_guidance_scale == 1.:
model_output = self.model.apply_model(x, t, c) model_output = self.model.apply_model(x, t, c)
else: else:

View File

@ -1317,12 +1317,12 @@ class DiffusionWrapper(torch.nn.Module):
self.conditioning_key = conditioning_key self.conditioning_key = conditioning_key
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm'] assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None): def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None, transformer_options={}):
if self.conditioning_key is None: if self.conditioning_key is None:
out = self.diffusion_model(x, t, control=control) out = self.diffusion_model(x, t, control=control, transformer_options=transformer_options)
elif self.conditioning_key == 'concat': elif self.conditioning_key == 'concat':
xc = torch.cat([x] + c_concat, dim=1) xc = torch.cat([x] + c_concat, dim=1)
out = self.diffusion_model(xc, t, control=control) out = self.diffusion_model(xc, t, control=control, transformer_options=transformer_options)
elif self.conditioning_key == 'crossattn': elif self.conditioning_key == 'crossattn':
if not self.sequential_cross_attn: if not self.sequential_cross_attn:
cc = torch.cat(c_crossattn, 1) cc = torch.cat(c_crossattn, 1)
@ -1332,25 +1332,25 @@ class DiffusionWrapper(torch.nn.Module):
# TorchScript changes names of the arguments # TorchScript changes names of the arguments
# with argument cc defined as context=cc scripted model will produce # with argument cc defined as context=cc scripted model will produce
# an error: RuntimeError: forward() is missing value for argument 'argument_3'. # an error: RuntimeError: forward() is missing value for argument 'argument_3'.
out = self.scripted_diffusion_model(x, t, cc, control=control) out = self.scripted_diffusion_model(x, t, cc, control=control, transformer_options=transformer_options)
else: else:
out = self.diffusion_model(x, t, context=cc, control=control) out = self.diffusion_model(x, t, context=cc, control=control, transformer_options=transformer_options)
elif self.conditioning_key == 'hybrid': elif self.conditioning_key == 'hybrid':
xc = torch.cat([x] + c_concat, dim=1) xc = torch.cat([x] + c_concat, dim=1)
cc = torch.cat(c_crossattn, 1) cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(xc, t, context=cc, control=control) out = self.diffusion_model(xc, t, context=cc, control=control, transformer_options=transformer_options)
elif self.conditioning_key == 'hybrid-adm': elif self.conditioning_key == 'hybrid-adm':
assert c_adm is not None assert c_adm is not None
xc = torch.cat([x] + c_concat, dim=1) xc = torch.cat([x] + c_concat, dim=1)
cc = torch.cat(c_crossattn, 1) cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control) out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control, transformer_options=transformer_options)
elif self.conditioning_key == 'crossattn-adm': elif self.conditioning_key == 'crossattn-adm':
assert c_adm is not None assert c_adm is not None
cc = torch.cat(c_crossattn, 1) cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control) out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control, transformer_options=transformer_options)
elif self.conditioning_key == 'adm': elif self.conditioning_key == 'adm':
cc = c_crossattn[0] cc = c_crossattn[0]
out = self.diffusion_model(x, t, y=cc, control=control) out = self.diffusion_model(x, t, y=cc, control=control, transformer_options=transformer_options)
else: else:
raise NotImplementedError() raise NotImplementedError()

View File

@ -504,10 +504,10 @@ class BasicTransformerBlock(nn.Module):
self.norm3 = nn.LayerNorm(dim) self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint self.checkpoint = checkpoint
def forward(self, x, context=None): def forward(self, x, context=None, transformer_options={}):
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
def _forward(self, x, context=None): def _forward(self, x, context=None, transformer_options={}):
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
x = self.attn2(self.norm2(x), context=context) + x x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x x = self.ff(self.norm3(x)) + x
@ -557,7 +557,7 @@ class SpatialTransformer(nn.Module):
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
self.use_linear = use_linear self.use_linear = use_linear
def forward(self, x, context=None): def forward(self, x, context=None, transformer_options={}):
# note: if no context is given, cross-attention defaults to self-attention # note: if no context is given, cross-attention defaults to self-attention
if not isinstance(context, list): if not isinstance(context, list):
context = [context] context = [context]
@ -570,7 +570,7 @@ class SpatialTransformer(nn.Module):
if self.use_linear: if self.use_linear:
x = self.proj_in(x) x = self.proj_in(x)
for i, block in enumerate(self.transformer_blocks): for i, block in enumerate(self.transformer_blocks):
x = block(x, context=context[i]) x = block(x, context=context[i], transformer_options=transformer_options)
if self.use_linear: if self.use_linear:
x = self.proj_out(x) x = self.proj_out(x)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()

View File

@ -76,12 +76,12 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
support it as an extra input. support it as an extra input.
""" """
def forward(self, x, emb, context=None): def forward(self, x, emb, context=None, transformer_options={}):
for layer in self: for layer in self:
if isinstance(layer, TimestepBlock): if isinstance(layer, TimestepBlock):
x = layer(x, emb) x = layer(x, emb)
elif isinstance(layer, SpatialTransformer): elif isinstance(layer, SpatialTransformer):
x = layer(x, context) x = layer(x, context, transformer_options)
else: else:
x = layer(x) x = layer(x)
return x return x
@ -753,7 +753,7 @@ class UNetModel(nn.Module):
self.middle_block.apply(convert_module_to_f32) self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32) self.output_blocks.apply(convert_module_to_f32)
def forward(self, x, timesteps=None, context=None, y=None, control=None, **kwargs): def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
""" """
Apply the model to an input batch. Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs. :param x: an [N x C x ...] Tensor of inputs.
@ -762,6 +762,7 @@ class UNetModel(nn.Module):
:param y: an [N] Tensor of labels, if class-conditional. :param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs. :return: an [N x C x ...] Tensor of outputs.
""" """
transformer_options["original_shape"] = list(x.shape)
assert (y is not None) == ( assert (y is not None) == (
self.num_classes is not None self.num_classes is not None
), "must specify y if and only if the model is class-conditional" ), "must specify y if and only if the model is class-conditional"
@ -775,13 +776,13 @@ class UNetModel(nn.Module):
h = x.type(self.dtype) h = x.type(self.dtype)
for id, module in enumerate(self.input_blocks): for id, module in enumerate(self.input_blocks):
h = module(h, emb, context) h = module(h, emb, context, transformer_options)
if control is not None and 'input' in control and len(control['input']) > 0: if control is not None and 'input' in control and len(control['input']) > 0:
ctrl = control['input'].pop() ctrl = control['input'].pop()
if ctrl is not None: if ctrl is not None:
h += ctrl h += ctrl
hs.append(h) hs.append(h)
h = self.middle_block(h, emb, context) h = self.middle_block(h, emb, context, transformer_options)
if control is not None and 'middle' in control and len(control['middle']) > 0: if control is not None and 'middle' in control and len(control['middle']) > 0:
h += control['middle'].pop() h += control['middle'].pop()
@ -793,7 +794,7 @@ class UNetModel(nn.Module):
hsp += ctrl hsp += ctrl
h = th.cat([h, hsp], dim=1) h = th.cat([h, hsp], dim=1)
del hsp del hsp
h = module(h, emb, context) h = module(h, emb, context, transformer_options)
h = h.type(x.dtype) h = h.type(x.dtype)
if self.predict_codebook_ids: if self.predict_codebook_ids:
return self.id_predictor(h) return self.id_predictor(h)

View File

@ -26,7 +26,7 @@ class CFGDenoiser(torch.nn.Module):
#The main sampling function shared by all the samplers #The main sampling function shared by all the samplers
#Returns predicted noise #Returns predicted noise
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None): def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}):
def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in): def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0) area = (x_in.shape[2], x_in.shape[3], 0, 0)
strength = 1.0 strength = 1.0
@ -169,6 +169,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
if control is not None: if control is not None:
c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'], len(cond_or_uncond)) c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'], len(cond_or_uncond))
if 'transformer_options' in model_options:
c['transformer_options'] = model_options['transformer_options']
output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks) output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
del input_x del input_x
@ -467,7 +470,7 @@ class KSampler:
x_T=z_enc, x_T=z_enc,
x0=latent_image, x0=latent_image,
denoise_function=sampling_function, denoise_function=sampling_function,
cond_concat=cond_concat, extra_args=extra_args,
mask=noise_mask, mask=noise_mask,
to_zero=sigmas[-1]==0, to_zero=sigmas[-1]==0,
end_step=sigmas.shape[0] - 1) end_step=sigmas.shape[0] - 1)