diff --git a/comfy/ops.py b/comfy/ops.py
index fc78dd83..6eb8eb3d 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -251,16 +251,16 @@ def fp8_linear(self, input):
 
     if len(input.shape) == 3:
         out = torch.empty((input.shape[0], input.shape[1], self.weight.shape[0]), device=input.device, dtype=input.dtype)
-        inn = input.to(dtype)
+        inn = input.view(-1, input.shape[2]).to(dtype)
         non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
         w = cast_to(self.weight, device=input.device, non_blocking=non_blocking).t()
-        for i in range(input.shape[0]):
-            if self.bias is not None:
-                o, _ = torch._scaled_mm(inn[i], w, out_dtype=input.dtype, bias=cast_to_input(self.bias, input, non_blocking=non_blocking))
-            else:
-                o, _ = torch._scaled_mm(inn[i], w, out_dtype=input.dtype)
-            out[i] = o
-        return out
+
+        if self.bias is not None:
+            o, _ = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=cast_to_input(self.bias, input, non_blocking=non_blocking))
+        else:
+            o, _ = torch._scaled_mm(inn, w, out_dtype=input.dtype)
+
+        return o.view((-1, input.shape[1], self.weight.shape[0]))
     return None
 
 class fp8_ops(manual_cast):
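
The change collapses the Python-level loop over the batch dimension: the (B, S, C) input is flattened to (B*S, C) so a single torch._scaled_mm call processes every row at once, then the result is reshaped back. Below is a minimal sketch (not from the PR) of why the flattened matmul is equivalent to the per-batch loop; it uses plain matmul in place of the fp8 kernel, since torch._scaled_mm needs fp8-capable hardware, and the names B, S, C_in, C_out are illustrative.

import torch

B, S, C_in, C_out = 2, 4, 8, 16
x = torch.randn(B, S, C_in)
w = torch.randn(C_out, C_in)

# Per-batch loop, mirroring the old code path.
out_loop = torch.empty(B, S, C_out)
for i in range(B):
    out_loop[i] = x[i] @ w.t()

# Single matmul on the flattened input, mirroring the new code path.
out_flat = (x.view(-1, C_in) @ w.t()).view(B, S, C_out)

assert torch.allclose(out_loop, out_flat, atol=1e-6)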