Compare commits

...

3 Commits

Author SHA1 Message Date
Alexander Piskun
c5f652ac79
Merge f4e86a4a07 into 2307ff6746 2025-01-08 19:17:00 -05:00
comfyanonymous
2307ff6746 Improve some logging messages. 2025-01-08 19:05:22 -05:00
Alexander Piskun
f4e86a4a07
applied RUFF RET504 rule (unnecessary assignment before return statement)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2025-01-01 12:45:55 +02:00
81 changed files with 223 additions and 439 deletions

View File

@ -50,8 +50,7 @@ class ResBlockUnionControlnet(nn.Module):
def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
return x + self.mlp(self.ln_2(x))
class ControlledUnetModel(UNetModel):
#implemented in the ldm unet

View File

@ -36,8 +36,7 @@ class CLIPMLP(torch.nn.Module):
def forward(self, x):
x = self.fc1(x)
x = self.activation(x)
x = self.fc2(x)
return x
return self.fc2(x)
class CLIPLayer(torch.nn.Module):
def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):

View File

@ -270,5 +270,4 @@ class CheckLazyMixin:
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lazy_evaluation#defining-check-lazy-status
"""
need = [name for name in kwargs if kwargs[name] is None]
return need
return [name for name in kwargs if kwargs[name] is None]

View File

@ -404,8 +404,7 @@ class ControlLora(ControlNet):
super().cleanup()
def get_models(self):
out = ControlBase.get_models(self)
return out
return ControlBase.get_models(self)
def inference_memory_requirements(self, dtype):
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
@ -461,8 +460,7 @@ def load_controlnet_mmdit(sd, model_options={}):
latent_format = comfy.latent_formats.SD3()
latent_format.shift_factor = 0 #SD3 controlnet weirdness
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control
return ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
class ControlNetSD35(ControlNet):
@ -536,8 +534,7 @@ def load_controlnet_sd35(sd, model_options={}):
elif depth_cnet:
preprocess_image = lambda a: 1.0 - a
control = ControlNetSD35(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, preprocess_image=preprocess_image)
return control
return ControlNetSD35(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, preprocess_image=preprocess_image)
@ -549,16 +546,14 @@ def load_controlnet_hunyuandit(controlnet_data, model_options={}):
latent_format = comfy.latent_formats.SDXL()
extra_conds = ['text_embedding_mask', 'encoder_hidden_states_t5', 'text_embedding_mask_t5', 'image_meta_size', 'style', 'cos_cis_img', 'sin_cis_img']
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
return control
return ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, sd)
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
return ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
def load_controlnet_flux_instantx(sd, model_options={}):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
@ -581,8 +576,7 @@ def load_controlnet_flux_instantx(sd, model_options={}):
latent_format = comfy.latent_formats.Flux()
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
return ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
def convert_mistoline(sd):
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
@ -738,8 +732,7 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
logging.debug("unexpected controlnet keys: {}".format(unexpected))
global_average_pooling = model_options.get("global_average_pooling", False)
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control
return ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
def load_controlnet(ckpt_path, model=None, model_options={}):
if "global_average_pooling" not in model_options:

View File

@ -99,8 +99,7 @@ def convert_unet_state_dict(unet_state_dict):
for sd_part, hf_part in unet_conversion_map_layer:
v = v.replace(hf_part, sd_part)
mapping[k] = v
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
return new_state_dict
return {v: unet_state_dict[k] for k, v in mapping.items()}
# ================#

View File

@ -136,8 +136,7 @@ class NoiseScheduleVP:
return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
elif self.schedule == 'cosine':
log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
return log_alpha_t
return log_alpha_fn(t) - self.cosine_log_alpha_0
def marginal_alpha(self, t):
"""
@ -174,8 +173,7 @@ class NoiseScheduleVP:
else:
log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
t = t_fn(log_alpha)
return t
return t_fn(log_alpha)
def model_wrapper(
@ -378,8 +376,7 @@ class UniPC:
p = self.dynamic_thresholding_ratio
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
x0 = torch.clamp(x0, -s, s) / s
return x0
return torch.clamp(x0, -s, s) / s
def noise_prediction_fn(self, x, t):
"""
@ -423,8 +420,7 @@ class UniPC:
return torch.linspace(t_T, t_0, N + 1).to(device)
elif skip_type == 'time_quadratic':
t_order = 2
t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
return t
return torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
else:
raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
@ -801,8 +797,7 @@ def interpolate_fn(x, xp, yp):
y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
return cand
return start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
def expand_dims(v, dims):

View File

@ -78,10 +78,9 @@ class GatedCrossAttentionDense(nn.Module):
x = x + self.scale * \
torch.tanh(self.alpha_attn) * self.attn(self.norm1(x), objs, objs)
x = x + self.scale * \
return x + self.scale * \
torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
return x
class GatedSelfAttentionDense(nn.Module):
@ -118,10 +117,9 @@ class GatedSelfAttentionDense(nn.Module):
x = x + self.scale * torch.tanh(self.alpha_attn) * self.attn(
self.norm1(torch.cat([x, objs], dim=1)))[:, 0:N_visual, :]
x = x + self.scale * \
return x + self.scale * \
torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
return x
class GatedSelfAttentionDense2(nn.Module):
@ -172,10 +170,9 @@ class GatedSelfAttentionDense2(nn.Module):
# add residual to visual feature
x = x + self.scale * torch.tanh(self.alpha_attn) * residual
x = x + self.scale * \
return x + self.scale * \
torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
return x
class FourierEmbedder():
@ -340,5 +337,4 @@ def load_gligen(sd):
w.position_net = PositionNet(in_dim, out_dim)
w.load_state_dict(sd, strict=False)
gligen = Gligen(output_list, w.position_net, key_dim)
return gligen
return Gligen(output_list, w.position_net, key_dim)

View File

@ -46,8 +46,7 @@ def cal_intergrand(beta_0, beta_1, taus):
log_alpha = alpha.log()
log_alpha.sum().backward()
d_log_alpha_dtau = taus.grad
integrand = -0.5 * d_log_alpha_dtau / torch.sqrt(alpha * (1 - alpha))
return integrand
return -0.5 * d_log_alpha_dtau / torch.sqrt(alpha * (1 - alpha))
#----------------------------------------------------------------------------

View File

@ -50,8 +50,7 @@ def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0., beta=0.5, device='cpu'):
x = torch.linspace(0, 1, n, device=device)
clamp = lambda x: torch.clamp(x, min=sigma_min, max=sigma_max)
lmb = mu - beta * torch.sign(0.5-x) * torch.log(1 - 2 * torch.abs(0.5-x) + epsilon)
sigmas = clamp(torch.exp(lmb))
return sigmas
return clamp(torch.exp(lmb))

View File

@ -70,9 +70,8 @@ class SnakeBeta(nn.Module):
if self.alpha_logscale:
alpha = torch.exp(alpha)
beta = torch.exp(beta)
x = snake_beta(x, alpha, beta)
return snake_beta(x, alpha, beta)
return x
def WNConv1d(*args, **kwargs):
try:

View File

@ -89,8 +89,7 @@ class AbsolutePositionalEmbedding(nn.Module):
pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
pos_emb = self.emb(pos)
pos_emb = pos_emb * self.scale
return pos_emb
return pos_emb * self.scale
class ScaledSinusoidalEmbedding(nn.Module):
def __init__(self, dim, theta = 10000):
@ -404,9 +403,8 @@ class ConformerModule(nn.Module):
x = self.swish(x)
x = rearrange(x, 'b n d -> b d n')
x = self.pointwise_conv_2(x)
x = rearrange(x, 'b d n -> b n d')
return rearrange(x, 'b d n -> b n d')
return x
class TransformerBlock(nn.Module):
def __init__(

View File

@ -21,8 +21,7 @@ class LearnedPositionalEmbedding(nn.Module):
x = rearrange(x, "b -> b 1")
freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * math.pi
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
fouriered = torch.cat((x, fouriered), dim=-1)
return fouriered
return torch.cat((x, fouriered), dim=-1)
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
return nn.Sequential(
@ -49,8 +48,7 @@ class NumberEmbedder(nn.Module):
shape = x.shape
x = rearrange(x, "... -> (...)")
embedding = self.embedding(x)
x = embedding.view(*shape, self.features)
return x # type: ignore
return embedding.view(*shape, self.features)
class Conditioner(nn.Module):

View File

@ -36,8 +36,7 @@ class MLP(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
x = self.c_proj(x)
return x
return self.c_proj(x)
class MultiHeadLayerNorm(nn.Module):
@ -95,8 +94,7 @@ class SingleAttention(nn.Module):
q, k = self.q_norm1(q), self.k_norm1(k)
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
c = self.w1o(output)
return c
return self.w1o(output)
@ -265,9 +263,8 @@ class DiTBlock(nn.Module):
mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
cx = gate_mlp.unsqueeze(1) * mlpout
cx = cxres + cx
return cxres + cx
return cx
@ -298,8 +295,7 @@ class TimestepEmbedder(nn.Module):
#@torch.compile()
def forward(self, t, dtype):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
t_emb = self.mlp(t_freq)
return t_emb
return self.mlp(t_freq)
class MMDiT(nn.Module):
@ -401,8 +397,7 @@ class MMDiT(nn.Module):
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum("nhwpqc->nchpwq", x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs
return x.reshape(shape=(x.shape[0], c, h * p, w * p))
def patchify(self, x):
B, C, H, W = x.size()
@ -415,8 +410,7 @@ class MMDiT(nn.Module):
(W + 1) // self.patch_size,
self.patch_size,
)
x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
return x
return x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
def apply_pos_embeds(self, x, h, w):
h = (h + 1) // self.patch_size
@ -494,5 +488,4 @@ class MMDiT(nn.Module):
x = modulate(x, fshift, fscale)
x = self.final_linear(x)
x = self.unpatchify(x, (h + 1) // self.patch_size, (w + 1) // self.patch_size)[:,:,:h,:w]
return x
return self.unpatchify(x, (h + 1) // self.patch_size, (w + 1) // self.patch_size)[:,:,:h,:w]

View File

@ -54,8 +54,7 @@ class Attention2D(nn.Module):
kv = torch.cat([x, kv], dim=1)
# x = self.attn(x, kv, kv, need_weights=False)[0]
x = self.attn(x, kv, kv)
x = x.permute(0, 2, 1).view(*orig_shape)
return x
return x.permute(0, 2, 1).view(*orig_shape)
def LayerNorm2d_op(operations):
@ -116,8 +115,7 @@ class AttnBlock(nn.Module):
def forward(self, x, kv):
kv = self.kv_mapper(kv)
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
return x
return x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
class FeedForwardBlock(nn.Module):
@ -133,8 +131,7 @@ class FeedForwardBlock(nn.Module):
)
def forward(self, x):
x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x
return x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
class TimestepBlock(nn.Module):

View File

@ -157,9 +157,8 @@ class ResBlock(nn.Module):
x = x + self.depthwise[1](x_temp) * mods[2]
x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4]
x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]
return x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]
return x
class StageA(nn.Module):
@ -218,8 +217,7 @@ class StageA(nn.Module):
def decode(self, x):
x = self.up_blocks(x)
x = self.out_block(x)
return x
return self.out_block(x)
def forward(self, x, quantize=False):
qe, x, _, vq_loss = self.encode(x, quantize)
@ -251,5 +249,4 @@ class Discriminator(nn.Module):
cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1))
x = torch.cat([x, cond], dim=1)
x = self.shuffle(x)
x = self.logits(x)
return x
return self.logits(x)

View File

@ -170,8 +170,7 @@ class StageB(nn.Module):
if len(clip.shape) == 2:
clip = clip.unsqueeze(1)
clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1)
clip = self.clip_norm(clip)
return clip
return self.clip_norm(clip)
def _down_encode(self, x, r_embed, clip):
level_outputs = []

View File

@ -179,8 +179,7 @@ class StageC(nn.Module):
clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1)
clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1)
clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
clip = self.clip_norm(clip)
return clip
return self.clip_norm(clip)
def _down_encode(self, x, r_embed, clip, cnet=None):
level_outputs = []

View File

@ -35,8 +35,7 @@ class EfficientNetEncoder(nn.Module):
def forward(self, x):
x = x * 0.5 + 0.5
x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1])
o = self.mapper(self.backbone(x))
return o
return self.mapper(self.backbone(x))
# Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192

View File

@ -256,5 +256,4 @@ class LastLayer(nn.Module):
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = self.linear(x)
return x
return self.linear(x)

View File

@ -9,8 +9,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
q, k = apply_rope(q, k, pe)
heads = q.shape[1]
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
return x
return optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:

View File

@ -183,8 +183,7 @@ class Flux(nn.Module):
img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img
return self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
bs, c, h, w = x.shape

View File

@ -21,5 +21,4 @@ class ReduxImageEncoder(torch.nn.Module):
self.redux_down = ops.Linear(txt_in_features * 3, txt_in_features, dtype=dtype)
def forward(self, sigclip_embeds) -> torch.Tensor:
projected_x = self.redux_down(torch.nn.functional.silu(self.redux_up(sigclip_embeds)))
return projected_x
return self.redux_down(torch.nn.functional.silu(self.redux_up(sigclip_embeds)))

View File

@ -34,9 +34,8 @@ import comfy.ops
def modulated_rmsnorm(x, scale, eps=1e-6):
# Normalize and modulate
x_normed = comfy.ldm.common_dit.rms_norm(x, eps=eps)
x_modulated = x_normed * (1 + scale.unsqueeze(1))
return x_normed * (1 + scale.unsqueeze(1))
return x_modulated
def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6):
@ -47,9 +46,8 @@ def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6):
x_normed = comfy.ldm.common_dit.rms_norm(x_res, eps=eps) * tanh_gate
# Apply residual connection
output = x + x_normed
return x + x_normed
return output
class AsymmetricAttention(nn.Module):
def __init__(
@ -275,14 +273,12 @@ class AsymmetricJointBlock(nn.Module):
def ff_block_x(self, x, scale_x, gate_x):
x_mod = modulated_rmsnorm(x, scale_x)
x_res = self.mlp_x(x_mod)
x = residual_tanh_gated_rmsnorm(x, x_res, gate_x) # Sandwich norm
return x
return residual_tanh_gated_rmsnorm(x, x_res, gate_x) # Sandwich norm
def ff_block_y(self, y, scale_y, gate_y):
y_mod = modulated_rmsnorm(y, scale_y)
y_res = self.mlp_y(y_mod)
y = residual_tanh_gated_rmsnorm(y, y_res, gate_y) # Sandwich norm
return y
return residual_tanh_gated_rmsnorm(y, y_res, gate_y) # Sandwich norm
class FinalLayer(nn.Module):
@ -312,8 +308,7 @@ class FinalLayer(nn.Module):
c = F.silu(c)
shift, scale = self.mod(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
return self.linear(x)
class AsymmDiTJoint(nn.Module):

View File

@ -64,8 +64,7 @@ class TimestepEmbedder(nn.Module):
if self.timestep_scale is not None:
t = t * self.timestep_scale
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype=out_dtype)
t_emb = self.mlp(t_freq)
return t_emb
return self.mlp(t_freq)
class FeedForward(nn.Module):
@ -93,8 +92,7 @@ class FeedForward(nn.Module):
def forward(self, x):
x, gate = self.w1(x).chunk(2, dim=-1)
x = self.w2(F.silu(x) * gate)
return x
return self.w2(F.silu(x) * gate)
class PatchEmbed(nn.Module):
@ -149,8 +147,7 @@ class PatchEmbed(nn.Module):
raise NotImplementedError("Must flatten output.")
x = rearrange(x, "(B T) C H W -> B (T H W) C", B=B, T=T)
x = self.norm(x)
return x
return self.norm(x)
class RMSNorm(torch.nn.Module):

View File

@ -60,9 +60,8 @@ def create_position_matrix(
# Stack and reshape the grids.
pos = torch.stack([grid_t, grid_h, grid_w], dim=-1) # [T, pH, pW, 3]
pos = pos.view(-1, 3) # [T * pH * pW, 3]
pos = pos.to(dtype=dtype, device=device)
return pos.to(dtype=dtype, device=device)
return pos
def compute_mixed_rotation(

View File

@ -30,5 +30,4 @@ def apply_rotary_emb_qk_real(
sin_part = (xqk_even * freqs_sin + xqk_odd * freqs_cos).type_as(xqk)
# Interleave the results back into the original shape
out = torch.stack([cos_part, sin_part], dim=-1).flatten(-2)
return out
return torch.stack([cos_part, sin_part], dim=-1).flatten(-2)

View File

@ -29,8 +29,7 @@ def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.
assert x.size(0) == mask.size(0) # Expected mask to have same batch size as tokens.
mask = mask[:, :, None].to(dtype=x.dtype)
mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
pooled = (x * mask).sum(dim=1, keepdim=keepdim)
return pooled
return (x * mask).sum(dim=1, keepdim=keepdim)
class AttentionPool(nn.Module):
@ -98,5 +97,4 @@ class AttentionPool(nn.Module):
# Concatenate heads and run output.
x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim)
x = self.to_out(x)
return x
return self.to_out(x)

View File

@ -99,8 +99,7 @@ class Conv1x1(ops.Linear):
"""
x = x.movedim(1, -1)
x = super().forward(x)
x = x.movedim(-1, 1)
return x
return x.movedim(-1, 1)
class DepthToSpaceTime(nn.Module):
@ -269,8 +268,7 @@ class Attention(nn.Module):
assert x.size(0) == q.size(0)
x = self.out(x)
x = rearrange(x, "(B h w) t C -> B C t h w", B=B, h=H, w=W)
return x
return rearrange(x, "(B h w) t C -> B C t h w", B=B, h=H, w=W)
class AttentionBlock(nn.Module):
@ -321,8 +319,7 @@ class CausalUpsampleBlock(nn.Module):
def forward(self, x):
x = self.blocks(x)
x = self.proj(x)
x = self.d2st(x)
return x
return self.d2st(x)
def block_fn(channels, *, affine: bool = True, has_attention: bool = False, **block_kwargs):

View File

@ -86,8 +86,7 @@ class TokenRefinerBlock(nn.Module):
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True)
x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1)
x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
return x
return x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
class IndividualTokenRefiner(nn.Module):
@ -157,8 +156,7 @@ class TokenRefiner(nn.Module):
c = t + self.c_embedder(c.to(x.dtype))
x = self.input_embedder(x)
x = self.individual_token_refiner(x, c, mask)
return x
return self.individual_token_refiner(x, c, mask)
class HunyuanVideo(nn.Module):
"""
@ -311,8 +309,7 @@ class HunyuanVideo(nn.Module):
shape[i] = shape[i] // self.patch_size[i]
img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
img = img.reshape(initial_shape)
return img
return img.reshape(initial_shape)
def forward(self, x, timestep, context, y, guidance, attention_mask=None, control=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
@ -326,5 +323,4 @@ class HunyuanVideo(nn.Module):
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options)
return out
return self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options)

View File

@ -166,9 +166,8 @@ class CrossAttention(nn.Module):
out = self.out_proj(context) # context.reshape - B, L1, -1
out = self.proj_drop(out)
out_tuple = (out,)
return (out,)
return out_tuple
class Attention(nn.Module):
@ -213,6 +212,5 @@ class Attention(nn.Module):
x = self.out_proj(x)
x = self.proj_drop(x)
out_tuple = (x,)
return (x,)
return out_tuple

View File

@ -19,8 +19,7 @@ def calc_rope(x, patch_size, head_size):
sub_args = [start, stop, (th, tw)]
# head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads']
rope = get_2d_rotary_pos_embed(head_size, *sub_args)
rope = (rope[0].to(x), rope[1].to(x))
return rope
return (rope[0].to(x), rope[1].to(x))
def modulate(x, shift, scale):
@ -110,9 +109,8 @@ class HunYuanDiTBlock(nn.Module):
# FFN Layer
mlp_inputs = self.norm2(x)
x = x + self.mlp(mlp_inputs)
return x + self.mlp(mlp_inputs)
return x
def forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
if self.gradient_checkpointing and self.training:
@ -136,8 +134,7 @@ class FinalLayer(nn.Module):
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
return self.linear(x)
class HunYuanDiT(nn.Module):
@ -413,5 +410,4 @@ class HunYuanDiT(nn.Module):
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs
return x.reshape(shape=(x.shape[0], c, h * p, w * p))

View File

@ -53,8 +53,7 @@ def get_meshgrid(start, *args):
grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32)
grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0) # [2, W, H]
return grid
return np.stack(grid, axis=0) # [2, W, H]
#################################################################################
# Sine/Cosine Positional Embedding Functions #
@ -87,8 +86,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
return np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
@ -108,8 +106,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
return np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
#################################################################################
@ -138,8 +135,7 @@ def get_2d_rotary_pos_embed(embed_dim, start, *args, use_real=True):
"""
grid = get_meshgrid(start, *args) # [2, H, w]
grid = grid.reshape([2, 1, *grid.shape[1:]]) # Returns a sampling matrix with the same resolution as the target resolution
pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
return pos_embed
return get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
@ -154,8 +150,7 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2)
return cos, sin
else:
emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
return emb
return torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
@ -187,8 +182,7 @@ def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
return freqs_cos, freqs_sin
else:
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis
return torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]

View File

@ -122,14 +122,13 @@ class Timesteps(nn.Module):
self.scale = scale
def forward(self, timesteps):
t_emb = get_timestep_embedding(
return get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
scale=self.scale,
)
return t_emb
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
@ -149,8 +148,7 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
return timesteps_emb
return self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
class AdaLayerNormSingle(nn.Module):
@ -209,8 +207,7 @@ class PixArtAlphaTextProjection(nn.Module):
def forward(self, caption):
hidden_states = self.linear_1(caption)
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
return self.linear_2(hidden_states)
class GELU_approx(nn.Module):
@ -247,9 +244,8 @@ def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and
t_dup = torch.stack((-t2, t1), dim=-1)
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
return input_tensor * cos_freqs + input_tensor_rot * sin_freqs
return out
class CrossAttention(nn.Module):
@ -316,14 +312,13 @@ class BasicTransformerBlock(nn.Module):
return x
def get_fractional_positions(indices_grid, max_pos):
fractional_positions = torch.stack(
return torch.stack(
[
indices_grid[:, i] / max_pos[i]
for i in range(3)
],
dim=-1,
)
return fractional_positions
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):

View File

@ -65,8 +65,7 @@ class Patchifier(ABC):
scale = scale_grid[i]
grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i]
grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
return grid
return rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
class SymmetricPatchifier(Patchifier):
@ -74,14 +73,13 @@ class SymmetricPatchifier(Patchifier):
self,
latents: Tensor,
) -> Tuple[Tensor, Tensor]:
latents = rearrange(
return rearrange(
latents,
"b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
p1=self._patch_size[0],
p2=self._patch_size[1],
p3=self._patch_size[2],
)
return latents
def unpatchify(
self,
@ -93,7 +91,7 @@ class SymmetricPatchifier(Patchifier):
) -> Tuple[Tensor, Tensor]:
output_height = output_height // self._patch_size[1]
output_width = output_width // self._patch_size[2]
latents = rearrange(
return rearrange(
latents,
"b (f h w) (c p q) -> b c f (h p) (w q) ",
f=output_num_frames,
@ -102,4 +100,3 @@ class SymmetricPatchifier(Patchifier):
p=self._patch_size[1],
q=self._patch_size[2],
)
return latents

View File

@ -56,8 +56,7 @@ class CausalConv3d(nn.Module):
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
)
x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
x = self.conv(x)
return x
return self.conv(x)
@property
def weight(self):

View File

@ -417,9 +417,8 @@ class Decoder(nn.Module):
sample = self.conv_act(sample)
sample = self.conv_out(sample, causal=self.causal)
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
return unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
return sample
class UNetMidBlock3D(nn.Module):
@ -563,8 +562,7 @@ class LayerNorm(nn.Module):
def forward(self, x):
x = rearrange(x, "b c d h w -> b d h w c")
x = self.norm(x)
x = rearrange(x, "b d h w c -> b c d h w")
return x
return rearrange(x, "b d h w c -> b c d h w")
class ResnetBlock3D(nn.Module):
@ -677,9 +675,8 @@ class ResnetBlock3D(nn.Module):
# similar to the "explicit noise inputs" method in style-gan
spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None]
scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...]
hidden_states = hidden_states + scaled_noise
return hidden_states + scaled_noise
return hidden_states
def forward(
self,
@ -740,9 +737,8 @@ class ResnetBlock3D(nn.Module):
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = input_tensor + hidden_states
return input_tensor + hidden_states
return output_tensor
def patchify(x, patch_size_hw, patch_size_t=1):

View File

@ -114,7 +114,7 @@ class DualConv3d(nn.Module):
return x
# Second convolution
x = F.conv3d(
return F.conv3d(
x,
self.weight2,
self.bias2,
@ -124,7 +124,6 @@ class DualConv3d(nn.Module):
self.groups,
)
return x
def forward_with_2d(self, x, skip_time_conv):
b, c, d, h, w = x.shape
@ -142,8 +141,7 @@ class DualConv3d(nn.Module):
_, _, h, w = x.shape
if skip_time_conv:
x = rearrange(x, "(b d) c h w -> b c d h w", b=b)
return x
return rearrange(x, "(b d) c h w -> b c d h w", b=b)
# Second convolution which is essentially treated as a 1D convolution across the 'd' dimension
x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)
@ -155,9 +153,8 @@ class DualConv3d(nn.Module):
padding2 = self.padding2[0]
dilation2 = self.dilation2[0]
x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups)
x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
return rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
return x
@property
def weight(self):

View File

@ -136,8 +136,7 @@ class AutoencodingEngine(AbstractAutoencoder):
return z
def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
x = self.decoder(z, **kwargs)
return x
return self.decoder(z, **kwargs)
def forward(
self, x: torch.Tensor, **additional_decode_kwargs
@ -178,8 +177,7 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
self.embed_dim = embed_dim
def get_autoencoder_params(self) -> list:
params = super().get_autoencoder_params()
return params
return super().get_autoencoder_params()
def encode(
self, x: torch.Tensor, return_reg_log: bool = False

View File

@ -142,13 +142,12 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
sim = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
out = (
return (
out.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
)
return out
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False):
@ -216,8 +215,7 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
hidden_states = hidden_states.to(dtype)
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
return hidden_states
return hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
attn_precision = get_attn_precision(attn_precision)
@ -326,13 +324,12 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
del q, k, v
r1 = (
return (
r1.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
)
return r1
BROKEN_XFORMERS = False
try:
@ -395,11 +392,10 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
out = (
return (
out.reshape(b, -1, heads * dim_head)
)
return out
if model_management.is_nvidia(): #pytorch 2.3 and up seem to have this issue.
SDP_BATCH_LIMIT = 2**15
@ -932,7 +928,6 @@ class SpatialVideoTransformer(SpatialTransformer):
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
if not self.use_linear:
x = self.proj_out(x)
out = x + x_in
return out
return x + x_in

View File

@ -51,8 +51,7 @@ class Mlp(nn.Module):
x = self.drop1(x)
x = self.norm(x)
x = self.fc2(x)
x = self.drop2(x)
return x
return self.drop2(x)
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
@ -103,8 +102,7 @@ class PatchEmbed(nn.Module):
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
x = self.norm(x)
return x
return self.norm(x)
def modulate(x, shift, scale):
if shift is None:
@ -155,8 +153,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
return np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
@ -176,8 +173,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
return np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos, device=None, dtype=torch.float32):
omega = torch.arange(embed_dim // 2, device=device, dtype=dtype)
@ -187,8 +183,7 @@ def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos, device=None, dtype=t
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = torch.sin(out) # (M, D/2)
emb_cos = torch.cos(out) # (M, D/2)
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
return emb
return torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
def get_2d_sincos_pos_embed_torch(embed_dim, w, h, val_center=7.5, val_magnitude=7.5, device=None, dtype=torch.float32):
small = min(h, w)
@ -197,8 +192,7 @@ def get_2d_sincos_pos_embed_torch(embed_dim, w, h, val_center=7.5, val_magnitude
grid_h, grid_w = torch.meshgrid(torch.linspace(-val_h + val_center, val_h + val_center, h, device=device, dtype=dtype), torch.linspace(-val_w + val_center, val_w + val_center, w, device=device, dtype=dtype), indexing='ij')
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
return emb
return torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
#################################################################################
@ -222,8 +216,7 @@ class TimestepEmbedder(nn.Module):
def forward(self, t, dtype, **kwargs):
t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype)
t_emb = self.mlp(t_freq)
return t_emb
return self.mlp(t_freq)
class VectorEmbedder(nn.Module):
@ -240,8 +233,7 @@ class VectorEmbedder(nn.Module):
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
emb = self.mlp(x)
return emb
return self.mlp(x)
#################################################################################
@ -307,16 +299,14 @@ class SelfAttention(nn.Module):
def post_attention(self, x: torch.Tensor) -> torch.Tensor:
assert not self.pre_only
x = self.proj(x)
x = self.proj_drop(x)
return x
return self.proj_drop(x)
def forward(self, x: torch.Tensor) -> torch.Tensor:
q, k, v = self.pre_attention(x)
x = optimized_attention(
q, k, v, heads=self.num_heads
)
x = self.post_attention(x)
return x
return self.post_attention(x)
class RMSNorm(torch.nn.Module):
@ -530,10 +520,9 @@ class DismantledBlock(nn.Module):
def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp):
assert not self.pre_only
x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
x = x + gate_mlp.unsqueeze(1) * self.mlp(
return x + gate_mlp.unsqueeze(1) * self.mlp(
modulate(self.norm2(x), shift_mlp, scale_mlp)
)
return x
def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
assert self.x_block_self_attn
@ -568,10 +557,9 @@ class DismantledBlock(nn.Module):
out2 = gate_msa2.unsqueeze(1) * attn2
x = x + out1
x = x + out2
x = x + gate_mlp.unsqueeze(1) * self.mlp(
return x + gate_mlp.unsqueeze(1) * self.mlp(
modulate(self.norm2(x), shift_mlp, scale_mlp)
)
return x
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
assert not self.pre_only
@ -696,8 +684,7 @@ class FinalLayer(nn.Module):
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
return self.linear(x)
class SelfAttentionContext(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dtype=None, device=None, operations=None):
@ -902,13 +889,12 @@ class MMDiT(nn.Module):
w=self.pos_embed_max_size,
)
spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
spatial_pos_embed = rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c")
return rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c")
# print(spatial_pos_embed, top, left, h, w)
# # t = get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, 7.875, 7.875, device=device) #matches exactly for 1024 res
# t = get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, 7.5, 7.5, device=device) #scales better
# # print(t)
# return t
return spatial_pos_embed
def unpatchify(self, x, hw=None):
"""
@ -927,8 +913,7 @@ class MMDiT(nn.Module):
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum("nhwpqc->nchpwq", x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs
return x.reshape(shape=(x.shape[0], c, h * p, w * p))
def forward_core_with_concat(
self,
@ -976,8 +961,7 @@ class MMDiT(nn.Module):
if add is not None:
x += add
x = self.final_layer(x, c_mod) # (N, T, patch_size ** 2 * out_channels)
return x
return self.final_layer(x, c_mod) # (N, T, patch_size ** 2 * out_channels)
def forward(
self,

View File

@ -493,8 +493,7 @@ class Model(nn.Module):
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
return self.conv_out(h)
def get_last_layer(self):
return self.conv_out.weight
@ -602,8 +601,7 @@ class Encoder(nn.Module):
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
return self.conv_out(h)
class Decoder(nn.Module):

View File

@ -350,8 +350,7 @@ class VideoResBlock(ResBlock):
x = self.time_mixer(
x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
)
x = rearrange(x, "b c t h w -> (b t) c h w")
return x
return rearrange(x, "b c t h w -> (b t) c h w")
class Timestep(nn.Module):

View File

@ -79,11 +79,10 @@ class AlphaBlender(nn.Module):
image_only_indicator=None,
) -> torch.Tensor:
alpha = self.get_alpha(image_only_indicator, x_spatial.device)
x = (
return (
alpha.to(x_spatial.dtype) * x_spatial
+ (1.0 - alpha).to(x_spatial.dtype) * x_temporal
)
return x
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
@ -201,8 +200,7 @@ class CheckpointFunction(torch.autograd.Function):
"dtype": torch.get_autocast_gpu_dtype(),
"cache_enabled": torch.is_autocast_cache_enabled()}
with torch.no_grad():
output_tensors = ctx.run_function(*ctx.input_tensors)
return output_tensors
return ctx.run_function(*ctx.input_tensors)
@staticmethod
def backward(ctx, *output_grads):

View File

@ -33,8 +33,7 @@ class DiagonalGaussianDistribution(object):
self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device)
def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device)
return x
return self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device)
def kl(self, other=None):
if self.deterministic:

View File

@ -15,13 +15,11 @@ class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
def scale(self, x):
# re-normalize to centered mean and unit variance
x = (x - self.data_mean.to(x.device)) * 1. / self.data_std.to(x.device)
return x
return (x - self.data_mean.to(x.device)) * 1. / self.data_std.to(x.device)
def unscale(self, x):
# back to original data stats
x = (x * self.data_std.to(x.device)) + self.data_mean.to(x.device)
return x
return (x * self.data_std.to(x.device)) + self.data_mean.to(x.device)
def forward(self, x, noise_level=None, seed=None):
if noise_level is None:

View File

@ -177,8 +177,7 @@ def _get_attention_scores_no_kv_chunking(
attn_scores /= summed
attn_probs = attn_scores
hidden_states_slice = torch.bmm(attn_probs.to(value.dtype), value)
return hidden_states_slice
return torch.bmm(attn_probs.to(value.dtype), value)
class ScannedChunk(NamedTuple):
chunk_idx: int
@ -264,7 +263,7 @@ def efficient_dot_product_attention(
# TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
# and pass slices to be mutated, instead of torch.cat()ing the returned slices
res = torch.cat([
return torch.cat([
compute_query_chunk_attn(
query=get_query_chunk(i * query_chunk_size),
key_t=key_t,
@ -272,4 +271,3 @@ def efficient_dot_product_attention(
mask=get_mask_chunk(i * query_chunk_size)
) for i in range(math.ceil(q_tokens / query_chunk_size))
], dim=1)
return res

View File

@ -73,8 +73,7 @@ class MultiHeadCrossAttention(nn.Module):
x = optimized_attention(q.view(B, -1, C), k.view(B, -1, C), v.view(B, -1, C), self.num_heads, mask=None)
x = self.proj(x)
x = self.proj_drop(x)
return x
return self.proj_drop(x)
class AttentionKVCompress(nn.Module):
@ -172,8 +171,7 @@ class AttentionKVCompress(nn.Module):
x = optimized_attention(q, k, v, self.num_heads, mask=None, skip_reshape=True)
x = x.view(B, N, C)
x = self.proj(x)
return x
return self.proj(x)
class FinalLayer(nn.Module):
@ -192,8 +190,7 @@ class FinalLayer(nn.Module):
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
return self.linear(x)
class T2IFinalLayer(nn.Module):
"""
@ -209,8 +206,7 @@ class T2IFinalLayer(nn.Module):
def forward(self, x, t):
shift, scale = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t[:, None]).chunk(2, dim=1)
x = t2i_modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
return self.linear(x)
class MaskFinalLayer(nn.Module):
@ -228,8 +224,7 @@ class MaskFinalLayer(nn.Module):
def forward(self, x, t):
shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
return self.linear(x)
class DecoderLayer(nn.Module):
@ -247,8 +242,7 @@ class DecoderLayer(nn.Module):
def forward(self, x, t):
shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
x = modulate(self.norm_decoder(x), shift, scale)
x = self.linear(x)
return x
return self.linear(x)
class SizeEmbedder(TimestepEmbedder):
@ -276,8 +270,7 @@ class SizeEmbedder(TimestepEmbedder):
s = rearrange(s, "b d -> (b d)")
s_freq = timestep_embedding(s, self.frequency_embedding_size)
s_emb = self.mlp(s_freq.to(s.dtype))
s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
return s_emb
return rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
class LabelEmbedder(nn.Module):
@ -299,15 +292,13 @@ class LabelEmbedder(nn.Module):
drop_ids = torch.rand(labels.shape[0]).cuda() < self.dropout_prob
else:
drop_ids = force_drop_ids == 1
labels = torch.where(drop_ids, self.num_classes, labels)
return labels
return torch.where(drop_ids, self.num_classes, labels)
def forward(self, labels, train, force_drop_ids=None):
use_dropout = self.dropout_prob > 0
if (train and use_dropout) or (force_drop_ids is not None):
labels = self.token_drop(labels, force_drop_ids)
embeddings = self.embedding_table(labels)
return embeddings
return self.embedding_table(labels)
class CaptionEmbedder(nn.Module):
@ -331,8 +322,7 @@ class CaptionEmbedder(nn.Module):
drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
else:
drop_ids = force_drop_ids == 1
caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
return caption
return torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
def forward(self, caption, train, force_drop_ids=None):
if train:
@ -340,8 +330,7 @@ class CaptionEmbedder(nn.Module):
use_dropout = self.uncond_prob > 0
if (train and use_dropout) or (force_drop_ids is not None):
caption = self.token_drop(caption, force_drop_ids)
caption = self.y_proj(caption)
return caption
return self.y_proj(caption)
class CaptionEmbedderDoubleBr(nn.Module):

View File

@ -23,8 +23,7 @@ def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_si
)
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
return emb
return torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
class PixArtMSBlock(nn.Module):
"""
@ -57,9 +56,8 @@ class PixArtMSBlock(nn.Module):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1)
x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
x = x + self.cross_attn(x, y, mask)
x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
return x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
return x
### Core PixArt Model ###
@ -212,9 +210,8 @@ class PixArtMS(nn.Module):
x = block(x, y, t0, y_lens, (H, W), **kwargs) # (N, T, D)
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x, H, W) # (N, out_channels, H, W)
return self.unpatchify(x, H, W) # (N, out_channels, H, W)
return x
def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs):
B, C, H, W = x.shape
@ -252,5 +249,4 @@ class PixArtMS(nn.Module):
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs
return x.reshape(shape=(x.shape[0], c, h * p, w * p))

View File

@ -29,8 +29,7 @@ def log_txt_as_img(wh, xc, size=10):
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
txts = np.stack(txts)
txts = torch.tensor(txts)
return txts
return torch.tensor(txts)
def ismap(x):

View File

@ -206,8 +206,7 @@ class BaseModel(torch.nn.Module):
cond_concat.append(torch.ones_like(noise)[:,:1])
elif ck == "masked_image":
cond_concat.append(self.blank_inpaint_image_like(noise))
data = torch.cat(cond_concat, dim=1)
return data
return torch.cat(cond_concat, dim=1)
return None
def extra_conds(self, **kwargs):
@ -418,8 +417,7 @@ class SVD_img2vid(BaseModel):
out.append(self.embedder(torch.Tensor([motion_bucket_id])))
out.append(self.embedder(torch.Tensor([augmentation])))
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
return flat
return torch.flatten(torch.cat(out)).unsqueeze(dim=0)
def extra_conds(self, **kwargs):
out = {}
@ -457,8 +455,7 @@ class SV3D_u(SVD_img2vid):
out = []
out.append(self.embedder(torch.flatten(torch.Tensor([augmentation]))))
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
return flat
return torch.flatten(torch.cat(out)).unsqueeze(dim=0)
class SV3D_p(SVD_img2vid):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):

View File

@ -887,8 +887,7 @@ class ModelPatcher:
all_models = self.get_additional_models()
models_set = set(all_models)
real_all_models = _evaluate_sub_additional_models(prev_models=all_models, cache_set=models_set)
return real_all_models
return _evaluate_sub_additional_models(prev_models=all_models, cache_set=models_set)
def use_ejected(self, skip_and_inject_on_exit_only=False):
return AutoPatcherEjector(self, skip_and_inject_on_exit_only=skip_and_inject_on_exit_only)

View File

@ -286,8 +286,7 @@ class StableCascadeSampling(ModelSamplingDiscrete):
var = 1 / ((sigma * sigma) + 1)
var = var.clamp(0, 1.0)
s, min_var = self.cosine_s.to(var.device), self._init_alpha_cumprod.to(var.device)
t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
return t
return (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
def percent_to_sigma(self, percent):
if percent <= 0.0:

View File

@ -21,8 +21,7 @@ def prepare_noise(latent_image, seed, noise_inds=None):
if i in unique_inds:
noises.append(noise)
noises = [noises[i] for i in inverse]
noises = torch.cat(noises, axis=0)
return noises
return torch.cat(noises, axis=0)
def fix_empty_latent_channels(model, latent_image):
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
@ -43,10 +42,8 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.to(comfy.model_management.intermediate_device())
return samples
return samples.to(comfy.model_management.intermediate_device())
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.to(comfy.model_management.intermediate_device())
return samples
return samples.to(comfy.model_management.intermediate_device())

View File

@ -713,8 +713,7 @@ class KSAMPLER(Sampler):
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples)
return samples
return model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples)
def ksampler(sampler_name, extra_options={}, inpaint_options={}):

View File

@ -111,7 +111,7 @@ class CLIP:
model_management.load_models_gpu([self.patcher], force_full_load=True)
self.layer_idx = None
self.use_clip_schedule = False
logging.info("CLIP model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
def clone(self):
n = CLIP(no_init=True)
@ -422,12 +422,11 @@ class VAE:
pbar = comfy.utils.ProgressBar(steps)
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
output = self.process_output(
return self.process_output(
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar))
/ 3.0)
return output
def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
@ -485,8 +484,7 @@ class VAE:
overlap = tile // 4
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples
return pixel_samples.to(self.output_device).movedim(1,-1)
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
@ -898,7 +896,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
if output_model:
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
if inital_load_device != torch.device("cpu"):
logging.info("loaded straight to GPU")
logging.info("loaded diffusion model directly to GPU")
model_management.load_models_gpu([model_patcher], force_full_load=True)
return (model_patcher, clip, vae, clipvision)

View File

@ -304,13 +304,11 @@ def token_weights(string, current_weight):
def escape_important(text):
text = text.replace("\\)", "\0\1")
text = text.replace("\\(", "\0\2")
return text
return text.replace("\\(", "\0\2")
def unescape_important(text):
text = text.replace("\0\1", ")")
text = text.replace("\0\2", "(")
return text
return text.replace("\0\2", "(")
def safe_load_embed_zip(embed_path):
with zipfile.ZipFile(embed_path) as myzip:
@ -635,8 +633,7 @@ class SD1ClipModel(torch.nn.Module):
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs = token_weight_pairs[self.clip_name]
out = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
return out
return getattr(self, self.clip).encode_token_weights(token_weight_pairs)
def load_sd(self, sd):
return getattr(self, self.clip).load_sd(sd)

View File

@ -51,8 +51,7 @@ class SD15(supported_models_base.BASE):
replace_prefix = {}
replace_prefix["cond_stage_model."] = "clip_l."
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
return state_dict
return utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
def process_clip_state_dict_for_saving(self, state_dict):
pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"]
@ -97,15 +96,13 @@ class SD20(supported_models_base.BASE):
replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format
replace_prefix["cond_stage_model.model."] = "clip_h."
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
state_dict = utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.")
return state_dict
return utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.")
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {}
replace_prefix["clip_h"] = "cond_stage_model.model"
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
return state_dict
return diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.sd2_clip.SD2Tokenizer, comfy.text_encoders.sd2_clip.SD2ClipModel)
@ -158,8 +155,7 @@ class SDXLRefiner(supported_models_base.BASE):
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.")
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
return state_dict
return utils.state_dict_key_replace(state_dict, keys_to_replace)
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {}
@ -167,8 +163,7 @@ class SDXLRefiner(supported_models_base.BASE):
if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g:
state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
replace_prefix["clip_g"] = "conditioner.embedders.0.model"
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
return state_dict_g
return utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)
@ -221,8 +216,7 @@ class SDXL(supported_models_base.BASE):
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.")
return state_dict
return utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.")
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {}
@ -239,8 +233,7 @@ class SDXL(supported_models_base.BASE):
replace_prefix["clip_g"] = "conditioner.embedders.1.model"
replace_prefix["clip_l"] = "conditioner.embedders.0"
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
return state_dict_g
return utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)
@ -310,8 +303,7 @@ class SVD_img2vid(supported_models_base.BASE):
sampling_settings = {"sigma_max": 700.0, "sigma_min": 0.002}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SVD_img2vid(self, device=device)
return out
return model_base.SVD_img2vid(self, device=device)
def clip_target(self, state_dict={}):
return None
@ -331,8 +323,7 @@ class SV3D_u(SVD_img2vid):
vae_key_prefix = ["conditioner.embedders.1.encoder."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SV3D_u(self, device=device)
return out
return model_base.SV3D_u(self, device=device)
class SV3D_p(SV3D_u):
unet_config = {
@ -348,8 +339,7 @@ class SV3D_p(SV3D_u):
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SV3D_p(self, device=device)
return out
return model_base.SV3D_p(self, device=device)
class Stable_Zero123(supported_models_base.BASE):
unet_config = {
@ -376,8 +366,7 @@ class Stable_Zero123(supported_models_base.BASE):
latent_format = latent_formats.SD15
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"])
return out
return model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"])
def clip_target(self, state_dict={}):
return None
@ -407,8 +396,7 @@ class SD_X4Upscaler(SD20):
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SD_X4Upscaler(self, device=device)
return out
return model_base.SD_X4Upscaler(self, device=device)
class Stable_Cascade_C(supported_models_base.BASE):
unet_config = {
@ -450,8 +438,7 @@ class Stable_Cascade_C(supported_models_base.BASE):
return state_dict
def get_model(self, state_dict, prefix="", device=None):
out = model_base.StableCascade_C(self, device=device)
return out
return model_base.StableCascade_C(self, device=device)
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel)
@ -473,8 +460,7 @@ class Stable_Cascade_B(Stable_Cascade_C):
clip_vision_prefix = None
def get_model(self, state_dict, prefix="", device=None):
out = model_base.StableCascade_B(self, device=device)
return out
return model_base.StableCascade_B(self, device=device)
class SD15_instructpix2pix(SD15):
unet_config = {
@ -521,8 +507,7 @@ class SD3(supported_models_base.BASE):
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SD3(self, device=device)
return out
return model_base.SD3(self, device=device)
def clip_target(self, state_dict={}):
clip_l = False
@ -587,8 +572,7 @@ class AuraFlow(supported_models_base.BASE):
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.AuraFlow(self, device=device)
return out
return model_base.AuraFlow(self, device=device)
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.aura_t5.AuraT5Tokenizer, comfy.text_encoders.aura_t5.AuraT5Model)
@ -648,8 +632,7 @@ class HunyuanDiT(supported_models_base.BASE):
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HunyuanDiT(self, device=device)
return out
return model_base.HunyuanDiT(self, device=device)
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.hydit.HyditTokenizer, comfy.text_encoders.hydit.HyditModel)
@ -686,8 +669,7 @@ class Flux(supported_models_base.BASE):
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Flux(self, device=device)
return out
return model_base.Flux(self, device=device)
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
@ -715,8 +697,7 @@ class FluxSchnell(Flux):
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device)
return out
return model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device)
class GenmoMochi(supported_models_base.BASE):
unet_config = {
@ -739,8 +720,7 @@ class GenmoMochi(supported_models_base.BASE):
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.GenmoMochi(self, device=device)
return out
return model_base.GenmoMochi(self, device=device)
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
@ -767,8 +747,7 @@ class LTXV(supported_models_base.BASE):
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.LTXV(self, device=device)
return out
return model_base.LTXV(self, device=device)
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
@ -795,8 +774,7 @@ class HunyuanVideo(supported_models_base.BASE):
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HunyuanVideo(self, device=device)
return out
return model_base.HunyuanVideo(self, device=device)
def process_unet_state_dict(self, state_dict):
out_sd = {}

View File

@ -87,8 +87,7 @@ class BASE:
return out
def process_clip_state_dict(self, state_dict):
state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
return state_dict
return utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
def process_unet_state_dict(self, state_dict):
return state_dict

View File

@ -60,8 +60,7 @@ class Downsample(nn.Module):
padding = [x.shape[2] % 2, x.shape[3] % 2]
self.op.padding = padding
x = self.op(x)
return x
return self.op(x)
class ResnetBlock(nn.Module):
@ -196,8 +195,7 @@ class ResidualAttentionBlock(nn.Module):
def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
return x + self.mlp(self.ln_2(x))
class StyleAdapter(nn.Module):
@ -224,9 +222,8 @@ class StyleAdapter(nn.Module):
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_post(x[:, -self.num_token:, :])
x = x @ self.proj
return x @ self.proj
return x
class ResnetBlock_light(nn.Module):
@ -262,9 +259,8 @@ class extractor(nn.Module):
x = self.down_opt(x)
x = self.in_conv(x)
x = self.body(x)
x = self.out_conv(x)
return self.out_conv(x)
return x
class Adapter_light(nn.Module):

View File

@ -72,8 +72,7 @@ class TAESD(nn.Module):
def decode(self, x):
x_sample = self.taesd_decoder((x - self.vae_shift) * self.vae_scale)
x_sample = x_sample.sub(0.5).mul(2)
return x_sample
return x_sample.sub(0.5).mul(2)
def encode(self, x):
return (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift

View File

@ -17,8 +17,7 @@ class BertAttention(torch.nn.Module):
k = self.key(x)
v = self.value(x)
out = optimized_attention(q, k, v, self.heads, mask)
return out
return optimized_attention(q, k, v, self.heads, mask)
class BertOutput(torch.nn.Module):
def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
@ -30,8 +29,7 @@ class BertOutput(torch.nn.Module):
def forward(self, x, y):
x = self.dense(x)
# hidden_states = self.dropout(hidden_states)
x = self.LayerNorm(x + y)
return x
return self.LayerNorm(x + y)
class BertAttentionBlock(torch.nn.Module):
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
@ -100,8 +98,7 @@ class BertEmbeddings(torch.nn.Module):
x += self.token_type_embeddings(token_type_ids, out_dtype=x.dtype)
else:
x += comfy.ops.cast_to_input(self.token_type_embeddings.weight[0], x)
x = self.LayerNorm(x)
return x
return self.LayerNorm(x)
class BertModel_(torch.nn.Module):

View File

@ -142,9 +142,8 @@ class TransformerBlock(nn.Module):
residual = x
x = self.post_attention_layernorm(x)
x = self.mlp(x)
x = residual + x
return residual + x
return x
class Llama2_(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):

View File

@ -30,8 +30,7 @@ class T5DenseActDense(torch.nn.Module):
def forward(self, x):
x = self.act(self.wi(x))
# x = self.dropout(x)
x = self.wo(x)
return x
return self.wo(x)
class T5DenseGatedActDense(torch.nn.Module):
def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations):
@ -47,8 +46,7 @@ class T5DenseGatedActDense(torch.nn.Module):
hidden_linear = self.wi_1(x)
x = hidden_gelu * hidden_linear
# x = self.dropout(x)
x = self.wo(x)
return x
return self.wo(x)
class T5LayerFF(torch.nn.Module):
def __init__(self, model_dim, ff_dim, ff_activation, gated_act, dtype, device, operations):
@ -145,8 +143,7 @@ class T5Attention(torch.nn.Module):
max_distance=self.relative_attention_max_distance,
)
values = self.relative_attention_bias(relative_position_bucket, out_dtype=dtype) # shape (query_length, key_length, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
return values
return values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
q = self.q(x)

View File

@ -292,8 +292,7 @@ def unet_to_diffusers(unet_config):
def swap_scale_shift(weight):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
return torch.cat([scale, shift], dim=0)
MMDIT_MAP_BASIC = {
("context_embedder.bias", "context_embedder.bias"),
@ -1000,8 +999,7 @@ def reshape_mask(input_mask, output_shape):
mask = torch.nn.functional.interpolate(input_mask, size=output_shape[2:], mode=scale_mode)
if mask.shape[1] < output_shape[1]:
mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:,:output_shape[1]]
mask = repeat_to_batch_size(mask, output_shape[0])
return mask
return repeat_to_batch_size(mask, output_shape[0])
def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out):
hi, wi = img_size_in
@ -1037,9 +1035,8 @@ def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out):
img_to_txt = rearrange (img_to_txt, "b t h w -> b (h w) t")
# reassemble the mask from blocks
out = torch.cat([
return torch.cat([
torch.cat([txt_to_txt, txt_to_img], dim=2),
torch.cat([img_to_txt, img_to_img], dim=2)],
dim=1
)
return out

View File

@ -12,8 +12,7 @@ def loglinear_interp(t_steps, num_steps):
new_xs = np.linspace(0, 1, num_steps)
new_ys = np.interp(new_xs, xs, ys)
interped_ys = np.exp(new_ys)[::-1].copy()
return interped_ys
return np.exp(new_ys)[::-1].copy()
NOISE_LEVELS = {"SD1": [14.6146412293, 6.4745760956, 3.8636745985, 2.6946151520, 1.8841921177, 1.3943805092, 0.9642583904, 0.6523686016, 0.3977456272, 0.1515232662, 0.0291671582],
"SDXL":[14.6146412293, 6.3184485287, 3.7681790315, 2.1811480769, 1.3405244945, 0.8620721141, 0.5550693289, 0.3798540708, 0.2332364134, 0.1114188177, 0.0291671582],

View File

@ -104,9 +104,8 @@ def create_vorbis_comment_block(comment_dict, last_block):
id = b'\x84'
else:
id = b'\x04'
comment_block = id + struct.pack('>I', len(comment_data))[1:] + comment_data
return id + struct.pack('>I', len(comment_data))[1:] + comment_data
return comment_block
def insert_or_replace_vorbis_comment(flac_io, comment_dict):
if len(comment_dict) == 0:

View File

@ -150,8 +150,7 @@ class PorterDuffImageComposite:
out_images.append(out_image)
out_alphas.append(out_alpha.squeeze(2))
result = (torch.stack(out_images), torch.stack(out_alphas))
return result
return (torch.stack(out_images), torch.stack(out_alphas))
class SplitImageWithAlpha:
@ -170,8 +169,7 @@ class SplitImageWithAlpha:
def split_image_with_alpha(self, image: torch.Tensor):
out_images = [i[:,:,:3] for i in image]
out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image]
result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas))
return result
return (torch.stack(out_images), 1.0 - torch.stack(out_alphas))
class JoinImageWithAlpha:
@ -196,8 +194,7 @@ class JoinImageWithAlpha:
for i in range(batch_size):
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
result = (torch.stack(out_images),)
return result
return (torch.stack(out_images),)
NODE_CLASS_MAPPINGS = {

View File

@ -12,8 +12,7 @@ def loglinear_interp(t_steps, num_steps):
new_xs = np.linspace(0, 1, num_steps)
new_ys = np.interp(new_xs, xs, ys)
interped_ys = np.exp(new_ys)[::-1].copy()
return interped_ys
return np.exp(new_ys)[::-1].copy()
NOISE_LEVELS = {
0.80: [

View File

@ -27,8 +27,7 @@ class Mahiro:
normm = torch.sqrt(merge.abs()) * merge.sign()
sim = F.cosine_similarity(normu, normm).mean()
simsc = 2 * (sim+1)
wm = (simsc*cfg + (4-simsc)*leap) / 4
return wm
return (simsc*cfg + (4-simsc)*leap) / 4
m.set_model_sampler_post_cfg_function(mahiro_normd)
return (m, )

View File

@ -11,8 +11,7 @@ def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, co
perp = neg - ((torch.mul(neg, pos).sum())/(torch.norm(pos)**2)) * pos
perp_neg = perp * neg_scale
cfg_result = noise_pred_nocond + cond_scale*(pos - perp_neg)
return cfg_result
return noise_pred_nocond + cond_scale*(pos - perp_neg)
#TODO: This node should be removed, it has been replaced with PerpNegGuider
class PerpNeg:
@ -44,8 +43,7 @@ class PerpNeg:
(noise_pred_nocond,) = comfy.samplers.calc_cond_batch(model, [nocond_processed], x, sigma, model_options)
cfg_result = x - perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale)
return cfg_result
return x - perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale)
m.set_model_sampler_cfg_function(cfg_function)

View File

@ -52,8 +52,7 @@ class FuseModule(nn.Module):
stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1)
stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds
stacked_id_embeds = self.mlp2(stacked_id_embeds)
stacked_id_embeds = self.layer_norm(stacked_id_embeds)
return stacked_id_embeds
return self.layer_norm(stacked_id_embeds)
def forward(
self,
@ -86,8 +85,7 @@ class FuseModule(nn.Module):
stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds)
assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}"
prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype))
updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1)
return updated_prompt_embeds
return prompt_embeds.view(batch_size, seq_length, -1)
class PhotoMakerIDEncoder(comfy.clip_model.CLIPVisionModelProjection):
def __init__(self):
@ -111,9 +109,8 @@ class PhotoMakerIDEncoder(comfy.clip_model.CLIPVisionModelProjection):
id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1)
id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1)
updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask)
return self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask)
return updated_prompt_embeds
class PhotoMakerLoader:

View File

@ -162,8 +162,7 @@ class Quantize:
result = result.to(dtype=torch.uint8)
im = Image.fromarray(result.cpu().numpy())
im = im.quantize(palette=pal_im, dither=Image.Dither.NONE)
return im
return im.quantize(palette=pal_im, dither=Image.Dither.NONE)
def quantize(self, image: torch.Tensor, colors: int, dither: str):
batch_size, height, width, _ = image.shape

View File

@ -50,8 +50,7 @@ class LatentRebatch:
def cat_batch(batch1, batch2):
if batch1[0] is None:
return batch2
result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
return result
return [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
def rebatch(self, latents, batch_size):
batch_size = batch_size[0]

View File

@ -82,8 +82,7 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
mask = F.interpolate(mask, (lh, lw))
blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
blurred = blurred * mask + x0 * (1 - mask)
return blurred
return blurred * mask + x0 * (1 - mask)
def gaussian_blur_2d(img, kernel_size, sigma):
ksize_half = (kernel_size - 1) * 0.5
@ -101,8 +100,7 @@ def gaussian_blur_2d(img, kernel_size, sigma):
padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
img = F.pad(img, padding, mode="reflect")
img = F.conv2d(img, kernel2d, groups=img.shape[-3])
return img
return F.conv2d(img, kernel2d, groups=img.shape[-3])
class SelfAttentionGuidance:
@classmethod

View File

@ -5,7 +5,7 @@ import comfy.utils
def camera_embeddings(elevation, azimuth):
elevation = torch.as_tensor([elevation])
azimuth = torch.as_tensor([azimuth])
embeddings = torch.stack(
return torch.stack(
[
torch.deg2rad(
(90 - elevation) - (90)
@ -17,7 +17,6 @@ def camera_embeddings(elevation, azimuth):
),
], dim=-1).unsqueeze(1)
return embeddings
class StableZero123_Conditioning:

View File

@ -81,11 +81,10 @@ class CacheSet:
self.objects = HierarchicalCache(CacheKeySetID)
def recursive_debug_dump(self):
result = {
return {
"outputs": self.outputs.recursive_debug_dump(),
"ui": self.ui.recursive_debug_dump(),
}
return result
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
valid_inputs = class_def.INPUT_TYPES()

View File

@ -356,8 +356,7 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im
input = input.replace("%day%", str(now.tm_mday).zfill(2))
input = input.replace("%hour%", str(now.tm_hour).zfill(2))
input = input.replace("%minute%", str(now.tm_min).zfill(2))
input = input.replace("%second%", str(now.tm_sec).zfill(2))
return input
return input.replace("%second%", str(now.tm_sec).zfill(2))
if "%" in filename_prefix:
filename_prefix = compute_vars(filename_prefix, image_width, image_height)

View File

@ -607,8 +607,7 @@ class unCLIPCheckpointLoader:
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return out
return comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
class CLIPSetLastLayer:
@classmethod

View File

@ -4,11 +4,13 @@ lint.ignore = ["ALL"]
# Enable specific rules
lint.select = [
"S307", # suspicious-eval-usage
"T201", # print-usage
"S102", # exec
"T", # print-usage
"W",
# The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
"F",
"RET504",
]
exclude = ["*.ipynb"]

View File

@ -129,8 +129,7 @@ class TestCompareImageMetrics:
def read_img(self, filename: str) -> np.ndarray:
cvImg = imread(filename)
cvImg = cvtColor(cvImg, COLOR_BGR2RGB)
return cvImg
return cvtColor(cvImg, COLOR_BGR2RGB)
def image_grid(self, img_list: list[list[Image.Image]]):
# imgs is a 2D list of images
@ -154,8 +153,7 @@ class TestCompareImageMetrics:
with open(metrics_output_file, 'r') as f:
for line in f:
if fname_basestr in line:
score = float(line.split('|')[5])
return score
return float(line.split('|')[5])
raise ValueError(f"Could not find score for {fname} in {metrics_output_file}")
def gather_file_basenames(self, directory: str):

View File

@ -141,14 +141,13 @@ class TestExecutionBlockerNode:
@classmethod
def INPUT_TYPES(cls):
inputs = {
return {
"required": {
"input": ("*",),
"block": ("BOOLEAN",),
"verbose": ("BOOLEAN", {"default": False}),
},
}
return inputs
RETURN_TYPES = ("*",)
RETURN_NAMES = ("output",)