diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py
index 2f26bf5c365b..9d98df536d68 100644
--- a/vllm/model_executor/layers/tuned_gemm.py
+++ b/vllm/model_executor/layers/tuned_gemm.py
@@ -129,10 +129,10 @@ def scaled_mm(
     def mm(self, inp, weights, bias=None):
         if not support_tuned_gemms:
             return F.linear(inp, weights, bias)
-        # F.Linear can take a 3 dimensional input. vllm
-        # uses this for linear units. However, sampler
-        # will use torch.matmul with 2 dimensions only
-        if inp.dim() == 3:
+        # F.Linear can take a 3 dimensional (or even larger)
+        # input. vllm uses this for linear units. However,
+        # sampler will use torch.matmul with 2 dimensions only
+        if inp.dim() >= 3:
             try:
                 inp_view = inp.view(-1, inp.size(-1))
                 batched = True
@@ -157,7 +157,7 @@ def mm(self, inp, weights, bias=None):
             out = self.apply_skinny(m, n, k, inp_view, weights)
             if out is not None:
                 if batched:
-                    out = out.view(inp.shape[0], inp.shape[1], weights.shape[0])
+                    out = out.view(*inp.shape[:-1], weights.shape[0])
                 if bias is not None:
                     return out + bias
                 return out
@@ -182,7 +182,7 @@ def mm(self, inp, weights, bias=None):
                 self.tuned_df.to_csv(self.untune_path, index=False)
             return F.linear(inp, weights, bias)
         if batched:
-            out = out.view(inp.shape[0], inp.shape[1], weights.shape[0])
+            out = out.view(*inp.shape[:-1], weights.shape[0])
         if bias is not None:
             out = out + bias
         return out
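
For context, here is a minimal standalone sketch of the reshape pattern this patch generalizes (assuming PyTorch; the tensor shapes and variable names are illustrative, not taken from the vLLM code): flatten every leading dimension into a single batch dimension so the GEMM runs on a 2-D input, then rebuild the original leading dimensions with `*inp.shape[:-1]`, which works for any input rank rather than only exactly 3-D inputs.

```python
# Minimal sketch of the flatten-then-restore pattern (assumes PyTorch).
# Shapes below are illustrative; they are not taken from vLLM itself.
import torch
import torch.nn.functional as F

inp = torch.randn(2, 3, 4, 8)   # 4-D input: leading dims (2, 3, 4), features 8
weights = torch.randn(16, 8)    # F.linear-style weight: (out_features, in_features)

# Collapse all leading dims into one batch dim so the GEMM sees a 2-D input.
# (view() raises on an incompatible memory layout, hence the try/except
# around it in the patched code.)
inp_view = inp.view(-1, inp.size(-1))              # (24, 8)
out = inp_view @ weights.t()                       # (24, 16)

# Restore the leading dims. *inp.shape[:-1] handles any rank, whereas the
# old inp.shape[0], inp.shape[1] indexing assumed exactly 3-D inputs.
out = out.view(*inp.shape[:-1], weights.shape[0])  # (2, 3, 4, 16)

assert torch.allclose(out, F.linear(inp, weights), atol=1e-5)
```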