Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-01-11 10:25:16 +00:00
Speed up fp8 matrix mult by using better code.
commit 510f3438c1 (parent ea63b1c092)

comfy/ops.py | 12
@@ -251,16 +251,16 @@ def fp8_linear(self, input):
 
     if len(input.shape) == 3:
         out = torch.empty((input.shape[0], input.shape[1], self.weight.shape[0]), device=input.device, dtype=input.dtype)
-        inn = input.to(dtype)
+        inn = input.view(-1, input.shape[2]).to(dtype)
         non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
         w = cast_to(self.weight, device=input.device, non_blocking=non_blocking).t()
-        for i in range(input.shape[0]):
-            if self.bias is not None:
-                o, _ = torch._scaled_mm(inn[i], w, out_dtype=input.dtype, bias=cast_to_input(self.bias, input, non_blocking=non_blocking))
-            else:
-                o, _ = torch._scaled_mm(inn[i], w, out_dtype=input.dtype)
-            out[i] = o
-        return out
+
+        if self.bias is not None:
+            o, _ = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=cast_to_input(self.bias, input, non_blocking=non_blocking))
+        else:
+            o, _ = torch._scaled_mm(inn, w, out_dtype=input.dtype)
+
+        return o.view((-1, input.shape[1], self.weight.shape[0]))
     return None
 
 
 class fp8_ops(manual_cast):
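Why this is faster: the old code launched one torch._scaled_mm per batch element from a Python loop and copied each result into a preallocated buffer; the new code flattens the (batch, sequence) dimensions into one, issues a single scaled matmul, and views the result back to 3D. (The `o, _ = ...` unpacking is because torch._scaled_mm returned a tuple in the PyTorch versions this commit targets.) Below is a minimal sketch of the same flatten-then-single-matmul trick, using plain torch.matmul as a stand-in since torch._scaled_mm requires fp8 tensors and supported hardware; the shapes are made up for illustration:

import torch

# Stand-in shapes for a transformer linear layer:
# input is (batch, tokens, in_features), weight is (out_features, in_features).
batch, tokens, in_features, out_features = 4, 77, 320, 1280
x = torch.randn(batch, tokens, in_features)
weight = torch.randn(out_features, in_features)

# Old pattern: one matmul per batch element, results copied into a buffer.
out_loop = torch.empty((batch, tokens, out_features))
for i in range(batch):
    out_loop[i] = torch.matmul(x[i], weight.t())

# New pattern: flatten (batch, tokens) into one dimension, run a single
# matmul, then view the result back to 3D. One kernel launch instead of
# `batch` launches, and no per-slice copies into a preallocated buffer.
inn = x.view(-1, x.shape[2])
out_single = torch.matmul(inn, weight.t()).view((-1, x.shape[1], weight.shape[0]))

assert torch.allclose(out_loop, out_single, rtol=1e-4, atol=1e-5)

The single call amortizes Python and kernel-launch overhead and hands the GEMM the whole problem at once, which is where the speedup comes from.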