diff --git a/comfy/float.py b/comfy/float.py index 9822ae48..51ae9872 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -32,8 +32,12 @@ def manual_stochastic_round_to_float8(x, dtype): result = sign * (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + mantissa) # Handle zero case - result = torch.where(abs_x == 0, torch.zeros_like(result), result) + zero_mask = (abs_x == 0) + result = torch.where(zero_mask, torch.zeros_like(result), result) + # Handle subnormal numbers + min_normal = 2.0 ** (-EXPONENT_BIAS + 1) + result = torch.where((abs_x < min_normal) & (~zero_mask), torch.round(x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS))) * (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)), result) return result.to(dtype=dtype)