mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-20 03:13:30 +00:00

add RMSNorm to comfy.ops

parent a14c2fc356
commit 8a438115fb
comfy/ldm/common_dit.py

@@ -1,5 +1,6 @@
 import torch
-import comfy.ops
+import comfy.rmsnorm
+
 
 def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
     if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
@@ -11,20 +12,5 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
 
     return torch.nn.functional.pad(img, pad, mode=padding_mode)
 
-try:
-    rms_norm_torch = torch.nn.functional.rms_norm
-except:
-    rms_norm_torch = None
-
-def rms_norm(x, weight=None, eps=1e-6):
-    if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
-        if weight is None:
-            return rms_norm_torch(x, (x.shape[-1],), eps=eps)
-        else:
-            return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
-    else:
-        r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
-        if weight is None:
-            return r
-        else:
-            return r * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
+
+rms_norm = comfy.rmsnorm.rms_norm
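The rms_norm helper that moves out of this file (into the new comfy.rmsnorm module, below) has two numerically equivalent code paths. A standalone sketch, not part of the commit, assuming torch >= 2.4 so that torch.nn.functional.rms_norm exists; manual_rms_norm is a hypothetical name for the fallback math:

    # Check that the manual fallback matches the torch builtin.
    import torch

    def manual_rms_norm(x, weight=None, eps=1e-6):
        # Same math as the fallback branch: x / sqrt(mean(x^2) + eps), then scale.
        r = x * torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps)
        return r if weight is None else r * weight

    x = torch.randn(2, 8, 16, dtype=torch.float32)
    w = torch.randn(16)

    expected = torch.nn.functional.rms_norm(x, (16,), weight=w, eps=1e-6)
    assert torch.allclose(manual_rms_norm(x, w), expected, atol=1e-5, rtol=1e-4)

The builtin is skipped while tracing or scripting, presumably because torch.jit does not handle it; the manual expression traces cleanly.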
comfy/ops.py (20 lines changed)

@@ -21,6 +21,7 @@ import logging
 import comfy.model_management
 from comfy.cli_args import args, PerformanceFeature
 import comfy.float
+import comfy.rmsnorm
 
 cast_to = comfy.model_management.cast_to #TODO: remove once no more references
 
@@ -146,6 +147,25 @@ class disable_weight_init:
             else:
                 return super().forward(*args, **kwargs)
 
+    class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
+        def reset_parameters(self):
+            self.bias = None
+            return None
+
+        def forward_comfy_cast_weights(self, input):
+            if self.weight is not None:
+                weight, bias = cast_bias_weight(self, input)
+            else:
+                weight = None
+            return comfy.rmsnorm.rms_norm(input, weight, self.eps)  # TODO: switch to commented out line when old torch is deprecated
+            # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
     class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
         def reset_parameters(self):
             return None
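A hedged usage sketch of the new op, not part of the commit: it assumes a ComfyUI checkout on the import path and torch >= 2.4. The explicit weight fill is demo-only; reset_parameters here deliberately skips initialization because real weights come from a loaded checkpoint.

    import torch
    import comfy.ops

    norm = comfy.ops.disable_weight_init.RMSNorm(512, eps=1e-6)
    with torch.no_grad():
        norm.weight.copy_(torch.ones(512))  # stand-in for checkpoint weights

    x = torch.randn(1, 77, 512)
    out = norm(x)  # default path: super().forward(), i.e. torch.nn.RMSNorm
    assert out.shape == x.shape

    # Flipping comfy_cast_weights routes through forward_comfy_cast_weights,
    # which casts the fp32 weight to the input's dtype/device on the fly.
    norm.comfy_cast_weights = True
    out_fp16 = norm(x.half())
    assert out_fp16.dtype == torch.float16

Without the cast path, feeding a half-precision input through an op holding fp32 weights would mix dtypes; casting per call lets one set of stored weights serve inputs of any dtype or device.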
comfy/rmsnorm.py (new file, 65 lines)

@@ -0,0 +1,65 @@
+import torch
+import comfy.model_management
+import numbers
+
+RMSNorm = None
+
+try:
+    rms_norm_torch = torch.nn.functional.rms_norm
+    RMSNorm = torch.nn.RMSNorm
+except:
+    rms_norm_torch = None
+
+
+def rms_norm(x, weight=None, eps=1e-6):
+    if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
+        if weight is None:
+            return rms_norm_torch(x, (x.shape[-1],), eps=eps)
+        else:
+            return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
+    else:
+        r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
+        if weight is None:
+            return r
+        else:
+            return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
+
+
+if RMSNorm is None:
+    class RMSNorm(torch.nn.Module):
+        def __init__(
+            self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None, **kwargs
+        ):
+            super().__init__()
+            self.eps = eps
+            self.learnable_scale = elementwise_affine
+            if self.learnable_scale:
+                self.weight = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
+            else:
+                self.register_parameter("weight", None)
+
+        def __init__(
+            self,
+            normalized_shape,
+            eps=None,
+            elementwise_affine=True,
+            device=None,
+            dtype=None,
+        ):
+            factory_kwargs = {"device": device, "dtype": dtype}
+            super().__init__()
+            if isinstance(normalized_shape, numbers.Integral):
+                # mypy error: incompatible types in assignment
+                normalized_shape = (normalized_shape,)  # type: ignore[assignment]
+            self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
+            self.eps = eps
+            self.elementwise_affine = elementwise_affine
+            if self.elementwise_affine:
+                self.weight = torch.nn.Parameter(
+                    torch.empty(self.normalized_shape, **factory_kwargs)
+                )
+            else:
+                self.register_parameter("weight", None)
+
+        def forward(self, x):
+            return rms_norm(x, self.weight, self.eps)
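Two details of the fallback class are easy to miss. First, it defines __init__ twice; Python keeps only the second definition, so the normalized_shape-based signature is the effective one. Second, the effective default eps=None would reach the manual arithmetic in rms_norm unchanged on old torch, so callers on that path should pass an explicit eps. A standalone sketch of the effective behavior, not part of the commit; FallbackRMSNorm is a hypothetical name:

    import numbers
    import torch

    class FallbackRMSNorm(torch.nn.Module):
        # Mirrors the effective (second) __init__, but defaults eps to 1e-6 and
        # initializes the weight to ones so the demo is runnable; the committed
        # code uses torch.empty because ComfyUI loads real weights afterwards.
        def __init__(self, normalized_shape, eps=1e-6, elementwise_affine=True, device=None, dtype=None):
            super().__init__()
            if isinstance(normalized_shape, numbers.Integral):
                normalized_shape = (normalized_shape,)
            self.normalized_shape = tuple(normalized_shape)
            self.eps = eps
            if elementwise_affine:
                self.weight = torch.nn.Parameter(torch.ones(self.normalized_shape, device=device, dtype=dtype))
            else:
                self.register_parameter("weight", None)

        def forward(self, x):
            # Same arithmetic as the manual branch of rms_norm above.
            r = x * torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
            return r if self.weight is None else r * self.weight

    x = torch.randn(4, 32)
    print(FallbackRMSNorm(32)(x).shape)  # torch.Size([4, 32])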