Make fast fp8 take a bit less peak memory.

commit 99a1fb6027 (parent 73e04987f7)
Author: comfyanonymous
Date: 2024-12-24 18:05:19 -05:00


@@ -255,9 +255,10 @@ def fp8_linear(self, input):
         tensor_2d = True
         input = input.unsqueeze(1)
     input_shape = input.shape
+    input_dtype = input.dtype
     if len(input.shape) == 3:
-        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
+        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
         w = w.t()
         scale_weight = self.scale_weight
@@ -269,23 +270,24 @@ def fp8_linear(self, input):
         if scale_input is None:
             scale_input = torch.ones((), device=input.device, dtype=torch.float32)
-            inn = torch.clamp(input, min=-448, max=448).reshape(-1, input.shape[2]).to(dtype)
+            input = torch.clamp(input, min=-448, max=448, out=input)
+            input = input.reshape(-1, input_shape[2]).to(dtype)
         else:
             scale_input = scale_input.to(input.device)
-            inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)
+            input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype)
         if bias is not None:
-            o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
+            o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
         else:
-            o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)
+            o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)
         if isinstance(o, tuple):
             o = o[0]
         if tensor_2d:
-            return o.reshape(input.shape[0], -1)
+            return o.reshape(input_shape[0], -1)
-        return o.reshape((-1, input.shape[1], self.weight.shape[0]))
+        return o.reshape((-1, input_shape[1], self.weight.shape[0]))
     return None
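
In short, the change drops the separate `inn` temporary: the input's shape and dtype are saved up front (`input_shape`, `input_dtype`), the clamp is done in place via `out=input`, and the same buffer is then reshaped and cast to fp8 before being handed to `torch._scaled_mm`, so a full-size clamped copy of the activation no longer has to coexist with the original at peak. Below is a minimal standalone sketch of that pattern; the function name, the 3D (batch, tokens, channels) activation, and the unit scales are assumptions for illustration, not the ComfyUI code itself.

import torch

# Hypothetical standalone sketch (not ComfyUI's fp8_linear): clamp the
# activation in place and cast the same buffer, instead of materializing a
# second full-size temporary the way the removed `inn` tensor did.
def fp8_matmul_sketch(x, w_fp8):
    # Save shape/dtype before `x` is reassigned below
    # (mirrors input_shape / input_dtype in the diff).
    x_shape = x.shape
    x_dtype = x.dtype

    # Unit scales, assumed for illustration; the real code may use
    # per-tensor scale_input / scale_weight values instead.
    scale_x = torch.ones((), device=x.device, dtype=torch.float32)
    scale_w = torch.ones((), device=x.device, dtype=torch.float32)

    # Old pattern (for contrast): torch.clamp(x, ...) without out= allocates a
    # new tensor, so the activation and its clamped copy are alive together:
    #   inn = torch.clamp(x, min=-448, max=448).reshape(-1, x_shape[-1]).to(torch.float8_e4m3fn)
    # New pattern: clamp in place, then reshape/cast the existing buffer.
    x = torch.clamp(x, min=-448, max=448, out=x)  # 448 is the float8_e4m3fn max
    x = x.reshape(-1, x_shape[-1]).to(torch.float8_e4m3fn)

    # _scaled_mm wants the second operand column-major; w_fp8 is assumed to be
    # an [out_features, in_features] float8_e4m3fn weight, so .t() provides that.
    o = torch._scaled_mm(x, w_fp8.t(), out_dtype=x_dtype,
                         scale_a=scale_x, scale_b=scale_w)
    if isinstance(o, tuple):  # some PyTorch versions return (out, amax)
        o = o[0]
    return o.reshape(x_shape[0], x_shape[1], -1)

With the in-place clamp, the only sizable allocation before the matmul is the fp8 copy produced by `.to(torch.float8_e4m3fn)`, roughly half the size of an fp16 activation; the old `inn = torch.clamp(...)` path additionally held a full fp16-sized clamped copy at that point, which is the peak-memory saving the commit message refers to.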