Use common function for casting weights to input.

This commit is contained in:
comfyanonymous 2024-07-30 05:03:20 -04:00
parent 79040635da
commit 25853d0be8
7 changed files with 51 additions and 31 deletions

View File

@ -9,6 +9,7 @@ from einops import rearrange
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
import math import math
import comfy.ops
class FourierFeatures(nn.Module): class FourierFeatures(nn.Module):
def __init__(self, in_features, out_features, std=1., dtype=None, device=None): def __init__(self, in_features, out_features, std=1., dtype=None, device=None):
@ -18,7 +19,7 @@ class FourierFeatures(nn.Module):
[out_features // 2, in_features], dtype=dtype, device=device)) [out_features // 2, in_features], dtype=dtype, device=device))
def forward(self, input): def forward(self, input):
f = 2 * math.pi * input @ self.weight.T.to(dtype=input.dtype, device=input.device) f = 2 * math.pi * input @ comfy.ops.cast_to_input(self.weight.T, input)
return torch.cat([f.cos(), f.sin()], dim=-1) return torch.cat([f.cos(), f.sin()], dim=-1)
# norms # norms
@ -38,9 +39,9 @@ class LayerNorm(nn.Module):
def forward(self, x): def forward(self, x):
beta = self.beta beta = self.beta
if self.beta is not None: if beta is not None:
beta = beta.to(dtype=x.dtype, device=x.device) beta = comfy.ops.cast_to_input(beta, x)
return F.layer_norm(x, x.shape[-1:], weight=self.gamma.to(dtype=x.dtype, device=x.device), bias=beta) return F.layer_norm(x, x.shape[-1:], weight=comfy.ops.cast_to_input(self.gamma, x), bias=beta)
class GLU(nn.Module): class GLU(nn.Module):
def __init__( def __init__(
@ -123,7 +124,9 @@ class RotaryEmbedding(nn.Module):
scale_base = 512, scale_base = 512,
interpolation_factor = 1., interpolation_factor = 1.,
base = 10000, base = 10000,
base_rescale_factor = 1. base_rescale_factor = 1.,
dtype=None,
device=None,
): ):
super().__init__() super().__init__()
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
@ -131,8 +134,8 @@ class RotaryEmbedding(nn.Module):
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
base *= base_rescale_factor ** (dim / (dim - 2)) base *= base_rescale_factor ** (dim / (dim - 2))
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) # inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq) self.register_buffer('inv_freq', torch.empty((dim // 2,), device=device, dtype=dtype))
assert interpolation_factor >= 1. assert interpolation_factor >= 1.
self.interpolation_factor = interpolation_factor self.interpolation_factor = interpolation_factor
@ -161,14 +164,14 @@ class RotaryEmbedding(nn.Module):
t = t / self.interpolation_factor t = t / self.interpolation_factor
freqs = torch.einsum('i , j -> i j', t, self.inv_freq.to(dtype=dtype, device=device)) freqs = torch.einsum('i , j -> i j', t, comfy.ops.cast_to_input(self.inv_freq, t))
freqs = torch.cat((freqs, freqs), dim = -1) freqs = torch.cat((freqs, freqs), dim = -1)
if self.scale is None: if self.scale is None:
return freqs, 1. return freqs, 1.
power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
scale = self.scale.to(dtype=dtype, device=device) ** rearrange(power, 'n -> n 1') scale = comfy.ops.cast_to_input(self.scale, t) ** rearrange(power, 'n -> n 1')
scale = torch.cat((scale, scale), dim = -1) scale = torch.cat((scale, scale), dim = -1)
return freqs, scale return freqs, scale
@ -568,7 +571,7 @@ class ContinuousTransformer(nn.Module):
self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device) if dim_out is not None else nn.Identity() self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device) if dim_out is not None else nn.Identity()
if rotary_pos_emb: if rotary_pos_emb:
self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32)) self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32), device=device, dtype=dtype)
else: else:
self.rotary_pos_emb = None self.rotary_pos_emb = None

View File

@ -8,6 +8,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.modules.attention import optimized_attention
import comfy.ops
def modulate(x, shift, scale): def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
@ -427,7 +428,7 @@ class MMDiT(nn.Module):
max_dim = max(h, w) max_dim = max(h, w)
cur_dim = self.h_max cur_dim = self.h_max
pos_encoding = self.positional_encoding.reshape(1, cur_dim, cur_dim, -1).to(device=x.device, dtype=x.dtype) pos_encoding = comfy.ops.cast_to_input(self.positional_encoding.reshape(1, cur_dim, cur_dim, -1), x)
if max_dim > cur_dim: if max_dim > cur_dim:
pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, -1) pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, -1)
@ -455,7 +456,7 @@ class MMDiT(nn.Module):
t = timestep t = timestep
c = self.cond_seq_linear(c_seq) # B, T_c, D c = self.cond_seq_linear(c_seq) # B, T_c, D
c = torch.cat([self.register_tokens.to(device=c.device, dtype=c.dtype).repeat(c.size(0), 1, 1), c], dim=1) c = torch.cat([comfy.ops.cast_to_input(self.register_tokens, c).repeat(c.size(0), 1, 1), c], dim=1)
global_cond = self.t_embedder(t, x.dtype) # B, D global_cond = self.t_embedder(t, x.dtype) # B, D

View File

@ -19,14 +19,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.modules.attention import optimized_attention
import comfy.ops
class Linear(torch.nn.Linear):
def reset_parameters(self):
return None
class Conv2d(torch.nn.Conv2d):
def reset_parameters(self):
return None
class OptimizedAttention(nn.Module): class OptimizedAttention(nn.Module):
def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None): def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
@ -78,13 +71,13 @@ class GlobalResponseNorm(nn.Module):
"from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105" "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
def __init__(self, dim, dtype=None, device=None): def __init__(self, dim, dtype=None, device=None):
super().__init__() super().__init__()
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device)) self.gamma = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device)) self.beta = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))
def forward(self, x): def forward(self, x):
Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
return self.gamma.to(device=x.device, dtype=x.dtype) * (x * Nx) + self.beta.to(device=x.device, dtype=x.dtype) + x return comfy.ops.cast_to_input(self.gamma, x) * (x * Nx) + comfy.ops.cast_to_input(self.beta, x) + x
class ResBlock(nn.Module): class ResBlock(nn.Module):

View File

@ -4,6 +4,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import comfy.ops
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
from torch.utils import checkpoint from torch.utils import checkpoint
@ -234,7 +235,7 @@ class HunYuanDiT(nn.Module):
if self.use_style_cond: if self.use_style_cond:
# Here we use a default learned embedder layer for future extension. # Here we use a default learned embedder layer for future extension.
self.style_embedder = nn.Embedding(1, hidden_size, dtype=dtype, device=device) self.style_embedder = operations.Embedding(1, hidden_size, dtype=dtype, device=device)
self.extra_in_dim += hidden_size self.extra_in_dim += hidden_size
# Text embedding for `add` # Text embedding for `add`
@ -321,7 +322,7 @@ class HunYuanDiT(nn.Module):
b_t5, l_t5, c_t5 = text_states_t5.shape b_t5, l_t5, c_t5 = text_states_t5.shape
text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1) text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)
padding = self.text_embedding_padding.to(text_states) padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states)
text_states[:,-self.text_len:] = torch.where(text_states_mask[:,-self.text_len:].unsqueeze(2), text_states[:,-self.text_len:], padding[:self.text_len]) text_states[:,-self.text_len:] = torch.where(text_states_mask[:,-self.text_len:].unsqueeze(2), text_states[:,-self.text_len:], padding[:self.text_len])
text_states_t5[:,-self.text_len_t5:] = torch.where(text_states_t5_mask[:,-self.text_len_t5:].unsqueeze(2), text_states_t5[:,-self.text_len_t5:], padding[self.text_len:]) text_states_t5[:,-self.text_len_t5:] = torch.where(text_states_t5_mask[:,-self.text_len_t5:].unsqueeze(2), text_states_t5[:,-self.text_len_t5:], padding[self.text_len:])

View File

@ -1,8 +1,8 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention #TODO from comfy.ldm.modules.attention import optimized_attention
import comfy.ops
class AttentionPool(nn.Module): class AttentionPool(nn.Module):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None, dtype=None, device=None, operations=None): def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None, dtype=None, device=None, operations=None):
@ -19,7 +19,7 @@ class AttentionPool(nn.Module):
x = x[:,:self.positional_embedding.shape[0] - 1] x = x[:,:self.positional_embedding.shape[0] - 1]
x = x.permute(1, 0, 2) # NLC -> LNC x = x.permute(1, 0, 2) # NLC -> LNC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
x = x + self.positional_embedding[:, None, :].to(dtype=x.dtype, device=x.device) # (L+1)NC x = x + comfy.ops.cast_to_input(self.positional_embedding[:, None, :], x) # (L+1)NC
q = self.q_proj(x[:1]) q = self.q_proj(x[:1])
k = self.k_proj(x) k = self.k_proj(x)

View File

@ -8,6 +8,7 @@ import torch.nn as nn
from .. import attention from .. import attention
from einops import rearrange, repeat from einops import rearrange, repeat
from .util import timestep_embedding from .util import timestep_embedding
import comfy.ops
def default(x, y): def default(x, y):
if x is not None: if x is not None:
@ -926,7 +927,7 @@ class MMDiT(nn.Module):
context = self.context_processor(context) context = self.context_processor(context)
hw = x.shape[-2:] hw = x.shape[-2:]
x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device) x = self.x_embedder(x) + comfy.ops.cast_to_input(self.cropped_pos_embed(hw, device=x.device), x)
c = self.t_embedder(t, dtype=x.dtype) # (N, D) c = self.t_embedder(t, dtype=x.dtype) # (N, D)
if y is not None and self.y_embedder is not None: if y is not None and self.y_embedder is not None:
y = self.y_embedder(y) # (N, D) y = self.y_embedder(y) # (N, D)

View File

@ -19,14 +19,17 @@
import torch import torch
import comfy.model_management import comfy.model_management
def cast_to_input(weight, input, non_blocking=False):
return weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
def cast_bias_weight(s, input): def cast_bias_weight(s, input):
bias = None bias = None
non_blocking = comfy.model_management.device_should_use_non_blocking(input.device) non_blocking = comfy.model_management.device_should_use_non_blocking(input.device)
if s.bias is not None: if s.bias is not None:
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking) bias = cast_to_input(s.bias, input, non_blocking=non_blocking)
if s.bias_function is not None: if s.bias_function is not None:
bias = s.bias_function(bias) bias = s.bias_function(bias)
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking) weight = cast_to_input(s.weight, input, non_blocking=non_blocking)
if s.weight_function is not None: if s.weight_function is not None:
weight = s.weight_function(weight) weight = s.weight_function(weight)
return weight, bias return weight, bias
@ -168,6 +171,21 @@ class disable_weight_init:
else: else:
return super().forward(*args, **kwargs) return super().forward(*args, **kwargs)
class Embedding(torch.nn.Embedding, CastWeightBiasOp):
def reset_parameters(self):
self.bias = None
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@classmethod @classmethod
def conv_nd(s, dims, *args, **kwargs): def conv_nd(s, dims, *args, **kwargs):
if dims == 2: if dims == 2:
@ -202,3 +220,6 @@ class manual_cast(disable_weight_init):
class ConvTranspose1d(disable_weight_init.ConvTranspose1d): class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
comfy_cast_weights = True comfy_cast_weights = True
class Embedding(disable_weight_init.Embedding):
comfy_cast_weights = True