# Mirror of https://github.com/comfyanonymous/ComfyUI.git
# (synced 2025-04-19 10:53:29 +00:00) — 66 lines, 2.3 KiB, Python
import torch
|
|
import comfy.model_management
|
|
import numbers
|
|
|
|
# Prefer PyTorch's native RMSNorm (torch.nn.functional.rms_norm and
# torch.nn.RMSNorm) when this torch build provides both; otherwise both names
# end up None and this module's pure-PyTorch fallbacks are used instead.
RMSNorm = None

try:
    rms_norm_torch = torch.nn.functional.rms_norm
    RMSNorm = torch.nn.RMSNorm
except AttributeError:
    # Older torch builds lack these attributes. Only AttributeError is caught
    # so unrelated failures (e.g. KeyboardInterrupt) are not silently swallowed.
    rms_norm_torch = None
|
|
|
|
|
|
def rms_norm(x, weight=None, eps=1e-6):
    """Apply root-mean-square normalization to ``x``, optionally scaled by ``weight``.

    Normalization is taken over the last dimension when ``weight`` is None, and
    over the trailing ``weight.ndim`` dimensions otherwise. This matches the
    ``normalized_shape=weight.shape`` handed to the native kernel, so both code
    paths agree.

    Args:
        x: Input tensor.
        weight: Optional scale tensor whose shape is expected to match the
            trailing dimensions of ``x``. It is cast to ``x``'s dtype and
            device (via ``comfy.model_management.cast_to``) before use.
        eps: Value added to the mean of squares before the reciprocal square
            root. ``None`` means ``torch.finfo(x.dtype).eps``, the same
            convention as ``torch.nn.functional.rms_norm``.

    Returns:
        A tensor with the same shape as ``x``.
    """
    if weight is None:
        normalized_shape = (x.shape[-1],)
    else:
        normalized_shape = tuple(weight.shape)
        weight = comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)

    # The native kernel is bypassed while torch.jit is tracing or scripting;
    # the explicit formula below is used in that case (and on torch builds
    # without the native op).
    if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
        return rms_norm_torch(x, normalized_shape, weight=weight, eps=eps)

    if eps is None:
        eps = torch.finfo(x.dtype).eps
    # Reduce over the same trailing dims the native path normalizes; at least
    # the last dim, which preserves the previous behavior for a 0-d weight.
    dims = tuple(range(-max(len(normalized_shape), 1), 0))
    r = x * torch.rsqrt(torch.mean(x**2, dim=dims, keepdim=True) + eps)
    if weight is None:
        return r
    return r * weight
|
|
|
|
|
|
if RMSNorm is None:
    class RMSNorm(torch.nn.Module):
        """Fallback RMSNorm module for torch builds without ``torch.nn.RMSNorm``.

        The constructor mirrors ``torch.nn.RMSNorm``'s signature.

        Args:
            normalized_shape: An int or a sequence of ints; stored as a tuple
                in ``self.normalized_shape``.
            eps: Stability epsilon. ``None`` (the default) is resolved to
                ``torch.finfo(x.dtype).eps`` at call time, as in
                ``torch.nn.RMSNorm``.
            elementwise_affine: If True, a learnable ``weight`` of shape
                ``normalized_shape`` is created and initialized to ones by
                ``reset_parameters``. Otherwise ``weight`` is registered as None.
            device: Device for the ``weight`` tensor.
            dtype: Dtype for the ``weight`` tensor.
        """

        def __init__(
            self,
            normalized_shape,
            eps=None,
            elementwise_affine=True,
            device=None,
            dtype=None,
        ):
            factory_kwargs = {"device": device, "dtype": dtype}
            super().__init__()
            if isinstance(normalized_shape, numbers.Integral):
                # mypy error: incompatible types in assignment
                normalized_shape = (normalized_shape,)  # type: ignore[assignment]
            self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
            self.eps = eps
            self.elementwise_affine = elementwise_affine
            if self.elementwise_affine:
                self.weight = torch.nn.Parameter(
                    torch.empty(self.normalized_shape, **factory_kwargs)
                )
            else:
                self.register_parameter("weight", None)
            # Called last, like torch.nn.RMSNorm, so weight is never left as
            # uninitialized memory and subclasses can override the init.
            self.reset_parameters()

        def reset_parameters(self):
            """Initialize ``weight`` to ones (identity scale); no-op when it is None."""
            if self.weight is not None:
                torch.nn.init.ones_(self.weight)

        def forward(self, x):
            """Normalize ``x`` with ``rms_norm`` using this module's weight and eps."""
            eps = self.eps
            if eps is None:
                # Same default as torch.nn.RMSNorm.
                eps = torch.finfo(x.dtype).eps
            # NOTE(review): when weight is None, rms_norm gets no shape info
            # and normalizes only the last dim, so a multi-dimensional
            # normalized_shape with elementwise_affine=False differs from
            # torch.nn.RMSNorm. Confirm callers only use 1-D shapes.
            return rms_norm(x, self.weight, eps)
|