applied RUFF RET504 rule (unnecessary assignment before return statement)

Signed-off-by: bigcat88 <bigcat88@icloud.com>
Alexander Piskun authored 2025-01-01 12:45:55 +02:00, committed by bigcat88
parent 79eea51a1d
commit f4e86a4a07
GPG Key ID: 1F0BF0EC3CF22721 (no known key found for this signature in database)
81 changed files with 219 additions and 436 deletions
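
For reference, RET504 is one of Ruff's flake8-return checks: it flags a value that is assigned to a variable and then immediately returned, and the fix is to return the expression directly. A minimal sketch of the pattern applied throughout the hunks below (the function name is invented for illustration, not taken from the changed files):

    # Before: RET504, unnecessary assignment before the return statement
    def scale_sigma(sigma, factor):
        result = sigma * factor
        return result

    # After: return the expression directly
    def scale_sigma(sigma, factor):
        return sigma * factor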

View File

@@ -50,8 +50,7 @@ class ResBlockUnionControlnet(nn.Module):
     def forward(self, x: torch.Tensor):
         x = x + self.attention(self.ln_1(x))
-        x = x + self.mlp(self.ln_2(x))
-        return x
+        return x + self.mlp(self.ln_2(x))
 class ControlledUnetModel(UNetModel):
     #implemented in the ldm unet

View File

@@ -36,8 +36,7 @@ class CLIPMLP(torch.nn.Module):
     def forward(self, x):
         x = self.fc1(x)
         x = self.activation(x)
-        x = self.fc2(x)
-        return x
+        return self.fc2(x)
 class CLIPLayer(torch.nn.Module):
     def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):

View File

@@ -270,5 +270,4 @@ class CheckLazyMixin:
         Comfy Docs: https://docs.comfy.org/essentials/custom_node_lazy_evaluation#defining-check-lazy-status
         """
-        need = [name for name in kwargs if kwargs[name] is None]
-        return need
+        return [name for name in kwargs if kwargs[name] is None]

View File

@@ -404,8 +404,7 @@ class ControlLora(ControlNet):
         super().cleanup()
     def get_models(self):
-        out = ControlBase.get_models(self)
-        return out
+        return ControlBase.get_models(self)
     def inference_memory_requirements(self, dtype):
         return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
@@ -461,8 +460,7 @@ def load_controlnet_mmdit(sd, model_options={}):
     latent_format = comfy.latent_formats.SD3()
     latent_format.shift_factor = 0 #SD3 controlnet weirdness
-    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
-    return control
+    return ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
 class ControlNetSD35(ControlNet):
@@ -536,8 +534,7 @@ def load_controlnet_sd35(sd, model_options={}):
     elif depth_cnet:
         preprocess_image = lambda a: 1.0 - a
-    control = ControlNetSD35(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, preprocess_image=preprocess_image)
-    return control
+    return ControlNetSD35(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, preprocess_image=preprocess_image)
@@ -549,16 +546,14 @@ def load_controlnet_hunyuandit(controlnet_data, model_options={}):
     latent_format = comfy.latent_formats.SDXL()
     extra_conds = ['text_embedding_mask', 'encoder_hidden_states_t5', 'text_embedding_mask_t5', 'image_meta_size', 'style', 'cos_cis_img', 'sin_cis_img']
-    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
-    return control
+    return ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
 def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
     model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
     control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
     control_model = controlnet_load_state_dict(control_model, sd)
     extra_conds = ['y', 'guidance']
-    control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
-    return control
+    return ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
 def load_controlnet_flux_instantx(sd, model_options={}):
     new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
@@ -581,8 +576,7 @@ def load_controlnet_flux_instantx(sd, model_options={}):
     latent_format = comfy.latent_formats.Flux()
     extra_conds = ['y', 'guidance']
-    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
-    return control
+    return ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
 def convert_mistoline(sd):
     return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
@@ -738,8 +732,7 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
         logging.debug("unexpected controlnet keys: {}".format(unexpected))
     global_average_pooling = model_options.get("global_average_pooling", False)
-    control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
-    return control
+    return ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
 def load_controlnet(ckpt_path, model=None, model_options={}):
     if "global_average_pooling" not in model_options:

View File

@@ -99,8 +99,7 @@ def convert_unet_state_dict(unet_state_dict):
         for sd_part, hf_part in unet_conversion_map_layer:
             v = v.replace(hf_part, sd_part)
         mapping[k] = v
-    new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
-    return new_state_dict
+    return {v: unet_state_dict[k] for k, v in mapping.items()}
 # ================#

View File

@@ -136,8 +136,7 @@ class NoiseScheduleVP:
             return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
         elif self.schedule == 'cosine':
             log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
-            log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
-            return log_alpha_t
+            return log_alpha_fn(t) - self.cosine_log_alpha_0
     def marginal_alpha(self, t):
         """
@@ -174,8 +173,7 @@ class NoiseScheduleVP:
         else:
             log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
             t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
-            t = t_fn(log_alpha)
-            return t
+            return t_fn(log_alpha)
 def model_wrapper(
@@ -378,8 +376,7 @@ class UniPC:
         p = self.dynamic_thresholding_ratio
         s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
         s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
-        x0 = torch.clamp(x0, -s, s) / s
-        return x0
+        return torch.clamp(x0, -s, s) / s
     def noise_prediction_fn(self, x, t):
         """
@@ -423,8 +420,7 @@ class UniPC:
             return torch.linspace(t_T, t_0, N + 1).to(device)
         elif skip_type == 'time_quadratic':
             t_order = 2
-            t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
-            return t
+            return torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
         else:
             raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
@@ -801,8 +797,7 @@ def interpolate_fn(x, xp, yp):
     y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
     start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
     end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
-    cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
-    return cand
+    return start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
 def expand_dims(v, dims):

View File

@@ -78,10 +78,9 @@ class GatedCrossAttentionDense(nn.Module):
         x = x + self.scale * \
             torch.tanh(self.alpha_attn) * self.attn(self.norm1(x), objs, objs)
-        x = x + self.scale * \
+        return x + self.scale * \
             torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
-        return x
 class GatedSelfAttentionDense(nn.Module):
@@ -118,10 +117,9 @@ class GatedSelfAttentionDense(nn.Module):
         x = x + self.scale * torch.tanh(self.alpha_attn) * self.attn(
             self.norm1(torch.cat([x, objs], dim=1)))[:, 0:N_visual, :]
-        x = x + self.scale * \
+        return x + self.scale * \
             torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
-        return x
 class GatedSelfAttentionDense2(nn.Module):
@@ -172,10 +170,9 @@ class GatedSelfAttentionDense2(nn.Module):
         # add residual to visual feature
         x = x + self.scale * torch.tanh(self.alpha_attn) * residual
-        x = x + self.scale * \
+        return x + self.scale * \
             torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
-        return x
 class FourierEmbedder():
@@ -340,5 +337,4 @@ def load_gligen(sd):
     w.position_net = PositionNet(in_dim, out_dim)
     w.load_state_dict(sd, strict=False)
-    gligen = Gligen(output_list, w.position_net, key_dim)
-    return gligen
+    return Gligen(output_list, w.position_net, key_dim)

View File

@@ -46,8 +46,7 @@ def cal_intergrand(beta_0, beta_1, taus):
     log_alpha = alpha.log()
     log_alpha.sum().backward()
     d_log_alpha_dtau = taus.grad
-    integrand = -0.5 * d_log_alpha_dtau / torch.sqrt(alpha * (1 - alpha))
-    return integrand
+    return -0.5 * d_log_alpha_dtau / torch.sqrt(alpha * (1 - alpha))
 #----------------------------------------------------------------------------

View File

@@ -50,8 +50,7 @@ def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0., beta=0.5, device='cpu'):
     x = torch.linspace(0, 1, n, device=device)
     clamp = lambda x: torch.clamp(x, min=sigma_min, max=sigma_max)
     lmb = mu - beta * torch.sign(0.5-x) * torch.log(1 - 2 * torch.abs(0.5-x) + epsilon)
-    sigmas = clamp(torch.exp(lmb))
-    return sigmas
+    return clamp(torch.exp(lmb))

View File

@@ -70,9 +70,8 @@ class SnakeBeta(nn.Module):
         if self.alpha_logscale:
             alpha = torch.exp(alpha)
             beta = torch.exp(beta)
-        x = snake_beta(x, alpha, beta)
-        return x
+        return snake_beta(x, alpha, beta)
 def WNConv1d(*args, **kwargs):
     try:

View File

@@ -89,8 +89,7 @@ class AbsolutePositionalEmbedding(nn.Module):
             pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
         pos_emb = self.emb(pos)
-        pos_emb = pos_emb * self.scale
-        return pos_emb
+        return pos_emb * self.scale
 class ScaledSinusoidalEmbedding(nn.Module):
     def __init__(self, dim, theta = 10000):
@@ -404,9 +403,8 @@ class ConformerModule(nn.Module):
         x = self.swish(x)
         x = rearrange(x, 'b n d -> b d n')
         x = self.pointwise_conv_2(x)
-        x = rearrange(x, 'b d n -> b n d')
-        return x
+        return rearrange(x, 'b d n -> b n d')
 class TransformerBlock(nn.Module):
     def __init__(

View File

@@ -21,8 +21,7 @@ class LearnedPositionalEmbedding(nn.Module):
         x = rearrange(x, "b -> b 1")
         freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * math.pi
         fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
-        fouriered = torch.cat((x, fouriered), dim=-1)
-        return fouriered
+        return torch.cat((x, fouriered), dim=-1)
 def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
     return nn.Sequential(
@@ -49,8 +48,7 @@ class NumberEmbedder(nn.Module):
         shape = x.shape
         x = rearrange(x, "... -> (...)")
         embedding = self.embedding(x)
-        x = embedding.view(*shape, self.features)
-        return x  # type: ignore
+        return embedding.view(*shape, self.features)
 class Conditioner(nn.Module):

View File

@@ -36,8 +36,7 @@ class MLP(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
-        x = self.c_proj(x)
-        return x
+        return self.c_proj(x)
 class MultiHeadLayerNorm(nn.Module):
@@ -95,8 +94,7 @@ class SingleAttention(nn.Module):
         q, k = self.q_norm1(q), self.k_norm1(k)
         output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
-        c = self.w1o(output)
-        return c
+        return self.w1o(output)
@@ -265,9 +263,8 @@ class DiTBlock(nn.Module):
         mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
         cx = gate_mlp.unsqueeze(1) * mlpout
-        cx = cxres + cx
-        return cx
+        return cxres + cx
@@ -298,8 +295,7 @@ class TimestepEmbedder(nn.Module):
     #@torch.compile()
     def forward(self, t, dtype):
         t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
-        t_emb = self.mlp(t_freq)
-        return t_emb
+        return self.mlp(t_freq)
 class MMDiT(nn.Module):
@@ -401,8 +397,7 @@ class MMDiT(nn.Module):
         x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
         x = torch.einsum("nhwpqc->nchpwq", x)
-        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
-        return imgs
+        return x.reshape(shape=(x.shape[0], c, h * p, w * p))
     def patchify(self, x):
         B, C, H, W = x.size()
@@ -415,8 +410,7 @@ class MMDiT(nn.Module):
             (W + 1) // self.patch_size,
            self.patch_size,
         )
-        x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
-        return x
+        return x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
     def apply_pos_embeds(self, x, h, w):
         h = (h + 1) // self.patch_size
@@ -494,5 +488,4 @@ class MMDiT(nn.Module):
         x = modulate(x, fshift, fscale)
         x = self.final_linear(x)
-        x = self.unpatchify(x, (h + 1) // self.patch_size, (w + 1) // self.patch_size)[:,:,:h,:w]
-        return x
+        return self.unpatchify(x, (h + 1) // self.patch_size, (w + 1) // self.patch_size)[:,:,:h,:w]

View File

@@ -54,8 +54,7 @@ class Attention2D(nn.Module):
             kv = torch.cat([x, kv], dim=1)
         # x = self.attn(x, kv, kv, need_weights=False)[0]
         x = self.attn(x, kv, kv)
-        x = x.permute(0, 2, 1).view(*orig_shape)
-        return x
+        return x.permute(0, 2, 1).view(*orig_shape)
 def LayerNorm2d_op(operations):
@@ -116,8 +115,7 @@ class AttnBlock(nn.Module):
     def forward(self, x, kv):
         kv = self.kv_mapper(kv)
-        x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
-        return x
+        return x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
 class FeedForwardBlock(nn.Module):
@@ -133,8 +131,7 @@ class FeedForwardBlock(nn.Module):
         )
     def forward(self, x):
-        x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-        return x
+        return x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
 class TimestepBlock(nn.Module):

View File

@@ -157,9 +157,8 @@ class ResBlock(nn.Module):
         x = x + self.depthwise[1](x_temp) * mods[2]
         x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4]
-        x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]
-        return x
+        return x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]
 class StageA(nn.Module):
@@ -218,8 +217,7 @@ class StageA(nn.Module):
     def decode(self, x):
         x = self.up_blocks(x)
-        x = self.out_block(x)
-        return x
+        return self.out_block(x)
     def forward(self, x, quantize=False):
         qe, x, _, vq_loss = self.encode(x, quantize)
@@ -251,5 +249,4 @@ class Discriminator(nn.Module):
             cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1))
             x = torch.cat([x, cond], dim=1)
         x = self.shuffle(x)
-        x = self.logits(x)
-        return x
+        return self.logits(x)

View File

@@ -170,8 +170,7 @@ class StageB(nn.Module):
         if len(clip.shape) == 2:
             clip = clip.unsqueeze(1)
         clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1)
-        clip = self.clip_norm(clip)
-        return clip
+        return self.clip_norm(clip)
     def _down_encode(self, x, r_embed, clip):
         level_outputs = []

View File

@@ -179,8 +179,7 @@ class StageC(nn.Module):
         clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1)
         clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1)
         clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
-        clip = self.clip_norm(clip)
-        return clip
+        return self.clip_norm(clip)
     def _down_encode(self, x, r_embed, clip, cnet=None):
         level_outputs = []

View File

@@ -35,8 +35,7 @@ class EfficientNetEncoder(nn.Module):
     def forward(self, x):
         x = x * 0.5 + 0.5
         x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1])
-        o = self.mapper(self.backbone(x))
-        return o
+        return self.mapper(self.backbone(x))
 # Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192

View File

@@ -256,5 +256,4 @@ class LastLayer(nn.Module):
     def forward(self, x: Tensor, vec: Tensor) -> Tensor:
         shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
         x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
-        x = self.linear(x)
-        return x
+        return self.linear(x)

View File

@@ -9,8 +9,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
     q, k = apply_rope(q, k, pe)
     heads = q.shape[1]
-    x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
-    return x
+    return optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
 def rope(pos: Tensor, dim: int, theta: int) -> Tensor:

View File

@@ -183,8 +183,7 @@ class Flux(nn.Module):
         img = img[:, txt.shape[1] :, ...]
-        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
-        return img
+        return self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
     def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
         bs, c, h, w = x.shape

View File

@@ -21,5 +21,4 @@ class ReduxImageEncoder(torch.nn.Module):
         self.redux_down = ops.Linear(txt_in_features * 3, txt_in_features, dtype=dtype)
     def forward(self, sigclip_embeds) -> torch.Tensor:
-        projected_x = self.redux_down(torch.nn.functional.silu(self.redux_up(sigclip_embeds)))
-        return projected_x
+        return self.redux_down(torch.nn.functional.silu(self.redux_up(sigclip_embeds)))

View File

@@ -34,9 +34,8 @@ import comfy.ops
 def modulated_rmsnorm(x, scale, eps=1e-6):
     # Normalize and modulate
     x_normed = comfy.ldm.common_dit.rms_norm(x, eps=eps)
-    x_modulated = x_normed * (1 + scale.unsqueeze(1))
-    return x_modulated
+    return x_normed * (1 + scale.unsqueeze(1))
 def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6):
@@ -47,9 +46,8 @@ def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6):
     x_normed = comfy.ldm.common_dit.rms_norm(x_res, eps=eps) * tanh_gate
     # Apply residual connection
-    output = x + x_normed
-    return output
+    return x + x_normed
 class AsymmetricAttention(nn.Module):
     def __init__(
@@ -275,14 +273,12 @@ class AsymmetricJointBlock(nn.Module):
     def ff_block_x(self, x, scale_x, gate_x):
         x_mod = modulated_rmsnorm(x, scale_x)
         x_res = self.mlp_x(x_mod)
-        x = residual_tanh_gated_rmsnorm(x, x_res, gate_x)  # Sandwich norm
-        return x
+        return residual_tanh_gated_rmsnorm(x, x_res, gate_x)  # Sandwich norm
     def ff_block_y(self, y, scale_y, gate_y):
         y_mod = modulated_rmsnorm(y, scale_y)
         y_res = self.mlp_y(y_mod)
-        y = residual_tanh_gated_rmsnorm(y, y_res, gate_y)  # Sandwich norm
-        return y
+        return residual_tanh_gated_rmsnorm(y, y_res, gate_y)  # Sandwich norm
 class FinalLayer(nn.Module):
@@ -312,8 +308,7 @@ class FinalLayer(nn.Module):
         c = F.silu(c)
         shift, scale = self.mod(c).chunk(2, dim=1)
         x = modulate(self.norm_final(x), shift, scale)
-        x = self.linear(x)
-        return x
+        return self.linear(x)
 class AsymmDiTJoint(nn.Module):

View File

@@ -64,8 +64,7 @@ class TimestepEmbedder(nn.Module):
         if self.timestep_scale is not None:
             t = t * self.timestep_scale
         t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype=out_dtype)
-        t_emb = self.mlp(t_freq)
-        return t_emb
+        return self.mlp(t_freq)
 class FeedForward(nn.Module):
@@ -93,8 +92,7 @@ class FeedForward(nn.Module):
     def forward(self, x):
         x, gate = self.w1(x).chunk(2, dim=-1)
-        x = self.w2(F.silu(x) * gate)
-        return x
+        return self.w2(F.silu(x) * gate)
 class PatchEmbed(nn.Module):
@@ -149,8 +147,7 @@ class PatchEmbed(nn.Module):
             raise NotImplementedError("Must flatten output.")
         x = rearrange(x, "(B T) C H W -> B (T H W) C", B=B, T=T)
-        x = self.norm(x)
-        return x
+        return self.norm(x)
 class RMSNorm(torch.nn.Module):

View File

@@ -60,9 +60,8 @@ def create_position_matrix(
     # Stack and reshape the grids.
     pos = torch.stack([grid_t, grid_h, grid_w], dim=-1)  # [T, pH, pW, 3]
     pos = pos.view(-1, 3)  # [T * pH * pW, 3]
-    pos = pos.to(dtype=dtype, device=device)
-    return pos
+    return pos.to(dtype=dtype, device=device)
 def compute_mixed_rotation(

View File

@@ -30,5 +30,4 @@ def apply_rotary_emb_qk_real(
     sin_part = (xqk_even * freqs_sin + xqk_odd * freqs_cos).type_as(xqk)
     # Interleave the results back into the original shape
-    out = torch.stack([cos_part, sin_part], dim=-1).flatten(-2)
-    return out
+    return torch.stack([cos_part, sin_part], dim=-1).flatten(-2)

View File

@@ -29,8 +29,7 @@ def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.
     assert x.size(0) == mask.size(0)  # Expected mask to have same batch size as tokens.
     mask = mask[:, :, None].to(dtype=x.dtype)
     mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
-    pooled = (x * mask).sum(dim=1, keepdim=keepdim)
-    return pooled
+    return (x * mask).sum(dim=1, keepdim=keepdim)
 class AttentionPool(nn.Module):
@@ -98,5 +97,4 @@ class AttentionPool(nn.Module):
         # Concatenate heads and run output.
         x = x.squeeze(2).flatten(1, 2)  # (B, D = H * head_dim)
-        x = self.to_out(x)
-        return x
+        return self.to_out(x)

View File

@@ -99,8 +99,7 @@ class Conv1x1(ops.Linear):
         """
         x = x.movedim(1, -1)
         x = super().forward(x)
-        x = x.movedim(-1, 1)
-        return x
+        return x.movedim(-1, 1)
 class DepthToSpaceTime(nn.Module):
@@ -269,8 +268,7 @@ class Attention(nn.Module):
         assert x.size(0) == q.size(0)
         x = self.out(x)
-        x = rearrange(x, "(B h w) t C -> B C t h w", B=B, h=H, w=W)
-        return x
+        return rearrange(x, "(B h w) t C -> B C t h w", B=B, h=H, w=W)
 class AttentionBlock(nn.Module):
@@ -321,8 +319,7 @@ class CausalUpsampleBlock(nn.Module):
     def forward(self, x):
         x = self.blocks(x)
         x = self.proj(x)
-        x = self.d2st(x)
-        return x
+        return self.d2st(x)
 def block_fn(channels, *, affine: bool = True, has_attention: bool = False, **block_kwargs):

View File

@@ -86,8 +86,7 @@ class TokenRefinerBlock(nn.Module):
         attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True)
         x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1)
-        x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
-        return x
+        return x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
 class IndividualTokenRefiner(nn.Module):
@@ -157,8 +156,7 @@ class TokenRefiner(nn.Module):
         c = t + self.c_embedder(c.to(x.dtype))
         x = self.input_embedder(x)
-        x = self.individual_token_refiner(x, c, mask)
-        return x
+        return self.individual_token_refiner(x, c, mask)
 class HunyuanVideo(nn.Module):
     """
@@ -311,8 +309,7 @@ class HunyuanVideo(nn.Module):
             shape[i] = shape[i] // self.patch_size[i]
         img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
         img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
-        img = img.reshape(initial_shape)
-        return img
+        return img.reshape(initial_shape)
     def forward(self, x, timestep, context, y, guidance, attention_mask=None, control=None, transformer_options={}, **kwargs):
         bs, c, t, h, w = x.shape
@@ -326,5 +323,4 @@ class HunyuanVideo(nn.Module):
         img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
         img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
         txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
-        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options)
-        return out
+        return self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options)

View File

@@ -166,9 +166,8 @@ class CrossAttention(nn.Module):
         out = self.out_proj(context)  # context.reshape - B, L1, -1
         out = self.proj_drop(out)
-        out_tuple = (out,)
-        return out_tuple
+        return (out,)
 class Attention(nn.Module):
@@ -213,6 +212,5 @@ class Attention(nn.Module):
         x = self.out_proj(x)
         x = self.proj_drop(x)
-        out_tuple = (x,)
-        return out_tuple
+        return (x,)

View File

@@ -19,8 +19,7 @@ def calc_rope(x, patch_size, head_size):
     sub_args = [start, stop, (th, tw)]
     # head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads']
     rope = get_2d_rotary_pos_embed(head_size, *sub_args)
-    rope = (rope[0].to(x), rope[1].to(x))
-    return rope
+    return (rope[0].to(x), rope[1].to(x))
 def modulate(x, shift, scale):
@@ -110,9 +109,8 @@ class HunYuanDiTBlock(nn.Module):
         # FFN Layer
         mlp_inputs = self.norm2(x)
-        x = x + self.mlp(mlp_inputs)
-        return x
+        return x + self.mlp(mlp_inputs)
     def forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
         if self.gradient_checkpointing and self.training:
@@ -136,8 +134,7 @@ class FinalLayer(nn.Module):
     def forward(self, x, c):
         shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
         x = modulate(self.norm_final(x), shift, scale)
-        x = self.linear(x)
-        return x
+        return self.linear(x)
 class HunYuanDiT(nn.Module):
@@ -413,5 +410,4 @@ class HunYuanDiT(nn.Module):
         x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
         x = torch.einsum('nhwpqc->nchpwq', x)
-        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
-        return imgs
+        return x.reshape(shape=(x.shape[0], c, h * p, w * p))

View File

@@ -53,8 +53,7 @@ def get_meshgrid(start, *args):
     grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32)
     grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32)
     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
-    grid = np.stack(grid, axis=0)  # [2, W, H]
-    return grid
+    return np.stack(grid, axis=0)  # [2, W, H]
 #################################################################################
 #                 Sine/Cosine Positional Embedding Functions                    #
@@ -87,8 +86,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
-    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
-    return emb
+    return np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
 def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
@@ -108,8 +106,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
     emb_sin = np.sin(out)  # (M, D/2)
     emb_cos = np.cos(out)  # (M, D/2)
-    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
-    return emb
+    return np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
 #################################################################################
@@ -138,8 +135,7 @@ def get_2d_rotary_pos_embed(embed_dim, start, *args, use_real=True):
     """
     grid = get_meshgrid(start, *args)  # [2, H, w]
     grid = grid.reshape([2, 1, *grid.shape[1:]])  # Returns a sampling matrix with the same resolution as the target resolution
-    pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
-    return pos_embed
+    return get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
 def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
@@ -154,8 +150,7 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
         sin = torch.cat([emb_h[1], emb_w[1]], dim=1)  # (H*W, D/2)
         return cos, sin
     else:
-        emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D/2)
-        return emb
+        return torch.cat([emb_h, emb_w], dim=1)  # (H*W, D/2)
 def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
@@ -187,8 +182,7 @@ def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float
         freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
         return freqs_cos, freqs_sin
     else:
-        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
-        return freqs_cis
+        return torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]

View File

@@ -122,14 +122,13 @@ class Timesteps(nn.Module):
         self.scale = scale
     def forward(self, timesteps):
-        t_emb = get_timestep_embedding(
+        return get_timestep_embedding(
             timesteps,
             self.num_channels,
             flip_sin_to_cos=self.flip_sin_to_cos,
             downscale_freq_shift=self.downscale_freq_shift,
             scale=self.scale,
         )
-        return t_emb
 class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
@@ -149,8 +148,7 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
     def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
         timesteps_proj = self.time_proj(timestep)
-        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)
-        return timesteps_emb
+        return self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)
 class AdaLayerNormSingle(nn.Module):
@@ -209,8 +207,7 @@ class PixArtAlphaTextProjection(nn.Module):
     def forward(self, caption):
         hidden_states = self.linear_1(caption)
         hidden_states = self.act_1(hidden_states)
-        hidden_states = self.linear_2(hidden_states)
-        return hidden_states
+        return self.linear_2(hidden_states)
 class GELU_approx(nn.Module):
@@ -247,9 +244,8 @@ def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and
     t_dup = torch.stack((-t2, t1), dim=-1)
     input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
-    out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
-    return out
+    return input_tensor * cos_freqs + input_tensor_rot * sin_freqs
 class CrossAttention(nn.Module):
@@ -316,14 +312,13 @@ class BasicTransformerBlock(nn.Module):
         return x
 def get_fractional_positions(indices_grid, max_pos):
-    fractional_positions = torch.stack(
+    return torch.stack(
         [
             indices_grid[:, i] / max_pos[i]
            for i in range(3)
        ],
        dim=-1,
    )
-    return fractional_positions
 def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):

View File

@@ -65,8 +65,7 @@ class Patchifier(ABC):
             scale = scale_grid[i]
             grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i]
-        grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
-        return grid
+        return rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
 class SymmetricPatchifier(Patchifier):
@@ -74,14 +73,13 @@ class SymmetricPatchifier(Patchifier):
         self,
         latents: Tensor,
     ) -> Tuple[Tensor, Tensor]:
-        latents = rearrange(
+        return rearrange(
             latents,
             "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
             p1=self._patch_size[0],
             p2=self._patch_size[1],
             p3=self._patch_size[2],
         )
-        return latents
     def unpatchify(
         self,
@@ -93,7 +91,7 @@ class SymmetricPatchifier(Patchifier):
     ) -> Tuple[Tensor, Tensor]:
         output_height = output_height // self._patch_size[1]
         output_width = output_width // self._patch_size[2]
-        latents = rearrange(
+        return rearrange(
             latents,
             "b (f h w) (c p q) -> b c f (h p) (w q) ",
             f=output_num_frames,
@@ -102,4 +100,3 @@ class SymmetricPatchifier(Patchifier):
             p=self._patch_size[1],
             q=self._patch_size[2],
         )
-        return latents

View File

@@ -56,8 +56,7 @@ class CausalConv3d(nn.Module):
                 (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
             )
         x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
-        x = self.conv(x)
-        return x
+        return self.conv(x)
     @property
     def weight(self):

View File

@@ -417,9 +417,8 @@ class Decoder(nn.Module):
         sample = self.conv_act(sample)
         sample = self.conv_out(sample, causal=self.causal)
-        sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
-        return sample
+        return unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
 class UNetMidBlock3D(nn.Module):
@@ -563,8 +562,7 @@ class LayerNorm(nn.Module):
     def forward(self, x):
         x = rearrange(x, "b c d h w -> b d h w c")
         x = self.norm(x)
-        x = rearrange(x, "b d h w c -> b c d h w")
-        return x
+        return rearrange(x, "b d h w c -> b c d h w")
 class ResnetBlock3D(nn.Module):
@@ -677,9 +675,8 @@ class ResnetBlock3D(nn.Module):
         # similar to the "explicit noise inputs" method in style-gan
         spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None]
         scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...]
-        hidden_states = hidden_states + scaled_noise
-        return hidden_states
+        return hidden_states + scaled_noise
     def forward(
         self,
@@ -740,9 +737,8 @@ class ResnetBlock3D(nn.Module):
             input_tensor = self.conv_shortcut(input_tensor)
-        output_tensor = input_tensor + hidden_states
-        return output_tensor
+        return input_tensor + hidden_states
 def patchify(x, patch_size_hw, patch_size_t=1):

View File

@@ -114,7 +114,7 @@ class DualConv3d(nn.Module):
             return x
         # Second convolution
-        x = F.conv3d(
+        return F.conv3d(
             x,
             self.weight2,
             self.bias2,
@@ -124,7 +124,6 @@ class DualConv3d(nn.Module):
             self.groups,
         )
-        return x
     def forward_with_2d(self, x, skip_time_conv):
         b, c, d, h, w = x.shape
@@ -142,8 +141,7 @@ class DualConv3d(nn.Module):
         _, _, h, w = x.shape
         if skip_time_conv:
-            x = rearrange(x, "(b d) c h w -> b c d h w", b=b)
-            return x
+            return rearrange(x, "(b d) c h w -> b c d h w", b=b)
         # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension
         x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)
@@ -155,9 +153,8 @@ class DualConv3d(nn.Module):
         padding2 = self.padding2[0]
         dilation2 = self.dilation2[0]
         x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups)
-        x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
-        return x
+        return rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
     @property
     def weight(self):

View File

@@ -136,8 +136,7 @@ class AutoencodingEngine(AbstractAutoencoder):
         return z
     def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
-        x = self.decoder(z, **kwargs)
-        return x
+        return self.decoder(z, **kwargs)
     def forward(
         self, x: torch.Tensor, **additional_decode_kwargs
@@ -178,8 +177,7 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
         self.embed_dim = embed_dim
     def get_autoencoder_params(self) -> list:
-        params = super().get_autoencoder_params()
-        return params
+        return super().get_autoencoder_params()
     def encode(
         self, x: torch.Tensor, return_reg_log: bool = False

View File

@@ -142,13 +142,12 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     sim = sim.softmax(dim=-1)
     out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
-    out = (
+    return (
         out.unsqueeze(0)
         .reshape(b, heads, -1, dim_head)
         .permute(0, 2, 1, 3)
         .reshape(b, -1, heads * dim_head)
     )
-    return out
 def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False):
@@ -216,8 +215,7 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
         hidden_states = hidden_states.to(dtype)
-    hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
-    return hidden_states
+    return hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
 def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
     attn_precision = get_attn_precision(attn_precision)
@@ -326,13 +324,12 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     del q, k, v
-    r1 = (
+    return (
         r1.unsqueeze(0)
         .reshape(b, heads, -1, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b, -1, heads * dim_head)
    )
-    return r1
 BROKEN_XFORMERS = False
 try:
@@ -395,11 +392,10 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
     out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
-    out = (
+    return (
         out.reshape(b, -1, heads * dim_head)
     )
-    return out
 if model_management.is_nvidia(): #pytorch 2.3 and up seem to have this issue.
     SDP_BATCH_LIMIT = 2**15
@@ -932,7 +928,6 @@ class SpatialVideoTransformer(SpatialTransformer):
         x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
         if not self.use_linear:
             x = self.proj_out(x)
-        out = x + x_in
-        return out
+        return x + x_in

View File

@@ -51,8 +51,7 @@ class Mlp(nn.Module):
         x = self.drop1(x)
         x = self.norm(x)
         x = self.fc2(x)
-        x = self.drop2(x)
-        return x
+        return self.drop2(x)
 class PatchEmbed(nn.Module):
     """ 2D Image to Patch Embedding
@@ -103,8 +102,7 @@ class PatchEmbed(nn.Module):
         x = self.proj(x)
         if self.flatten:
             x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
-        x = self.norm(x)
-        return x
+        return self.norm(x)
 def modulate(x, shift, scale):
     if shift is None:
@@ -155,8 +153,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
-    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
-    return emb
+    return np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
 def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
@@ -176,8 +173,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
     emb_sin = np.sin(out)  # (M, D/2)
     emb_cos = np.cos(out)  # (M, D/2)
-    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
-    return emb
+    return np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
 def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos, device=None, dtype=torch.float32):
     omega = torch.arange(embed_dim // 2, device=device, dtype=dtype)
@@ -187,8 +183,7 @@ def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos, device=None, dtype=t
     out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
     emb_sin = torch.sin(out)  # (M, D/2)
     emb_cos = torch.cos(out)  # (M, D/2)
-    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
-    return emb
+    return torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
 def get_2d_sincos_pos_embed_torch(embed_dim, w, h, val_center=7.5, val_magnitude=7.5, device=None, dtype=torch.float32):
     small = min(h, w)
@@ -197,8 +192,7 @@ def get_2d_sincos_pos_embed_torch(embed_dim, w, h, val_center=7.5, val_magnitude
     grid_h, grid_w = torch.meshgrid(torch.linspace(-val_h + val_center, val_h + val_center, h, device=device, dtype=dtype), torch.linspace(-val_w + val_center, val_w + val_center, w, device=device, dtype=dtype), indexing='ij')
     emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
     emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
-    emb = torch.cat([emb_w, emb_h], dim=1)  # (H*W, D)
-    return emb
+    return torch.cat([emb_w, emb_h], dim=1)  # (H*W, D)
 #################################################################################
@@ -222,8 +216,7 @@ class TimestepEmbedder(nn.Module):
     def forward(self, t, dtype, **kwargs):
         t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype)
-        t_emb = self.mlp(t_freq)
-        return t_emb
+        return self.mlp(t_freq)
 class VectorEmbedder(nn.Module):
@@ -240,8 +233,7 @@ class VectorEmbedder(nn.Module):
         )
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        emb = self.mlp(x)
-        return emb
+        return self.mlp(x)
 #################################################################################
@@ -307,16 +299,14 @@ class SelfAttention(nn.Module):
     def post_attention(self, x: torch.Tensor) -> torch.Tensor:
         assert not self.pre_only
         x = self.proj(x)
-        x = self.proj_drop(x)
-        return x
+        return self.proj_drop(x)
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         q, k, v = self.pre_attention(x)
         x = optimized_attention(
             q, k, v, heads=self.num_heads
         )
-        x = self.post_attention(x)
-        return x
+        return self.post_attention(x)
 class RMSNorm(torch.nn.Module):
@@ -530,10 +520,9 @@ class DismantledBlock(nn.Module):
     def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp):
         assert not self.pre_only
         x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
-        x = x + gate_mlp.unsqueeze(1) * self.mlp(
+        return x + gate_mlp.unsqueeze(1) * self.mlp(
             modulate(self.norm2(x), shift_mlp, scale_mlp)
         )
-        return x
     def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
         assert self.x_block_self_attn
@@ -568,10 +557,9 @@ class DismantledBlock(nn.Module):
         out2 = gate_msa2.unsqueeze(1) * attn2
         x = x + out1
         x = x + out2
-        x = x + gate_mlp.unsqueeze(1) * self.mlp(
+        return x + gate_mlp.unsqueeze(1) * self.mlp(
            modulate(self.norm2(x), shift_mlp, scale_mlp)
        )
-        return x
     def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
         assert not self.pre_only
@@ -696,8 +684,7 @@ class FinalLayer(nn.Module):
     def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
         shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
         x = modulate(self.norm_final(x), shift, scale)
-        x = self.linear(x)
-        return x
+        return self.linear(x)
 class SelfAttentionContext(nn.Module):
     def __init__(self, dim, heads=8, dim_head=64, dtype=None, device=None, operations=None):
@@ -902,13 +889,12 @@ class MMDiT(nn.Module):
             w=self.pos_embed_max_size,
         )
         spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
-        spatial_pos_embed = rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c")
+        return rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c")
         # print(spatial_pos_embed, top, left, h, w)
         # # t = get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, 7.875, 7.875, device=device) #matches exactly for 1024 res
         # t = get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, 7.5, 7.5, device=device) #scales better
         # # print(t)
         # return t
return spatial_pos_embed
def unpatchify(self, x, hw=None): def unpatchify(self, x, hw=None):
""" """
@ -927,8 +913,7 @@ class MMDiT(nn.Module):
x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum("nhwpqc->nchpwq", x) x = torch.einsum("nhwpqc->nchpwq", x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p)) return x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs
def forward_core_with_concat( def forward_core_with_concat(
self, self,
@ -976,8 +961,7 @@ class MMDiT(nn.Module):
if add is not None: if add is not None:
x += add x += add
x = self.final_layer(x, c_mod) # (N, T, patch_size ** 2 * out_channels) return self.final_layer(x, c_mod) # (N, T, patch_size ** 2 * out_channels)
return x
def forward( def forward(
self, self,


@ -493,8 +493,7 @@ class Model(nn.Module):
# end # end
h = self.norm_out(h) h = self.norm_out(h)
h = nonlinearity(h) h = nonlinearity(h)
h = self.conv_out(h) return self.conv_out(h)
return h
def get_last_layer(self): def get_last_layer(self):
return self.conv_out.weight return self.conv_out.weight
@ -602,8 +601,7 @@ class Encoder(nn.Module):
# end # end
h = self.norm_out(h) h = self.norm_out(h)
h = nonlinearity(h) h = nonlinearity(h)
h = self.conv_out(h) return self.conv_out(h)
return h
class Decoder(nn.Module): class Decoder(nn.Module):


@ -350,8 +350,7 @@ class VideoResBlock(ResBlock):
x = self.time_mixer( x = self.time_mixer(
x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
) )
x = rearrange(x, "b c t h w -> (b t) c h w") return rearrange(x, "b c t h w -> (b t) c h w")
return x
class Timestep(nn.Module): class Timestep(nn.Module):


@ -79,11 +79,10 @@ class AlphaBlender(nn.Module):
image_only_indicator=None, image_only_indicator=None,
) -> torch.Tensor: ) -> torch.Tensor:
alpha = self.get_alpha(image_only_indicator, x_spatial.device) alpha = self.get_alpha(image_only_indicator, x_spatial.device)
x = ( return (
alpha.to(x_spatial.dtype) * x_spatial alpha.to(x_spatial.dtype) * x_spatial
+ (1.0 - alpha).to(x_spatial.dtype) * x_temporal + (1.0 - alpha).to(x_spatial.dtype) * x_temporal
) )
return x
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
@ -201,8 +200,7 @@ class CheckpointFunction(torch.autograd.Function):
"dtype": torch.get_autocast_gpu_dtype(), "dtype": torch.get_autocast_gpu_dtype(),
"cache_enabled": torch.is_autocast_cache_enabled()} "cache_enabled": torch.is_autocast_cache_enabled()}
with torch.no_grad(): with torch.no_grad():
output_tensors = ctx.run_function(*ctx.input_tensors) return ctx.run_function(*ctx.input_tensors)
return output_tensors
@staticmethod @staticmethod
def backward(ctx, *output_grads): def backward(ctx, *output_grads):


@ -33,8 +33,7 @@ class DiagonalGaussianDistribution(object):
self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device) self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device)
def sample(self): def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device) return self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device)
return x
def kl(self, other=None): def kl(self, other=None):
if self.deterministic: if self.deterministic:


@ -15,13 +15,11 @@ class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
def scale(self, x): def scale(self, x):
# re-normalize to centered mean and unit variance # re-normalize to centered mean and unit variance
x = (x - self.data_mean.to(x.device)) * 1. / self.data_std.to(x.device) return (x - self.data_mean.to(x.device)) * 1. / self.data_std.to(x.device)
return x
def unscale(self, x): def unscale(self, x):
# back to original data stats # back to original data stats
x = (x * self.data_std.to(x.device)) + self.data_mean.to(x.device) return (x * self.data_std.to(x.device)) + self.data_mean.to(x.device)
return x
def forward(self, x, noise_level=None, seed=None): def forward(self, x, noise_level=None, seed=None):
if noise_level is None: if noise_level is None:


@ -177,8 +177,7 @@ def _get_attention_scores_no_kv_chunking(
attn_scores /= summed attn_scores /= summed
attn_probs = attn_scores attn_probs = attn_scores
hidden_states_slice = torch.bmm(attn_probs.to(value.dtype), value) return torch.bmm(attn_probs.to(value.dtype), value)
return hidden_states_slice
class ScannedChunk(NamedTuple): class ScannedChunk(NamedTuple):
chunk_idx: int chunk_idx: int
@ -264,7 +263,7 @@ def efficient_dot_product_attention(
# TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance, # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
# and pass slices to be mutated, instead of torch.cat()ing the returned slices # and pass slices to be mutated, instead of torch.cat()ing the returned slices
res = torch.cat([ return torch.cat([
compute_query_chunk_attn( compute_query_chunk_attn(
query=get_query_chunk(i * query_chunk_size), query=get_query_chunk(i * query_chunk_size),
key_t=key_t, key_t=key_t,
@ -272,4 +271,3 @@ def efficient_dot_product_attention(
mask=get_mask_chunk(i * query_chunk_size) mask=get_mask_chunk(i * query_chunk_size)
) for i in range(math.ceil(q_tokens / query_chunk_size)) ) for i in range(math.ceil(q_tokens / query_chunk_size))
], dim=1) ], dim=1)
return res


@ -73,8 +73,7 @@ class MultiHeadCrossAttention(nn.Module):
x = optimized_attention(q.view(B, -1, C), k.view(B, -1, C), v.view(B, -1, C), self.num_heads, mask=None) x = optimized_attention(q.view(B, -1, C), k.view(B, -1, C), v.view(B, -1, C), self.num_heads, mask=None)
x = self.proj(x) x = self.proj(x)
x = self.proj_drop(x) return self.proj_drop(x)
return x
class AttentionKVCompress(nn.Module): class AttentionKVCompress(nn.Module):
@ -172,8 +171,7 @@ class AttentionKVCompress(nn.Module):
x = optimized_attention(q, k, v, self.num_heads, mask=None, skip_reshape=True) x = optimized_attention(q, k, v, self.num_heads, mask=None, skip_reshape=True)
x = x.view(B, N, C) x = x.view(B, N, C)
x = self.proj(x) return self.proj(x)
return x
class FinalLayer(nn.Module): class FinalLayer(nn.Module):
@ -192,8 +190,7 @@ class FinalLayer(nn.Module):
def forward(self, x, c): def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale) x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x) return self.linear(x)
return x
class T2IFinalLayer(nn.Module): class T2IFinalLayer(nn.Module):
""" """
@ -209,8 +206,7 @@ class T2IFinalLayer(nn.Module):
def forward(self, x, t): def forward(self, x, t):
shift, scale = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t[:, None]).chunk(2, dim=1) shift, scale = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t[:, None]).chunk(2, dim=1)
x = t2i_modulate(self.norm_final(x), shift, scale) x = t2i_modulate(self.norm_final(x), shift, scale)
x = self.linear(x) return self.linear(x)
return x
class MaskFinalLayer(nn.Module): class MaskFinalLayer(nn.Module):
@ -228,8 +224,7 @@ class MaskFinalLayer(nn.Module):
def forward(self, x, t): def forward(self, x, t):
shift, scale = self.adaLN_modulation(t).chunk(2, dim=1) shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale) x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x) return self.linear(x)
return x
class DecoderLayer(nn.Module): class DecoderLayer(nn.Module):
@ -247,8 +242,7 @@ class DecoderLayer(nn.Module):
def forward(self, x, t): def forward(self, x, t):
shift, scale = self.adaLN_modulation(t).chunk(2, dim=1) shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
x = modulate(self.norm_decoder(x), shift, scale) x = modulate(self.norm_decoder(x), shift, scale)
x = self.linear(x) return self.linear(x)
return x
class SizeEmbedder(TimestepEmbedder): class SizeEmbedder(TimestepEmbedder):
@ -276,8 +270,7 @@ class SizeEmbedder(TimestepEmbedder):
s = rearrange(s, "b d -> (b d)") s = rearrange(s, "b d -> (b d)")
s_freq = timestep_embedding(s, self.frequency_embedding_size) s_freq = timestep_embedding(s, self.frequency_embedding_size)
s_emb = self.mlp(s_freq.to(s.dtype)) s_emb = self.mlp(s_freq.to(s.dtype))
s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim) return rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
return s_emb
class LabelEmbedder(nn.Module): class LabelEmbedder(nn.Module):
@ -299,15 +292,13 @@ class LabelEmbedder(nn.Module):
drop_ids = torch.rand(labels.shape[0]).cuda() < self.dropout_prob drop_ids = torch.rand(labels.shape[0]).cuda() < self.dropout_prob
else: else:
drop_ids = force_drop_ids == 1 drop_ids = force_drop_ids == 1
labels = torch.where(drop_ids, self.num_classes, labels) return torch.where(drop_ids, self.num_classes, labels)
return labels
def forward(self, labels, train, force_drop_ids=None): def forward(self, labels, train, force_drop_ids=None):
use_dropout = self.dropout_prob > 0 use_dropout = self.dropout_prob > 0
if (train and use_dropout) or (force_drop_ids is not None): if (train and use_dropout) or (force_drop_ids is not None):
labels = self.token_drop(labels, force_drop_ids) labels = self.token_drop(labels, force_drop_ids)
embeddings = self.embedding_table(labels) return self.embedding_table(labels)
return embeddings
class CaptionEmbedder(nn.Module): class CaptionEmbedder(nn.Module):
@ -331,8 +322,7 @@ class CaptionEmbedder(nn.Module):
drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
else: else:
drop_ids = force_drop_ids == 1 drop_ids = force_drop_ids == 1
caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption) return torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
return caption
def forward(self, caption, train, force_drop_ids=None): def forward(self, caption, train, force_drop_ids=None):
if train: if train:
@ -340,8 +330,7 @@ class CaptionEmbedder(nn.Module):
use_dropout = self.uncond_prob > 0 use_dropout = self.uncond_prob > 0
if (train and use_dropout) or (force_drop_ids is not None): if (train and use_dropout) or (force_drop_ids is not None):
caption = self.token_drop(caption, force_drop_ids) caption = self.token_drop(caption, force_drop_ids)
caption = self.y_proj(caption) return self.y_proj(caption)
return caption
class CaptionEmbedderDoubleBr(nn.Module): class CaptionEmbedderDoubleBr(nn.Module):


@ -23,8 +23,7 @@ def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_si
) )
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype) emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype) emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D) return torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
return emb
class PixArtMSBlock(nn.Module): class PixArtMSBlock(nn.Module):
""" """
@ -57,9 +56,8 @@ class PixArtMSBlock(nn.Module):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1)
x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW)) x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
x = x + self.cross_attn(x, y, mask) x = x + self.cross_attn(x, y, mask)
x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) return x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
return x
### Core PixArt Model ### ### Core PixArt Model ###
@ -212,9 +210,8 @@ class PixArtMS(nn.Module):
x = block(x, y, t0, y_lens, (H, W), **kwargs) # (N, T, D) x = block(x, y, t0, y_lens, (H, W), **kwargs) # (N, T, D)
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x, H, W) # (N, out_channels, H, W) return self.unpatchify(x, H, W) # (N, out_channels, H, W)
return x
def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs): def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs):
B, C, H, W = x.shape B, C, H, W = x.shape
@ -252,5 +249,4 @@ class PixArtMS(nn.Module):
x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum('nhwpqc->nchpwq', x) x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p)) return x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs


@ -29,8 +29,7 @@ def log_txt_as_img(wh, xc, size=10):
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt) txts.append(txt)
txts = np.stack(txts) txts = np.stack(txts)
txts = torch.tensor(txts) return torch.tensor(txts)
return txts
def ismap(x): def ismap(x):


@ -206,8 +206,7 @@ class BaseModel(torch.nn.Module):
cond_concat.append(torch.ones_like(noise)[:,:1]) cond_concat.append(torch.ones_like(noise)[:,:1])
elif ck == "masked_image": elif ck == "masked_image":
cond_concat.append(self.blank_inpaint_image_like(noise)) cond_concat.append(self.blank_inpaint_image_like(noise))
data = torch.cat(cond_concat, dim=1) return torch.cat(cond_concat, dim=1)
return data
return None return None
def extra_conds(self, **kwargs): def extra_conds(self, **kwargs):
@ -418,8 +417,7 @@ class SVD_img2vid(BaseModel):
out.append(self.embedder(torch.Tensor([motion_bucket_id]))) out.append(self.embedder(torch.Tensor([motion_bucket_id])))
out.append(self.embedder(torch.Tensor([augmentation]))) out.append(self.embedder(torch.Tensor([augmentation])))
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0) return torch.flatten(torch.cat(out)).unsqueeze(dim=0)
return flat
def extra_conds(self, **kwargs): def extra_conds(self, **kwargs):
out = {} out = {}
@ -457,8 +455,7 @@ class SV3D_u(SVD_img2vid):
out = [] out = []
out.append(self.embedder(torch.flatten(torch.Tensor([augmentation])))) out.append(self.embedder(torch.flatten(torch.Tensor([augmentation]))))
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0) return torch.flatten(torch.cat(out)).unsqueeze(dim=0)
return flat
class SV3D_p(SVD_img2vid): class SV3D_p(SVD_img2vid):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None): def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):


@ -874,8 +874,7 @@ class ModelPatcher:
all_models = self.get_additional_models() all_models = self.get_additional_models()
models_set = set(all_models) models_set = set(all_models)
real_all_models = _evaluate_sub_additional_models(prev_models=all_models, cache_set=models_set) return _evaluate_sub_additional_models(prev_models=all_models, cache_set=models_set)
return real_all_models
def use_ejected(self, skip_and_inject_on_exit_only=False): def use_ejected(self, skip_and_inject_on_exit_only=False):
return AutoPatcherEjector(self, skip_and_inject_on_exit_only=skip_and_inject_on_exit_only) return AutoPatcherEjector(self, skip_and_inject_on_exit_only=skip_and_inject_on_exit_only)


@ -286,8 +286,7 @@ class StableCascadeSampling(ModelSamplingDiscrete):
var = 1 / ((sigma * sigma) + 1) var = 1 / ((sigma * sigma) + 1)
var = var.clamp(0, 1.0) var = var.clamp(0, 1.0)
s, min_var = self.cosine_s.to(var.device), self._init_alpha_cumprod.to(var.device) s, min_var = self.cosine_s.to(var.device), self._init_alpha_cumprod.to(var.device)
t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s return (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
return t
def percent_to_sigma(self, percent): def percent_to_sigma(self, percent):
if percent <= 0.0: if percent <= 0.0:


@ -21,8 +21,7 @@ def prepare_noise(latent_image, seed, noise_inds=None):
if i in unique_inds: if i in unique_inds:
noises.append(noise) noises.append(noise)
noises = [noises[i] for i in inverse] noises = [noises[i] for i in inverse]
noises = torch.cat(noises, axis=0) return torch.cat(noises, axis=0)
return noises
def fix_empty_latent_channels(model, latent_image): def fix_empty_latent_channels(model, latent_image):
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
@ -43,10 +42,8 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed) samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.to(comfy.model_management.intermediate_device()) return samples.to(comfy.model_management.intermediate_device())
return samples
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None): def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.to(comfy.model_management.intermediate_device()) return samples.to(comfy.model_management.intermediate_device())
return samples


@ -713,8 +713,7 @@ class KSAMPLER(Sampler):
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps) k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options) samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples) return model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples)
return samples
def ksampler(sampler_name, extra_options={}, inpaint_options={}): def ksampler(sampler_name, extra_options={}, inpaint_options={}):


@ -422,12 +422,11 @@ class VAE:
pbar = comfy.utils.ProgressBar(steps) pbar = comfy.utils.ProgressBar(steps)
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
output = self.process_output( return self.process_output(
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar)) comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar))
/ 3.0) / 3.0)
return output
def decode_tiled_1d(self, samples, tile_x=128, overlap=32): def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
@ -485,8 +484,7 @@ class VAE:
overlap = tile // 4 overlap = tile // 4
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1) return pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile


@ -304,13 +304,11 @@ def token_weights(string, current_weight):
def escape_important(text): def escape_important(text):
text = text.replace("\\)", "\0\1") text = text.replace("\\)", "\0\1")
text = text.replace("\\(", "\0\2") return text.replace("\\(", "\0\2")
return text
def unescape_important(text): def unescape_important(text):
text = text.replace("\0\1", ")") text = text.replace("\0\1", ")")
text = text.replace("\0\2", "(") return text.replace("\0\2", "(")
return text
def safe_load_embed_zip(embed_path): def safe_load_embed_zip(embed_path):
with zipfile.ZipFile(embed_path) as myzip: with zipfile.ZipFile(embed_path) as myzip:
@ -635,8 +633,7 @@ class SD1ClipModel(torch.nn.Module):
def encode_token_weights(self, token_weight_pairs): def encode_token_weights(self, token_weight_pairs):
token_weight_pairs = token_weight_pairs[self.clip_name] token_weight_pairs = token_weight_pairs[self.clip_name]
out = getattr(self, self.clip).encode_token_weights(token_weight_pairs) return getattr(self, self.clip).encode_token_weights(token_weight_pairs)
return out
def load_sd(self, sd): def load_sd(self, sd):
return getattr(self, self.clip).load_sd(sd) return getattr(self, self.clip).load_sd(sd)


@ -51,8 +51,7 @@ class SD15(supported_models_base.BASE):
replace_prefix = {} replace_prefix = {}
replace_prefix["cond_stage_model."] = "clip_l." replace_prefix["cond_stage_model."] = "clip_l."
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) return utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
return state_dict
def process_clip_state_dict_for_saving(self, state_dict): def process_clip_state_dict_for_saving(self, state_dict):
pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"] pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"]
@ -97,15 +96,13 @@ class SD20(supported_models_base.BASE):
replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format
replace_prefix["cond_stage_model.model."] = "clip_h." replace_prefix["cond_stage_model.model."] = "clip_h."
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
state_dict = utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.") return utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.")
return state_dict
def process_clip_state_dict_for_saving(self, state_dict): def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {} replace_prefix = {}
replace_prefix["clip_h"] = "cond_stage_model.model" replace_prefix["clip_h"] = "cond_stage_model.model"
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict) return diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
return state_dict
def clip_target(self, state_dict={}): def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.sd2_clip.SD2Tokenizer, comfy.text_encoders.sd2_clip.SD2ClipModel) return supported_models_base.ClipTarget(comfy.text_encoders.sd2_clip.SD2Tokenizer, comfy.text_encoders.sd2_clip.SD2ClipModel)
@ -158,8 +155,7 @@ class SDXLRefiner(supported_models_base.BASE):
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.") state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.")
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace) return utils.state_dict_key_replace(state_dict, keys_to_replace)
return state_dict
def process_clip_state_dict_for_saving(self, state_dict): def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {} replace_prefix = {}
@ -167,8 +163,7 @@ class SDXLRefiner(supported_models_base.BASE):
if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g: if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g:
state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids") state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
replace_prefix["clip_g"] = "conditioner.embedders.0.model" replace_prefix["clip_g"] = "conditioner.embedders.0.model"
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix) return utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
return state_dict_g
def clip_target(self, state_dict={}): def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel) return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)
@ -221,8 +216,7 @@ class SDXL(supported_models_base.BASE):
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace) state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.") return utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.")
return state_dict
def process_clip_state_dict_for_saving(self, state_dict): def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {} replace_prefix = {}
@ -239,8 +233,7 @@ class SDXL(supported_models_base.BASE):
replace_prefix["clip_g"] = "conditioner.embedders.1.model" replace_prefix["clip_g"] = "conditioner.embedders.1.model"
replace_prefix["clip_l"] = "conditioner.embedders.0" replace_prefix["clip_l"] = "conditioner.embedders.0"
state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix) return utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
return state_dict_g
def clip_target(self, state_dict={}): def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel) return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)
@ -310,8 +303,7 @@ class SVD_img2vid(supported_models_base.BASE):
sampling_settings = {"sigma_max": 700.0, "sigma_min": 0.002} sampling_settings = {"sigma_max": 700.0, "sigma_min": 0.002}
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
out = model_base.SVD_img2vid(self, device=device) return model_base.SVD_img2vid(self, device=device)
return out
def clip_target(self, state_dict={}): def clip_target(self, state_dict={}):
return None return None
@ -331,8 +323,7 @@ class SV3D_u(SVD_img2vid):
vae_key_prefix = ["conditioner.embedders.1.encoder."] vae_key_prefix = ["conditioner.embedders.1.encoder."]
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
out = model_base.SV3D_u(self, device=device) return model_base.SV3D_u(self, device=device)
return out
class SV3D_p(SV3D_u): class SV3D_p(SV3D_u):
unet_config = { unet_config = {
@ -348,8 +339,7 @@ class SV3D_p(SV3D_u):
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
out = model_base.SV3D_p(self, device=device) return model_base.SV3D_p(self, device=device)
return out
class Stable_Zero123(supported_models_base.BASE): class Stable_Zero123(supported_models_base.BASE):
unet_config = { unet_config = {
@ -376,8 +366,7 @@ class Stable_Zero123(supported_models_base.BASE):
latent_format = latent_formats.SD15 latent_format = latent_formats.SD15
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"]) return model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"])
return out
def clip_target(self, state_dict={}): def clip_target(self, state_dict={}):
return None return None
@ -407,8 +396,7 @@ class SD_X4Upscaler(SD20):
} }
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
out = model_base.SD_X4Upscaler(self, device=device) return model_base.SD_X4Upscaler(self, device=device)
return out
class Stable_Cascade_C(supported_models_base.BASE): class Stable_Cascade_C(supported_models_base.BASE):
unet_config = { unet_config = {
@ -450,8 +438,7 @@ class Stable_Cascade_C(supported_models_base.BASE):
return state_dict return state_dict
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
out = model_base.StableCascade_C(self, device=device) return model_base.StableCascade_C(self, device=device)
return out
def clip_target(self, state_dict={}): def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel) return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel)
@ -473,8 +460,7 @@ class Stable_Cascade_B(Stable_Cascade_C):
clip_vision_prefix = None clip_vision_prefix = None
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
out = model_base.StableCascade_B(self, device=device) return model_base.StableCascade_B(self, device=device)
return out
class SD15_instructpix2pix(SD15): class SD15_instructpix2pix(SD15):
unet_config = { unet_config = {
@ -521,8 +507,7 @@ class SD3(supported_models_base.BASE):
text_encoder_key_prefix = ["text_encoders."] text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
out = model_base.SD3(self, device=device) return model_base.SD3(self, device=device)
return out
def clip_target(self, state_dict={}): def clip_target(self, state_dict={}):
clip_l = False clip_l = False
@ -587,8 +572,7 @@ class AuraFlow(supported_models_base.BASE):
text_encoder_key_prefix = ["text_encoders."] text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
out = model_base.AuraFlow(self, device=device) return model_base.AuraFlow(self, device=device)
return out
def clip_target(self, state_dict={}): def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.aura_t5.AuraT5Tokenizer, comfy.text_encoders.aura_t5.AuraT5Model) return supported_models_base.ClipTarget(comfy.text_encoders.aura_t5.AuraT5Tokenizer, comfy.text_encoders.aura_t5.AuraT5Model)
@ -648,8 +632,7 @@ class HunyuanDiT(supported_models_base.BASE):
text_encoder_key_prefix = ["text_encoders."] text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
out = model_base.HunyuanDiT(self, device=device) return model_base.HunyuanDiT(self, device=device)
return out
def clip_target(self, state_dict={}): def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.hydit.HyditTokenizer, comfy.text_encoders.hydit.HyditModel) return supported_models_base.ClipTarget(comfy.text_encoders.hydit.HyditTokenizer, comfy.text_encoders.hydit.HyditModel)
@ -686,8 +669,7 @@ class Flux(supported_models_base.BASE):
text_encoder_key_prefix = ["text_encoders."] text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
out = model_base.Flux(self, device=device) return model_base.Flux(self, device=device)
return out
def clip_target(self, state_dict={}): def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0] pref = self.text_encoder_key_prefix[0]
@ -715,8 +697,7 @@ class FluxSchnell(Flux):
} }
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device) return model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device)
return out
class GenmoMochi(supported_models_base.BASE): class GenmoMochi(supported_models_base.BASE):
unet_config = { unet_config = {
@ -739,8 +720,7 @@ class GenmoMochi(supported_models_base.BASE):
text_encoder_key_prefix = ["text_encoders."] text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
out = model_base.GenmoMochi(self, device=device) return model_base.GenmoMochi(self, device=device)
return out
def clip_target(self, state_dict={}): def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0] pref = self.text_encoder_key_prefix[0]
@ -767,8 +747,7 @@ class LTXV(supported_models_base.BASE):
text_encoder_key_prefix = ["text_encoders."] text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
out = model_base.LTXV(self, device=device) return model_base.LTXV(self, device=device)
return out
def clip_target(self, state_dict={}): def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0] pref = self.text_encoder_key_prefix[0]
@ -795,8 +774,7 @@ class HunyuanVideo(supported_models_base.BASE):
text_encoder_key_prefix = ["text_encoders."] text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
out = model_base.HunyuanVideo(self, device=device) return model_base.HunyuanVideo(self, device=device)
return out
def process_unet_state_dict(self, state_dict): def process_unet_state_dict(self, state_dict):
out_sd = {} out_sd = {}


@ -87,8 +87,7 @@ class BASE:
return out return out
def process_clip_state_dict(self, state_dict): def process_clip_state_dict(self, state_dict):
state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True) return utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
return state_dict
def process_unet_state_dict(self, state_dict): def process_unet_state_dict(self, state_dict):
return state_dict return state_dict


@ -60,8 +60,7 @@ class Downsample(nn.Module):
padding = [x.shape[2] % 2, x.shape[3] % 2] padding = [x.shape[2] % 2, x.shape[3] % 2]
self.op.padding = padding self.op.padding = padding
x = self.op(x) return self.op(x)
return x
class ResnetBlock(nn.Module): class ResnetBlock(nn.Module):
@ -196,8 +195,7 @@ class ResidualAttentionBlock(nn.Module):
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x)) x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x)) return x + self.mlp(self.ln_2(x))
return x
class StyleAdapter(nn.Module): class StyleAdapter(nn.Module):
@ -224,9 +222,8 @@ class StyleAdapter(nn.Module):
x = x.permute(1, 0, 2) # LND -> NLD x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_post(x[:, -self.num_token:, :]) x = self.ln_post(x[:, -self.num_token:, :])
x = x @ self.proj return x @ self.proj
return x
class ResnetBlock_light(nn.Module): class ResnetBlock_light(nn.Module):
@ -262,9 +259,8 @@ class extractor(nn.Module):
x = self.down_opt(x) x = self.down_opt(x)
x = self.in_conv(x) x = self.in_conv(x)
x = self.body(x) x = self.body(x)
x = self.out_conv(x) return self.out_conv(x)
return x
class Adapter_light(nn.Module): class Adapter_light(nn.Module):


@ -72,8 +72,7 @@ class TAESD(nn.Module):
def decode(self, x): def decode(self, x):
x_sample = self.taesd_decoder((x - self.vae_shift) * self.vae_scale) x_sample = self.taesd_decoder((x - self.vae_shift) * self.vae_scale)
x_sample = x_sample.sub(0.5).mul(2) return x_sample.sub(0.5).mul(2)
return x_sample
def encode(self, x): def encode(self, x):
return (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift return (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift


@ -17,8 +17,7 @@ class BertAttention(torch.nn.Module):
k = self.key(x) k = self.key(x)
v = self.value(x) v = self.value(x)
out = optimized_attention(q, k, v, self.heads, mask) return optimized_attention(q, k, v, self.heads, mask)
return out
class BertOutput(torch.nn.Module): class BertOutput(torch.nn.Module):
def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations): def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
@ -30,8 +29,7 @@ class BertOutput(torch.nn.Module):
def forward(self, x, y): def forward(self, x, y):
x = self.dense(x) x = self.dense(x)
# hidden_states = self.dropout(hidden_states) # hidden_states = self.dropout(hidden_states)
x = self.LayerNorm(x + y) return self.LayerNorm(x + y)
return x
class BertAttentionBlock(torch.nn.Module): class BertAttentionBlock(torch.nn.Module):
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations): def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
@ -100,8 +98,7 @@ class BertEmbeddings(torch.nn.Module):
x += self.token_type_embeddings(token_type_ids, out_dtype=x.dtype) x += self.token_type_embeddings(token_type_ids, out_dtype=x.dtype)
else: else:
x += comfy.ops.cast_to_input(self.token_type_embeddings.weight[0], x) x += comfy.ops.cast_to_input(self.token_type_embeddings.weight[0], x)
x = self.LayerNorm(x) return self.LayerNorm(x)
return x
class BertModel_(torch.nn.Module): class BertModel_(torch.nn.Module):


@ -142,9 +142,8 @@ class TransformerBlock(nn.Module):
residual = x residual = x
x = self.post_attention_layernorm(x) x = self.post_attention_layernorm(x)
x = self.mlp(x) x = self.mlp(x)
x = residual + x return residual + x
return x
class Llama2_(nn.Module): class Llama2_(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None): def __init__(self, config, device=None, dtype=None, ops=None):


@ -30,8 +30,7 @@ class T5DenseActDense(torch.nn.Module):
def forward(self, x): def forward(self, x):
x = self.act(self.wi(x)) x = self.act(self.wi(x))
# x = self.dropout(x) # x = self.dropout(x)
x = self.wo(x) return self.wo(x)
return x
class T5DenseGatedActDense(torch.nn.Module): class T5DenseGatedActDense(torch.nn.Module):
def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations): def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations):
@ -47,8 +46,7 @@ class T5DenseGatedActDense(torch.nn.Module):
hidden_linear = self.wi_1(x) hidden_linear = self.wi_1(x)
x = hidden_gelu * hidden_linear x = hidden_gelu * hidden_linear
# x = self.dropout(x) # x = self.dropout(x)
x = self.wo(x) return self.wo(x)
return x
class T5LayerFF(torch.nn.Module): class T5LayerFF(torch.nn.Module):
def __init__(self, model_dim, ff_dim, ff_activation, gated_act, dtype, device, operations): def __init__(self, model_dim, ff_dim, ff_activation, gated_act, dtype, device, operations):
@ -145,8 +143,7 @@ class T5Attention(torch.nn.Module):
max_distance=self.relative_attention_max_distance, max_distance=self.relative_attention_max_distance,
) )
values = self.relative_attention_bias(relative_position_bucket, out_dtype=dtype) # shape (query_length, key_length, num_heads) values = self.relative_attention_bias(relative_position_bucket, out_dtype=dtype) # shape (query_length, key_length, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) return values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
return values
def forward(self, x, mask=None, past_bias=None, optimized_attention=None): def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
q = self.q(x) q = self.q(x)


@ -292,8 +292,7 @@ def unet_to_diffusers(unet_config):
def swap_scale_shift(weight): def swap_scale_shift(weight):
shift, scale = weight.chunk(2, dim=0) shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0) return torch.cat([scale, shift], dim=0)
return new_weight
MMDIT_MAP_BASIC = { MMDIT_MAP_BASIC = {
("context_embedder.bias", "context_embedder.bias"), ("context_embedder.bias", "context_embedder.bias"),
@ -982,8 +981,7 @@ def reshape_mask(input_mask, output_shape):
mask = torch.nn.functional.interpolate(input_mask, size=output_shape[2:], mode=scale_mode) mask = torch.nn.functional.interpolate(input_mask, size=output_shape[2:], mode=scale_mode)
if mask.shape[1] < output_shape[1]: if mask.shape[1] < output_shape[1]:
mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:,:output_shape[1]] mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:,:output_shape[1]]
mask = repeat_to_batch_size(mask, output_shape[0]) return repeat_to_batch_size(mask, output_shape[0])
return mask
def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out): def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out):
hi, wi = img_size_in hi, wi = img_size_in
@ -1019,9 +1017,8 @@ def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out):
img_to_txt = rearrange (img_to_txt, "b t h w -> b (h w) t") img_to_txt = rearrange (img_to_txt, "b t h w -> b (h w) t")
# reassemble the mask from blocks # reassemble the mask from blocks
out = torch.cat([ return torch.cat([
torch.cat([txt_to_txt, txt_to_img], dim=2), torch.cat([txt_to_txt, txt_to_img], dim=2),
torch.cat([img_to_txt, img_to_img], dim=2)], torch.cat([img_to_txt, img_to_img], dim=2)],
dim=1 dim=1
) )
return out


@ -12,8 +12,7 @@ def loglinear_interp(t_steps, num_steps):
new_xs = np.linspace(0, 1, num_steps) new_xs = np.linspace(0, 1, num_steps)
new_ys = np.interp(new_xs, xs, ys) new_ys = np.interp(new_xs, xs, ys)
interped_ys = np.exp(new_ys)[::-1].copy() return np.exp(new_ys)[::-1].copy()
return interped_ys
NOISE_LEVELS = {"SD1": [14.6146412293, 6.4745760956, 3.8636745985, 2.6946151520, 1.8841921177, 1.3943805092, 0.9642583904, 0.6523686016, 0.3977456272, 0.1515232662, 0.0291671582], NOISE_LEVELS = {"SD1": [14.6146412293, 6.4745760956, 3.8636745985, 2.6946151520, 1.8841921177, 1.3943805092, 0.9642583904, 0.6523686016, 0.3977456272, 0.1515232662, 0.0291671582],
"SDXL":[14.6146412293, 6.3184485287, 3.7681790315, 2.1811480769, 1.3405244945, 0.8620721141, 0.5550693289, 0.3798540708, 0.2332364134, 0.1114188177, 0.0291671582], "SDXL":[14.6146412293, 6.3184485287, 3.7681790315, 2.1811480769, 1.3405244945, 0.8620721141, 0.5550693289, 0.3798540708, 0.2332364134, 0.1114188177, 0.0291671582],


@ -104,9 +104,8 @@ def create_vorbis_comment_block(comment_dict, last_block):
id = b'\x84' id = b'\x84'
else: else:
id = b'\x04' id = b'\x04'
comment_block = id + struct.pack('>I', len(comment_data))[1:] + comment_data return id + struct.pack('>I', len(comment_data))[1:] + comment_data
return comment_block
def insert_or_replace_vorbis_comment(flac_io, comment_dict): def insert_or_replace_vorbis_comment(flac_io, comment_dict):
if len(comment_dict) == 0: if len(comment_dict) == 0:


@ -150,8 +150,7 @@ class PorterDuffImageComposite:
out_images.append(out_image) out_images.append(out_image)
out_alphas.append(out_alpha.squeeze(2)) out_alphas.append(out_alpha.squeeze(2))
result = (torch.stack(out_images), torch.stack(out_alphas)) return (torch.stack(out_images), torch.stack(out_alphas))
return result
class SplitImageWithAlpha: class SplitImageWithAlpha:
@ -170,8 +169,7 @@ class SplitImageWithAlpha:
def split_image_with_alpha(self, image: torch.Tensor): def split_image_with_alpha(self, image: torch.Tensor):
out_images = [i[:,:,:3] for i in image] out_images = [i[:,:,:3] for i in image]
out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image] out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image]
result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas)) return (torch.stack(out_images), 1.0 - torch.stack(out_alphas))
return result
class JoinImageWithAlpha: class JoinImageWithAlpha:
@ -196,8 +194,7 @@ class JoinImageWithAlpha:
for i in range(batch_size): for i in range(batch_size):
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2)) out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
result = (torch.stack(out_images),) return (torch.stack(out_images),)
return result
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {


@ -12,8 +12,7 @@ def loglinear_interp(t_steps, num_steps):
new_xs = np.linspace(0, 1, num_steps) new_xs = np.linspace(0, 1, num_steps)
new_ys = np.interp(new_xs, xs, ys) new_ys = np.interp(new_xs, xs, ys)
interped_ys = np.exp(new_ys)[::-1].copy() return np.exp(new_ys)[::-1].copy()
return interped_ys
NOISE_LEVELS = { NOISE_LEVELS = {
0.80: [ 0.80: [


@ -27,8 +27,7 @@ class Mahiro:
normm = torch.sqrt(merge.abs()) * merge.sign() normm = torch.sqrt(merge.abs()) * merge.sign()
sim = F.cosine_similarity(normu, normm).mean() sim = F.cosine_similarity(normu, normm).mean()
simsc = 2 * (sim+1) simsc = 2 * (sim+1)
wm = (simsc*cfg + (4-simsc)*leap) / 4 return (simsc*cfg + (4-simsc)*leap) / 4
return wm
m.set_model_sampler_post_cfg_function(mahiro_normd) m.set_model_sampler_post_cfg_function(mahiro_normd)
return (m, ) return (m, )


@ -11,8 +11,7 @@ def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, co
perp = neg - ((torch.mul(neg, pos).sum())/(torch.norm(pos)**2)) * pos perp = neg - ((torch.mul(neg, pos).sum())/(torch.norm(pos)**2)) * pos
perp_neg = perp * neg_scale perp_neg = perp * neg_scale
cfg_result = noise_pred_nocond + cond_scale*(pos - perp_neg) return noise_pred_nocond + cond_scale*(pos - perp_neg)
return cfg_result
#TODO: This node should be removed, it has been replaced with PerpNegGuider #TODO: This node should be removed, it has been replaced with PerpNegGuider
class PerpNeg: class PerpNeg:
@ -44,8 +43,7 @@ class PerpNeg:
(noise_pred_nocond,) = comfy.samplers.calc_cond_batch(model, [nocond_processed], x, sigma, model_options) (noise_pred_nocond,) = comfy.samplers.calc_cond_batch(model, [nocond_processed], x, sigma, model_options)
cfg_result = x - perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale) return x - perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale)
return cfg_result
m.set_model_sampler_cfg_function(cfg_function) m.set_model_sampler_cfg_function(cfg_function)


@ -52,8 +52,7 @@ class FuseModule(nn.Module):
stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1) stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1)
stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds
stacked_id_embeds = self.mlp2(stacked_id_embeds) stacked_id_embeds = self.mlp2(stacked_id_embeds)
stacked_id_embeds = self.layer_norm(stacked_id_embeds) return self.layer_norm(stacked_id_embeds)
return stacked_id_embeds
def forward( def forward(
self, self,
@ -86,8 +85,7 @@ class FuseModule(nn.Module):
stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds) stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds)
assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}" assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}"
prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype)) prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype))
updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1) return prompt_embeds.view(batch_size, seq_length, -1)
return updated_prompt_embeds
class PhotoMakerIDEncoder(comfy.clip_model.CLIPVisionModelProjection): class PhotoMakerIDEncoder(comfy.clip_model.CLIPVisionModelProjection):
def __init__(self): def __init__(self):
@ -111,9 +109,8 @@ class PhotoMakerIDEncoder(comfy.clip_model.CLIPVisionModelProjection):
id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1) id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1)
id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1) id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1)
updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask) return self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask)
return updated_prompt_embeds
class PhotoMakerLoader: class PhotoMakerLoader:


@ -162,8 +162,7 @@ class Quantize:
result = result.to(dtype=torch.uint8) result = result.to(dtype=torch.uint8)
im = Image.fromarray(result.cpu().numpy()) im = Image.fromarray(result.cpu().numpy())
im = im.quantize(palette=pal_im, dither=Image.Dither.NONE) return im.quantize(palette=pal_im, dither=Image.Dither.NONE)
return im
def quantize(self, image: torch.Tensor, colors: int, dither: str): def quantize(self, image: torch.Tensor, colors: int, dither: str):
batch_size, height, width, _ = image.shape batch_size, height, width, _ = image.shape


@ -50,8 +50,7 @@ class LatentRebatch:
def cat_batch(batch1, batch2): def cat_batch(batch1, batch2):
if batch1[0] is None: if batch1[0] is None:
return batch2 return batch2
result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)] return [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
return result
def rebatch(self, latents, batch_size): def rebatch(self, latents, batch_size):
batch_size = batch_size[0] batch_size = batch_size[0]


@ -82,8 +82,7 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
mask = F.interpolate(mask, (lh, lw)) mask = F.interpolate(mask, (lh, lw))
blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma) blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
blurred = blurred * mask + x0 * (1 - mask) return blurred * mask + x0 * (1 - mask)
return blurred
def gaussian_blur_2d(img, kernel_size, sigma): def gaussian_blur_2d(img, kernel_size, sigma):
ksize_half = (kernel_size - 1) * 0.5 ksize_half = (kernel_size - 1) * 0.5
@ -101,8 +100,7 @@ def gaussian_blur_2d(img, kernel_size, sigma):
padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
img = F.pad(img, padding, mode="reflect") img = F.pad(img, padding, mode="reflect")
img = F.conv2d(img, kernel2d, groups=img.shape[-3]) return F.conv2d(img, kernel2d, groups=img.shape[-3])
return img
class SelfAttentionGuidance: class SelfAttentionGuidance:
@classmethod @classmethod

View File

@ -5,7 +5,7 @@ import comfy.utils
def camera_embeddings(elevation, azimuth): def camera_embeddings(elevation, azimuth):
elevation = torch.as_tensor([elevation]) elevation = torch.as_tensor([elevation])
azimuth = torch.as_tensor([azimuth]) azimuth = torch.as_tensor([azimuth])
embeddings = torch.stack( return torch.stack(
[ [
torch.deg2rad( torch.deg2rad(
(90 - elevation) - (90) (90 - elevation) - (90)
@ -17,7 +17,6 @@ def camera_embeddings(elevation, azimuth):
), ),
], dim=-1).unsqueeze(1) ], dim=-1).unsqueeze(1)
return embeddings
class StableZero123_Conditioning: class StableZero123_Conditioning:

View File

@ -81,11 +81,10 @@ class CacheSet:
self.objects = HierarchicalCache(CacheKeySetID) self.objects = HierarchicalCache(CacheKeySetID)
def recursive_debug_dump(self): def recursive_debug_dump(self):
result = { return {
"outputs": self.outputs.recursive_debug_dump(), "outputs": self.outputs.recursive_debug_dump(),
"ui": self.ui.recursive_debug_dump(), "ui": self.ui.recursive_debug_dump(),
} }
return result
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}): def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
valid_inputs = class_def.INPUT_TYPES() valid_inputs = class_def.INPUT_TYPES()

View File

@ -356,8 +356,7 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im
input = input.replace("%day%", str(now.tm_mday).zfill(2)) input = input.replace("%day%", str(now.tm_mday).zfill(2))
input = input.replace("%hour%", str(now.tm_hour).zfill(2)) input = input.replace("%hour%", str(now.tm_hour).zfill(2))
input = input.replace("%minute%", str(now.tm_min).zfill(2)) input = input.replace("%minute%", str(now.tm_min).zfill(2))
input = input.replace("%second%", str(now.tm_sec).zfill(2)) return input.replace("%second%", str(now.tm_sec).zfill(2))
return input
if "%" in filename_prefix: if "%" in filename_prefix:
filename_prefix = compute_vars(filename_prefix, image_width, image_height) filename_prefix = compute_vars(filename_prefix, image_width, image_height)

View File

@ -607,8 +607,7 @@ class unCLIPCheckpointLoader:
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) return comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return out
class CLIPSetLastLayer: class CLIPSetLastLayer:
@classmethod @classmethod

View File

@ -9,6 +9,7 @@ lint.select = [
# The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names. # The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f # See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
"F", "F",
"RET504",
] ]
exclude = ["*.ipynb"] exclude = ["*.ipynb"]

View File

@ -129,8 +129,7 @@ class TestCompareImageMetrics:
def read_img(self, filename: str) -> np.ndarray: def read_img(self, filename: str) -> np.ndarray:
cvImg = imread(filename) cvImg = imread(filename)
cvImg = cvtColor(cvImg, COLOR_BGR2RGB) return cvtColor(cvImg, COLOR_BGR2RGB)
return cvImg
def image_grid(self, img_list: list[list[Image.Image]]): def image_grid(self, img_list: list[list[Image.Image]]):
# imgs is a 2D list of images # imgs is a 2D list of images
@ -154,8 +153,7 @@ class TestCompareImageMetrics:
with open(metrics_output_file, 'r') as f: with open(metrics_output_file, 'r') as f:
for line in f: for line in f:
if fname_basestr in line: if fname_basestr in line:
score = float(line.split('|')[5]) return float(line.split('|')[5])
return score
raise ValueError(f"Could not find score for {fname} in {metrics_output_file}") raise ValueError(f"Could not find score for {fname} in {metrics_output_file}")
def gather_file_basenames(self, directory: str): def gather_file_basenames(self, directory: str):

View File

@ -141,14 +141,13 @@ class TestExecutionBlockerNode:
@classmethod @classmethod
def INPUT_TYPES(cls): def INPUT_TYPES(cls):
inputs = { return {
"required": { "required": {
"input": ("*",), "input": ("*",),
"block": ("BOOLEAN",), "block": ("BOOLEAN",),
"verbose": ("BOOLEAN", {"default": False}), "verbose": ("BOOLEAN", {"default": False}),
}, },
} }
return inputs
RETURN_TYPES = ("*",) RETURN_TYPES = ("*",)
RETURN_NAMES = ("output",) RETURN_NAMES = ("output",)