@@ -655,7 +655,7 @@ def test_matmul_split_k(self):
655
655
def matmul_split_k (x : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
656
656
m , k = x .size ()
657
657
k2 , n = y .size ()
658
- out = torch .empty ([m , n ], dtype = x .dtype , device = x .device )
658
+ out = torch .zeros ([m , n ], dtype = x .dtype , device = x .device )
659
659
for tile_m , tile_n , outer_k in hl .tile ([m , n , k ]):
660
660
acc = hl .zeros ([tile_m , tile_n ], dtype = torch .float32 )
661
661
for inner_k in hl .tile (outer_k .begin , outer_k .end ):
@@ -710,7 +710,7 @@ def _matmul_split_k_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_
710
710
def matmul_split_k(x: torch.Tensor, y: torch.Tensor):
711
711
m, k = x.size()
712
712
k2, n = y.size()
713
- out = torch.empty ([m, n], dtype=x.dtype, device=x.device)
713
+ out = torch.zeros ([m, n], dtype=x.dtype, device=x.device)
714
714
_BLOCK_SIZE_0 = 16
715
715
_BLOCK_SIZE_1 = 16
716
716
_BLOCK_SIZE_2 = 256
@@ -721,7 +721,7 @@ def matmul_split_k(x: torch.Tensor, y: torch.Tensor):
721
721
def _matmul_split_k_make_precompiler(x: torch.Tensor, y: torch.Tensor):
722
722
m, k = x.size()
723
723
k2, n = y.size()
724
- out = torch.empty ([m, n], dtype=x.dtype, device=x.device)
724
+ out = torch.zeros ([m, n], dtype=x.dtype, device=x.device)
725
725
_BLOCK_SIZE_0 = 16
726
726
_BLOCK_SIZE_1 = 16
727
727
_BLOCK_SIZE_2 = 256
0 commit comments