mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Use common function for casting weights to input.
This commit is contained in:
parent
79040635da
commit
25853d0be8
@ -9,6 +9,7 @@ from einops import rearrange
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
import math
|
import math
|
||||||
|
import comfy.ops
|
||||||
|
|
||||||
class FourierFeatures(nn.Module):
|
class FourierFeatures(nn.Module):
|
||||||
def __init__(self, in_features, out_features, std=1., dtype=None, device=None):
|
def __init__(self, in_features, out_features, std=1., dtype=None, device=None):
|
||||||
@ -18,7 +19,7 @@ class FourierFeatures(nn.Module):
|
|||||||
[out_features // 2, in_features], dtype=dtype, device=device))
|
[out_features // 2, in_features], dtype=dtype, device=device))
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
f = 2 * math.pi * input @ self.weight.T.to(dtype=input.dtype, device=input.device)
|
f = 2 * math.pi * input @ comfy.ops.cast_to_input(self.weight.T, input)
|
||||||
return torch.cat([f.cos(), f.sin()], dim=-1)
|
return torch.cat([f.cos(), f.sin()], dim=-1)
|
||||||
|
|
||||||
# norms
|
# norms
|
||||||
@ -38,9 +39,9 @@ class LayerNorm(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
beta = self.beta
|
beta = self.beta
|
||||||
if self.beta is not None:
|
if beta is not None:
|
||||||
beta = beta.to(dtype=x.dtype, device=x.device)
|
beta = comfy.ops.cast_to_input(beta, x)
|
||||||
return F.layer_norm(x, x.shape[-1:], weight=self.gamma.to(dtype=x.dtype, device=x.device), bias=beta)
|
return F.layer_norm(x, x.shape[-1:], weight=comfy.ops.cast_to_input(self.gamma, x), bias=beta)
|
||||||
|
|
||||||
class GLU(nn.Module):
|
class GLU(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -123,7 +124,9 @@ class RotaryEmbedding(nn.Module):
|
|||||||
scale_base = 512,
|
scale_base = 512,
|
||||||
interpolation_factor = 1.,
|
interpolation_factor = 1.,
|
||||||
base = 10000,
|
base = 10000,
|
||||||
base_rescale_factor = 1.
|
base_rescale_factor = 1.,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
||||||
@ -131,8 +134,8 @@ class RotaryEmbedding(nn.Module):
|
|||||||
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
||||||
base *= base_rescale_factor ** (dim / (dim - 2))
|
base *= base_rescale_factor ** (dim / (dim - 2))
|
||||||
|
|
||||||
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
|
# inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
|
||||||
self.register_buffer('inv_freq', inv_freq)
|
self.register_buffer('inv_freq', torch.empty((dim // 2,), device=device, dtype=dtype))
|
||||||
|
|
||||||
assert interpolation_factor >= 1.
|
assert interpolation_factor >= 1.
|
||||||
self.interpolation_factor = interpolation_factor
|
self.interpolation_factor = interpolation_factor
|
||||||
@ -161,14 +164,14 @@ class RotaryEmbedding(nn.Module):
|
|||||||
|
|
||||||
t = t / self.interpolation_factor
|
t = t / self.interpolation_factor
|
||||||
|
|
||||||
freqs = torch.einsum('i , j -> i j', t, self.inv_freq.to(dtype=dtype, device=device))
|
freqs = torch.einsum('i , j -> i j', t, comfy.ops.cast_to_input(self.inv_freq, t))
|
||||||
freqs = torch.cat((freqs, freqs), dim = -1)
|
freqs = torch.cat((freqs, freqs), dim = -1)
|
||||||
|
|
||||||
if self.scale is None:
|
if self.scale is None:
|
||||||
return freqs, 1.
|
return freqs, 1.
|
||||||
|
|
||||||
power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
|
power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
|
||||||
scale = self.scale.to(dtype=dtype, device=device) ** rearrange(power, 'n -> n 1')
|
scale = comfy.ops.cast_to_input(self.scale, t) ** rearrange(power, 'n -> n 1')
|
||||||
scale = torch.cat((scale, scale), dim = -1)
|
scale = torch.cat((scale, scale), dim = -1)
|
||||||
|
|
||||||
return freqs, scale
|
return freqs, scale
|
||||||
@ -568,7 +571,7 @@ class ContinuousTransformer(nn.Module):
|
|||||||
self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device) if dim_out is not None else nn.Identity()
|
self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device) if dim_out is not None else nn.Identity()
|
||||||
|
|
||||||
if rotary_pos_emb:
|
if rotary_pos_emb:
|
||||||
self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
|
self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32), device=device, dtype=dtype)
|
||||||
else:
|
else:
|
||||||
self.rotary_pos_emb = None
|
self.rotary_pos_emb = None
|
||||||
|
|
||||||
|
@ -8,6 +8,7 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from comfy.ldm.modules.attention import optimized_attention
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
import comfy.ops
|
||||||
|
|
||||||
def modulate(x, shift, scale):
|
def modulate(x, shift, scale):
|
||||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||||
@ -427,7 +428,7 @@ class MMDiT(nn.Module):
|
|||||||
max_dim = max(h, w)
|
max_dim = max(h, w)
|
||||||
|
|
||||||
cur_dim = self.h_max
|
cur_dim = self.h_max
|
||||||
pos_encoding = self.positional_encoding.reshape(1, cur_dim, cur_dim, -1).to(device=x.device, dtype=x.dtype)
|
pos_encoding = comfy.ops.cast_to_input(self.positional_encoding.reshape(1, cur_dim, cur_dim, -1), x)
|
||||||
|
|
||||||
if max_dim > cur_dim:
|
if max_dim > cur_dim:
|
||||||
pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, -1)
|
pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, -1)
|
||||||
@ -455,7 +456,7 @@ class MMDiT(nn.Module):
|
|||||||
t = timestep
|
t = timestep
|
||||||
|
|
||||||
c = self.cond_seq_linear(c_seq) # B, T_c, D
|
c = self.cond_seq_linear(c_seq) # B, T_c, D
|
||||||
c = torch.cat([self.register_tokens.to(device=c.device, dtype=c.dtype).repeat(c.size(0), 1, 1), c], dim=1)
|
c = torch.cat([comfy.ops.cast_to_input(self.register_tokens, c).repeat(c.size(0), 1, 1), c], dim=1)
|
||||||
|
|
||||||
global_cond = self.t_embedder(t, x.dtype) # B, D
|
global_cond = self.t_embedder(t, x.dtype) # B, D
|
||||||
|
|
||||||
|
@ -19,14 +19,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from comfy.ldm.modules.attention import optimized_attention
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
import comfy.ops
|
||||||
class Linear(torch.nn.Linear):
|
|
||||||
def reset_parameters(self):
|
|
||||||
return None
|
|
||||||
|
|
||||||
class Conv2d(torch.nn.Conv2d):
|
|
||||||
def reset_parameters(self):
|
|
||||||
return None
|
|
||||||
|
|
||||||
class OptimizedAttention(nn.Module):
|
class OptimizedAttention(nn.Module):
|
||||||
def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
|
def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
|
||||||
@ -78,13 +71,13 @@ class GlobalResponseNorm(nn.Module):
|
|||||||
"from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
|
"from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
|
||||||
def __init__(self, dim, dtype=None, device=None):
|
def __init__(self, dim, dtype=None, device=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
|
self.gamma = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))
|
||||||
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
|
self.beta = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
|
Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
|
||||||
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
|
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
|
||||||
return self.gamma.to(device=x.device, dtype=x.dtype) * (x * Nx) + self.beta.to(device=x.device, dtype=x.dtype) + x
|
return comfy.ops.cast_to_input(self.gamma, x) * (x * Nx) + comfy.ops.cast_to_input(self.beta, x) + x
|
||||||
|
|
||||||
|
|
||||||
class ResBlock(nn.Module):
|
class ResBlock(nn.Module):
|
||||||
|
@ -4,6 +4,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
import comfy.ops
|
||||||
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
|
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
|
||||||
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
|
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
|
||||||
from torch.utils import checkpoint
|
from torch.utils import checkpoint
|
||||||
@ -234,7 +235,7 @@ class HunYuanDiT(nn.Module):
|
|||||||
|
|
||||||
if self.use_style_cond:
|
if self.use_style_cond:
|
||||||
# Here we use a default learned embedder layer for future extension.
|
# Here we use a default learned embedder layer for future extension.
|
||||||
self.style_embedder = nn.Embedding(1, hidden_size, dtype=dtype, device=device)
|
self.style_embedder = operations.Embedding(1, hidden_size, dtype=dtype, device=device)
|
||||||
self.extra_in_dim += hidden_size
|
self.extra_in_dim += hidden_size
|
||||||
|
|
||||||
# Text embedding for `add`
|
# Text embedding for `add`
|
||||||
@ -321,7 +322,7 @@ class HunYuanDiT(nn.Module):
|
|||||||
b_t5, l_t5, c_t5 = text_states_t5.shape
|
b_t5, l_t5, c_t5 = text_states_t5.shape
|
||||||
text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)
|
text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)
|
||||||
|
|
||||||
padding = self.text_embedding_padding.to(text_states)
|
padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states)
|
||||||
|
|
||||||
text_states[:,-self.text_len:] = torch.where(text_states_mask[:,-self.text_len:].unsqueeze(2), text_states[:,-self.text_len:], padding[:self.text_len])
|
text_states[:,-self.text_len:] = torch.where(text_states_mask[:,-self.text_len:].unsqueeze(2), text_states[:,-self.text_len:], padding[:self.text_len])
|
||||||
text_states_t5[:,-self.text_len_t5:] = torch.where(text_states_t5_mask[:,-self.text_len_t5:].unsqueeze(2), text_states_t5[:,-self.text_len_t5:], padding[self.text_len:])
|
text_states_t5[:,-self.text_len_t5:] = torch.where(text_states_t5_mask[:,-self.text_len_t5:].unsqueeze(2), text_states_t5[:,-self.text_len_t5:], padding[self.text_len:])
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from comfy.ldm.modules.attention import optimized_attention #TODO
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
import comfy.ops
|
||||||
|
|
||||||
class AttentionPool(nn.Module):
|
class AttentionPool(nn.Module):
|
||||||
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None, dtype=None, device=None, operations=None):
|
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None, dtype=None, device=None, operations=None):
|
||||||
@ -19,7 +19,7 @@ class AttentionPool(nn.Module):
|
|||||||
x = x[:,:self.positional_embedding.shape[0] - 1]
|
x = x[:,:self.positional_embedding.shape[0] - 1]
|
||||||
x = x.permute(1, 0, 2) # NLC -> LNC
|
x = x.permute(1, 0, 2) # NLC -> LNC
|
||||||
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
|
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
|
||||||
x = x + self.positional_embedding[:, None, :].to(dtype=x.dtype, device=x.device) # (L+1)NC
|
x = x + comfy.ops.cast_to_input(self.positional_embedding[:, None, :], x) # (L+1)NC
|
||||||
|
|
||||||
q = self.q_proj(x[:1])
|
q = self.q_proj(x[:1])
|
||||||
k = self.k_proj(x)
|
k = self.k_proj(x)
|
||||||
|
@ -8,6 +8,7 @@ import torch.nn as nn
|
|||||||
from .. import attention
|
from .. import attention
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from .util import timestep_embedding
|
from .util import timestep_embedding
|
||||||
|
import comfy.ops
|
||||||
|
|
||||||
def default(x, y):
|
def default(x, y):
|
||||||
if x is not None:
|
if x is not None:
|
||||||
@ -926,7 +927,7 @@ class MMDiT(nn.Module):
|
|||||||
context = self.context_processor(context)
|
context = self.context_processor(context)
|
||||||
|
|
||||||
hw = x.shape[-2:]
|
hw = x.shape[-2:]
|
||||||
x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
|
x = self.x_embedder(x) + comfy.ops.cast_to_input(self.cropped_pos_embed(hw, device=x.device), x)
|
||||||
c = self.t_embedder(t, dtype=x.dtype) # (N, D)
|
c = self.t_embedder(t, dtype=x.dtype) # (N, D)
|
||||||
if y is not None and self.y_embedder is not None:
|
if y is not None and self.y_embedder is not None:
|
||||||
y = self.y_embedder(y) # (N, D)
|
y = self.y_embedder(y) # (N, D)
|
||||||
|
25
comfy/ops.py
25
comfy/ops.py
@ -19,14 +19,17 @@
|
|||||||
import torch
|
import torch
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
|
||||||
|
def cast_to_input(weight, input, non_blocking=False):
|
||||||
|
return weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
|
||||||
|
|
||||||
def cast_bias_weight(s, input):
|
def cast_bias_weight(s, input):
|
||||||
bias = None
|
bias = None
|
||||||
non_blocking = comfy.model_management.device_should_use_non_blocking(input.device)
|
non_blocking = comfy.model_management.device_should_use_non_blocking(input.device)
|
||||||
if s.bias is not None:
|
if s.bias is not None:
|
||||||
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
|
bias = cast_to_input(s.bias, input, non_blocking=non_blocking)
|
||||||
if s.bias_function is not None:
|
if s.bias_function is not None:
|
||||||
bias = s.bias_function(bias)
|
bias = s.bias_function(bias)
|
||||||
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
|
weight = cast_to_input(s.weight, input, non_blocking=non_blocking)
|
||||||
if s.weight_function is not None:
|
if s.weight_function is not None:
|
||||||
weight = s.weight_function(weight)
|
weight = s.weight_function(weight)
|
||||||
return weight, bias
|
return weight, bias
|
||||||
@ -168,6 +171,21 @@ class disable_weight_init:
|
|||||||
else:
|
else:
|
||||||
return super().forward(*args, **kwargs)
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
|
class Embedding(torch.nn.Embedding, CastWeightBiasOp):
|
||||||
|
def reset_parameters(self):
|
||||||
|
self.bias = None
|
||||||
|
return None
|
||||||
|
|
||||||
|
def forward_comfy_cast_weights(self, input):
|
||||||
|
weight, bias = cast_bias_weight(self, input)
|
||||||
|
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
if self.comfy_cast_weights:
|
||||||
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def conv_nd(s, dims, *args, **kwargs):
|
def conv_nd(s, dims, *args, **kwargs):
|
||||||
if dims == 2:
|
if dims == 2:
|
||||||
@ -202,3 +220,6 @@ class manual_cast(disable_weight_init):
|
|||||||
|
|
||||||
class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
|
class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
|
||||||
comfy_cast_weights = True
|
comfy_cast_weights = True
|
||||||
|
|
||||||
|
class Embedding(disable_weight_init.Embedding):
|
||||||
|
comfy_cast_weights = True
|
||||||
|
Loading…
Reference in New Issue
Block a user