From 25853d0be8be6622195afaba2bc92e49e518bdcc Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 30 Jul 2024 05:03:20 -0400 Subject: [PATCH] Use common function for casting weights to input. --- comfy/ldm/audio/dit.py | 23 ++++++++++--------- comfy/ldm/aura/mmdit.py | 5 +++-- comfy/ldm/cascade/common.py | 15 ++++--------- comfy/ldm/hydit/models.py | 5 +++-- comfy/ldm/hydit/poolers.py | 6 ++--- comfy/ldm/modules/diffusionmodules/mmdit.py | 3 ++- comfy/ops.py | 25 +++++++++++++++++++-- 7 files changed, 51 insertions(+), 31 deletions(-) diff --git a/comfy/ldm/audio/dit.py b/comfy/ldm/audio/dit.py index 1c1112c5..4d2185be 100644 --- a/comfy/ldm/audio/dit.py +++ b/comfy/ldm/audio/dit.py @@ -9,6 +9,7 @@ from einops import rearrange from torch import nn from torch.nn import functional as F import math +import comfy.ops class FourierFeatures(nn.Module): def __init__(self, in_features, out_features, std=1., dtype=None, device=None): @@ -18,7 +19,7 @@ class FourierFeatures(nn.Module): [out_features // 2, in_features], dtype=dtype, device=device)) def forward(self, input): - f = 2 * math.pi * input @ self.weight.T.to(dtype=input.dtype, device=input.device) + f = 2 * math.pi * input @ comfy.ops.cast_to_input(self.weight.T, input) return torch.cat([f.cos(), f.sin()], dim=-1) # norms @@ -38,9 +39,9 @@ class LayerNorm(nn.Module): def forward(self, x): beta = self.beta - if self.beta is not None: - beta = beta.to(dtype=x.dtype, device=x.device) - return F.layer_norm(x, x.shape[-1:], weight=self.gamma.to(dtype=x.dtype, device=x.device), bias=beta) + if beta is not None: + beta = comfy.ops.cast_to_input(beta, x) + return F.layer_norm(x, x.shape[-1:], weight=comfy.ops.cast_to_input(self.gamma, x), bias=beta) class GLU(nn.Module): def __init__( @@ -123,7 +124,9 @@ class RotaryEmbedding(nn.Module): scale_base = 512, interpolation_factor = 1., base = 10000, - base_rescale_factor = 1. + base_rescale_factor = 1., + dtype=None, + device=None, ): super().__init__() # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning @@ -131,8 +134,8 @@ class RotaryEmbedding(nn.Module): # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ base *= base_rescale_factor ** (dim / (dim - 2)) - inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) - self.register_buffer('inv_freq', inv_freq) + # inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', torch.empty((dim // 2,), device=device, dtype=dtype)) assert interpolation_factor >= 1. self.interpolation_factor = interpolation_factor @@ -161,14 +164,14 @@ class RotaryEmbedding(nn.Module): t = t / self.interpolation_factor - freqs = torch.einsum('i , j -> i j', t, self.inv_freq.to(dtype=dtype, device=device)) + freqs = torch.einsum('i , j -> i j', t, comfy.ops.cast_to_input(self.inv_freq, t)) freqs = torch.cat((freqs, freqs), dim = -1) if self.scale is None: return freqs, 1. power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base - scale = self.scale.to(dtype=dtype, device=device) ** rearrange(power, 'n -> n 1') + scale = comfy.ops.cast_to_input(self.scale, t) ** rearrange(power, 'n -> n 1') scale = torch.cat((scale, scale), dim = -1) return freqs, scale @@ -568,7 +571,7 @@ class ContinuousTransformer(nn.Module): self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device) if dim_out is not None else nn.Identity() if rotary_pos_emb: - self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32)) + self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32), device=device, dtype=dtype) else: self.rotary_pos_emb = None diff --git a/comfy/ldm/aura/mmdit.py b/comfy/ldm/aura/mmdit.py index 2564166a..9956d363 100644 --- a/comfy/ldm/aura/mmdit.py +++ b/comfy/ldm/aura/mmdit.py @@ -8,6 +8,7 @@ import torch.nn as nn import torch.nn.functional as F from comfy.ldm.modules.attention import optimized_attention +import comfy.ops def modulate(x, shift, scale): return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) @@ -427,7 +428,7 @@ class MMDiT(nn.Module): max_dim = max(h, w) cur_dim = self.h_max - pos_encoding = self.positional_encoding.reshape(1, cur_dim, cur_dim, -1).to(device=x.device, dtype=x.dtype) + pos_encoding = comfy.ops.cast_to_input(self.positional_encoding.reshape(1, cur_dim, cur_dim, -1), x) if max_dim > cur_dim: pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, -1) @@ -455,7 +456,7 @@ class MMDiT(nn.Module): t = timestep c = self.cond_seq_linear(c_seq) # B, T_c, D - c = torch.cat([self.register_tokens.to(device=c.device, dtype=c.dtype).repeat(c.size(0), 1, 1), c], dim=1) + c = torch.cat([comfy.ops.cast_to_input(self.register_tokens, c).repeat(c.size(0), 1, 1), c], dim=1) global_cond = self.t_embedder(t, x.dtype) # B, D diff --git a/comfy/ldm/cascade/common.py b/comfy/ldm/cascade/common.py index 124902c0..3eaa0c82 100644 --- a/comfy/ldm/cascade/common.py +++ b/comfy/ldm/cascade/common.py @@ -19,14 +19,7 @@ import torch import torch.nn as nn from comfy.ldm.modules.attention import optimized_attention - -class Linear(torch.nn.Linear): - def reset_parameters(self): - return None - -class Conv2d(torch.nn.Conv2d): - def reset_parameters(self): - return None +import comfy.ops class OptimizedAttention(nn.Module): def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None): @@ -78,13 +71,13 @@ class GlobalResponseNorm(nn.Module): "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105" def __init__(self, dim, dtype=None, device=None): super().__init__() - self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device)) - self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device)) + self.gamma = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device)) + self.beta = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device)) def forward(self, x): Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) - return self.gamma.to(device=x.device, dtype=x.dtype) * (x * Nx) + self.beta.to(device=x.device, dtype=x.dtype) + x + return comfy.ops.cast_to_input(self.gamma, x) * (x * Nx) + comfy.ops.cast_to_input(self.beta, x) + x class ResBlock(nn.Module): diff --git a/comfy/ldm/hydit/models.py b/comfy/ldm/hydit/models.py index bde9687e..37836fc3 100644 --- a/comfy/ldm/hydit/models.py +++ b/comfy/ldm/hydit/models.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +import comfy.ops from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm from comfy.ldm.modules.diffusionmodules.util import timestep_embedding from torch.utils import checkpoint @@ -234,7 +235,7 @@ class HunYuanDiT(nn.Module): if self.use_style_cond: # Here we use a default learned embedder layer for future extension. - self.style_embedder = nn.Embedding(1, hidden_size, dtype=dtype, device=device) + self.style_embedder = operations.Embedding(1, hidden_size, dtype=dtype, device=device) self.extra_in_dim += hidden_size # Text embedding for `add` @@ -321,7 +322,7 @@ class HunYuanDiT(nn.Module): b_t5, l_t5, c_t5 = text_states_t5.shape text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1) - padding = self.text_embedding_padding.to(text_states) + padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states) text_states[:,-self.text_len:] = torch.where(text_states_mask[:,-self.text_len:].unsqueeze(2), text_states[:,-self.text_len:], padding[:self.text_len]) text_states_t5[:,-self.text_len_t5:] = torch.where(text_states_t5_mask[:,-self.text_len_t5:].unsqueeze(2), text_states_t5[:,-self.text_len_t5:], padding[self.text_len:]) diff --git a/comfy/ldm/hydit/poolers.py b/comfy/ldm/hydit/poolers.py index 2c6e46e6..f5e5b406 100644 --- a/comfy/ldm/hydit/poolers.py +++ b/comfy/ldm/hydit/poolers.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn import torch.nn.functional as F -from comfy.ldm.modules.attention import optimized_attention #TODO - +from comfy.ldm.modules.attention import optimized_attention +import comfy.ops class AttentionPool(nn.Module): def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None, dtype=None, device=None, operations=None): @@ -19,7 +19,7 @@ class AttentionPool(nn.Module): x = x[:,:self.positional_embedding.shape[0] - 1] x = x.permute(1, 0, 2) # NLC -> LNC x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC - x = x + self.positional_embedding[:, None, :].to(dtype=x.dtype, device=x.device) # (L+1)NC + x = x + comfy.ops.cast_to_input(self.positional_embedding[:, None, :], x) # (L+1)NC q = self.q_proj(x[:1]) k = self.k_proj(x) diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index aac48a7f..ea1b5aa0 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -8,6 +8,7 @@ import torch.nn as nn from .. import attention from einops import rearrange, repeat from .util import timestep_embedding +import comfy.ops def default(x, y): if x is not None: @@ -926,7 +927,7 @@ class MMDiT(nn.Module): context = self.context_processor(context) hw = x.shape[-2:] - x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device) + x = self.x_embedder(x) + comfy.ops.cast_to_input(self.cropped_pos_embed(hw, device=x.device), x) c = self.t_embedder(t, dtype=x.dtype) # (N, D) if y is not None and self.y_embedder is not None: y = self.y_embedder(y) # (N, D) diff --git a/comfy/ops.py b/comfy/ops.py index 0f1ceb57..6d88a1b9 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -19,14 +19,17 @@ import torch import comfy.model_management +def cast_to_input(weight, input, non_blocking=False): + return weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking) + def cast_bias_weight(s, input): bias = None non_blocking = comfy.model_management.device_should_use_non_blocking(input.device) if s.bias is not None: - bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking) + bias = cast_to_input(s.bias, input, non_blocking=non_blocking) if s.bias_function is not None: bias = s.bias_function(bias) - weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking) + weight = cast_to_input(s.weight, input, non_blocking=non_blocking) if s.weight_function is not None: weight = s.weight_function(weight) return weight, bias @@ -168,6 +171,21 @@ class disable_weight_init: else: return super().forward(*args, **kwargs) + class Embedding(torch.nn.Embedding, CastWeightBiasOp): + def reset_parameters(self): + self.bias = None + return None + + def forward_comfy_cast_weights(self, input): + weight, bias = cast_bias_weight(self, input) + return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse) + + def forward(self, *args, **kwargs): + if self.comfy_cast_weights: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) + @classmethod def conv_nd(s, dims, *args, **kwargs): if dims == 2: @@ -202,3 +220,6 @@ class manual_cast(disable_weight_init): class ConvTranspose1d(disable_weight_init.ConvTranspose1d): comfy_cast_weights = True + + class Embedding(disable_weight_init.Embedding): + comfy_cast_weights = True