We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 96253e7 commit 4165a69 (Copy full SHA for 4165a69)
examples/matmul.py
@@ -23,11 +23,11 @@ def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
23
return out
24
25
26
-def check(n: int, k: int, m: int) -> None:
+def check(m: int, k: int, n: int) -> None:
27
from triton.testing import do_bench
28
29
- x = torch.randn([n, k], device="cuda", dtype=torch.float16)
30
- y = torch.randn([k, m], device="cuda", dtype=torch.float16)
+ x = torch.randn([m, k], device="cuda", dtype=torch.float16)
+ y = torch.randn([k, n], device="cuda", dtype=torch.float16)
31
result = matmul(x, y)
32
torch.testing.assert_close(result, x @ y, rtol=1e-2, atol=1e-1)
33
sec = do_bench(lambda: matmul(x, y))
0 commit comments