mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-20 03:13:30 +00:00
Lower fp8 lora memory usage.
This commit is contained in:
parent
d043997d30
commit
00a5d08103
@@ -41,9 +41,8 @@ def manual_stochastic_round_to_float8(x, dtype, generator=None):
|
|||||||
(2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
|
(2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
|
||||||
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
|
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
|
||||||
)
|
)
|
||||||
del abs_x
|
|
||||||
|
|
||||||
return sign.to(dtype=dtype)
|
return sign
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -57,6 +56,11 @@ def stochastic_rounding(value, dtype, seed=0):
|
|||||||
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
|
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
|
||||||
generator = torch.Generator(device=value.device)
|
generator = torch.Generator(device=value.device)
|
||||||
generator.manual_seed(seed)
|
generator.manual_seed(seed)
|
||||||
return manual_stochastic_round_to_float8(value, dtype, generator=generator)
|
output = torch.empty_like(value, dtype=dtype)
|
||||||
|
num_slices = max(1, (value.numel() / (4096 * 4096)))
|
||||||
|
slice_size = max(1, round(value.shape[0] / num_slices))
|
||||||
|
for i in range(0, value.shape[0], slice_size):
|
||||||
|
output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
|
||||||
|
return output
|
||||||
|
|
||||||
return value.to(dtype=dtype)
|
return value.to(dtype=dtype)
|
||||||
|
Loading…
Reference in New Issue
Block a user