mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Merge f4e86a4a07
into ff838657fa
This commit is contained in:
commit
3b40134e3b
@ -50,8 +50,7 @@ class ResBlockUnionControlnet(nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = x + self.attention(self.ln_1(x))
|
||||
x = x + self.mlp(self.ln_2(x))
|
||||
return x
|
||||
return x + self.mlp(self.ln_2(x))
|
||||
|
||||
class ControlledUnetModel(UNetModel):
|
||||
#implemented in the ldm unet
|
||||
|
@ -36,8 +36,7 @@ class CLIPMLP(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.activation(x)
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
return self.fc2(x)
|
||||
|
||||
class CLIPLayer(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
|
||||
|
@ -270,5 +270,4 @@ class CheckLazyMixin:
|
||||
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lazy_evaluation#defining-check-lazy-status
|
||||
"""
|
||||
|
||||
need = [name for name in kwargs if kwargs[name] is None]
|
||||
return need
|
||||
return [name for name in kwargs if kwargs[name] is None]
|
||||
|
@ -404,8 +404,7 @@ class ControlLora(ControlNet):
|
||||
super().cleanup()
|
||||
|
||||
def get_models(self):
|
||||
out = ControlBase.get_models(self)
|
||||
return out
|
||||
return ControlBase.get_models(self)
|
||||
|
||||
def inference_memory_requirements(self, dtype):
|
||||
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
|
||||
@ -461,8 +460,7 @@ def load_controlnet_mmdit(sd, model_options={}):
|
||||
|
||||
latent_format = comfy.latent_formats.SD3()
|
||||
latent_format.shift_factor = 0 #SD3 controlnet weirdness
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||
return control
|
||||
return ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||
|
||||
|
||||
class ControlNetSD35(ControlNet):
|
||||
@ -536,8 +534,7 @@ def load_controlnet_sd35(sd, model_options={}):
|
||||
elif depth_cnet:
|
||||
preprocess_image = lambda a: 1.0 - a
|
||||
|
||||
control = ControlNetSD35(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, preprocess_image=preprocess_image)
|
||||
return control
|
||||
return ControlNetSD35(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, preprocess_image=preprocess_image)
|
||||
|
||||
|
||||
|
||||
@ -549,16 +546,14 @@ def load_controlnet_hunyuandit(controlnet_data, model_options={}):
|
||||
|
||||
latent_format = comfy.latent_formats.SDXL()
|
||||
extra_conds = ['text_embedding_mask', 'encoder_hidden_states_t5', 'text_embedding_mask_t5', 'image_meta_size', 'style', 'cos_cis_img', 'sin_cis_img']
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
|
||||
return control
|
||||
return ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
|
||||
|
||||
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
|
||||
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
control_model = controlnet_load_state_dict(control_model, sd)
|
||||
extra_conds = ['y', 'guidance']
|
||||
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
return ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
|
||||
def load_controlnet_flux_instantx(sd, model_options={}):
|
||||
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||
@ -581,8 +576,7 @@ def load_controlnet_flux_instantx(sd, model_options={}):
|
||||
|
||||
latent_format = comfy.latent_formats.Flux()
|
||||
extra_conds = ['y', 'guidance']
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
return ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
|
||||
def convert_mistoline(sd):
|
||||
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
||||
@ -738,8 +732,7 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||
|
||||
global_average_pooling = model_options.get("global_average_pooling", False)
|
||||
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||
return control
|
||||
return ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||
|
||||
def load_controlnet(ckpt_path, model=None, model_options={}):
|
||||
if "global_average_pooling" not in model_options:
|
||||
|
@ -99,8 +99,7 @@ def convert_unet_state_dict(unet_state_dict):
|
||||
for sd_part, hf_part in unet_conversion_map_layer:
|
||||
v = v.replace(hf_part, sd_part)
|
||||
mapping[k] = v
|
||||
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
|
||||
return new_state_dict
|
||||
return {v: unet_state_dict[k] for k, v in mapping.items()}
|
||||
|
||||
|
||||
# ================#
|
||||
|
@ -136,8 +136,7 @@ class NoiseScheduleVP:
|
||||
return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
|
||||
elif self.schedule == 'cosine':
|
||||
log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
|
||||
log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
|
||||
return log_alpha_t
|
||||
return log_alpha_fn(t) - self.cosine_log_alpha_0
|
||||
|
||||
def marginal_alpha(self, t):
|
||||
"""
|
||||
@ -174,8 +173,7 @@ class NoiseScheduleVP:
|
||||
else:
|
||||
log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
|
||||
t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
|
||||
t = t_fn(log_alpha)
|
||||
return t
|
||||
return t_fn(log_alpha)
|
||||
|
||||
|
||||
def model_wrapper(
|
||||
@ -378,8 +376,7 @@ class UniPC:
|
||||
p = self.dynamic_thresholding_ratio
|
||||
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
||||
s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
|
||||
x0 = torch.clamp(x0, -s, s) / s
|
||||
return x0
|
||||
return torch.clamp(x0, -s, s) / s
|
||||
|
||||
def noise_prediction_fn(self, x, t):
|
||||
"""
|
||||
@ -423,8 +420,7 @@ class UniPC:
|
||||
return torch.linspace(t_T, t_0, N + 1).to(device)
|
||||
elif skip_type == 'time_quadratic':
|
||||
t_order = 2
|
||||
t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
|
||||
return t
|
||||
return torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
|
||||
else:
|
||||
raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
|
||||
|
||||
@ -801,8 +797,7 @@ def interpolate_fn(x, xp, yp):
|
||||
y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
|
||||
start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
|
||||
end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
|
||||
cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
|
||||
return cand
|
||||
return start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
|
||||
|
||||
|
||||
def expand_dims(v, dims):
|
||||
|
@ -78,10 +78,9 @@ class GatedCrossAttentionDense(nn.Module):
|
||||
|
||||
x = x + self.scale * \
|
||||
torch.tanh(self.alpha_attn) * self.attn(self.norm1(x), objs, objs)
|
||||
x = x + self.scale * \
|
||||
return x + self.scale * \
|
||||
torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class GatedSelfAttentionDense(nn.Module):
|
||||
@ -118,10 +117,9 @@ class GatedSelfAttentionDense(nn.Module):
|
||||
|
||||
x = x + self.scale * torch.tanh(self.alpha_attn) * self.attn(
|
||||
self.norm1(torch.cat([x, objs], dim=1)))[:, 0:N_visual, :]
|
||||
x = x + self.scale * \
|
||||
return x + self.scale * \
|
||||
torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class GatedSelfAttentionDense2(nn.Module):
|
||||
@ -172,10 +170,9 @@ class GatedSelfAttentionDense2(nn.Module):
|
||||
|
||||
# add residual to visual feature
|
||||
x = x + self.scale * torch.tanh(self.alpha_attn) * residual
|
||||
x = x + self.scale * \
|
||||
return x + self.scale * \
|
||||
torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class FourierEmbedder():
|
||||
@ -340,5 +337,4 @@ def load_gligen(sd):
|
||||
w.position_net = PositionNet(in_dim, out_dim)
|
||||
w.load_state_dict(sd, strict=False)
|
||||
|
||||
gligen = Gligen(output_list, w.position_net, key_dim)
|
||||
return gligen
|
||||
return Gligen(output_list, w.position_net, key_dim)
|
||||
|
@ -46,8 +46,7 @@ def cal_intergrand(beta_0, beta_1, taus):
|
||||
log_alpha = alpha.log()
|
||||
log_alpha.sum().backward()
|
||||
d_log_alpha_dtau = taus.grad
|
||||
integrand = -0.5 * d_log_alpha_dtau / torch.sqrt(alpha * (1 - alpha))
|
||||
return integrand
|
||||
return -0.5 * d_log_alpha_dtau / torch.sqrt(alpha * (1 - alpha))
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
|
@ -50,8 +50,7 @@ def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0., beta=0.5, device='cpu'):
|
||||
x = torch.linspace(0, 1, n, device=device)
|
||||
clamp = lambda x: torch.clamp(x, min=sigma_min, max=sigma_max)
|
||||
lmb = mu - beta * torch.sign(0.5-x) * torch.log(1 - 2 * torch.abs(0.5-x) + epsilon)
|
||||
sigmas = clamp(torch.exp(lmb))
|
||||
return sigmas
|
||||
return clamp(torch.exp(lmb))
|
||||
|
||||
|
||||
|
||||
|
@ -70,9 +70,8 @@ class SnakeBeta(nn.Module):
|
||||
if self.alpha_logscale:
|
||||
alpha = torch.exp(alpha)
|
||||
beta = torch.exp(beta)
|
||||
x = snake_beta(x, alpha, beta)
|
||||
return snake_beta(x, alpha, beta)
|
||||
|
||||
return x
|
||||
|
||||
def WNConv1d(*args, **kwargs):
|
||||
try:
|
||||
|
@ -89,8 +89,7 @@ class AbsolutePositionalEmbedding(nn.Module):
|
||||
pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
|
||||
|
||||
pos_emb = self.emb(pos)
|
||||
pos_emb = pos_emb * self.scale
|
||||
return pos_emb
|
||||
return pos_emb * self.scale
|
||||
|
||||
class ScaledSinusoidalEmbedding(nn.Module):
|
||||
def __init__(self, dim, theta = 10000):
|
||||
@ -404,9 +403,8 @@ class ConformerModule(nn.Module):
|
||||
x = self.swish(x)
|
||||
x = rearrange(x, 'b n d -> b d n')
|
||||
x = self.pointwise_conv_2(x)
|
||||
x = rearrange(x, 'b d n -> b n d')
|
||||
return rearrange(x, 'b d n -> b n d')
|
||||
|
||||
return x
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
|
@ -21,8 +21,7 @@ class LearnedPositionalEmbedding(nn.Module):
|
||||
x = rearrange(x, "b -> b 1")
|
||||
freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * math.pi
|
||||
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
|
||||
fouriered = torch.cat((x, fouriered), dim=-1)
|
||||
return fouriered
|
||||
return torch.cat((x, fouriered), dim=-1)
|
||||
|
||||
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
|
||||
return nn.Sequential(
|
||||
@ -49,8 +48,7 @@ class NumberEmbedder(nn.Module):
|
||||
shape = x.shape
|
||||
x = rearrange(x, "... -> (...)")
|
||||
embedding = self.embedding(x)
|
||||
x = embedding.view(*shape, self.features)
|
||||
return x # type: ignore
|
||||
return embedding.view(*shape, self.features)
|
||||
|
||||
|
||||
class Conditioner(nn.Module):
|
||||
|
@ -36,8 +36,7 @@ class MLP(nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
|
||||
x = self.c_proj(x)
|
||||
return x
|
||||
return self.c_proj(x)
|
||||
|
||||
|
||||
class MultiHeadLayerNorm(nn.Module):
|
||||
@ -95,8 +94,7 @@ class SingleAttention(nn.Module):
|
||||
q, k = self.q_norm1(q), self.k_norm1(k)
|
||||
|
||||
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
|
||||
c = self.w1o(output)
|
||||
return c
|
||||
return self.w1o(output)
|
||||
|
||||
|
||||
|
||||
@ -265,9 +263,8 @@ class DiTBlock(nn.Module):
|
||||
mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
|
||||
cx = gate_mlp.unsqueeze(1) * mlpout
|
||||
|
||||
cx = cxres + cx
|
||||
return cxres + cx
|
||||
|
||||
return cx
|
||||
|
||||
|
||||
|
||||
@ -298,8 +295,7 @@ class TimestepEmbedder(nn.Module):
|
||||
#@torch.compile()
|
||||
def forward(self, t, dtype):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
return self.mlp(t_freq)
|
||||
|
||||
|
||||
class MMDiT(nn.Module):
|
||||
@ -401,8 +397,7 @@ class MMDiT(nn.Module):
|
||||
|
||||
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
||||
x = torch.einsum("nhwpqc->nchpwq", x)
|
||||
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
||||
return imgs
|
||||
return x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
||||
|
||||
def patchify(self, x):
|
||||
B, C, H, W = x.size()
|
||||
@ -415,8 +410,7 @@ class MMDiT(nn.Module):
|
||||
(W + 1) // self.patch_size,
|
||||
self.patch_size,
|
||||
)
|
||||
x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
|
||||
return x
|
||||
return x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
|
||||
|
||||
def apply_pos_embeds(self, x, h, w):
|
||||
h = (h + 1) // self.patch_size
|
||||
@ -494,5 +488,4 @@ class MMDiT(nn.Module):
|
||||
|
||||
x = modulate(x, fshift, fscale)
|
||||
x = self.final_linear(x)
|
||||
x = self.unpatchify(x, (h + 1) // self.patch_size, (w + 1) // self.patch_size)[:,:,:h,:w]
|
||||
return x
|
||||
return self.unpatchify(x, (h + 1) // self.patch_size, (w + 1) // self.patch_size)[:,:,:h,:w]
|
||||
|
@ -54,8 +54,7 @@ class Attention2D(nn.Module):
|
||||
kv = torch.cat([x, kv], dim=1)
|
||||
# x = self.attn(x, kv, kv, need_weights=False)[0]
|
||||
x = self.attn(x, kv, kv)
|
||||
x = x.permute(0, 2, 1).view(*orig_shape)
|
||||
return x
|
||||
return x.permute(0, 2, 1).view(*orig_shape)
|
||||
|
||||
|
||||
def LayerNorm2d_op(operations):
|
||||
@ -116,8 +115,7 @@ class AttnBlock(nn.Module):
|
||||
|
||||
def forward(self, x, kv):
|
||||
kv = self.kv_mapper(kv)
|
||||
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
|
||||
return x
|
||||
return x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
|
||||
|
||||
|
||||
class FeedForwardBlock(nn.Module):
|
||||
@ -133,8 +131,7 @@ class FeedForwardBlock(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
return x
|
||||
return x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
|
||||
|
||||
class TimestepBlock(nn.Module):
|
||||
|
@ -157,9 +157,8 @@ class ResBlock(nn.Module):
|
||||
x = x + self.depthwise[1](x_temp) * mods[2]
|
||||
|
||||
x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4]
|
||||
x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]
|
||||
return x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class StageA(nn.Module):
|
||||
@ -218,8 +217,7 @@ class StageA(nn.Module):
|
||||
|
||||
def decode(self, x):
|
||||
x = self.up_blocks(x)
|
||||
x = self.out_block(x)
|
||||
return x
|
||||
return self.out_block(x)
|
||||
|
||||
def forward(self, x, quantize=False):
|
||||
qe, x, _, vq_loss = self.encode(x, quantize)
|
||||
@ -251,5 +249,4 @@ class Discriminator(nn.Module):
|
||||
cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1))
|
||||
x = torch.cat([x, cond], dim=1)
|
||||
x = self.shuffle(x)
|
||||
x = self.logits(x)
|
||||
return x
|
||||
return self.logits(x)
|
||||
|
@ -170,8 +170,7 @@ class StageB(nn.Module):
|
||||
if len(clip.shape) == 2:
|
||||
clip = clip.unsqueeze(1)
|
||||
clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1)
|
||||
clip = self.clip_norm(clip)
|
||||
return clip
|
||||
return self.clip_norm(clip)
|
||||
|
||||
def _down_encode(self, x, r_embed, clip):
|
||||
level_outputs = []
|
||||
|
@ -179,8 +179,7 @@ class StageC(nn.Module):
|
||||
clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1)
|
||||
clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1)
|
||||
clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
|
||||
clip = self.clip_norm(clip)
|
||||
return clip
|
||||
return self.clip_norm(clip)
|
||||
|
||||
def _down_encode(self, x, r_embed, clip, cnet=None):
|
||||
level_outputs = []
|
||||
|
@ -35,8 +35,7 @@ class EfficientNetEncoder(nn.Module):
|
||||
def forward(self, x):
|
||||
x = x * 0.5 + 0.5
|
||||
x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1])
|
||||
o = self.mapper(self.backbone(x))
|
||||
return o
|
||||
return self.mapper(self.backbone(x))
|
||||
|
||||
|
||||
# Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192
|
||||
|
@ -256,5 +256,4 @@ class LastLayer(nn.Module):
|
||||
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
||||
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
||||
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
||||
x = self.linear(x)
|
||||
return x
|
||||
return self.linear(x)
|
||||
|
@ -9,8 +9,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
|
||||
q, k = apply_rope(q, k, pe)
|
||||
|
||||
heads = q.shape[1]
|
||||
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
|
||||
return x
|
||||
return optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
|
||||
|
||||
|
||||
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
|
@ -183,8 +183,7 @@ class Flux(nn.Module):
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
return self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, h, w = x.shape
|
||||
|
@ -21,5 +21,4 @@ class ReduxImageEncoder(torch.nn.Module):
|
||||
self.redux_down = ops.Linear(txt_in_features * 3, txt_in_features, dtype=dtype)
|
||||
|
||||
def forward(self, sigclip_embeds) -> torch.Tensor:
|
||||
projected_x = self.redux_down(torch.nn.functional.silu(self.redux_up(sigclip_embeds)))
|
||||
return projected_x
|
||||
return self.redux_down(torch.nn.functional.silu(self.redux_up(sigclip_embeds)))
|
||||
|
@ -34,9 +34,8 @@ import comfy.ops
|
||||
def modulated_rmsnorm(x, scale, eps=1e-6):
|
||||
# Normalize and modulate
|
||||
x_normed = comfy.ldm.common_dit.rms_norm(x, eps=eps)
|
||||
x_modulated = x_normed * (1 + scale.unsqueeze(1))
|
||||
return x_normed * (1 + scale.unsqueeze(1))
|
||||
|
||||
return x_modulated
|
||||
|
||||
|
||||
def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6):
|
||||
@ -47,9 +46,8 @@ def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6):
|
||||
x_normed = comfy.ldm.common_dit.rms_norm(x_res, eps=eps) * tanh_gate
|
||||
|
||||
# Apply residual connection
|
||||
output = x + x_normed
|
||||
return x + x_normed
|
||||
|
||||
return output
|
||||
|
||||
class AsymmetricAttention(nn.Module):
|
||||
def __init__(
|
||||
@ -275,14 +273,12 @@ class AsymmetricJointBlock(nn.Module):
|
||||
def ff_block_x(self, x, scale_x, gate_x):
|
||||
x_mod = modulated_rmsnorm(x, scale_x)
|
||||
x_res = self.mlp_x(x_mod)
|
||||
x = residual_tanh_gated_rmsnorm(x, x_res, gate_x) # Sandwich norm
|
||||
return x
|
||||
return residual_tanh_gated_rmsnorm(x, x_res, gate_x) # Sandwich norm
|
||||
|
||||
def ff_block_y(self, y, scale_y, gate_y):
|
||||
y_mod = modulated_rmsnorm(y, scale_y)
|
||||
y_res = self.mlp_y(y_mod)
|
||||
y = residual_tanh_gated_rmsnorm(y, y_res, gate_y) # Sandwich norm
|
||||
return y
|
||||
return residual_tanh_gated_rmsnorm(y, y_res, gate_y) # Sandwich norm
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
@ -312,8 +308,7 @@ class FinalLayer(nn.Module):
|
||||
c = F.silu(c)
|
||||
shift, scale = self.mod(c).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
return self.linear(x)
|
||||
|
||||
|
||||
class AsymmDiTJoint(nn.Module):
|
||||
|
@ -64,8 +64,7 @@ class TimestepEmbedder(nn.Module):
|
||||
if self.timestep_scale is not None:
|
||||
t = t * self.timestep_scale
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype=out_dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
return self.mlp(t_freq)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
@ -93,8 +92,7 @@ class FeedForward(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.w1(x).chunk(2, dim=-1)
|
||||
x = self.w2(F.silu(x) * gate)
|
||||
return x
|
||||
return self.w2(F.silu(x) * gate)
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
@ -149,8 +147,7 @@ class PatchEmbed(nn.Module):
|
||||
raise NotImplementedError("Must flatten output.")
|
||||
x = rearrange(x, "(B T) C H W -> B (T H W) C", B=B, T=T)
|
||||
|
||||
x = self.norm(x)
|
||||
return x
|
||||
return self.norm(x)
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
|
@ -60,9 +60,8 @@ def create_position_matrix(
|
||||
# Stack and reshape the grids.
|
||||
pos = torch.stack([grid_t, grid_h, grid_w], dim=-1) # [T, pH, pW, 3]
|
||||
pos = pos.view(-1, 3) # [T * pH * pW, 3]
|
||||
pos = pos.to(dtype=dtype, device=device)
|
||||
return pos.to(dtype=dtype, device=device)
|
||||
|
||||
return pos
|
||||
|
||||
|
||||
def compute_mixed_rotation(
|
||||
|
@ -30,5 +30,4 @@ def apply_rotary_emb_qk_real(
|
||||
sin_part = (xqk_even * freqs_sin + xqk_odd * freqs_cos).type_as(xqk)
|
||||
|
||||
# Interleave the results back into the original shape
|
||||
out = torch.stack([cos_part, sin_part], dim=-1).flatten(-2)
|
||||
return out
|
||||
return torch.stack([cos_part, sin_part], dim=-1).flatten(-2)
|
||||
|
@ -29,8 +29,7 @@ def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.
|
||||
assert x.size(0) == mask.size(0) # Expected mask to have same batch size as tokens.
|
||||
mask = mask[:, :, None].to(dtype=x.dtype)
|
||||
mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
|
||||
pooled = (x * mask).sum(dim=1, keepdim=keepdim)
|
||||
return pooled
|
||||
return (x * mask).sum(dim=1, keepdim=keepdim)
|
||||
|
||||
|
||||
class AttentionPool(nn.Module):
|
||||
@ -98,5 +97,4 @@ class AttentionPool(nn.Module):
|
||||
|
||||
# Concatenate heads and run output.
|
||||
x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim)
|
||||
x = self.to_out(x)
|
||||
return x
|
||||
return self.to_out(x)
|
||||
|
@ -99,8 +99,7 @@ class Conv1x1(ops.Linear):
|
||||
"""
|
||||
x = x.movedim(1, -1)
|
||||
x = super().forward(x)
|
||||
x = x.movedim(-1, 1)
|
||||
return x
|
||||
return x.movedim(-1, 1)
|
||||
|
||||
|
||||
class DepthToSpaceTime(nn.Module):
|
||||
@ -269,8 +268,7 @@ class Attention(nn.Module):
|
||||
assert x.size(0) == q.size(0)
|
||||
|
||||
x = self.out(x)
|
||||
x = rearrange(x, "(B h w) t C -> B C t h w", B=B, h=H, w=W)
|
||||
return x
|
||||
return rearrange(x, "(B h w) t C -> B C t h w", B=B, h=H, w=W)
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
@ -321,8 +319,7 @@ class CausalUpsampleBlock(nn.Module):
|
||||
def forward(self, x):
|
||||
x = self.blocks(x)
|
||||
x = self.proj(x)
|
||||
x = self.d2st(x)
|
||||
return x
|
||||
return self.d2st(x)
|
||||
|
||||
|
||||
def block_fn(channels, *, affine: bool = True, has_attention: bool = False, **block_kwargs):
|
||||
|
@ -86,8 +86,7 @@ class TokenRefinerBlock(nn.Module):
|
||||
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True)
|
||||
|
||||
x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1)
|
||||
x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
|
||||
return x
|
||||
return x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
|
||||
|
||||
|
||||
class IndividualTokenRefiner(nn.Module):
|
||||
@ -157,8 +156,7 @@ class TokenRefiner(nn.Module):
|
||||
|
||||
c = t + self.c_embedder(c.to(x.dtype))
|
||||
x = self.input_embedder(x)
|
||||
x = self.individual_token_refiner(x, c, mask)
|
||||
return x
|
||||
return self.individual_token_refiner(x, c, mask)
|
||||
|
||||
class HunyuanVideo(nn.Module):
|
||||
"""
|
||||
@ -311,8 +309,7 @@ class HunyuanVideo(nn.Module):
|
||||
shape[i] = shape[i] // self.patch_size[i]
|
||||
img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
|
||||
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
||||
img = img.reshape(initial_shape)
|
||||
return img
|
||||
return img.reshape(initial_shape)
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance, attention_mask=None, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
@ -326,5 +323,4 @@ class HunyuanVideo(nn.Module):
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
|
||||
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options)
|
||||
return out
|
||||
return self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options)
|
||||
|
@ -166,9 +166,8 @@ class CrossAttention(nn.Module):
|
||||
out = self.out_proj(context) # context.reshape - B, L1, -1
|
||||
out = self.proj_drop(out)
|
||||
|
||||
out_tuple = (out,)
|
||||
return (out,)
|
||||
|
||||
return out_tuple
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
@ -213,6 +212,5 @@ class Attention(nn.Module):
|
||||
x = self.out_proj(x)
|
||||
x = self.proj_drop(x)
|
||||
|
||||
out_tuple = (x,)
|
||||
return (x,)
|
||||
|
||||
return out_tuple
|
||||
|
@ -19,8 +19,7 @@ def calc_rope(x, patch_size, head_size):
|
||||
sub_args = [start, stop, (th, tw)]
|
||||
# head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads']
|
||||
rope = get_2d_rotary_pos_embed(head_size, *sub_args)
|
||||
rope = (rope[0].to(x), rope[1].to(x))
|
||||
return rope
|
||||
return (rope[0].to(x), rope[1].to(x))
|
||||
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
@ -110,9 +109,8 @@ class HunYuanDiTBlock(nn.Module):
|
||||
|
||||
# FFN Layer
|
||||
mlp_inputs = self.norm2(x)
|
||||
x = x + self.mlp(mlp_inputs)
|
||||
return x + self.mlp(mlp_inputs)
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
|
||||
if self.gradient_checkpointing and self.training:
|
||||
@ -136,8 +134,7 @@ class FinalLayer(nn.Module):
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
return self.linear(x)
|
||||
|
||||
|
||||
class HunYuanDiT(nn.Module):
|
||||
@ -413,5 +410,4 @@ class HunYuanDiT(nn.Module):
|
||||
|
||||
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
||||
x = torch.einsum('nhwpqc->nchpwq', x)
|
||||
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
||||
return imgs
|
||||
return x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
||||
|
@ -53,8 +53,7 @@ def get_meshgrid(start, *args):
|
||||
grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32)
|
||||
grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32)
|
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
grid = np.stack(grid, axis=0) # [2, W, H]
|
||||
return grid
|
||||
return np.stack(grid, axis=0) # [2, W, H]
|
||||
|
||||
#################################################################################
|
||||
# Sine/Cosine Positional Embedding Functions #
|
||||
@ -87,8 +86,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
|
||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
return emb
|
||||
return np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
@ -108,8 +106,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
emb_sin = np.sin(out) # (M, D/2)
|
||||
emb_cos = np.cos(out) # (M, D/2)
|
||||
|
||||
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||
return emb
|
||||
return np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||
|
||||
|
||||
#################################################################################
|
||||
@ -138,8 +135,7 @@ def get_2d_rotary_pos_embed(embed_dim, start, *args, use_real=True):
|
||||
"""
|
||||
grid = get_meshgrid(start, *args) # [2, H, w]
|
||||
grid = grid.reshape([2, 1, *grid.shape[1:]]) # Returns a sampling matrix with the same resolution as the target resolution
|
||||
pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
|
||||
return pos_embed
|
||||
return get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
|
||||
|
||||
|
||||
def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
|
||||
@ -154,8 +150,7 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
|
||||
sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2)
|
||||
return cos, sin
|
||||
else:
|
||||
emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
|
||||
return emb
|
||||
return torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
|
||||
|
||||
|
||||
def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
|
||||
@ -187,8 +182,7 @@ def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float
|
||||
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
|
||||
return freqs_cos, freqs_sin
|
||||
else:
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
|
||||
return freqs_cis
|
||||
return torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
|
||||
|
||||
|
||||
|
||||
|
@ -122,14 +122,13 @@ class Timesteps(nn.Module):
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, timesteps):
|
||||
t_emb = get_timestep_embedding(
|
||||
return get_timestep_embedding(
|
||||
timesteps,
|
||||
self.num_channels,
|
||||
flip_sin_to_cos=self.flip_sin_to_cos,
|
||||
downscale_freq_shift=self.downscale_freq_shift,
|
||||
scale=self.scale,
|
||||
)
|
||||
return t_emb
|
||||
|
||||
|
||||
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
|
||||
@ -149,8 +148,7 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
|
||||
|
||||
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
|
||||
return timesteps_emb
|
||||
return self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
|
||||
|
||||
|
||||
class AdaLayerNormSingle(nn.Module):
|
||||
@ -209,8 +207,7 @@ class PixArtAlphaTextProjection(nn.Module):
|
||||
def forward(self, caption):
|
||||
hidden_states = self.linear_1(caption)
|
||||
hidden_states = self.act_1(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
return self.linear_2(hidden_states)
|
||||
|
||||
|
||||
class GELU_approx(nn.Module):
|
||||
@ -247,9 +244,8 @@ def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and
|
||||
t_dup = torch.stack((-t2, t1), dim=-1)
|
||||
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
|
||||
|
||||
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
|
||||
return input_tensor * cos_freqs + input_tensor_rot * sin_freqs
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
@ -316,14 +312,13 @@ class BasicTransformerBlock(nn.Module):
|
||||
return x
|
||||
|
||||
def get_fractional_positions(indices_grid, max_pos):
|
||||
fractional_positions = torch.stack(
|
||||
return torch.stack(
|
||||
[
|
||||
indices_grid[:, i] / max_pos[i]
|
||||
for i in range(3)
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
return fractional_positions
|
||||
|
||||
|
||||
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
|
||||
|
@ -65,8 +65,7 @@ class Patchifier(ABC):
|
||||
scale = scale_grid[i]
|
||||
grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i]
|
||||
|
||||
grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
|
||||
return grid
|
||||
return rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
|
||||
|
||||
|
||||
class SymmetricPatchifier(Patchifier):
|
||||
@ -74,14 +73,13 @@ class SymmetricPatchifier(Patchifier):
|
||||
self,
|
||||
latents: Tensor,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
latents = rearrange(
|
||||
return rearrange(
|
||||
latents,
|
||||
"b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
|
||||
p1=self._patch_size[0],
|
||||
p2=self._patch_size[1],
|
||||
p3=self._patch_size[2],
|
||||
)
|
||||
return latents
|
||||
|
||||
def unpatchify(
|
||||
self,
|
||||
@ -93,7 +91,7 @@ class SymmetricPatchifier(Patchifier):
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
output_height = output_height // self._patch_size[1]
|
||||
output_width = output_width // self._patch_size[2]
|
||||
latents = rearrange(
|
||||
return rearrange(
|
||||
latents,
|
||||
"b (f h w) (c p q) -> b c f (h p) (w q) ",
|
||||
f=output_num_frames,
|
||||
@ -102,4 +100,3 @@ class SymmetricPatchifier(Patchifier):
|
||||
p=self._patch_size[1],
|
||||
q=self._patch_size[2],
|
||||
)
|
||||
return latents
|
||||
|
@ -56,8 +56,7 @@ class CausalConv3d(nn.Module):
|
||||
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
|
||||
)
|
||||
x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
return self.conv(x)
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
|
@ -417,9 +417,8 @@ class Decoder(nn.Module):
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample, causal=self.causal)
|
||||
|
||||
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
return unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
class UNetMidBlock3D(nn.Module):
|
||||
@ -563,8 +562,7 @@ class LayerNorm(nn.Module):
|
||||
def forward(self, x):
|
||||
x = rearrange(x, "b c d h w -> b d h w c")
|
||||
x = self.norm(x)
|
||||
x = rearrange(x, "b d h w c -> b c d h w")
|
||||
return x
|
||||
return rearrange(x, "b d h w c -> b c d h w")
|
||||
|
||||
|
||||
class ResnetBlock3D(nn.Module):
|
||||
@ -677,9 +675,8 @@ class ResnetBlock3D(nn.Module):
|
||||
# similar to the "explicit noise inputs" method in style-gan
|
||||
spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None]
|
||||
scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...]
|
||||
hidden_states = hidden_states + scaled_noise
|
||||
return hidden_states + scaled_noise
|
||||
|
||||
return hidden_states
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -740,9 +737,8 @@ class ResnetBlock3D(nn.Module):
|
||||
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
|
||||
output_tensor = input_tensor + hidden_states
|
||||
return input_tensor + hidden_states
|
||||
|
||||
return output_tensor
|
||||
|
||||
|
||||
def patchify(x, patch_size_hw, patch_size_t=1):
|
||||
|
@ -114,7 +114,7 @@ class DualConv3d(nn.Module):
|
||||
return x
|
||||
|
||||
# Second convolution
|
||||
x = F.conv3d(
|
||||
return F.conv3d(
|
||||
x,
|
||||
self.weight2,
|
||||
self.bias2,
|
||||
@ -124,7 +124,6 @@ class DualConv3d(nn.Module):
|
||||
self.groups,
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
def forward_with_2d(self, x, skip_time_conv):
|
||||
b, c, d, h, w = x.shape
|
||||
@ -142,8 +141,7 @@ class DualConv3d(nn.Module):
|
||||
_, _, h, w = x.shape
|
||||
|
||||
if skip_time_conv:
|
||||
x = rearrange(x, "(b d) c h w -> b c d h w", b=b)
|
||||
return x
|
||||
return rearrange(x, "(b d) c h w -> b c d h w", b=b)
|
||||
|
||||
# Second convolution which is essentially treated as a 1D convolution across the 'd' dimension
|
||||
x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)
|
||||
@ -155,9 +153,8 @@ class DualConv3d(nn.Module):
|
||||
padding2 = self.padding2[0]
|
||||
dilation2 = self.dilation2[0]
|
||||
x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups)
|
||||
x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
|
||||
return rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
|
||||
|
||||
return x
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
|
@ -136,8 +136,7 @@ class AutoencodingEngine(AbstractAutoencoder):
|
||||
return z
|
||||
|
||||
def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
x = self.decoder(z, **kwargs)
|
||||
return x
|
||||
return self.decoder(z, **kwargs)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, **additional_decode_kwargs
|
||||
@ -178,8 +177,7 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
def get_autoencoder_params(self) -> list:
|
||||
params = super().get_autoencoder_params()
|
||||
return params
|
||||
return super().get_autoencoder_params()
|
||||
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_reg_log: bool = False
|
||||
|
@ -142,13 +142,12 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
sim = sim.softmax(dim=-1)
|
||||
|
||||
out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
|
||||
out = (
|
||||
return (
|
||||
out.unsqueeze(0)
|
||||
.reshape(b, heads, -1, dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||
@ -216,8 +215,7 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
|
||||
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
|
||||
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
|
||||
return hidden_states
|
||||
return hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
|
||||
|
||||
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||
attn_precision = get_attn_precision(attn_precision)
|
||||
@ -326,13 +324,12 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
|
||||
del q, k, v
|
||||
|
||||
r1 = (
|
||||
return (
|
||||
r1.unsqueeze(0)
|
||||
.reshape(b, heads, -1, dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
return r1
|
||||
|
||||
BROKEN_XFORMERS = False
|
||||
try:
|
||||
@ -395,11 +392,10 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
||||
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
|
||||
|
||||
out = (
|
||||
return (
|
||||
out.reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
if model_management.is_nvidia(): #pytorch 2.3 and up seem to have this issue.
|
||||
SDP_BATCH_LIMIT = 2**15
|
||||
@ -932,7 +928,6 @@ class SpatialVideoTransformer(SpatialTransformer):
|
||||
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
|
||||
if not self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
out = x + x_in
|
||||
return out
|
||||
return x + x_in
|
||||
|
||||
|
||||
|
@ -51,8 +51,7 @@ class Mlp(nn.Module):
|
||||
x = self.drop1(x)
|
||||
x = self.norm(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop2(x)
|
||||
return x
|
||||
return self.drop2(x)
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
""" 2D Image to Patch Embedding
|
||||
@ -103,8 +102,7 @@ class PatchEmbed(nn.Module):
|
||||
x = self.proj(x)
|
||||
if self.flatten:
|
||||
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
|
||||
x = self.norm(x)
|
||||
return x
|
||||
return self.norm(x)
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
if shift is None:
|
||||
@ -155,8 +153,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
|
||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
return emb
|
||||
return np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
@ -176,8 +173,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
emb_sin = np.sin(out) # (M, D/2)
|
||||
emb_cos = np.cos(out) # (M, D/2)
|
||||
|
||||
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||
return emb
|
||||
return np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos, device=None, dtype=torch.float32):
|
||||
omega = torch.arange(embed_dim // 2, device=device, dtype=dtype)
|
||||
@ -187,8 +183,7 @@ def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos, device=None, dtype=t
|
||||
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
||||
emb_sin = torch.sin(out) # (M, D/2)
|
||||
emb_cos = torch.cos(out) # (M, D/2)
|
||||
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
|
||||
return emb
|
||||
return torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
|
||||
|
||||
def get_2d_sincos_pos_embed_torch(embed_dim, w, h, val_center=7.5, val_magnitude=7.5, device=None, dtype=torch.float32):
|
||||
small = min(h, w)
|
||||
@ -197,8 +192,7 @@ def get_2d_sincos_pos_embed_torch(embed_dim, w, h, val_center=7.5, val_magnitude
|
||||
grid_h, grid_w = torch.meshgrid(torch.linspace(-val_h + val_center, val_h + val_center, h, device=device, dtype=dtype), torch.linspace(-val_w + val_center, val_w + val_center, w, device=device, dtype=dtype), indexing='ij')
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
|
||||
emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
|
||||
return emb
|
||||
return torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
|
||||
|
||||
|
||||
#################################################################################
|
||||
@ -222,8 +216,7 @@ class TimestepEmbedder(nn.Module):
|
||||
|
||||
def forward(self, t, dtype, **kwargs):
|
||||
t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
return self.mlp(t_freq)
|
||||
|
||||
|
||||
class VectorEmbedder(nn.Module):
|
||||
@ -240,8 +233,7 @@ class VectorEmbedder(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
emb = self.mlp(x)
|
||||
return emb
|
||||
return self.mlp(x)
|
||||
|
||||
|
||||
#################################################################################
|
||||
@ -307,16 +299,14 @@ class SelfAttention(nn.Module):
|
||||
def post_attention(self, x: torch.Tensor) -> torch.Tensor:
|
||||
assert not self.pre_only
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
return self.proj_drop(x)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
q, k, v = self.pre_attention(x)
|
||||
x = optimized_attention(
|
||||
q, k, v, heads=self.num_heads
|
||||
)
|
||||
x = self.post_attention(x)
|
||||
return x
|
||||
return self.post_attention(x)
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
@ -530,10 +520,9 @@ class DismantledBlock(nn.Module):
|
||||
def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp):
|
||||
assert not self.pre_only
|
||||
x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
|
||||
x = x + gate_mlp.unsqueeze(1) * self.mlp(
|
||||
return x + gate_mlp.unsqueeze(1) * self.mlp(
|
||||
modulate(self.norm2(x), shift_mlp, scale_mlp)
|
||||
)
|
||||
return x
|
||||
|
||||
def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
assert self.x_block_self_attn
|
||||
@ -568,10 +557,9 @@ class DismantledBlock(nn.Module):
|
||||
out2 = gate_msa2.unsqueeze(1) * attn2
|
||||
x = x + out1
|
||||
x = x + out2
|
||||
x = x + gate_mlp.unsqueeze(1) * self.mlp(
|
||||
return x + gate_mlp.unsqueeze(1) * self.mlp(
|
||||
modulate(self.norm2(x), shift_mlp, scale_mlp)
|
||||
)
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
assert not self.pre_only
|
||||
@ -696,8 +684,7 @@ class FinalLayer(nn.Module):
|
||||
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
return self.linear(x)
|
||||
|
||||
class SelfAttentionContext(nn.Module):
|
||||
def __init__(self, dim, heads=8, dim_head=64, dtype=None, device=None, operations=None):
|
||||
@ -902,13 +889,12 @@ class MMDiT(nn.Module):
|
||||
w=self.pos_embed_max_size,
|
||||
)
|
||||
spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
|
||||
spatial_pos_embed = rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c")
|
||||
return rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c")
|
||||
# print(spatial_pos_embed, top, left, h, w)
|
||||
# # t = get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, 7.875, 7.875, device=device) #matches exactly for 1024 res
|
||||
# t = get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, 7.5, 7.5, device=device) #scales better
|
||||
# # print(t)
|
||||
# return t
|
||||
return spatial_pos_embed
|
||||
|
||||
def unpatchify(self, x, hw=None):
|
||||
"""
|
||||
@ -927,8 +913,7 @@ class MMDiT(nn.Module):
|
||||
|
||||
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
||||
x = torch.einsum("nhwpqc->nchpwq", x)
|
||||
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
||||
return imgs
|
||||
return x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
||||
|
||||
def forward_core_with_concat(
|
||||
self,
|
||||
@ -976,8 +961,7 @@ class MMDiT(nn.Module):
|
||||
if add is not None:
|
||||
x += add
|
||||
|
||||
x = self.final_layer(x, c_mod) # (N, T, patch_size ** 2 * out_channels)
|
||||
return x
|
||||
return self.final_layer(x, c_mod) # (N, T, patch_size ** 2 * out_channels)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -493,8 +493,7 @@ class Model(nn.Module):
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
return self.conv_out(h)
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.conv_out.weight
|
||||
@ -602,8 +601,7 @@ class Encoder(nn.Module):
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
return self.conv_out(h)
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
|
@ -350,8 +350,7 @@ class VideoResBlock(ResBlock):
|
||||
x = self.time_mixer(
|
||||
x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
|
||||
)
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
return x
|
||||
return rearrange(x, "b c t h w -> (b t) c h w")
|
||||
|
||||
|
||||
class Timestep(nn.Module):
|
||||
|
@ -79,11 +79,10 @@ class AlphaBlender(nn.Module):
|
||||
image_only_indicator=None,
|
||||
) -> torch.Tensor:
|
||||
alpha = self.get_alpha(image_only_indicator, x_spatial.device)
|
||||
x = (
|
||||
return (
|
||||
alpha.to(x_spatial.dtype) * x_spatial
|
||||
+ (1.0 - alpha).to(x_spatial.dtype) * x_temporal
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
||||
@ -201,8 +200,7 @@ class CheckpointFunction(torch.autograd.Function):
|
||||
"dtype": torch.get_autocast_gpu_dtype(),
|
||||
"cache_enabled": torch.is_autocast_cache_enabled()}
|
||||
with torch.no_grad():
|
||||
output_tensors = ctx.run_function(*ctx.input_tensors)
|
||||
return output_tensors
|
||||
return ctx.run_function(*ctx.input_tensors)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *output_grads):
|
||||
|
@ -33,8 +33,7 @@ class DiagonalGaussianDistribution(object):
|
||||
self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device)
|
||||
|
||||
def sample(self):
|
||||
x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device)
|
||||
return x
|
||||
return self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device)
|
||||
|
||||
def kl(self, other=None):
|
||||
if self.deterministic:
|
||||
|
@ -15,13 +15,11 @@ class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
|
||||
|
||||
def scale(self, x):
|
||||
# re-normalize to centered mean and unit variance
|
||||
x = (x - self.data_mean.to(x.device)) * 1. / self.data_std.to(x.device)
|
||||
return x
|
||||
return (x - self.data_mean.to(x.device)) * 1. / self.data_std.to(x.device)
|
||||
|
||||
def unscale(self, x):
|
||||
# back to original data stats
|
||||
x = (x * self.data_std.to(x.device)) + self.data_mean.to(x.device)
|
||||
return x
|
||||
return (x * self.data_std.to(x.device)) + self.data_mean.to(x.device)
|
||||
|
||||
def forward(self, x, noise_level=None, seed=None):
|
||||
if noise_level is None:
|
||||
|
@ -177,8 +177,7 @@ def _get_attention_scores_no_kv_chunking(
|
||||
attn_scores /= summed
|
||||
attn_probs = attn_scores
|
||||
|
||||
hidden_states_slice = torch.bmm(attn_probs.to(value.dtype), value)
|
||||
return hidden_states_slice
|
||||
return torch.bmm(attn_probs.to(value.dtype), value)
|
||||
|
||||
class ScannedChunk(NamedTuple):
|
||||
chunk_idx: int
|
||||
@ -264,7 +263,7 @@ def efficient_dot_product_attention(
|
||||
|
||||
# TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
|
||||
# and pass slices to be mutated, instead of torch.cat()ing the returned slices
|
||||
res = torch.cat([
|
||||
return torch.cat([
|
||||
compute_query_chunk_attn(
|
||||
query=get_query_chunk(i * query_chunk_size),
|
||||
key_t=key_t,
|
||||
@ -272,4 +271,3 @@ def efficient_dot_product_attention(
|
||||
mask=get_mask_chunk(i * query_chunk_size)
|
||||
) for i in range(math.ceil(q_tokens / query_chunk_size))
|
||||
], dim=1)
|
||||
return res
|
||||
|
@ -73,8 +73,7 @@ class MultiHeadCrossAttention(nn.Module):
|
||||
|
||||
x = optimized_attention(q.view(B, -1, C), k.view(B, -1, C), v.view(B, -1, C), self.num_heads, mask=None)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
return self.proj_drop(x)
|
||||
|
||||
|
||||
class AttentionKVCompress(nn.Module):
|
||||
@ -172,8 +171,7 @@ class AttentionKVCompress(nn.Module):
|
||||
x = optimized_attention(q, k, v, self.num_heads, mask=None, skip_reshape=True)
|
||||
|
||||
x = x.view(B, N, C)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
return self.proj(x)
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
@ -192,8 +190,7 @@ class FinalLayer(nn.Module):
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
return self.linear(x)
|
||||
|
||||
class T2IFinalLayer(nn.Module):
|
||||
"""
|
||||
@ -209,8 +206,7 @@ class T2IFinalLayer(nn.Module):
|
||||
def forward(self, x, t):
|
||||
shift, scale = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t[:, None]).chunk(2, dim=1)
|
||||
x = t2i_modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
return self.linear(x)
|
||||
|
||||
|
||||
class MaskFinalLayer(nn.Module):
|
||||
@ -228,8 +224,7 @@ class MaskFinalLayer(nn.Module):
|
||||
def forward(self, x, t):
|
||||
shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
return self.linear(x)
|
||||
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
@ -247,8 +242,7 @@ class DecoderLayer(nn.Module):
|
||||
def forward(self, x, t):
|
||||
shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
|
||||
x = modulate(self.norm_decoder(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
return self.linear(x)
|
||||
|
||||
|
||||
class SizeEmbedder(TimestepEmbedder):
|
||||
@ -276,8 +270,7 @@ class SizeEmbedder(TimestepEmbedder):
|
||||
s = rearrange(s, "b d -> (b d)")
|
||||
s_freq = timestep_embedding(s, self.frequency_embedding_size)
|
||||
s_emb = self.mlp(s_freq.to(s.dtype))
|
||||
s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
|
||||
return s_emb
|
||||
return rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
|
||||
|
||||
|
||||
class LabelEmbedder(nn.Module):
|
||||
@ -299,15 +292,13 @@ class LabelEmbedder(nn.Module):
|
||||
drop_ids = torch.rand(labels.shape[0]).cuda() < self.dropout_prob
|
||||
else:
|
||||
drop_ids = force_drop_ids == 1
|
||||
labels = torch.where(drop_ids, self.num_classes, labels)
|
||||
return labels
|
||||
return torch.where(drop_ids, self.num_classes, labels)
|
||||
|
||||
def forward(self, labels, train, force_drop_ids=None):
|
||||
use_dropout = self.dropout_prob > 0
|
||||
if (train and use_dropout) or (force_drop_ids is not None):
|
||||
labels = self.token_drop(labels, force_drop_ids)
|
||||
embeddings = self.embedding_table(labels)
|
||||
return embeddings
|
||||
return self.embedding_table(labels)
|
||||
|
||||
|
||||
class CaptionEmbedder(nn.Module):
|
||||
@ -331,8 +322,7 @@ class CaptionEmbedder(nn.Module):
|
||||
drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
|
||||
else:
|
||||
drop_ids = force_drop_ids == 1
|
||||
caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
|
||||
return caption
|
||||
return torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
|
||||
|
||||
def forward(self, caption, train, force_drop_ids=None):
|
||||
if train:
|
||||
@ -340,8 +330,7 @@ class CaptionEmbedder(nn.Module):
|
||||
use_dropout = self.uncond_prob > 0
|
||||
if (train and use_dropout) or (force_drop_ids is not None):
|
||||
caption = self.token_drop(caption, force_drop_ids)
|
||||
caption = self.y_proj(caption)
|
||||
return caption
|
||||
return self.y_proj(caption)
|
||||
|
||||
|
||||
class CaptionEmbedderDoubleBr(nn.Module):
|
||||
|
@ -23,8 +23,7 @@ def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_si
|
||||
)
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
|
||||
emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
|
||||
return emb
|
||||
return torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
|
||||
|
||||
class PixArtMSBlock(nn.Module):
|
||||
"""
|
||||
@ -57,9 +56,8 @@ class PixArtMSBlock(nn.Module):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1)
|
||||
x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
|
||||
x = x + self.cross_attn(x, y, mask)
|
||||
x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
|
||||
return x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
### Core PixArt Model ###
|
||||
@ -212,9 +210,8 @@ class PixArtMS(nn.Module):
|
||||
x = block(x, y, t0, y_lens, (H, W), **kwargs) # (N, T, D)
|
||||
|
||||
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
|
||||
x = self.unpatchify(x, H, W) # (N, out_channels, H, W)
|
||||
return self.unpatchify(x, H, W) # (N, out_channels, H, W)
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs):
|
||||
B, C, H, W = x.shape
|
||||
@ -252,5 +249,4 @@ class PixArtMS(nn.Module):
|
||||
|
||||
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
||||
x = torch.einsum('nhwpqc->nchpwq', x)
|
||||
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
||||
return imgs
|
||||
return x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
||||
|
@ -29,8 +29,7 @@ def log_txt_as_img(wh, xc, size=10):
|
||||
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
||||
txts.append(txt)
|
||||
txts = np.stack(txts)
|
||||
txts = torch.tensor(txts)
|
||||
return txts
|
||||
return torch.tensor(txts)
|
||||
|
||||
|
||||
def ismap(x):
|
||||
|
@ -206,8 +206,7 @@ class BaseModel(torch.nn.Module):
|
||||
cond_concat.append(torch.ones_like(noise)[:,:1])
|
||||
elif ck == "masked_image":
|
||||
cond_concat.append(self.blank_inpaint_image_like(noise))
|
||||
data = torch.cat(cond_concat, dim=1)
|
||||
return data
|
||||
return torch.cat(cond_concat, dim=1)
|
||||
return None
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
@ -418,8 +417,7 @@ class SVD_img2vid(BaseModel):
|
||||
out.append(self.embedder(torch.Tensor([motion_bucket_id])))
|
||||
out.append(self.embedder(torch.Tensor([augmentation])))
|
||||
|
||||
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
|
||||
return flat
|
||||
return torch.flatten(torch.cat(out)).unsqueeze(dim=0)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
@ -457,8 +455,7 @@ class SV3D_u(SVD_img2vid):
|
||||
out = []
|
||||
out.append(self.embedder(torch.flatten(torch.Tensor([augmentation]))))
|
||||
|
||||
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
|
||||
return flat
|
||||
return torch.flatten(torch.cat(out)).unsqueeze(dim=0)
|
||||
|
||||
class SV3D_p(SVD_img2vid):
|
||||
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
|
||||
|
@ -887,8 +887,7 @@ class ModelPatcher:
|
||||
|
||||
all_models = self.get_additional_models()
|
||||
models_set = set(all_models)
|
||||
real_all_models = _evaluate_sub_additional_models(prev_models=all_models, cache_set=models_set)
|
||||
return real_all_models
|
||||
return _evaluate_sub_additional_models(prev_models=all_models, cache_set=models_set)
|
||||
|
||||
def use_ejected(self, skip_and_inject_on_exit_only=False):
|
||||
return AutoPatcherEjector(self, skip_and_inject_on_exit_only=skip_and_inject_on_exit_only)
|
||||
|
@ -286,8 +286,7 @@ class StableCascadeSampling(ModelSamplingDiscrete):
|
||||
var = 1 / ((sigma * sigma) + 1)
|
||||
var = var.clamp(0, 1.0)
|
||||
s, min_var = self.cosine_s.to(var.device), self._init_alpha_cumprod.to(var.device)
|
||||
t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
|
||||
return t
|
||||
return (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
|
||||
|
||||
def percent_to_sigma(self, percent):
|
||||
if percent <= 0.0:
|
||||
|
@ -21,8 +21,7 @@ def prepare_noise(latent_image, seed, noise_inds=None):
|
||||
if i in unique_inds:
|
||||
noises.append(noise)
|
||||
noises = [noises[i] for i in inverse]
|
||||
noises = torch.cat(noises, axis=0)
|
||||
return noises
|
||||
return torch.cat(noises, axis=0)
|
||||
|
||||
def fix_empty_latent_channels(model, latent_image):
|
||||
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
|
||||
@ -43,10 +42,8 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
|
||||
sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
||||
|
||||
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
samples = samples.to(comfy.model_management.intermediate_device())
|
||||
return samples
|
||||
return samples.to(comfy.model_management.intermediate_device())
|
||||
|
||||
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
samples = samples.to(comfy.model_management.intermediate_device())
|
||||
return samples
|
||||
return samples.to(comfy.model_management.intermediate_device())
|
||||
|
@ -713,8 +713,7 @@ class KSAMPLER(Sampler):
|
||||
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
|
||||
|
||||
samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
|
||||
samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples)
|
||||
return samples
|
||||
return model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples)
|
||||
|
||||
|
||||
def ksampler(sampler_name, extra_options={}, inpaint_options={}):
|
||||
|
@ -422,12 +422,11 @@ class VAE:
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
output = self.process_output(
|
||||
return self.process_output(
|
||||
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||
comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar))
|
||||
/ 3.0)
|
||||
return output
|
||||
|
||||
def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
@ -485,8 +484,7 @@ class VAE:
|
||||
overlap = tile // 4
|
||||
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
|
||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||
return pixel_samples
|
||||
return pixel_samples.to(self.output_device).movedim(1,-1)
|
||||
|
||||
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
||||
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
|
||||
|
@ -304,13 +304,11 @@ def token_weights(string, current_weight):
|
||||
|
||||
def escape_important(text):
|
||||
text = text.replace("\\)", "\0\1")
|
||||
text = text.replace("\\(", "\0\2")
|
||||
return text
|
||||
return text.replace("\\(", "\0\2")
|
||||
|
||||
def unescape_important(text):
|
||||
text = text.replace("\0\1", ")")
|
||||
text = text.replace("\0\2", "(")
|
||||
return text
|
||||
return text.replace("\0\2", "(")
|
||||
|
||||
def safe_load_embed_zip(embed_path):
|
||||
with zipfile.ZipFile(embed_path) as myzip:
|
||||
@ -635,8 +633,7 @@ class SD1ClipModel(torch.nn.Module):
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
token_weight_pairs = token_weight_pairs[self.clip_name]
|
||||
out = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
|
||||
return out
|
||||
return getattr(self, self.clip).encode_token_weights(token_weight_pairs)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return getattr(self, self.clip).load_sd(sd)
|
||||
|
@ -51,8 +51,7 @@ class SD15(supported_models_base.BASE):
|
||||
|
||||
replace_prefix = {}
|
||||
replace_prefix["cond_stage_model."] = "clip_l."
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
|
||||
return state_dict
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"]
|
||||
@ -97,15 +96,13 @@ class SD20(supported_models_base.BASE):
|
||||
replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format
|
||||
replace_prefix["cond_stage_model.model."] = "clip_h."
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
|
||||
state_dict = utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.")
|
||||
return state_dict
|
||||
return utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.")
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {}
|
||||
replace_prefix["clip_h"] = "cond_stage_model.model"
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
|
||||
return state_dict
|
||||
return diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.sd2_clip.SD2Tokenizer, comfy.text_encoders.sd2_clip.SD2ClipModel)
|
||||
@ -158,8 +155,7 @@ class SDXLRefiner(supported_models_base.BASE):
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
|
||||
|
||||
state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.")
|
||||
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
|
||||
return state_dict
|
||||
return utils.state_dict_key_replace(state_dict, keys_to_replace)
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {}
|
||||
@ -167,8 +163,7 @@ class SDXLRefiner(supported_models_base.BASE):
|
||||
if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g:
|
||||
state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
|
||||
replace_prefix["clip_g"] = "conditioner.embedders.0.model"
|
||||
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
|
||||
return state_dict_g
|
||||
return utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)
|
||||
@ -221,8 +216,7 @@ class SDXL(supported_models_base.BASE):
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
|
||||
|
||||
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
|
||||
state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.")
|
||||
return state_dict
|
||||
return utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.")
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {}
|
||||
@ -239,8 +233,7 @@ class SDXL(supported_models_base.BASE):
|
||||
|
||||
replace_prefix["clip_g"] = "conditioner.embedders.1.model"
|
||||
replace_prefix["clip_l"] = "conditioner.embedders.0"
|
||||
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
|
||||
return state_dict_g
|
||||
return utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)
|
||||
@ -310,8 +303,7 @@ class SVD_img2vid(supported_models_base.BASE):
|
||||
sampling_settings = {"sigma_max": 700.0, "sigma_min": 0.002}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.SVD_img2vid(self, device=device)
|
||||
return out
|
||||
return model_base.SVD_img2vid(self, device=device)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return None
|
||||
@ -331,8 +323,7 @@ class SV3D_u(SVD_img2vid):
|
||||
vae_key_prefix = ["conditioner.embedders.1.encoder."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.SV3D_u(self, device=device)
|
||||
return out
|
||||
return model_base.SV3D_u(self, device=device)
|
||||
|
||||
class SV3D_p(SV3D_u):
|
||||
unet_config = {
|
||||
@ -348,8 +339,7 @@ class SV3D_p(SV3D_u):
|
||||
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.SV3D_p(self, device=device)
|
||||
return out
|
||||
return model_base.SV3D_p(self, device=device)
|
||||
|
||||
class Stable_Zero123(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
@ -376,8 +366,7 @@ class Stable_Zero123(supported_models_base.BASE):
|
||||
latent_format = latent_formats.SD15
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"])
|
||||
return out
|
||||
return model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"])
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return None
|
||||
@ -407,8 +396,7 @@ class SD_X4Upscaler(SD20):
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.SD_X4Upscaler(self, device=device)
|
||||
return out
|
||||
return model_base.SD_X4Upscaler(self, device=device)
|
||||
|
||||
class Stable_Cascade_C(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
@ -450,8 +438,7 @@ class Stable_Cascade_C(supported_models_base.BASE):
|
||||
return state_dict
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.StableCascade_C(self, device=device)
|
||||
return out
|
||||
return model_base.StableCascade_C(self, device=device)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel)
|
||||
@ -473,8 +460,7 @@ class Stable_Cascade_B(Stable_Cascade_C):
|
||||
clip_vision_prefix = None
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.StableCascade_B(self, device=device)
|
||||
return out
|
||||
return model_base.StableCascade_B(self, device=device)
|
||||
|
||||
class SD15_instructpix2pix(SD15):
|
||||
unet_config = {
|
||||
@ -521,8 +507,7 @@ class SD3(supported_models_base.BASE):
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.SD3(self, device=device)
|
||||
return out
|
||||
return model_base.SD3(self, device=device)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
clip_l = False
|
||||
@ -587,8 +572,7 @@ class AuraFlow(supported_models_base.BASE):
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.AuraFlow(self, device=device)
|
||||
return out
|
||||
return model_base.AuraFlow(self, device=device)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.aura_t5.AuraT5Tokenizer, comfy.text_encoders.aura_t5.AuraT5Model)
|
||||
@ -648,8 +632,7 @@ class HunyuanDiT(supported_models_base.BASE):
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.HunyuanDiT(self, device=device)
|
||||
return out
|
||||
return model_base.HunyuanDiT(self, device=device)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.hydit.HyditTokenizer, comfy.text_encoders.hydit.HyditModel)
|
||||
@ -686,8 +669,7 @@ class Flux(supported_models_base.BASE):
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Flux(self, device=device)
|
||||
return out
|
||||
return model_base.Flux(self, device=device)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
@ -715,8 +697,7 @@ class FluxSchnell(Flux):
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device)
|
||||
return out
|
||||
return model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device)
|
||||
|
||||
class GenmoMochi(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
@ -739,8 +720,7 @@ class GenmoMochi(supported_models_base.BASE):
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.GenmoMochi(self, device=device)
|
||||
return out
|
||||
return model_base.GenmoMochi(self, device=device)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
@ -767,8 +747,7 @@ class LTXV(supported_models_base.BASE):
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.LTXV(self, device=device)
|
||||
return out
|
||||
return model_base.LTXV(self, device=device)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
@ -795,8 +774,7 @@ class HunyuanVideo(supported_models_base.BASE):
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.HunyuanVideo(self, device=device)
|
||||
return out
|
||||
return model_base.HunyuanVideo(self, device=device)
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
out_sd = {}
|
||||
|
@ -87,8 +87,7 @@ class BASE:
|
||||
return out
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
|
||||
return state_dict
|
||||
return utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
return state_dict
|
||||
|
@ -60,8 +60,7 @@ class Downsample(nn.Module):
|
||||
padding = [x.shape[2] % 2, x.shape[3] % 2]
|
||||
self.op.padding = padding
|
||||
|
||||
x = self.op(x)
|
||||
return x
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
@ -196,8 +195,7 @@ class ResidualAttentionBlock(nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = x + self.attention(self.ln_1(x))
|
||||
x = x + self.mlp(self.ln_2(x))
|
||||
return x
|
||||
return x + self.mlp(self.ln_2(x))
|
||||
|
||||
|
||||
class StyleAdapter(nn.Module):
|
||||
@ -224,9 +222,8 @@ class StyleAdapter(nn.Module):
|
||||
x = x.permute(1, 0, 2) # LND -> NLD
|
||||
|
||||
x = self.ln_post(x[:, -self.num_token:, :])
|
||||
x = x @ self.proj
|
||||
return x @ self.proj
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ResnetBlock_light(nn.Module):
|
||||
@ -262,9 +259,8 @@ class extractor(nn.Module):
|
||||
x = self.down_opt(x)
|
||||
x = self.in_conv(x)
|
||||
x = self.body(x)
|
||||
x = self.out_conv(x)
|
||||
return self.out_conv(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Adapter_light(nn.Module):
|
||||
|
@ -72,8 +72,7 @@ class TAESD(nn.Module):
|
||||
|
||||
def decode(self, x):
|
||||
x_sample = self.taesd_decoder((x - self.vae_shift) * self.vae_scale)
|
||||
x_sample = x_sample.sub(0.5).mul(2)
|
||||
return x_sample
|
||||
return x_sample.sub(0.5).mul(2)
|
||||
|
||||
def encode(self, x):
|
||||
return (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift
|
||||
|
@ -17,8 +17,7 @@ class BertAttention(torch.nn.Module):
|
||||
k = self.key(x)
|
||||
v = self.value(x)
|
||||
|
||||
out = optimized_attention(q, k, v, self.heads, mask)
|
||||
return out
|
||||
return optimized_attention(q, k, v, self.heads, mask)
|
||||
|
||||
class BertOutput(torch.nn.Module):
|
||||
def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
|
||||
@ -30,8 +29,7 @@ class BertOutput(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
x = self.dense(x)
|
||||
# hidden_states = self.dropout(hidden_states)
|
||||
x = self.LayerNorm(x + y)
|
||||
return x
|
||||
return self.LayerNorm(x + y)
|
||||
|
||||
class BertAttentionBlock(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
|
||||
@ -100,8 +98,7 @@ class BertEmbeddings(torch.nn.Module):
|
||||
x += self.token_type_embeddings(token_type_ids, out_dtype=x.dtype)
|
||||
else:
|
||||
x += comfy.ops.cast_to_input(self.token_type_embeddings.weight[0], x)
|
||||
x = self.LayerNorm(x)
|
||||
return x
|
||||
return self.LayerNorm(x)
|
||||
|
||||
|
||||
class BertModel_(torch.nn.Module):
|
||||
|
@ -142,9 +142,8 @@ class TransformerBlock(nn.Module):
|
||||
residual = x
|
||||
x = self.post_attention_layernorm(x)
|
||||
x = self.mlp(x)
|
||||
x = residual + x
|
||||
return residual + x
|
||||
|
||||
return x
|
||||
|
||||
class Llama2_(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
|
@ -30,8 +30,7 @@ class T5DenseActDense(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
x = self.act(self.wi(x))
|
||||
# x = self.dropout(x)
|
||||
x = self.wo(x)
|
||||
return x
|
||||
return self.wo(x)
|
||||
|
||||
class T5DenseGatedActDense(torch.nn.Module):
|
||||
def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations):
|
||||
@ -47,8 +46,7 @@ class T5DenseGatedActDense(torch.nn.Module):
|
||||
hidden_linear = self.wi_1(x)
|
||||
x = hidden_gelu * hidden_linear
|
||||
# x = self.dropout(x)
|
||||
x = self.wo(x)
|
||||
return x
|
||||
return self.wo(x)
|
||||
|
||||
class T5LayerFF(torch.nn.Module):
|
||||
def __init__(self, model_dim, ff_dim, ff_activation, gated_act, dtype, device, operations):
|
||||
@ -145,8 +143,7 @@ class T5Attention(torch.nn.Module):
|
||||
max_distance=self.relative_attention_max_distance,
|
||||
)
|
||||
values = self.relative_attention_bias(relative_position_bucket, out_dtype=dtype) # shape (query_length, key_length, num_heads)
|
||||
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
|
||||
return values
|
||||
return values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
|
||||
|
||||
def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
|
||||
q = self.q(x)
|
||||
|
@ -292,8 +292,7 @@ def unet_to_diffusers(unet_config):
|
||||
|
||||
def swap_scale_shift(weight):
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
return new_weight
|
||||
return torch.cat([scale, shift], dim=0)
|
||||
|
||||
MMDIT_MAP_BASIC = {
|
||||
("context_embedder.bias", "context_embedder.bias"),
|
||||
@ -1000,8 +999,7 @@ def reshape_mask(input_mask, output_shape):
|
||||
mask = torch.nn.functional.interpolate(input_mask, size=output_shape[2:], mode=scale_mode)
|
||||
if mask.shape[1] < output_shape[1]:
|
||||
mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:,:output_shape[1]]
|
||||
mask = repeat_to_batch_size(mask, output_shape[0])
|
||||
return mask
|
||||
return repeat_to_batch_size(mask, output_shape[0])
|
||||
|
||||
def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out):
|
||||
hi, wi = img_size_in
|
||||
@ -1037,9 +1035,8 @@ def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out):
|
||||
img_to_txt = rearrange (img_to_txt, "b t h w -> b (h w) t")
|
||||
|
||||
# reassemble the mask from blocks
|
||||
out = torch.cat([
|
||||
return torch.cat([
|
||||
torch.cat([txt_to_txt, txt_to_img], dim=2),
|
||||
torch.cat([img_to_txt, img_to_img], dim=2)],
|
||||
dim=1
|
||||
)
|
||||
return out
|
||||
|
@ -12,8 +12,7 @@ def loglinear_interp(t_steps, num_steps):
|
||||
new_xs = np.linspace(0, 1, num_steps)
|
||||
new_ys = np.interp(new_xs, xs, ys)
|
||||
|
||||
interped_ys = np.exp(new_ys)[::-1].copy()
|
||||
return interped_ys
|
||||
return np.exp(new_ys)[::-1].copy()
|
||||
|
||||
NOISE_LEVELS = {"SD1": [14.6146412293, 6.4745760956, 3.8636745985, 2.6946151520, 1.8841921177, 1.3943805092, 0.9642583904, 0.6523686016, 0.3977456272, 0.1515232662, 0.0291671582],
|
||||
"SDXL":[14.6146412293, 6.3184485287, 3.7681790315, 2.1811480769, 1.3405244945, 0.8620721141, 0.5550693289, 0.3798540708, 0.2332364134, 0.1114188177, 0.0291671582],
|
||||
|
@ -104,9 +104,8 @@ def create_vorbis_comment_block(comment_dict, last_block):
|
||||
id = b'\x84'
|
||||
else:
|
||||
id = b'\x04'
|
||||
comment_block = id + struct.pack('>I', len(comment_data))[1:] + comment_data
|
||||
return id + struct.pack('>I', len(comment_data))[1:] + comment_data
|
||||
|
||||
return comment_block
|
||||
|
||||
def insert_or_replace_vorbis_comment(flac_io, comment_dict):
|
||||
if len(comment_dict) == 0:
|
||||
|
@ -150,8 +150,7 @@ class PorterDuffImageComposite:
|
||||
out_images.append(out_image)
|
||||
out_alphas.append(out_alpha.squeeze(2))
|
||||
|
||||
result = (torch.stack(out_images), torch.stack(out_alphas))
|
||||
return result
|
||||
return (torch.stack(out_images), torch.stack(out_alphas))
|
||||
|
||||
|
||||
class SplitImageWithAlpha:
|
||||
@ -170,8 +169,7 @@ class SplitImageWithAlpha:
|
||||
def split_image_with_alpha(self, image: torch.Tensor):
|
||||
out_images = [i[:,:,:3] for i in image]
|
||||
out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image]
|
||||
result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas))
|
||||
return result
|
||||
return (torch.stack(out_images), 1.0 - torch.stack(out_alphas))
|
||||
|
||||
|
||||
class JoinImageWithAlpha:
|
||||
@ -196,8 +194,7 @@ class JoinImageWithAlpha:
|
||||
for i in range(batch_size):
|
||||
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
|
||||
|
||||
result = (torch.stack(out_images),)
|
||||
return result
|
||||
return (torch.stack(out_images),)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
|
@ -12,8 +12,7 @@ def loglinear_interp(t_steps, num_steps):
|
||||
new_xs = np.linspace(0, 1, num_steps)
|
||||
new_ys = np.interp(new_xs, xs, ys)
|
||||
|
||||
interped_ys = np.exp(new_ys)[::-1].copy()
|
||||
return interped_ys
|
||||
return np.exp(new_ys)[::-1].copy()
|
||||
|
||||
NOISE_LEVELS = {
|
||||
0.80: [
|
||||
|
@ -27,8 +27,7 @@ class Mahiro:
|
||||
normm = torch.sqrt(merge.abs()) * merge.sign()
|
||||
sim = F.cosine_similarity(normu, normm).mean()
|
||||
simsc = 2 * (sim+1)
|
||||
wm = (simsc*cfg + (4-simsc)*leap) / 4
|
||||
return wm
|
||||
return (simsc*cfg + (4-simsc)*leap) / 4
|
||||
m.set_model_sampler_post_cfg_function(mahiro_normd)
|
||||
return (m, )
|
||||
|
||||
|
@ -11,8 +11,7 @@ def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, co
|
||||
|
||||
perp = neg - ((torch.mul(neg, pos).sum())/(torch.norm(pos)**2)) * pos
|
||||
perp_neg = perp * neg_scale
|
||||
cfg_result = noise_pred_nocond + cond_scale*(pos - perp_neg)
|
||||
return cfg_result
|
||||
return noise_pred_nocond + cond_scale*(pos - perp_neg)
|
||||
|
||||
#TODO: This node should be removed, it has been replaced with PerpNegGuider
|
||||
class PerpNeg:
|
||||
@ -44,8 +43,7 @@ class PerpNeg:
|
||||
|
||||
(noise_pred_nocond,) = comfy.samplers.calc_cond_batch(model, [nocond_processed], x, sigma, model_options)
|
||||
|
||||
cfg_result = x - perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale)
|
||||
return cfg_result
|
||||
return x - perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale)
|
||||
|
||||
m.set_model_sampler_cfg_function(cfg_function)
|
||||
|
||||
|
@ -52,8 +52,7 @@ class FuseModule(nn.Module):
|
||||
stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1)
|
||||
stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds
|
||||
stacked_id_embeds = self.mlp2(stacked_id_embeds)
|
||||
stacked_id_embeds = self.layer_norm(stacked_id_embeds)
|
||||
return stacked_id_embeds
|
||||
return self.layer_norm(stacked_id_embeds)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -86,8 +85,7 @@ class FuseModule(nn.Module):
|
||||
stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds)
|
||||
assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}"
|
||||
prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype))
|
||||
updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1)
|
||||
return updated_prompt_embeds
|
||||
return prompt_embeds.view(batch_size, seq_length, -1)
|
||||
|
||||
class PhotoMakerIDEncoder(comfy.clip_model.CLIPVisionModelProjection):
|
||||
def __init__(self):
|
||||
@ -111,9 +109,8 @@ class PhotoMakerIDEncoder(comfy.clip_model.CLIPVisionModelProjection):
|
||||
id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1)
|
||||
|
||||
id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1)
|
||||
updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask)
|
||||
return self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask)
|
||||
|
||||
return updated_prompt_embeds
|
||||
|
||||
|
||||
class PhotoMakerLoader:
|
||||
|
@ -162,8 +162,7 @@ class Quantize:
|
||||
result = result.to(dtype=torch.uint8)
|
||||
|
||||
im = Image.fromarray(result.cpu().numpy())
|
||||
im = im.quantize(palette=pal_im, dither=Image.Dither.NONE)
|
||||
return im
|
||||
return im.quantize(palette=pal_im, dither=Image.Dither.NONE)
|
||||
|
||||
def quantize(self, image: torch.Tensor, colors: int, dither: str):
|
||||
batch_size, height, width, _ = image.shape
|
||||
|
@ -50,8 +50,7 @@ class LatentRebatch:
|
||||
def cat_batch(batch1, batch2):
|
||||
if batch1[0] is None:
|
||||
return batch2
|
||||
result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
|
||||
return result
|
||||
return [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
|
||||
|
||||
def rebatch(self, latents, batch_size):
|
||||
batch_size = batch_size[0]
|
||||
|
@ -82,8 +82,7 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
|
||||
mask = F.interpolate(mask, (lh, lw))
|
||||
|
||||
blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
|
||||
blurred = blurred * mask + x0 * (1 - mask)
|
||||
return blurred
|
||||
return blurred * mask + x0 * (1 - mask)
|
||||
|
||||
def gaussian_blur_2d(img, kernel_size, sigma):
|
||||
ksize_half = (kernel_size - 1) * 0.5
|
||||
@ -101,8 +100,7 @@ def gaussian_blur_2d(img, kernel_size, sigma):
|
||||
padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
|
||||
|
||||
img = F.pad(img, padding, mode="reflect")
|
||||
img = F.conv2d(img, kernel2d, groups=img.shape[-3])
|
||||
return img
|
||||
return F.conv2d(img, kernel2d, groups=img.shape[-3])
|
||||
|
||||
class SelfAttentionGuidance:
|
||||
@classmethod
|
||||
|
@ -5,7 +5,7 @@ import comfy.utils
|
||||
def camera_embeddings(elevation, azimuth):
|
||||
elevation = torch.as_tensor([elevation])
|
||||
azimuth = torch.as_tensor([azimuth])
|
||||
embeddings = torch.stack(
|
||||
return torch.stack(
|
||||
[
|
||||
torch.deg2rad(
|
||||
(90 - elevation) - (90)
|
||||
@ -17,7 +17,6 @@ def camera_embeddings(elevation, azimuth):
|
||||
),
|
||||
], dim=-1).unsqueeze(1)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
class StableZero123_Conditioning:
|
||||
|
@ -81,11 +81,10 @@ class CacheSet:
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
|
||||
def recursive_debug_dump(self):
|
||||
result = {
|
||||
return {
|
||||
"outputs": self.outputs.recursive_debug_dump(),
|
||||
"ui": self.ui.recursive_debug_dump(),
|
||||
}
|
||||
return result
|
||||
|
||||
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
|
||||
valid_inputs = class_def.INPUT_TYPES()
|
||||
|
@ -356,8 +356,7 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im
|
||||
input = input.replace("%day%", str(now.tm_mday).zfill(2))
|
||||
input = input.replace("%hour%", str(now.tm_hour).zfill(2))
|
||||
input = input.replace("%minute%", str(now.tm_min).zfill(2))
|
||||
input = input.replace("%second%", str(now.tm_sec).zfill(2))
|
||||
return input
|
||||
return input.replace("%second%", str(now.tm_sec).zfill(2))
|
||||
|
||||
if "%" in filename_prefix:
|
||||
filename_prefix = compute_vars(filename_prefix, image_width, image_height)
|
||||
|
3
nodes.py
3
nodes.py
@ -607,8 +607,7 @@ class unCLIPCheckpointLoader:
|
||||
|
||||
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
||||
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
return out
|
||||
return comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
|
||||
class CLIPSetLastLayer:
|
||||
@classmethod
|
||||
|
@ -10,6 +10,7 @@ lint.select = [
|
||||
# The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
|
||||
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
|
||||
"F",
|
||||
"RET504",
|
||||
]
|
||||
|
||||
exclude = ["*.ipynb"]
|
||||
|
@ -129,8 +129,7 @@ class TestCompareImageMetrics:
|
||||
|
||||
def read_img(self, filename: str) -> np.ndarray:
|
||||
cvImg = imread(filename)
|
||||
cvImg = cvtColor(cvImg, COLOR_BGR2RGB)
|
||||
return cvImg
|
||||
return cvtColor(cvImg, COLOR_BGR2RGB)
|
||||
|
||||
def image_grid(self, img_list: list[list[Image.Image]]):
|
||||
# imgs is a 2D list of images
|
||||
@ -154,8 +153,7 @@ class TestCompareImageMetrics:
|
||||
with open(metrics_output_file, 'r') as f:
|
||||
for line in f:
|
||||
if fname_basestr in line:
|
||||
score = float(line.split('|')[5])
|
||||
return score
|
||||
return float(line.split('|')[5])
|
||||
raise ValueError(f"Could not find score for {fname} in {metrics_output_file}")
|
||||
|
||||
def gather_file_basenames(self, directory: str):
|
||||
|
@ -141,14 +141,13 @@ class TestExecutionBlockerNode:
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
inputs = {
|
||||
return {
|
||||
"required": {
|
||||
"input": ("*",),
|
||||
"block": ("BOOLEAN",),
|
||||
"verbose": ("BOOLEAN", {"default": False}),
|
||||
},
|
||||
}
|
||||
return inputs
|
||||
|
||||
RETURN_TYPES = ("*",)
|
||||
RETURN_NAMES = ("output",)
|
||||
|
Loading…
Reference in New Issue
Block a user