Speed up fp8 matrix multiplication by flattening the batch into a single torch._scaled_mm call instead of looping over each batch element.

This commit is contained in:
comfyanonymous 2024-08-20 22:53:26 -04:00
parent ea63b1c092
commit 510f3438c1

View File

@@ -251,16 +251,16 @@ def fp8_linear(self, input):
if len(input.shape) == 3:
out = torch.empty((input.shape[0], input.shape[1], self.weight.shape[0]), device=input.device, dtype=input.dtype)
inn = input.to(dtype)
inn = input.view(-1, input.shape[2]).to(dtype)
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
w = cast_to(self.weight, device=input.device, non_blocking=non_blocking).t()
for i in range(input.shape[0]):
if self.bias is not None:
o, _ = torch._scaled_mm(inn[i], w, out_dtype=input.dtype, bias=cast_to_input(self.bias, input, non_blocking=non_blocking))
else:
o, _ = torch._scaled_mm(inn[i], w, out_dtype=input.dtype)
out[i] = o
return out
if self.bias is not None:
o, _ = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=cast_to_input(self.bias, input, non_blocking=non_blocking))
else:
o, _ = torch._scaled_mm(inn, w, out_dtype=input.dtype)
return o.view((-1, input.shape[1], self.weight.shape[0]))
return None
class fp8_ops(manual_cast):