diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index ec01665e..a0618fb7 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -50,8 +50,7 @@ class ResBlockUnionControlnet(nn.Module): def forward(self, x: torch.Tensor): x = x + self.attention(self.ln_1(x)) - x = x + self.mlp(self.ln_2(x)) - return x + return x + self.mlp(self.ln_2(x)) class ControlledUnetModel(UNetModel): #implemented in the ldm unet diff --git a/comfy/clip_model.py b/comfy/clip_model.py index 23ddea9c..b931d274 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -36,8 +36,7 @@ class CLIPMLP(torch.nn.Module): def forward(self, x): x = self.fc1(x) x = self.activation(x) - x = self.fc2(x) - return x + return self.fc2(x) class CLIPLayer(torch.nn.Module): def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations): diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py index 056b1aa6..22697828 100644 --- a/comfy/comfy_types/node_typing.py +++ b/comfy/comfy_types/node_typing.py @@ -270,5 +270,4 @@ class CheckLazyMixin: Comfy Docs: https://docs.comfy.org/essentials/custom_node_lazy_evaluation#defining-check-lazy-status """ - need = [name for name in kwargs if kwargs[name] is None] - return need + return [name for name in kwargs if kwargs[name] is None] diff --git a/comfy/controlnet.py b/comfy/controlnet.py index ee29251b..f5569839 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -404,8 +404,7 @@ class ControlLora(ControlNet): super().cleanup() def get_models(self): - out = ControlBase.get_models(self) - return out + return ControlBase.get_models(self) def inference_memory_requirements(self, dtype): return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype) @@ -461,8 +460,7 @@ def load_controlnet_mmdit(sd, model_options={}): latent_format = comfy.latent_formats.SD3() latent_format.shift_factor = 0 #SD3 controlnet weirdness - control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype) - return control + return ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype) class ControlNetSD35(ControlNet): @@ -536,8 +534,7 @@ def load_controlnet_sd35(sd, model_options={}): elif depth_cnet: preprocess_image = lambda a: 1.0 - a - control = ControlNetSD35(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, preprocess_image=preprocess_image) - return control + return ControlNetSD35(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, preprocess_image=preprocess_image) @@ -549,16 +546,14 @@ def load_controlnet_hunyuandit(controlnet_data, model_options={}): latent_format = comfy.latent_formats.SDXL() extra_conds = ['text_embedding_mask', 'encoder_hidden_states_t5', 'text_embedding_mask_t5', 'image_meta_size', 'style', 'cos_cis_img', 'sin_cis_img'] - control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT) - return control + return ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT) def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}): model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options) control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) control_model = controlnet_load_state_dict(control_model, sd) extra_conds = ['y', 'guidance'] - control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) - return control + return ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) def load_controlnet_flux_instantx(sd, model_options={}): new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "") @@ -581,8 +576,7 @@ def load_controlnet_flux_instantx(sd, model_options={}): latent_format = comfy.latent_formats.Flux() extra_conds = ['y', 'guidance'] - control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) - return control + return ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) def convert_mistoline(sd): return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."}) @@ -738,8 +732,7 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}): logging.debug("unexpected controlnet keys: {}".format(unexpected)) global_average_pooling = model_options.get("global_average_pooling", False) - control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype) - return control + return ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype) def load_controlnet(ckpt_path, model=None, model_options={}): if "global_average_pooling" not in model_options: diff --git a/comfy/diffusers_convert.py b/comfy/diffusers_convert.py index 26e8d96d..244f7755 100644 --- a/comfy/diffusers_convert.py +++ b/comfy/diffusers_convert.py @@ -99,8 +99,7 @@ def convert_unet_state_dict(unet_state_dict): for sd_part, hf_part in unet_conversion_map_layer: v = v.replace(hf_part, sd_part) mapping[k] = v - new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} - return new_state_dict + return {v: unet_state_dict[k] for k, v in mapping.items()} # ================# diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index 5b80a8af..98f30aee 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -136,8 +136,7 @@ class NoiseScheduleVP: return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 elif self.schedule == 'cosine': log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) - log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 - return log_alpha_t + return log_alpha_fn(t) - self.cosine_log_alpha_0 def marginal_alpha(self, t): """ @@ -174,8 +173,7 @@ class NoiseScheduleVP: else: log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s - t = t_fn(log_alpha) - return t + return t_fn(log_alpha) def model_wrapper( @@ -378,8 +376,7 @@ class UniPC: p = self.dynamic_thresholding_ratio s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims) - x0 = torch.clamp(x0, -s, s) / s - return x0 + return torch.clamp(x0, -s, s) / s def noise_prediction_fn(self, x, t): """ @@ -423,8 +420,7 @@ class UniPC: return torch.linspace(t_T, t_0, N + 1).to(device) elif skip_type == 'time_quadratic': t_order = 2 - t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device) - return t + return torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device) else: raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) @@ -801,8 +797,7 @@ def interpolate_fn(x, xp, yp): y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) - cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) - return cand + return start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) def expand_dims(v, dims): diff --git a/comfy/gligen.py b/comfy/gligen.py index 161d8a5e..575c3a0d 100644 --- a/comfy/gligen.py +++ b/comfy/gligen.py @@ -78,10 +78,9 @@ class GatedCrossAttentionDense(nn.Module): x = x + self.scale * \ torch.tanh(self.alpha_attn) * self.attn(self.norm1(x), objs, objs) - x = x + self.scale * \ + return x + self.scale * \ torch.tanh(self.alpha_dense) * self.ff(self.norm2(x)) - return x class GatedSelfAttentionDense(nn.Module): @@ -118,10 +117,9 @@ class GatedSelfAttentionDense(nn.Module): x = x + self.scale * torch.tanh(self.alpha_attn) * self.attn( self.norm1(torch.cat([x, objs], dim=1)))[:, 0:N_visual, :] - x = x + self.scale * \ + return x + self.scale * \ torch.tanh(self.alpha_dense) * self.ff(self.norm2(x)) - return x class GatedSelfAttentionDense2(nn.Module): @@ -172,10 +170,9 @@ class GatedSelfAttentionDense2(nn.Module): # add residual to visual feature x = x + self.scale * torch.tanh(self.alpha_attn) * residual - x = x + self.scale * \ + return x + self.scale * \ torch.tanh(self.alpha_dense) * self.ff(self.norm2(x)) - return x class FourierEmbedder(): @@ -340,5 +337,4 @@ def load_gligen(sd): w.position_net = PositionNet(in_dim, out_dim) w.load_state_dict(sd, strict=False) - gligen = Gligen(output_list, w.position_net, key_dim) - return gligen + return Gligen(output_list, w.position_net, key_dim) diff --git a/comfy/k_diffusion/deis.py b/comfy/k_diffusion/deis.py index a1167a4a..96154705 100644 --- a/comfy/k_diffusion/deis.py +++ b/comfy/k_diffusion/deis.py @@ -46,8 +46,7 @@ def cal_intergrand(beta_0, beta_1, taus): log_alpha = alpha.log() log_alpha.sum().backward() d_log_alpha_dtau = taus.grad - integrand = -0.5 * d_log_alpha_dtau / torch.sqrt(alpha * (1 - alpha)) - return integrand + return -0.5 * d_log_alpha_dtau / torch.sqrt(alpha * (1 - alpha)) #---------------------------------------------------------------------------- diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 0f7cc4ca..7d454c88 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -50,8 +50,7 @@ def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0., beta=0.5, device='cpu'): x = torch.linspace(0, 1, n, device=device) clamp = lambda x: torch.clamp(x, min=sigma_min, max=sigma_max) lmb = mu - beta * torch.sign(0.5-x) * torch.log(1 - 2 * torch.abs(0.5-x) + epsilon) - sigmas = clamp(torch.exp(lmb)) - return sigmas + return clamp(torch.exp(lmb)) diff --git a/comfy/ldm/audio/autoencoder.py b/comfy/ldm/audio/autoencoder.py index 9e7e7c87..0a10b45a 100644 --- a/comfy/ldm/audio/autoencoder.py +++ b/comfy/ldm/audio/autoencoder.py @@ -70,9 +70,8 @@ class SnakeBeta(nn.Module): if self.alpha_logscale: alpha = torch.exp(alpha) beta = torch.exp(beta) - x = snake_beta(x, alpha, beta) + return snake_beta(x, alpha, beta) - return x def WNConv1d(*args, **kwargs): try: diff --git a/comfy/ldm/audio/dit.py b/comfy/ldm/audio/dit.py index 179c5b67..801c1693 100644 --- a/comfy/ldm/audio/dit.py +++ b/comfy/ldm/audio/dit.py @@ -89,8 +89,7 @@ class AbsolutePositionalEmbedding(nn.Module): pos = (pos - seq_start_pos[..., None]).clamp(min = 0) pos_emb = self.emb(pos) - pos_emb = pos_emb * self.scale - return pos_emb + return pos_emb * self.scale class ScaledSinusoidalEmbedding(nn.Module): def __init__(self, dim, theta = 10000): @@ -404,9 +403,8 @@ class ConformerModule(nn.Module): x = self.swish(x) x = rearrange(x, 'b n d -> b d n') x = self.pointwise_conv_2(x) - x = rearrange(x, 'b d n -> b n d') + return rearrange(x, 'b d n -> b n d') - return x class TransformerBlock(nn.Module): def __init__( diff --git a/comfy/ldm/audio/embedders.py b/comfy/ldm/audio/embedders.py index 20edb365..7541fdb7 100644 --- a/comfy/ldm/audio/embedders.py +++ b/comfy/ldm/audio/embedders.py @@ -21,8 +21,7 @@ class LearnedPositionalEmbedding(nn.Module): x = rearrange(x, "b -> b 1") freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * math.pi fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) - fouriered = torch.cat((x, fouriered), dim=-1) - return fouriered + return torch.cat((x, fouriered), dim=-1) def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module: return nn.Sequential( @@ -49,8 +48,7 @@ class NumberEmbedder(nn.Module): shape = x.shape x = rearrange(x, "... -> (...)") embedding = self.embedding(x) - x = embedding.view(*shape, self.features) - return x # type: ignore + return embedding.view(*shape, self.features) class Conditioner(nn.Module): diff --git a/comfy/ldm/aura/mmdit.py b/comfy/ldm/aura/mmdit.py index 1258ae11..00cf9852 100644 --- a/comfy/ldm/aura/mmdit.py +++ b/comfy/ldm/aura/mmdit.py @@ -36,8 +36,7 @@ class MLP(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: x = F.silu(self.c_fc1(x)) * self.c_fc2(x) - x = self.c_proj(x) - return x + return self.c_proj(x) class MultiHeadLayerNorm(nn.Module): @@ -95,8 +94,7 @@ class SingleAttention(nn.Module): q, k = self.q_norm1(q), self.k_norm1(k) output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True) - c = self.w1o(output) - return c + return self.w1o(output) @@ -265,9 +263,8 @@ class DiTBlock(nn.Module): mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp)) cx = gate_mlp.unsqueeze(1) * mlpout - cx = cxres + cx + return cxres + cx - return cx @@ -298,8 +295,7 @@ class TimestepEmbedder(nn.Module): #@torch.compile() def forward(self, t, dtype): t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype) - t_emb = self.mlp(t_freq) - return t_emb + return self.mlp(t_freq) class MMDiT(nn.Module): @@ -401,8 +397,7 @@ class MMDiT(nn.Module): x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) x = torch.einsum("nhwpqc->nchpwq", x) - imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p)) - return imgs + return x.reshape(shape=(x.shape[0], c, h * p, w * p)) def patchify(self, x): B, C, H, W = x.size() @@ -415,8 +410,7 @@ class MMDiT(nn.Module): (W + 1) // self.patch_size, self.patch_size, ) - x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2) - return x + return x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2) def apply_pos_embeds(self, x, h, w): h = (h + 1) // self.patch_size @@ -494,5 +488,4 @@ class MMDiT(nn.Module): x = modulate(x, fshift, fscale) x = self.final_linear(x) - x = self.unpatchify(x, (h + 1) // self.patch_size, (w + 1) // self.patch_size)[:,:,:h,:w] - return x + return self.unpatchify(x, (h + 1) // self.patch_size, (w + 1) // self.patch_size)[:,:,:h,:w] diff --git a/comfy/ldm/cascade/common.py b/comfy/ldm/cascade/common.py index 3eaa0c82..ba7dd14e 100644 --- a/comfy/ldm/cascade/common.py +++ b/comfy/ldm/cascade/common.py @@ -54,8 +54,7 @@ class Attention2D(nn.Module): kv = torch.cat([x, kv], dim=1) # x = self.attn(x, kv, kv, need_weights=False)[0] x = self.attn(x, kv, kv) - x = x.permute(0, 2, 1).view(*orig_shape) - return x + return x.permute(0, 2, 1).view(*orig_shape) def LayerNorm2d_op(operations): @@ -116,8 +115,7 @@ class AttnBlock(nn.Module): def forward(self, x, kv): kv = self.kv_mapper(kv) - x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn) - return x + return x + self.attention(self.norm(x), kv, self_attn=self.self_attn) class FeedForwardBlock(nn.Module): @@ -133,8 +131,7 @@ class FeedForwardBlock(nn.Module): ) def forward(self, x): - x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - return x + return x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) class TimestepBlock(nn.Module): diff --git a/comfy/ldm/cascade/stage_a.py b/comfy/ldm/cascade/stage_a.py index ca8867ea..93971811 100644 --- a/comfy/ldm/cascade/stage_a.py +++ b/comfy/ldm/cascade/stage_a.py @@ -157,9 +157,8 @@ class ResBlock(nn.Module): x = x + self.depthwise[1](x_temp) * mods[2] x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4] - x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] + return x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] - return x class StageA(nn.Module): @@ -218,8 +217,7 @@ class StageA(nn.Module): def decode(self, x): x = self.up_blocks(x) - x = self.out_block(x) - return x + return self.out_block(x) def forward(self, x, quantize=False): qe, x, _, vq_loss = self.encode(x, quantize) @@ -251,5 +249,4 @@ class Discriminator(nn.Module): cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1)) x = torch.cat([x, cond], dim=1) x = self.shuffle(x) - x = self.logits(x) - return x + return self.logits(x) diff --git a/comfy/ldm/cascade/stage_b.py b/comfy/ldm/cascade/stage_b.py index 77383095..ed2cd1ae 100644 --- a/comfy/ldm/cascade/stage_b.py +++ b/comfy/ldm/cascade/stage_b.py @@ -170,8 +170,7 @@ class StageB(nn.Module): if len(clip.shape) == 2: clip = clip.unsqueeze(1) clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1) - clip = self.clip_norm(clip) - return clip + return self.clip_norm(clip) def _down_encode(self, x, r_embed, clip): level_outputs = [] diff --git a/comfy/ldm/cascade/stage_c.py b/comfy/ldm/cascade/stage_c.py index b952d034..d4fdbf64 100644 --- a/comfy/ldm/cascade/stage_c.py +++ b/comfy/ldm/cascade/stage_c.py @@ -179,8 +179,7 @@ class StageC(nn.Module): clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1) clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1) clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1) - clip = self.clip_norm(clip) - return clip + return self.clip_norm(clip) def _down_encode(self, x, r_embed, clip, cnet=None): level_outputs = [] diff --git a/comfy/ldm/cascade/stage_c_coder.py b/comfy/ldm/cascade/stage_c_coder.py index 0cb7c49f..5cfd4940 100644 --- a/comfy/ldm/cascade/stage_c_coder.py +++ b/comfy/ldm/cascade/stage_c_coder.py @@ -35,8 +35,7 @@ class EfficientNetEncoder(nn.Module): def forward(self, x): x = x * 0.5 + 0.5 x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1]) - o = self.mapper(self.backbone(x)) - return o + return self.mapper(self.backbone(x)) # Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192 diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 8e055151..b925bd29 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -256,5 +256,4 @@ class LastLayer(nn.Module): def forward(self, x: Tensor, vec: Tensor) -> Tensor: shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] - x = self.linear(x) - return x + return self.linear(x) diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index b6549585..03f6c6d4 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -9,8 +9,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor: q, k = apply_rope(q, k, pe) heads = q.shape[1] - x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask) - return x + return optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask) def rope(pos: Tensor, dim: int, theta: int) -> Tensor: diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index dead87de..dc5b516c 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -183,8 +183,7 @@ class Flux(nn.Module): img = img[:, txt.shape[1] :, ...] - img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) - return img + return self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs): bs, c, h, w = x.shape diff --git a/comfy/ldm/flux/redux.py b/comfy/ldm/flux/redux.py index 527e8316..f444c161 100644 --- a/comfy/ldm/flux/redux.py +++ b/comfy/ldm/flux/redux.py @@ -21,5 +21,4 @@ class ReduxImageEncoder(torch.nn.Module): self.redux_down = ops.Linear(txt_in_features * 3, txt_in_features, dtype=dtype) def forward(self, sigclip_embeds) -> torch.Tensor: - projected_x = self.redux_down(torch.nn.functional.silu(self.redux_up(sigclip_embeds))) - return projected_x + return self.redux_down(torch.nn.functional.silu(self.redux_up(sigclip_embeds))) diff --git a/comfy/ldm/genmo/joint_model/asymm_models_joint.py b/comfy/ldm/genmo/joint_model/asymm_models_joint.py index 2c46c24b..138e58b6 100644 --- a/comfy/ldm/genmo/joint_model/asymm_models_joint.py +++ b/comfy/ldm/genmo/joint_model/asymm_models_joint.py @@ -34,9 +34,8 @@ import comfy.ops def modulated_rmsnorm(x, scale, eps=1e-6): # Normalize and modulate x_normed = comfy.ldm.common_dit.rms_norm(x, eps=eps) - x_modulated = x_normed * (1 + scale.unsqueeze(1)) + return x_normed * (1 + scale.unsqueeze(1)) - return x_modulated def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6): @@ -47,9 +46,8 @@ def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6): x_normed = comfy.ldm.common_dit.rms_norm(x_res, eps=eps) * tanh_gate # Apply residual connection - output = x + x_normed + return x + x_normed - return output class AsymmetricAttention(nn.Module): def __init__( @@ -275,14 +273,12 @@ class AsymmetricJointBlock(nn.Module): def ff_block_x(self, x, scale_x, gate_x): x_mod = modulated_rmsnorm(x, scale_x) x_res = self.mlp_x(x_mod) - x = residual_tanh_gated_rmsnorm(x, x_res, gate_x) # Sandwich norm - return x + return residual_tanh_gated_rmsnorm(x, x_res, gate_x) # Sandwich norm def ff_block_y(self, y, scale_y, gate_y): y_mod = modulated_rmsnorm(y, scale_y) y_res = self.mlp_y(y_mod) - y = residual_tanh_gated_rmsnorm(y, y_res, gate_y) # Sandwich norm - return y + return residual_tanh_gated_rmsnorm(y, y_res, gate_y) # Sandwich norm class FinalLayer(nn.Module): @@ -312,8 +308,7 @@ class FinalLayer(nn.Module): c = F.silu(c) shift, scale = self.mod(c).chunk(2, dim=1) x = modulate(self.norm_final(x), shift, scale) - x = self.linear(x) - return x + return self.linear(x) class AsymmDiTJoint(nn.Module): diff --git a/comfy/ldm/genmo/joint_model/layers.py b/comfy/ldm/genmo/joint_model/layers.py index 51d97955..3eab79c6 100644 --- a/comfy/ldm/genmo/joint_model/layers.py +++ b/comfy/ldm/genmo/joint_model/layers.py @@ -64,8 +64,7 @@ class TimestepEmbedder(nn.Module): if self.timestep_scale is not None: t = t * self.timestep_scale t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype=out_dtype) - t_emb = self.mlp(t_freq) - return t_emb + return self.mlp(t_freq) class FeedForward(nn.Module): @@ -93,8 +92,7 @@ class FeedForward(nn.Module): def forward(self, x): x, gate = self.w1(x).chunk(2, dim=-1) - x = self.w2(F.silu(x) * gate) - return x + return self.w2(F.silu(x) * gate) class PatchEmbed(nn.Module): @@ -149,8 +147,7 @@ class PatchEmbed(nn.Module): raise NotImplementedError("Must flatten output.") x = rearrange(x, "(B T) C H W -> B (T H W) C", B=B, T=T) - x = self.norm(x) - return x + return self.norm(x) class RMSNorm(torch.nn.Module): diff --git a/comfy/ldm/genmo/joint_model/rope_mixed.py b/comfy/ldm/genmo/joint_model/rope_mixed.py index dee3fa21..990ff426 100644 --- a/comfy/ldm/genmo/joint_model/rope_mixed.py +++ b/comfy/ldm/genmo/joint_model/rope_mixed.py @@ -60,9 +60,8 @@ def create_position_matrix( # Stack and reshape the grids. pos = torch.stack([grid_t, grid_h, grid_w], dim=-1) # [T, pH, pW, 3] pos = pos.view(-1, 3) # [T * pH * pW, 3] - pos = pos.to(dtype=dtype, device=device) + return pos.to(dtype=dtype, device=device) - return pos def compute_mixed_rotation( diff --git a/comfy/ldm/genmo/joint_model/temporal_rope.py b/comfy/ldm/genmo/joint_model/temporal_rope.py index 88f5d6d2..53a69d05 100644 --- a/comfy/ldm/genmo/joint_model/temporal_rope.py +++ b/comfy/ldm/genmo/joint_model/temporal_rope.py @@ -30,5 +30,4 @@ def apply_rotary_emb_qk_real( sin_part = (xqk_even * freqs_sin + xqk_odd * freqs_cos).type_as(xqk) # Interleave the results back into the original shape - out = torch.stack([cos_part, sin_part], dim=-1).flatten(-2) - return out + return torch.stack([cos_part, sin_part], dim=-1).flatten(-2) diff --git a/comfy/ldm/genmo/joint_model/utils.py b/comfy/ldm/genmo/joint_model/utils.py index 1b399d5d..0ba42359 100644 --- a/comfy/ldm/genmo/joint_model/utils.py +++ b/comfy/ldm/genmo/joint_model/utils.py @@ -29,8 +29,7 @@ def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch. assert x.size(0) == mask.size(0) # Expected mask to have same batch size as tokens. mask = mask[:, :, None].to(dtype=x.dtype) mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1) - pooled = (x * mask).sum(dim=1, keepdim=keepdim) - return pooled + return (x * mask).sum(dim=1, keepdim=keepdim) class AttentionPool(nn.Module): @@ -98,5 +97,4 @@ class AttentionPool(nn.Module): # Concatenate heads and run output. x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim) - x = self.to_out(x) - return x + return self.to_out(x) diff --git a/comfy/ldm/genmo/vae/model.py b/comfy/ldm/genmo/vae/model.py index 1bde0c1e..d9c4f921 100644 --- a/comfy/ldm/genmo/vae/model.py +++ b/comfy/ldm/genmo/vae/model.py @@ -99,8 +99,7 @@ class Conv1x1(ops.Linear): """ x = x.movedim(1, -1) x = super().forward(x) - x = x.movedim(-1, 1) - return x + return x.movedim(-1, 1) class DepthToSpaceTime(nn.Module): @@ -269,8 +268,7 @@ class Attention(nn.Module): assert x.size(0) == q.size(0) x = self.out(x) - x = rearrange(x, "(B h w) t C -> B C t h w", B=B, h=H, w=W) - return x + return rearrange(x, "(B h w) t C -> B C t h w", B=B, h=H, w=W) class AttentionBlock(nn.Module): @@ -321,8 +319,7 @@ class CausalUpsampleBlock(nn.Module): def forward(self, x): x = self.blocks(x) x = self.proj(x) - x = self.d2st(x) - return x + return self.d2st(x) def block_fn(channels, *, affine: bool = True, has_attention: bool = False, **block_kwargs): diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index d6d85408..99ff96c3 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -86,8 +86,7 @@ class TokenRefinerBlock(nn.Module): attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True) x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1) - x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1) - return x + return x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1) class IndividualTokenRefiner(nn.Module): @@ -157,8 +156,7 @@ class TokenRefiner(nn.Module): c = t + self.c_embedder(c.to(x.dtype)) x = self.input_embedder(x) - x = self.individual_token_refiner(x, c, mask) - return x + return self.individual_token_refiner(x, c, mask) class HunyuanVideo(nn.Module): """ @@ -311,8 +309,7 @@ class HunyuanVideo(nn.Module): shape[i] = shape[i] // self.patch_size[i] img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size) img = img.permute(0, 4, 1, 5, 2, 6, 3, 7) - img = img.reshape(initial_shape) - return img + return img.reshape(initial_shape) def forward(self, x, timestep, context, y, guidance, attention_mask=None, control=None, transformer_options={}, **kwargs): bs, c, t, h, w = x.shape @@ -326,5 +323,4 @@ class HunyuanVideo(nn.Module): img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1) img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) - out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options) - return out + return self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options) diff --git a/comfy/ldm/hydit/attn_layers.py b/comfy/ldm/hydit/attn_layers.py index 3ca25a5d..4b31f731 100644 --- a/comfy/ldm/hydit/attn_layers.py +++ b/comfy/ldm/hydit/attn_layers.py @@ -166,9 +166,8 @@ class CrossAttention(nn.Module): out = self.out_proj(context) # context.reshape - B, L1, -1 out = self.proj_drop(out) - out_tuple = (out,) + return (out,) - return out_tuple class Attention(nn.Module): @@ -213,6 +212,5 @@ class Attention(nn.Module): x = self.out_proj(x) x = self.proj_drop(x) - out_tuple = (x,) + return (x,) - return out_tuple diff --git a/comfy/ldm/hydit/models.py b/comfy/ldm/hydit/models.py index 359f6a96..0b724fa7 100644 --- a/comfy/ldm/hydit/models.py +++ b/comfy/ldm/hydit/models.py @@ -19,8 +19,7 @@ def calc_rope(x, patch_size, head_size): sub_args = [start, stop, (th, tw)] # head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads'] rope = get_2d_rotary_pos_embed(head_size, *sub_args) - rope = (rope[0].to(x), rope[1].to(x)) - return rope + return (rope[0].to(x), rope[1].to(x)) def modulate(x, shift, scale): @@ -110,9 +109,8 @@ class HunYuanDiTBlock(nn.Module): # FFN Layer mlp_inputs = self.norm2(x) - x = x + self.mlp(mlp_inputs) + return x + self.mlp(mlp_inputs) - return x def forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None): if self.gradient_checkpointing and self.training: @@ -136,8 +134,7 @@ class FinalLayer(nn.Module): def forward(self, x, c): shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) x = modulate(self.norm_final(x), shift, scale) - x = self.linear(x) - return x + return self.linear(x) class HunYuanDiT(nn.Module): @@ -413,5 +410,4 @@ class HunYuanDiT(nn.Module): x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) x = torch.einsum('nhwpqc->nchpwq', x) - imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p)) - return imgs + return x.reshape(shape=(x.shape[0], c, h * p, w * p)) diff --git a/comfy/ldm/hydit/posemb_layers.py b/comfy/ldm/hydit/posemb_layers.py index dcb41a71..bff7813f 100644 --- a/comfy/ldm/hydit/posemb_layers.py +++ b/comfy/ldm/hydit/posemb_layers.py @@ -53,8 +53,7 @@ def get_meshgrid(start, *args): grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32) grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32) grid = np.meshgrid(grid_w, grid_h) # here w goes first - grid = np.stack(grid, axis=0) # [2, W, H] - return grid + return np.stack(grid, axis=0) # [2, W, H] ################################################################################# # Sine/Cosine Positional Embedding Functions # @@ -87,8 +86,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) - emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) - return emb + return np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): @@ -108,8 +106,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): emb_sin = np.sin(out) # (M, D/2) emb_cos = np.cos(out) # (M, D/2) - emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) - return emb + return np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) ################################################################################# @@ -138,8 +135,7 @@ def get_2d_rotary_pos_embed(embed_dim, start, *args, use_real=True): """ grid = get_meshgrid(start, *args) # [2, H, w] grid = grid.reshape([2, 1, *grid.shape[1:]]) # Returns a sampling matrix with the same resolution as the target resolution - pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real) - return pos_embed + return get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real) def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False): @@ -154,8 +150,7 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False): sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2) return cos, sin else: - emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2) - return emb + return torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2) def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False): @@ -187,8 +182,7 @@ def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] return freqs_cos, freqs_sin else: - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] - return freqs_cis + return torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index eeeeaea0..aed22013 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -122,14 +122,13 @@ class Timesteps(nn.Module): self.scale = scale def forward(self, timesteps): - t_emb = get_timestep_embedding( + return get_timestep_embedding( timesteps, self.num_channels, flip_sin_to_cos=self.flip_sin_to_cos, downscale_freq_shift=self.downscale_freq_shift, scale=self.scale, ) - return t_emb class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module): @@ -149,8 +148,7 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module): def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): timesteps_proj = self.time_proj(timestep) - timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) - return timesteps_emb + return self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) class AdaLayerNormSingle(nn.Module): @@ -209,8 +207,7 @@ class PixArtAlphaTextProjection(nn.Module): def forward(self, caption): hidden_states = self.linear_1(caption) hidden_states = self.act_1(hidden_states) - hidden_states = self.linear_2(hidden_states) - return hidden_states + return self.linear_2(hidden_states) class GELU_approx(nn.Module): @@ -247,9 +244,8 @@ def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and t_dup = torch.stack((-t2, t1), dim=-1) input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)") - out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + return input_tensor * cos_freqs + input_tensor_rot * sin_freqs - return out class CrossAttention(nn.Module): @@ -316,14 +312,13 @@ class BasicTransformerBlock(nn.Module): return x def get_fractional_positions(indices_grid, max_pos): - fractional_positions = torch.stack( + return torch.stack( [ indices_grid[:, i] / max_pos[i] for i in range(3) ], dim=-1, ) - return fractional_positions def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]): diff --git a/comfy/ldm/lightricks/symmetric_patchifier.py b/comfy/ldm/lightricks/symmetric_patchifier.py index c58dfb20..c946b312 100644 --- a/comfy/ldm/lightricks/symmetric_patchifier.py +++ b/comfy/ldm/lightricks/symmetric_patchifier.py @@ -65,8 +65,7 @@ class Patchifier(ABC): scale = scale_grid[i] grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i] - grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size) - return grid + return rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size) class SymmetricPatchifier(Patchifier): @@ -74,14 +73,13 @@ class SymmetricPatchifier(Patchifier): self, latents: Tensor, ) -> Tuple[Tensor, Tensor]: - latents = rearrange( + return rearrange( latents, "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", p1=self._patch_size[0], p2=self._patch_size[1], p3=self._patch_size[2], ) - return latents def unpatchify( self, @@ -93,7 +91,7 @@ class SymmetricPatchifier(Patchifier): ) -> Tuple[Tensor, Tensor]: output_height = output_height // self._patch_size[1] output_width = output_width // self._patch_size[2] - latents = rearrange( + return rearrange( latents, "b (f h w) (c p q) -> b c f (h p) (w q) ", f=output_num_frames, @@ -102,4 +100,3 @@ class SymmetricPatchifier(Patchifier): p=self._patch_size[1], q=self._patch_size[2], ) - return latents diff --git a/comfy/ldm/lightricks/vae/causal_conv3d.py b/comfy/ldm/lightricks/vae/causal_conv3d.py index c572e7e8..fe48c309 100644 --- a/comfy/ldm/lightricks/vae/causal_conv3d.py +++ b/comfy/ldm/lightricks/vae/causal_conv3d.py @@ -56,8 +56,7 @@ class CausalConv3d(nn.Module): (1, 1, (self.time_kernel_size - 1) // 2, 1, 1) ) x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2) - x = self.conv(x) - return x + return self.conv(x) @property def weight(self): diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index e0344dee..6570b33c 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -417,9 +417,8 @@ class Decoder(nn.Module): sample = self.conv_act(sample) sample = self.conv_out(sample, causal=self.causal) - sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + return unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) - return sample class UNetMidBlock3D(nn.Module): @@ -563,8 +562,7 @@ class LayerNorm(nn.Module): def forward(self, x): x = rearrange(x, "b c d h w -> b d h w c") x = self.norm(x) - x = rearrange(x, "b d h w c -> b c d h w") - return x + return rearrange(x, "b d h w c -> b c d h w") class ResnetBlock3D(nn.Module): @@ -677,9 +675,8 @@ class ResnetBlock3D(nn.Module): # similar to the "explicit noise inputs" method in style-gan spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None] scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...] - hidden_states = hidden_states + scaled_noise + return hidden_states + scaled_noise - return hidden_states def forward( self, @@ -740,9 +737,8 @@ class ResnetBlock3D(nn.Module): input_tensor = self.conv_shortcut(input_tensor) - output_tensor = input_tensor + hidden_states + return input_tensor + hidden_states - return output_tensor def patchify(x, patch_size_hw, patch_size_t=1): diff --git a/comfy/ldm/lightricks/vae/dual_conv3d.py b/comfy/ldm/lightricks/vae/dual_conv3d.py index 6bd54c0a..7ad55e84 100644 --- a/comfy/ldm/lightricks/vae/dual_conv3d.py +++ b/comfy/ldm/lightricks/vae/dual_conv3d.py @@ -114,7 +114,7 @@ class DualConv3d(nn.Module): return x # Second convolution - x = F.conv3d( + return F.conv3d( x, self.weight2, self.bias2, @@ -124,7 +124,6 @@ class DualConv3d(nn.Module): self.groups, ) - return x def forward_with_2d(self, x, skip_time_conv): b, c, d, h, w = x.shape @@ -142,8 +141,7 @@ class DualConv3d(nn.Module): _, _, h, w = x.shape if skip_time_conv: - x = rearrange(x, "(b d) c h w -> b c d h w", b=b) - return x + return rearrange(x, "(b d) c h w -> b c d h w", b=b) # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b) @@ -155,9 +153,8 @@ class DualConv3d(nn.Module): padding2 = self.padding2[0] dilation2 = self.dilation2[0] x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups) - x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w) + return rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w) - return x @property def weight(self): diff --git a/comfy/ldm/models/autoencoder.py b/comfy/ldm/models/autoencoder.py index e6493155..6fa57652 100644 --- a/comfy/ldm/models/autoencoder.py +++ b/comfy/ldm/models/autoencoder.py @@ -136,8 +136,7 @@ class AutoencodingEngine(AbstractAutoencoder): return z def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor: - x = self.decoder(z, **kwargs) - return x + return self.decoder(z, **kwargs) def forward( self, x: torch.Tensor, **additional_decode_kwargs @@ -178,8 +177,7 @@ class AutoencodingEngineLegacy(AutoencodingEngine): self.embed_dim = embed_dim def get_autoencoder_params(self) -> list: - params = super().get_autoencoder_params() - return params + return super().get_autoencoder_params() def encode( self, x: torch.Tensor, return_reg_log: bool = False diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 0d54e6be..4de6370f 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -142,13 +142,12 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape sim = sim.softmax(dim=-1) out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v) - out = ( + return ( out.unsqueeze(0) .reshape(b, heads, -1, dim_head) .permute(0, 2, 1, 3) .reshape(b, -1, heads * dim_head) ) - return out def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False): @@ -216,8 +215,7 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, hidden_states = hidden_states.to(dtype) - hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2) - return hidden_states + return hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2) def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False): attn_precision = get_attn_precision(attn_precision) @@ -326,13 +324,12 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape del q, k, v - r1 = ( + return ( r1.unsqueeze(0) .reshape(b, heads, -1, dim_head) .permute(0, 2, 1, 3) .reshape(b, -1, heads * dim_head) ) - return r1 BROKEN_XFORMERS = False try: @@ -395,11 +392,10 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask) - out = ( + return ( out.reshape(b, -1, heads * dim_head) ) - return out if model_management.is_nvidia(): #pytorch 2.3 and up seem to have this issue. SDP_BATCH_LIMIT = 2**15 @@ -932,7 +928,6 @@ class SpatialVideoTransformer(SpatialTransformer): x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) if not self.use_linear: x = self.proj_out(x) - out = x + x_in - return out + return x + x_in diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index e70f4431..5f520197 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -51,8 +51,7 @@ class Mlp(nn.Module): x = self.drop1(x) x = self.norm(x) x = self.fc2(x) - x = self.drop2(x) - return x + return self.drop2(x) class PatchEmbed(nn.Module): """ 2D Image to Patch Embedding @@ -103,8 +102,7 @@ class PatchEmbed(nn.Module): x = self.proj(x) if self.flatten: x = x.flatten(2).transpose(1, 2) # NCHW -> NLC - x = self.norm(x) - return x + return self.norm(x) def modulate(x, shift, scale): if shift is None: @@ -155,8 +153,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) - emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) - return emb + return np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): @@ -176,8 +173,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): emb_sin = np.sin(out) # (M, D/2) emb_cos = np.cos(out) # (M, D/2) - emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) - return emb + return np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos, device=None, dtype=torch.float32): omega = torch.arange(embed_dim // 2, device=device, dtype=dtype) @@ -187,8 +183,7 @@ def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos, device=None, dtype=t out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product emb_sin = torch.sin(out) # (M, D/2) emb_cos = torch.cos(out) # (M, D/2) - emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) - return emb + return torch.cat([emb_sin, emb_cos], dim=1) # (M, D) def get_2d_sincos_pos_embed_torch(embed_dim, w, h, val_center=7.5, val_magnitude=7.5, device=None, dtype=torch.float32): small = min(h, w) @@ -197,8 +192,7 @@ def get_2d_sincos_pos_embed_torch(embed_dim, w, h, val_center=7.5, val_magnitude grid_h, grid_w = torch.meshgrid(torch.linspace(-val_h + val_center, val_h + val_center, h, device=device, dtype=dtype), torch.linspace(-val_w + val_center, val_w + val_center, w, device=device, dtype=dtype), indexing='ij') emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype) emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype) - emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D) - return emb + return torch.cat([emb_w, emb_h], dim=1) # (H*W, D) ################################################################################# @@ -222,8 +216,7 @@ class TimestepEmbedder(nn.Module): def forward(self, t, dtype, **kwargs): t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype) - t_emb = self.mlp(t_freq) - return t_emb + return self.mlp(t_freq) class VectorEmbedder(nn.Module): @@ -240,8 +233,7 @@ class VectorEmbedder(nn.Module): ) def forward(self, x: torch.Tensor) -> torch.Tensor: - emb = self.mlp(x) - return emb + return self.mlp(x) ################################################################################# @@ -307,16 +299,14 @@ class SelfAttention(nn.Module): def post_attention(self, x: torch.Tensor) -> torch.Tensor: assert not self.pre_only x = self.proj(x) - x = self.proj_drop(x) - return x + return self.proj_drop(x) def forward(self, x: torch.Tensor) -> torch.Tensor: q, k, v = self.pre_attention(x) x = optimized_attention( q, k, v, heads=self.num_heads ) - x = self.post_attention(x) - return x + return self.post_attention(x) class RMSNorm(torch.nn.Module): @@ -530,10 +520,9 @@ class DismantledBlock(nn.Module): def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp): assert not self.pre_only x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn) - x = x + gate_mlp.unsqueeze(1) * self.mlp( + return x + gate_mlp.unsqueeze(1) * self.mlp( modulate(self.norm2(x), shift_mlp, scale_mlp) ) - return x def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: assert self.x_block_self_attn @@ -568,10 +557,9 @@ class DismantledBlock(nn.Module): out2 = gate_msa2.unsqueeze(1) * attn2 x = x + out1 x = x + out2 - x = x + gate_mlp.unsqueeze(1) * self.mlp( + return x + gate_mlp.unsqueeze(1) * self.mlp( modulate(self.norm2(x), shift_mlp, scale_mlp) ) - return x def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: assert not self.pre_only @@ -696,8 +684,7 @@ class FinalLayer(nn.Module): def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) x = modulate(self.norm_final(x), shift, scale) - x = self.linear(x) - return x + return self.linear(x) class SelfAttentionContext(nn.Module): def __init__(self, dim, heads=8, dim_head=64, dtype=None, device=None, operations=None): @@ -902,13 +889,12 @@ class MMDiT(nn.Module): w=self.pos_embed_max_size, ) spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :] - spatial_pos_embed = rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c") + return rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c") # print(spatial_pos_embed, top, left, h, w) # # t = get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, 7.875, 7.875, device=device) #matches exactly for 1024 res # t = get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, 7.5, 7.5, device=device) #scales better # # print(t) # return t - return spatial_pos_embed def unpatchify(self, x, hw=None): """ @@ -927,8 +913,7 @@ class MMDiT(nn.Module): x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) x = torch.einsum("nhwpqc->nchpwq", x) - imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p)) - return imgs + return x.reshape(shape=(x.shape[0], c, h * p, w * p)) def forward_core_with_concat( self, @@ -976,8 +961,7 @@ class MMDiT(nn.Module): if add is not None: x += add - x = self.final_layer(x, c_mod) # (N, T, patch_size ** 2 * out_channels) - return x + return self.final_layer(x, c_mod) # (N, T, patch_size ** 2 * out_channels) def forward( self, diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index ed1e8821..aa2a85b9 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -493,8 +493,7 @@ class Model(nn.Module): # end h = self.norm_out(h) h = nonlinearity(h) - h = self.conv_out(h) - return h + return self.conv_out(h) def get_last_layer(self): return self.conv_out.weight @@ -602,8 +601,7 @@ class Encoder(nn.Module): # end h = self.norm_out(h) h = nonlinearity(h) - h = self.conv_out(h) - return h + return self.conv_out(h) class Decoder(nn.Module): diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 4c8d53ca..12967243 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -350,8 +350,7 @@ class VideoResBlock(ResBlock): x = self.time_mixer( x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator ) - x = rearrange(x, "b c t h w -> (b t) c h w") - return x + return rearrange(x, "b c t h w -> (b t) c h w") class Timestep(nn.Module): diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py index 233011dc..38b2cd3f 100644 --- a/comfy/ldm/modules/diffusionmodules/util.py +++ b/comfy/ldm/modules/diffusionmodules/util.py @@ -79,11 +79,10 @@ class AlphaBlender(nn.Module): image_only_indicator=None, ) -> torch.Tensor: alpha = self.get_alpha(image_only_indicator, x_spatial.device) - x = ( + return ( alpha.to(x_spatial.dtype) * x_spatial + (1.0 - alpha).to(x_spatial.dtype) * x_temporal ) - return x def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): @@ -201,8 +200,7 @@ class CheckpointFunction(torch.autograd.Function): "dtype": torch.get_autocast_gpu_dtype(), "cache_enabled": torch.is_autocast_cache_enabled()} with torch.no_grad(): - output_tensors = ctx.run_function(*ctx.input_tensors) - return output_tensors + return ctx.run_function(*ctx.input_tensors) @staticmethod def backward(ctx, *output_grads): diff --git a/comfy/ldm/modules/distributions/distributions.py b/comfy/ldm/modules/distributions/distributions.py index df987c5e..80155006 100644 --- a/comfy/ldm/modules/distributions/distributions.py +++ b/comfy/ldm/modules/distributions/distributions.py @@ -33,8 +33,7 @@ class DiagonalGaussianDistribution(object): self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device) def sample(self): - x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device) - return x + return self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device) def kl(self, other=None): if self.deterministic: diff --git a/comfy/ldm/modules/encoders/noise_aug_modules.py b/comfy/ldm/modules/encoders/noise_aug_modules.py index a5d86603..c41c77ee 100644 --- a/comfy/ldm/modules/encoders/noise_aug_modules.py +++ b/comfy/ldm/modules/encoders/noise_aug_modules.py @@ -15,13 +15,11 @@ class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation): def scale(self, x): # re-normalize to centered mean and unit variance - x = (x - self.data_mean.to(x.device)) * 1. / self.data_std.to(x.device) - return x + return (x - self.data_mean.to(x.device)) * 1. / self.data_std.to(x.device) def unscale(self, x): # back to original data stats - x = (x * self.data_std.to(x.device)) + self.data_mean.to(x.device) - return x + return (x * self.data_std.to(x.device)) + self.data_mean.to(x.device) def forward(self, x, noise_level=None, seed=None): if noise_level is None: diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py index 21c72373..08bb7c96 100644 --- a/comfy/ldm/modules/sub_quadratic_attention.py +++ b/comfy/ldm/modules/sub_quadratic_attention.py @@ -177,8 +177,7 @@ def _get_attention_scores_no_kv_chunking( attn_scores /= summed attn_probs = attn_scores - hidden_states_slice = torch.bmm(attn_probs.to(value.dtype), value) - return hidden_states_slice + return torch.bmm(attn_probs.to(value.dtype), value) class ScannedChunk(NamedTuple): chunk_idx: int @@ -264,7 +263,7 @@ def efficient_dot_product_attention( # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance, # and pass slices to be mutated, instead of torch.cat()ing the returned slices - res = torch.cat([ + return torch.cat([ compute_query_chunk_attn( query=get_query_chunk(i * query_chunk_size), key_t=key_t, @@ -272,4 +271,3 @@ def efficient_dot_product_attention( mask=get_mask_chunk(i * query_chunk_size) ) for i in range(math.ceil(q_tokens / query_chunk_size)) ], dim=1) - return res diff --git a/comfy/ldm/pixart/blocks.py b/comfy/ldm/pixart/blocks.py index 2225076e..390f7c26 100644 --- a/comfy/ldm/pixart/blocks.py +++ b/comfy/ldm/pixart/blocks.py @@ -73,8 +73,7 @@ class MultiHeadCrossAttention(nn.Module): x = optimized_attention(q.view(B, -1, C), k.view(B, -1, C), v.view(B, -1, C), self.num_heads, mask=None) x = self.proj(x) - x = self.proj_drop(x) - return x + return self.proj_drop(x) class AttentionKVCompress(nn.Module): @@ -172,8 +171,7 @@ class AttentionKVCompress(nn.Module): x = optimized_attention(q, k, v, self.num_heads, mask=None, skip_reshape=True) x = x.view(B, N, C) - x = self.proj(x) - return x + return self.proj(x) class FinalLayer(nn.Module): @@ -192,8 +190,7 @@ class FinalLayer(nn.Module): def forward(self, x, c): shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) x = modulate(self.norm_final(x), shift, scale) - x = self.linear(x) - return x + return self.linear(x) class T2IFinalLayer(nn.Module): """ @@ -209,8 +206,7 @@ class T2IFinalLayer(nn.Module): def forward(self, x, t): shift, scale = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t[:, None]).chunk(2, dim=1) x = t2i_modulate(self.norm_final(x), shift, scale) - x = self.linear(x) - return x + return self.linear(x) class MaskFinalLayer(nn.Module): @@ -228,8 +224,7 @@ class MaskFinalLayer(nn.Module): def forward(self, x, t): shift, scale = self.adaLN_modulation(t).chunk(2, dim=1) x = modulate(self.norm_final(x), shift, scale) - x = self.linear(x) - return x + return self.linear(x) class DecoderLayer(nn.Module): @@ -247,8 +242,7 @@ class DecoderLayer(nn.Module): def forward(self, x, t): shift, scale = self.adaLN_modulation(t).chunk(2, dim=1) x = modulate(self.norm_decoder(x), shift, scale) - x = self.linear(x) - return x + return self.linear(x) class SizeEmbedder(TimestepEmbedder): @@ -276,8 +270,7 @@ class SizeEmbedder(TimestepEmbedder): s = rearrange(s, "b d -> (b d)") s_freq = timestep_embedding(s, self.frequency_embedding_size) s_emb = self.mlp(s_freq.to(s.dtype)) - s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim) - return s_emb + return rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim) class LabelEmbedder(nn.Module): @@ -299,15 +292,13 @@ class LabelEmbedder(nn.Module): drop_ids = torch.rand(labels.shape[0]).cuda() < self.dropout_prob else: drop_ids = force_drop_ids == 1 - labels = torch.where(drop_ids, self.num_classes, labels) - return labels + return torch.where(drop_ids, self.num_classes, labels) def forward(self, labels, train, force_drop_ids=None): use_dropout = self.dropout_prob > 0 if (train and use_dropout) or (force_drop_ids is not None): labels = self.token_drop(labels, force_drop_ids) - embeddings = self.embedding_table(labels) - return embeddings + return self.embedding_table(labels) class CaptionEmbedder(nn.Module): @@ -331,8 +322,7 @@ class CaptionEmbedder(nn.Module): drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob else: drop_ids = force_drop_ids == 1 - caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption) - return caption + return torch.where(drop_ids[:, None, None, None], self.y_embedding, caption) def forward(self, caption, train, force_drop_ids=None): if train: @@ -340,8 +330,7 @@ class CaptionEmbedder(nn.Module): use_dropout = self.uncond_prob > 0 if (train and use_dropout) or (force_drop_ids is not None): caption = self.token_drop(caption, force_drop_ids) - caption = self.y_proj(caption) - return caption + return self.y_proj(caption) class CaptionEmbedderDoubleBr(nn.Module): diff --git a/comfy/ldm/pixart/pixartms.py b/comfy/ldm/pixart/pixartms.py index 7d4eebdc..78a472f7 100644 --- a/comfy/ldm/pixart/pixartms.py +++ b/comfy/ldm/pixart/pixartms.py @@ -23,8 +23,7 @@ def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_si ) emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype) emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype) - emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D) - return emb + return torch.cat([emb_w, emb_h], dim=1) # (H*W, D) class PixArtMSBlock(nn.Module): """ @@ -57,9 +56,8 @@ class PixArtMSBlock(nn.Module): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1) x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW)) x = x + self.cross_attn(x, y, mask) - x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) + return x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) - return x ### Core PixArt Model ### @@ -212,9 +210,8 @@ class PixArtMS(nn.Module): x = block(x, y, t0, y_lens, (H, W), **kwargs) # (N, T, D) x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) - x = self.unpatchify(x, H, W) # (N, out_channels, H, W) + return self.unpatchify(x, H, W) # (N, out_channels, H, W) - return x def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs): B, C, H, W = x.shape @@ -252,5 +249,4 @@ class PixArtMS(nn.Module): x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) x = torch.einsum('nhwpqc->nchpwq', x) - imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p)) - return imgs + return x.reshape(shape=(x.shape[0], c, h * p, w * p)) diff --git a/comfy/ldm/util.py b/comfy/ldm/util.py index 30b4b472..4f1ac703 100644 --- a/comfy/ldm/util.py +++ b/comfy/ldm/util.py @@ -29,8 +29,7 @@ def log_txt_as_img(wh, xc, size=10): txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 txts.append(txt) txts = np.stack(txts) - txts = torch.tensor(txts) - return txts + return torch.tensor(txts) def ismap(x): diff --git a/comfy/model_base.py b/comfy/model_base.py index 141f3f40..61d0209c 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -206,8 +206,7 @@ class BaseModel(torch.nn.Module): cond_concat.append(torch.ones_like(noise)[:,:1]) elif ck == "masked_image": cond_concat.append(self.blank_inpaint_image_like(noise)) - data = torch.cat(cond_concat, dim=1) - return data + return torch.cat(cond_concat, dim=1) return None def extra_conds(self, **kwargs): @@ -418,8 +417,7 @@ class SVD_img2vid(BaseModel): out.append(self.embedder(torch.Tensor([motion_bucket_id]))) out.append(self.embedder(torch.Tensor([augmentation]))) - flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0) - return flat + return torch.flatten(torch.cat(out)).unsqueeze(dim=0) def extra_conds(self, **kwargs): out = {} @@ -457,8 +455,7 @@ class SV3D_u(SVD_img2vid): out = [] out.append(self.embedder(torch.flatten(torch.Tensor([augmentation])))) - flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0) - return flat + return torch.flatten(torch.cat(out)).unsqueeze(dim=0) class SV3D_p(SVD_img2vid): def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 4597ce11..77d2d84f 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -874,8 +874,7 @@ class ModelPatcher: all_models = self.get_additional_models() models_set = set(all_models) - real_all_models = _evaluate_sub_additional_models(prev_models=all_models, cache_set=models_set) - return real_all_models + return _evaluate_sub_additional_models(prev_models=all_models, cache_set=models_set) def use_ejected(self, skip_and_inject_on_exit_only=False): return AutoPatcherEjector(self, skip_and_inject_on_exit_only=skip_and_inject_on_exit_only) diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 4370516b..5bdd359a 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -286,8 +286,7 @@ class StableCascadeSampling(ModelSamplingDiscrete): var = 1 / ((sigma * sigma) + 1) var = var.clamp(0, 1.0) s, min_var = self.cosine_s.to(var.device), self._init_alpha_cumprod.to(var.device) - t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s - return t + return (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s def percent_to_sigma(self, percent): if percent <= 0.0: diff --git a/comfy/sample.py b/comfy/sample.py index be5a7e24..50258958 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -21,8 +21,7 @@ def prepare_noise(latent_image, seed, noise_inds=None): if i in unique_inds: noises.append(noise) noises = [noises[i] for i in inverse] - noises = torch.cat(noises, axis=0) - return noises + return torch.cat(noises, axis=0) def fix_empty_latent_channels(model, latent_image): latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels @@ -43,10 +42,8 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed) - samples = samples.to(comfy.model_management.intermediate_device()) - return samples + return samples.to(comfy.model_management.intermediate_device()) def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None): samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) - samples = samples.to(comfy.model_management.intermediate_device()) - return samples + return samples.to(comfy.model_management.intermediate_device()) diff --git a/comfy/samplers.py b/comfy/samplers.py index 89464a42..1ccf0ca3 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -713,8 +713,7 @@ class KSAMPLER(Sampler): k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps) samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options) - samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples) - return samples + return model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples) def ksampler(sampler_name, extra_options={}, inpaint_options={}): diff --git a/comfy/sd.py b/comfy/sd.py index c6d6236b..cd311678 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -422,12 +422,11 @@ class VAE: pbar = comfy.utils.ProgressBar(steps) decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() - output = self.process_output( + return self.process_output( (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar)) / 3.0) - return output def decode_tiled_1d(self, samples, tile_x=128, overlap=32): decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() @@ -485,8 +484,7 @@ class VAE: overlap = tile // 4 pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) - pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1) - return pixel_samples + return pixel_samples.to(self.output_device).movedim(1,-1) def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 95d41c30..5c546e16 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -304,13 +304,11 @@ def token_weights(string, current_weight): def escape_important(text): text = text.replace("\\)", "\0\1") - text = text.replace("\\(", "\0\2") - return text + return text.replace("\\(", "\0\2") def unescape_important(text): text = text.replace("\0\1", ")") - text = text.replace("\0\2", "(") - return text + return text.replace("\0\2", "(") def safe_load_embed_zip(embed_path): with zipfile.ZipFile(embed_path) as myzip: @@ -635,8 +633,7 @@ class SD1ClipModel(torch.nn.Module): def encode_token_weights(self, token_weight_pairs): token_weight_pairs = token_weight_pairs[self.clip_name] - out = getattr(self, self.clip).encode_token_weights(token_weight_pairs) - return out + return getattr(self, self.clip).encode_token_weights(token_weight_pairs) def load_sd(self, sd): return getattr(self, self.clip).load_sd(sd) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 6a2cc75a..540cfe72 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -51,8 +51,7 @@ class SD15(supported_models_base.BASE): replace_prefix = {} replace_prefix["cond_stage_model."] = "clip_l." - state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) - return state_dict + return utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) def process_clip_state_dict_for_saving(self, state_dict): pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"] @@ -97,15 +96,13 @@ class SD20(supported_models_base.BASE): replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format replace_prefix["cond_stage_model.model."] = "clip_h." state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) - state_dict = utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.") - return state_dict + return utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.") def process_clip_state_dict_for_saving(self, state_dict): replace_prefix = {} replace_prefix["clip_h"] = "cond_stage_model.model" state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) - state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict) - return state_dict + return diffusers_convert.convert_text_enc_state_dict_v20(state_dict) def clip_target(self, state_dict={}): return supported_models_base.ClipTarget(comfy.text_encoders.sd2_clip.SD2Tokenizer, comfy.text_encoders.sd2_clip.SD2ClipModel) @@ -158,8 +155,7 @@ class SDXLRefiner(supported_models_base.BASE): state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.") - state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace) - return state_dict + return utils.state_dict_key_replace(state_dict, keys_to_replace) def process_clip_state_dict_for_saving(self, state_dict): replace_prefix = {} @@ -167,8 +163,7 @@ class SDXLRefiner(supported_models_base.BASE): if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g: state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids") replace_prefix["clip_g"] = "conditioner.embedders.0.model" - state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix) - return state_dict_g + return utils.state_dict_prefix_replace(state_dict_g, replace_prefix) def clip_target(self, state_dict={}): return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel) @@ -221,8 +216,7 @@ class SDXL(supported_models_base.BASE): state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace) - state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.") - return state_dict + return utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.") def process_clip_state_dict_for_saving(self, state_dict): replace_prefix = {} @@ -239,8 +233,7 @@ class SDXL(supported_models_base.BASE): replace_prefix["clip_g"] = "conditioner.embedders.1.model" replace_prefix["clip_l"] = "conditioner.embedders.0" - state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix) - return state_dict_g + return utils.state_dict_prefix_replace(state_dict_g, replace_prefix) def clip_target(self, state_dict={}): return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel) @@ -310,8 +303,7 @@ class SVD_img2vid(supported_models_base.BASE): sampling_settings = {"sigma_max": 700.0, "sigma_min": 0.002} def get_model(self, state_dict, prefix="", device=None): - out = model_base.SVD_img2vid(self, device=device) - return out + return model_base.SVD_img2vid(self, device=device) def clip_target(self, state_dict={}): return None @@ -331,8 +323,7 @@ class SV3D_u(SVD_img2vid): vae_key_prefix = ["conditioner.embedders.1.encoder."] def get_model(self, state_dict, prefix="", device=None): - out = model_base.SV3D_u(self, device=device) - return out + return model_base.SV3D_u(self, device=device) class SV3D_p(SV3D_u): unet_config = { @@ -348,8 +339,7 @@ class SV3D_p(SV3D_u): def get_model(self, state_dict, prefix="", device=None): - out = model_base.SV3D_p(self, device=device) - return out + return model_base.SV3D_p(self, device=device) class Stable_Zero123(supported_models_base.BASE): unet_config = { @@ -376,8 +366,7 @@ class Stable_Zero123(supported_models_base.BASE): latent_format = latent_formats.SD15 def get_model(self, state_dict, prefix="", device=None): - out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"]) - return out + return model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"]) def clip_target(self, state_dict={}): return None @@ -407,8 +396,7 @@ class SD_X4Upscaler(SD20): } def get_model(self, state_dict, prefix="", device=None): - out = model_base.SD_X4Upscaler(self, device=device) - return out + return model_base.SD_X4Upscaler(self, device=device) class Stable_Cascade_C(supported_models_base.BASE): unet_config = { @@ -450,8 +438,7 @@ class Stable_Cascade_C(supported_models_base.BASE): return state_dict def get_model(self, state_dict, prefix="", device=None): - out = model_base.StableCascade_C(self, device=device) - return out + return model_base.StableCascade_C(self, device=device) def clip_target(self, state_dict={}): return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel) @@ -473,8 +460,7 @@ class Stable_Cascade_B(Stable_Cascade_C): clip_vision_prefix = None def get_model(self, state_dict, prefix="", device=None): - out = model_base.StableCascade_B(self, device=device) - return out + return model_base.StableCascade_B(self, device=device) class SD15_instructpix2pix(SD15): unet_config = { @@ -521,8 +507,7 @@ class SD3(supported_models_base.BASE): text_encoder_key_prefix = ["text_encoders."] def get_model(self, state_dict, prefix="", device=None): - out = model_base.SD3(self, device=device) - return out + return model_base.SD3(self, device=device) def clip_target(self, state_dict={}): clip_l = False @@ -587,8 +572,7 @@ class AuraFlow(supported_models_base.BASE): text_encoder_key_prefix = ["text_encoders."] def get_model(self, state_dict, prefix="", device=None): - out = model_base.AuraFlow(self, device=device) - return out + return model_base.AuraFlow(self, device=device) def clip_target(self, state_dict={}): return supported_models_base.ClipTarget(comfy.text_encoders.aura_t5.AuraT5Tokenizer, comfy.text_encoders.aura_t5.AuraT5Model) @@ -648,8 +632,7 @@ class HunyuanDiT(supported_models_base.BASE): text_encoder_key_prefix = ["text_encoders."] def get_model(self, state_dict, prefix="", device=None): - out = model_base.HunyuanDiT(self, device=device) - return out + return model_base.HunyuanDiT(self, device=device) def clip_target(self, state_dict={}): return supported_models_base.ClipTarget(comfy.text_encoders.hydit.HyditTokenizer, comfy.text_encoders.hydit.HyditModel) @@ -686,8 +669,7 @@ class Flux(supported_models_base.BASE): text_encoder_key_prefix = ["text_encoders."] def get_model(self, state_dict, prefix="", device=None): - out = model_base.Flux(self, device=device) - return out + return model_base.Flux(self, device=device) def clip_target(self, state_dict={}): pref = self.text_encoder_key_prefix[0] @@ -715,8 +697,7 @@ class FluxSchnell(Flux): } def get_model(self, state_dict, prefix="", device=None): - out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device) - return out + return model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device) class GenmoMochi(supported_models_base.BASE): unet_config = { @@ -739,8 +720,7 @@ class GenmoMochi(supported_models_base.BASE): text_encoder_key_prefix = ["text_encoders."] def get_model(self, state_dict, prefix="", device=None): - out = model_base.GenmoMochi(self, device=device) - return out + return model_base.GenmoMochi(self, device=device) def clip_target(self, state_dict={}): pref = self.text_encoder_key_prefix[0] @@ -767,8 +747,7 @@ class LTXV(supported_models_base.BASE): text_encoder_key_prefix = ["text_encoders."] def get_model(self, state_dict, prefix="", device=None): - out = model_base.LTXV(self, device=device) - return out + return model_base.LTXV(self, device=device) def clip_target(self, state_dict={}): pref = self.text_encoder_key_prefix[0] @@ -795,8 +774,7 @@ class HunyuanVideo(supported_models_base.BASE): text_encoder_key_prefix = ["text_encoders."] def get_model(self, state_dict, prefix="", device=None): - out = model_base.HunyuanVideo(self, device=device) - return out + return model_base.HunyuanVideo(self, device=device) def process_unet_state_dict(self, state_dict): out_sd = {} diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 54573abb..0743157a 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -87,8 +87,7 @@ class BASE: return out def process_clip_state_dict(self, state_dict): - state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True) - return state_dict + return utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True) def process_unet_state_dict(self, state_dict): return state_dict diff --git a/comfy/t2i_adapter/adapter.py b/comfy/t2i_adapter/adapter.py index 10ea18e3..4364249b 100644 --- a/comfy/t2i_adapter/adapter.py +++ b/comfy/t2i_adapter/adapter.py @@ -60,8 +60,7 @@ class Downsample(nn.Module): padding = [x.shape[2] % 2, x.shape[3] % 2] self.op.padding = padding - x = self.op(x) - return x + return self.op(x) class ResnetBlock(nn.Module): @@ -196,8 +195,7 @@ class ResidualAttentionBlock(nn.Module): def forward(self, x: torch.Tensor): x = x + self.attention(self.ln_1(x)) - x = x + self.mlp(self.ln_2(x)) - return x + return x + self.mlp(self.ln_2(x)) class StyleAdapter(nn.Module): @@ -224,9 +222,8 @@ class StyleAdapter(nn.Module): x = x.permute(1, 0, 2) # LND -> NLD x = self.ln_post(x[:, -self.num_token:, :]) - x = x @ self.proj + return x @ self.proj - return x class ResnetBlock_light(nn.Module): @@ -262,9 +259,8 @@ class extractor(nn.Module): x = self.down_opt(x) x = self.in_conv(x) x = self.body(x) - x = self.out_conv(x) + return self.out_conv(x) - return x class Adapter_light(nn.Module): diff --git a/comfy/taesd/taesd.py b/comfy/taesd/taesd.py index ce36f1a8..9077f6e6 100644 --- a/comfy/taesd/taesd.py +++ b/comfy/taesd/taesd.py @@ -72,8 +72,7 @@ class TAESD(nn.Module): def decode(self, x): x_sample = self.taesd_decoder((x - self.vae_shift) * self.vae_scale) - x_sample = x_sample.sub(0.5).mul(2) - return x_sample + return x_sample.sub(0.5).mul(2) def encode(self, x): return (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift diff --git a/comfy/text_encoders/bert.py b/comfy/text_encoders/bert.py index fc9bac1d..7be2b367 100644 --- a/comfy/text_encoders/bert.py +++ b/comfy/text_encoders/bert.py @@ -17,8 +17,7 @@ class BertAttention(torch.nn.Module): k = self.key(x) v = self.value(x) - out = optimized_attention(q, k, v, self.heads, mask) - return out + return optimized_attention(q, k, v, self.heads, mask) class BertOutput(torch.nn.Module): def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations): @@ -30,8 +29,7 @@ class BertOutput(torch.nn.Module): def forward(self, x, y): x = self.dense(x) # hidden_states = self.dropout(hidden_states) - x = self.LayerNorm(x + y) - return x + return self.LayerNorm(x + y) class BertAttentionBlock(torch.nn.Module): def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations): @@ -100,8 +98,7 @@ class BertEmbeddings(torch.nn.Module): x += self.token_type_embeddings(token_type_ids, out_dtype=x.dtype) else: x += comfy.ops.cast_to_input(self.token_type_embeddings.weight[0], x) - x = self.LayerNorm(x) - return x + return self.LayerNorm(x) class BertModel_(torch.nn.Module): diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index ad4b4623..0d1ef277 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -142,9 +142,8 @@ class TransformerBlock(nn.Module): residual = x x = self.post_attention_layernorm(x) x = self.mlp(x) - x = residual + x + return residual + x - return x class Llama2_(nn.Module): def __init__(self, config, device=None, dtype=None, ops=None): diff --git a/comfy/text_encoders/t5.py b/comfy/text_encoders/t5.py index 38d8d523..942dc5a1 100644 --- a/comfy/text_encoders/t5.py +++ b/comfy/text_encoders/t5.py @@ -30,8 +30,7 @@ class T5DenseActDense(torch.nn.Module): def forward(self, x): x = self.act(self.wi(x)) # x = self.dropout(x) - x = self.wo(x) - return x + return self.wo(x) class T5DenseGatedActDense(torch.nn.Module): def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations): @@ -47,8 +46,7 @@ class T5DenseGatedActDense(torch.nn.Module): hidden_linear = self.wi_1(x) x = hidden_gelu * hidden_linear # x = self.dropout(x) - x = self.wo(x) - return x + return self.wo(x) class T5LayerFF(torch.nn.Module): def __init__(self, model_dim, ff_dim, ff_activation, gated_act, dtype, device, operations): @@ -145,8 +143,7 @@ class T5Attention(torch.nn.Module): max_distance=self.relative_attention_max_distance, ) values = self.relative_attention_bias(relative_position_bucket, out_dtype=dtype) # shape (query_length, key_length, num_heads) - values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) - return values + return values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) def forward(self, x, mask=None, past_bias=None, optimized_attention=None): q = self.q(x) diff --git a/comfy/utils.py b/comfy/utils.py index 8e64dbe0..9cc291f2 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -292,8 +292,7 @@ def unet_to_diffusers(unet_config): def swap_scale_shift(weight): shift, scale = weight.chunk(2, dim=0) - new_weight = torch.cat([scale, shift], dim=0) - return new_weight + return torch.cat([scale, shift], dim=0) MMDIT_MAP_BASIC = { ("context_embedder.bias", "context_embedder.bias"), @@ -982,8 +981,7 @@ def reshape_mask(input_mask, output_shape): mask = torch.nn.functional.interpolate(input_mask, size=output_shape[2:], mode=scale_mode) if mask.shape[1] < output_shape[1]: mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:,:output_shape[1]] - mask = repeat_to_batch_size(mask, output_shape[0]) - return mask + return repeat_to_batch_size(mask, output_shape[0]) def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out): hi, wi = img_size_in @@ -1019,9 +1017,8 @@ def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out): img_to_txt = rearrange (img_to_txt, "b t h w -> b (h w) t") # reassemble the mask from blocks - out = torch.cat([ + return torch.cat([ torch.cat([txt_to_txt, txt_to_img], dim=2), torch.cat([img_to_txt, img_to_img], dim=2)], dim=1 ) - return out diff --git a/comfy_extras/nodes_align_your_steps.py b/comfy_extras/nodes_align_your_steps.py index 8d856d0e..85f128b1 100644 --- a/comfy_extras/nodes_align_your_steps.py +++ b/comfy_extras/nodes_align_your_steps.py @@ -12,8 +12,7 @@ def loglinear_interp(t_steps, num_steps): new_xs = np.linspace(0, 1, num_steps) new_ys = np.interp(new_xs, xs, ys) - interped_ys = np.exp(new_ys)[::-1].copy() - return interped_ys + return np.exp(new_ys)[::-1].copy() NOISE_LEVELS = {"SD1": [14.6146412293, 6.4745760956, 3.8636745985, 2.6946151520, 1.8841921177, 1.3943805092, 0.9642583904, 0.6523686016, 0.3977456272, 0.1515232662, 0.0291671582], "SDXL":[14.6146412293, 6.3184485287, 3.7681790315, 2.1811480769, 1.3405244945, 0.8620721141, 0.5550693289, 0.3798540708, 0.2332364134, 0.1114188177, 0.0291671582], diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index 3cb918e0..e01a63d6 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -104,9 +104,8 @@ def create_vorbis_comment_block(comment_dict, last_block): id = b'\x84' else: id = b'\x04' - comment_block = id + struct.pack('>I', len(comment_data))[1:] + comment_data + return id + struct.pack('>I', len(comment_data))[1:] + comment_data - return comment_block def insert_or_replace_vorbis_comment(flac_io, comment_dict): if len(comment_dict) == 0: diff --git a/comfy_extras/nodes_compositing.py b/comfy_extras/nodes_compositing.py index 2f994fa1..8246bc23 100644 --- a/comfy_extras/nodes_compositing.py +++ b/comfy_extras/nodes_compositing.py @@ -150,8 +150,7 @@ class PorterDuffImageComposite: out_images.append(out_image) out_alphas.append(out_alpha.squeeze(2)) - result = (torch.stack(out_images), torch.stack(out_alphas)) - return result + return (torch.stack(out_images), torch.stack(out_alphas)) class SplitImageWithAlpha: @@ -170,8 +169,7 @@ class SplitImageWithAlpha: def split_image_with_alpha(self, image: torch.Tensor): out_images = [i[:,:,:3] for i in image] out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image] - result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas)) - return result + return (torch.stack(out_images), 1.0 - torch.stack(out_alphas)) class JoinImageWithAlpha: @@ -196,8 +194,7 @@ class JoinImageWithAlpha: for i in range(batch_size): out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2)) - result = (torch.stack(out_images),) - return result + return (torch.stack(out_images),) NODE_CLASS_MAPPINGS = { diff --git a/comfy_extras/nodes_gits.py b/comfy_extras/nodes_gits.py index 47b1dd04..72f82c02 100644 --- a/comfy_extras/nodes_gits.py +++ b/comfy_extras/nodes_gits.py @@ -12,8 +12,7 @@ def loglinear_interp(t_steps, num_steps): new_xs = np.linspace(0, 1, num_steps) new_ys = np.interp(new_xs, xs, ys) - interped_ys = np.exp(new_ys)[::-1].copy() - return interped_ys + return np.exp(new_ys)[::-1].copy() NOISE_LEVELS = { 0.80: [ diff --git a/comfy_extras/nodes_mahiro.py b/comfy_extras/nodes_mahiro.py index 8fcdfba7..2db7cc40 100644 --- a/comfy_extras/nodes_mahiro.py +++ b/comfy_extras/nodes_mahiro.py @@ -27,8 +27,7 @@ class Mahiro: normm = torch.sqrt(merge.abs()) * merge.sign() sim = F.cosine_similarity(normu, normm).mean() simsc = 2 * (sim+1) - wm = (simsc*cfg + (4-simsc)*leap) / 4 - return wm + return (simsc*cfg + (4-simsc)*leap) / 4 m.set_model_sampler_post_cfg_function(mahiro_normd) return (m, ) diff --git a/comfy_extras/nodes_perpneg.py b/comfy_extras/nodes_perpneg.py index 6c6f7176..7bee271f 100644 --- a/comfy_extras/nodes_perpneg.py +++ b/comfy_extras/nodes_perpneg.py @@ -11,8 +11,7 @@ def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, co perp = neg - ((torch.mul(neg, pos).sum())/(torch.norm(pos)**2)) * pos perp_neg = perp * neg_scale - cfg_result = noise_pred_nocond + cond_scale*(pos - perp_neg) - return cfg_result + return noise_pred_nocond + cond_scale*(pos - perp_neg) #TODO: This node should be removed, it has been replaced with PerpNegGuider class PerpNeg: @@ -44,8 +43,7 @@ class PerpNeg: (noise_pred_nocond,) = comfy.samplers.calc_cond_batch(model, [nocond_processed], x, sigma, model_options) - cfg_result = x - perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale) - return cfg_result + return x - perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale) m.set_model_sampler_cfg_function(cfg_function) diff --git a/comfy_extras/nodes_photomaker.py b/comfy_extras/nodes_photomaker.py index d358ed6d..ffb90a6b 100644 --- a/comfy_extras/nodes_photomaker.py +++ b/comfy_extras/nodes_photomaker.py @@ -52,8 +52,7 @@ class FuseModule(nn.Module): stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1) stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds stacked_id_embeds = self.mlp2(stacked_id_embeds) - stacked_id_embeds = self.layer_norm(stacked_id_embeds) - return stacked_id_embeds + return self.layer_norm(stacked_id_embeds) def forward( self, @@ -86,8 +85,7 @@ class FuseModule(nn.Module): stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds) assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}" prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype)) - updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1) - return updated_prompt_embeds + return prompt_embeds.view(batch_size, seq_length, -1) class PhotoMakerIDEncoder(comfy.clip_model.CLIPVisionModelProjection): def __init__(self): @@ -111,9 +109,8 @@ class PhotoMakerIDEncoder(comfy.clip_model.CLIPVisionModelProjection): id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1) id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1) - updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask) + return self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask) - return updated_prompt_embeds class PhotoMakerLoader: diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 68f6ef51..f48a0c11 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -162,8 +162,7 @@ class Quantize: result = result.to(dtype=torch.uint8) im = Image.fromarray(result.cpu().numpy()) - im = im.quantize(palette=pal_im, dither=Image.Dither.NONE) - return im + return im.quantize(palette=pal_im, dither=Image.Dither.NONE) def quantize(self, image: torch.Tensor, colors: int, dither: str): batch_size, height, width, _ = image.shape diff --git a/comfy_extras/nodes_rebatch.py b/comfy_extras/nodes_rebatch.py index e29cb9ed..cadc77fd 100644 --- a/comfy_extras/nodes_rebatch.py +++ b/comfy_extras/nodes_rebatch.py @@ -50,8 +50,7 @@ class LatentRebatch: def cat_batch(batch1, batch2): if batch1[0] is None: return batch2 - result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)] - return result + return [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)] def rebatch(self, latents, batch_size): batch_size = batch_size[0] diff --git a/comfy_extras/nodes_sag.py b/comfy_extras/nodes_sag.py index 1bd8d736..5c9bbdc1 100644 --- a/comfy_extras/nodes_sag.py +++ b/comfy_extras/nodes_sag.py @@ -82,8 +82,7 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0): mask = F.interpolate(mask, (lh, lw)) blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma) - blurred = blurred * mask + x0 * (1 - mask) - return blurred + return blurred * mask + x0 * (1 - mask) def gaussian_blur_2d(img, kernel_size, sigma): ksize_half = (kernel_size - 1) * 0.5 @@ -101,8 +100,7 @@ def gaussian_blur_2d(img, kernel_size, sigma): padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] img = F.pad(img, padding, mode="reflect") - img = F.conv2d(img, kernel2d, groups=img.shape[-3]) - return img + return F.conv2d(img, kernel2d, groups=img.shape[-3]) class SelfAttentionGuidance: @classmethod diff --git a/comfy_extras/nodes_stable3d.py b/comfy_extras/nodes_stable3d.py index be2e34c2..02430a0b 100644 --- a/comfy_extras/nodes_stable3d.py +++ b/comfy_extras/nodes_stable3d.py @@ -5,7 +5,7 @@ import comfy.utils def camera_embeddings(elevation, azimuth): elevation = torch.as_tensor([elevation]) azimuth = torch.as_tensor([azimuth]) - embeddings = torch.stack( + return torch.stack( [ torch.deg2rad( (90 - elevation) - (90) @@ -17,7 +17,6 @@ def camera_embeddings(elevation, azimuth): ), ], dim=-1).unsqueeze(1) - return embeddings class StableZero123_Conditioning: diff --git a/execution.py b/execution.py index 2c979205..4ce174fd 100644 --- a/execution.py +++ b/execution.py @@ -81,11 +81,10 @@ class CacheSet: self.objects = HierarchicalCache(CacheKeySetID) def recursive_debug_dump(self): - result = { + return { "outputs": self.outputs.recursive_debug_dump(), "ui": self.ui.recursive_debug_dump(), } - return result def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}): valid_inputs = class_def.INPUT_TYPES() diff --git a/folder_paths.py b/folder_paths.py index 3542d2ed..d9712ac3 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -356,8 +356,7 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im input = input.replace("%day%", str(now.tm_mday).zfill(2)) input = input.replace("%hour%", str(now.tm_hour).zfill(2)) input = input.replace("%minute%", str(now.tm_min).zfill(2)) - input = input.replace("%second%", str(now.tm_sec).zfill(2)) - return input + return input.replace("%second%", str(now.tm_sec).zfill(2)) if "%" in filename_prefix: filename_prefix = compute_vars(filename_prefix, image_width, image_height) diff --git a/nodes.py b/nodes.py index 45686fc7..3de5b1c6 100644 --- a/nodes.py +++ b/nodes.py @@ -607,8 +607,7 @@ class unCLIPCheckpointLoader: def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) - out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) - return out + return comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) class CLIPSetLastLayer: @classmethod diff --git a/ruff.toml b/ruff.toml index 660831f2..017f1203 100644 --- a/ruff.toml +++ b/ruff.toml @@ -9,6 +9,7 @@ lint.select = [ # The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names. # See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f "F", + "RET504", ] exclude = ["*.ipynb"] diff --git a/tests/compare/test_quality.py b/tests/compare/test_quality.py index 01c19054..115916d9 100644 --- a/tests/compare/test_quality.py +++ b/tests/compare/test_quality.py @@ -129,8 +129,7 @@ class TestCompareImageMetrics: def read_img(self, filename: str) -> np.ndarray: cvImg = imread(filename) - cvImg = cvtColor(cvImg, COLOR_BGR2RGB) - return cvImg + return cvtColor(cvImg, COLOR_BGR2RGB) def image_grid(self, img_list: list[list[Image.Image]]): # imgs is a 2D list of images @@ -154,8 +153,7 @@ class TestCompareImageMetrics: with open(metrics_output_file, 'r') as f: for line in f: if fname_basestr in line: - score = float(line.split('|')[5]) - return score + return float(line.split('|')[5]) raise ValueError(f"Could not find score for {fname} in {metrics_output_file}") def gather_file_basenames(self, directory: str): diff --git a/tests/inference/testing_nodes/testing-pack/flow_control.py b/tests/inference/testing_nodes/testing-pack/flow_control.py index ba943be6..f590ab86 100644 --- a/tests/inference/testing_nodes/testing-pack/flow_control.py +++ b/tests/inference/testing_nodes/testing-pack/flow_control.py @@ -141,14 +141,13 @@ class TestExecutionBlockerNode: @classmethod def INPUT_TYPES(cls): - inputs = { + return { "required": { "input": ("*",), "block": ("BOOLEAN",), "verbose": ("BOOLEAN", {"default": False}), }, } - return inputs RETURN_TYPES = ("*",) RETURN_NAMES = ("output",)