Skip to content

Commit 59bf929

Browse files
authored
Fix bug in test_matmul_split_k (#156)
1 parent 2729f43 commit 59bf929

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

test/test_matmul.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -655,7 +655,7 @@ def test_matmul_split_k(self):
655655
def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
656656
m, k = x.size()
657657
k2, n = y.size()
658-
out = torch.empty([m, n], dtype=x.dtype, device=x.device)
658+
out = torch.zeros([m, n], dtype=x.dtype, device=x.device)
659659
for tile_m, tile_n, outer_k in hl.tile([m, n, k]):
660660
acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
661661
for inner_k in hl.tile(outer_k.begin, outer_k.end):
@@ -710,7 +710,7 @@ def _matmul_split_k_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_
710710
def matmul_split_k(x: torch.Tensor, y: torch.Tensor):
711711
m, k = x.size()
712712
k2, n = y.size()
713-
out = torch.empty([m, n], dtype=x.dtype, device=x.device)
713+
out = torch.zeros([m, n], dtype=x.dtype, device=x.device)
714714
_BLOCK_SIZE_0 = 16
715715
_BLOCK_SIZE_1 = 16
716716
_BLOCK_SIZE_2 = 256
@@ -721,7 +721,7 @@ def matmul_split_k(x: torch.Tensor, y: torch.Tensor):
721721
def _matmul_split_k_make_precompiler(x: torch.Tensor, y: torch.Tensor):
722722
m, k = x.size()
723723
k2, n = y.size()
724-
out = torch.empty([m, n], dtype=x.dtype, device=x.device)
724+
out = torch.zeros([m, n], dtype=x.dtype, device=x.device)
725725
_BLOCK_SIZE_0 = 16
726726
_BLOCK_SIZE_1 = 16
727727
_BLOCK_SIZE_2 = 256

0 commit comments

Comments (0)