From 8a438115fb9e3ed8327de25b23d341dccde229d9 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Mon, 14 Apr 2025 18:00:33 -0400
Subject: [PATCH] add RMSNorm to comfy.ops

---
 comfy/ldm/common_dit.py | 21 ++++-----------------
 comfy/ops.py            | 20 ++++++++++++++++++++
 comfy/rmsnorm.py        | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 78 insertions(+), 18 deletions(-)
 create mode 100644 comfy/rmsnorm.py

diff --git a/comfy/ldm/common_dit.py b/comfy/ldm/common_dit.py
index e0f3057f..f7f56b72 100644
--- a/comfy/ldm/common_dit.py
+++ b/comfy/ldm/common_dit.py
@@ -1,5 +1,6 @@
 import torch
-import comfy.ops
+import comfy.rmsnorm
+
 
 def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
     if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
@@ -11,20 +12,5 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
 
     return torch.nn.functional.pad(img, pad, mode=padding_mode)
 
-try:
-    rms_norm_torch = torch.nn.functional.rms_norm
-except:
-    rms_norm_torch = None
-
-def rms_norm(x, weight=None, eps=1e-6):
-    if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
-        if weight is None:
-            return rms_norm_torch(x, (x.shape[-1],), eps=eps)
-        else:
-            return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
-    else:
-        r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
-        if weight is None:
-            return r
-        else:
-            return r * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
+
+rms_norm = comfy.rmsnorm.rms_norm
diff --git a/comfy/ops.py b/comfy/ops.py
index 9a5c1ee9..6b0e2930 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -21,6 +21,7 @@ import logging
 import comfy.model_management
 from comfy.cli_args import args, PerformanceFeature
 import comfy.float
+import comfy.rmsnorm
 
 cast_to = comfy.model_management.cast_to #TODO: remove once no more references
 
@@ -146,6 +147,25 @@ class disable_weight_init:
             else:
                 return super().forward(*args, **kwargs)
 
+    class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
+        def reset_parameters(self):
+            self.bias = None
+            return None
+
+        def forward_comfy_cast_weights(self, input):
+            if self.weight is not None:
+                weight, bias = cast_bias_weight(self, input)
+            else:
+                weight = None
+            return comfy.rmsnorm.rms_norm(input, weight, self.eps)  # TODO: switch to commented out line when old torch is deprecated
+            # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
     class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
         def reset_parameters(self):
             return None
diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py
new file mode 100644
index 00000000..81b3e906
--- /dev/null
+++ b/comfy/rmsnorm.py
@@ -0,0 +1,54 @@
+import torch
+import comfy.model_management
+import numbers
+
+RMSNorm = None
+
+try:
+    rms_norm_torch = torch.nn.functional.rms_norm
+    RMSNorm = torch.nn.RMSNorm
+except:
+    rms_norm_torch = None
+
+
+def rms_norm(x, weight=None, eps=1e-6):
+    if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
+        if weight is None:
+            return rms_norm_torch(x, (x.shape[-1],), eps=eps)
+        else:
+            return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
+    else:
+        r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
+        if weight is None:
+            return r
+        else:
+            return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
+
+
+if RMSNorm is None:
+    class RMSNorm(torch.nn.Module):
+        def __init__(
+            self,
+            normalized_shape,
+            eps=None,
+            elementwise_affine=True,
+            device=None,
+            dtype=None,
+        ):
+            factory_kwargs = {"device": device, "dtype": dtype}
+            super().__init__()
+            if isinstance(normalized_shape, numbers.Integral):
+                # mypy error: incompatible types in assignment
+                normalized_shape = (normalized_shape,)  # type: ignore[assignment]
+            self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
+            self.eps = eps
+            self.elementwise_affine = elementwise_affine
+            if self.elementwise_affine:
+                self.weight = torch.nn.Parameter(
+                    torch.empty(self.normalized_shape, **factory_kwargs)
+                )
+            else:
+                self.register_parameter("weight", None)
+
+        def forward(self, x):
+            return rms_norm(x, self.weight, self.eps)
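
Note for reviewers: the fallback RMSNorm class originally carried a second, dead __init__ (a dim-based signature with a learnable_scale flag) that Python would silently discard in favor of the torch.nn.RMSNorm-compatible one; it has been dropped above and the diffstat adjusted accordingly. Factoring the helper into its own comfy.rmsnorm module also avoids a circular import: comfy.ops now needs rms_norm, and the helper's old home, comfy.ldm.common_dit, imported comfy.ops.

The snippet below is a minimal sanity-check sketch, not part of the patch. It assumes a ComfyUI checkout on sys.path; the shapes, weight values, and tolerance are illustrative. It checks that comfy.rmsnorm.rms_norm matches the hand-rolled fallback math on whichever branch it takes, and that the shim class accepts the same constructor arguments as torch.nn.RMSNorm, which is what lets the ops wrapper subclass either one.

# Sanity-check sketch (not part of the patch). Assumes a ComfyUI checkout on
# sys.path; shapes, weights, and tolerance are illustrative.
import torch
import comfy.rmsnorm

x = torch.randn(2, 8, 64)
w = torch.randn(64)
eps = 1e-6

# The fallback math kept for old torch / torchscript tracing:
# x * rsqrt(mean(x^2, dim=-1) + eps), optionally scaled by weight.
manual = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps) * w

# rms_norm should agree whether it dispatches to F.rms_norm or falls back.
assert torch.allclose(comfy.rmsnorm.rms_norm(x, w, eps), manual, atol=1e-5)

# The shim mirrors torch.nn.RMSNorm's constructor, so comfy.ops.RMSNorm can
# subclass whichever class was picked at import time.
norm = comfy.rmsnorm.RMSNorm(64, eps=eps, elementwise_affine=True)
with torch.no_grad():
    norm.weight.copy_(w)
assert torch.allclose(norm(x), manual, atol=1e-5)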