From 22ec02afc0c900ee86ba6f0387bfd9dc0e34fb83 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 19 Aug 2024 05:19:59 -0400 Subject: [PATCH] Handle subnormal numbers in float8 rounding. --- comfy/float.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/comfy/float.py b/comfy/float.py index 9822ae48..51ae9872 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -32,8 +32,12 @@ def manual_stochastic_round_to_float8(x, dtype): result = sign * (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + mantissa) # Handle zero case - result = torch.where(abs_x == 0, torch.zeros_like(result), result) + zero_mask = (abs_x == 0) + result = torch.where(zero_mask, torch.zeros_like(result), result) + # Handle subnormal numbers + min_normal = 2.0 ** (-EXPONENT_BIAS + 1) + result = torch.where((abs_x < min_normal) & (~zero_mask), torch.round(x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS))) * (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)), result) return result.to(dtype=dtype)