Skip to content

Commit 05d834c

Browse files
author
Prashant Kumar
committed
Add comprehensive tests to test the kernel across available dtypes.
Added softmax and gemm kernels to test across the available float and int dtypes.
1 parent 971231c commit 05d834c

File tree

1 file changed

+106
-0
lines changed

1 file changed

+106
-0
lines changed

core/tests/kernel/coverage_test.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import torch
2+
import shark_turbine.kernel as tk
3+
import shark_turbine.kernel.lang as tkl
4+
import pytest
5+
6+
7+
# Kernel-language dtype -> torch dtype used to build the host-side test
# tensors. Kernel dtypes that do not appear here have no entry to convert to,
# and the tests check membership before building tensors.
TKL_TO_TORCH_DTYPE = dict(
    [
        (tkl.f16, torch.half),
        (tkl.f32, torch.float),
        (tkl.f64, torch.double),
        (tkl.bool, torch.bool),
        (tkl.i8, torch.int8),
        (tkl.i16, torch.int16),
        (tkl.i32, torch.int32),
        (tkl.i64, torch.int64),
    ]
)
17+
18+
# Floating-point kernel dtypes exercised by the tests.
FLOAT_DTYPES = [
    tkl.f16,
    tkl.f32,
    tkl.f64,
]

# Kernel dtypes that the kernels below treat as integer-like (this includes
# bool and index). Kept as a list because it is concatenated with
# FLOAT_DTYPES when parametrizing the tests.
INT_DTYPES = [tkl.bool, tkl.i4, tkl.i8, tkl.i16, tkl.i32, tkl.i64, tkl.index]
28+
29+
30+
def softmax_krnl(dtype, input, output):
    """Define a row-wise softmax-style kernel for ``dtype`` and launch it.

    The kernel is declared over symbolic dimensions ``M`` and ``K`` with one
    program per row. Each program reads its row of ``input``, computes
    ``exp2(row - max(row))``, divides by the sum of that result and writes the
    row to ``output``. Dtypes listed in ``INT_DTYPES`` use floor division
    (``//``); every other dtype uses true division (``/``).

    Args:
        dtype: Kernel-language element type used for both buffers.
        input: Tensor passed as the kernel's input buffer.
        output: Tensor passed as the kernel's output buffer.
    """
    M = tkl.sym.M
    K = tkl.sym.K
    # The dtype is fixed for the whole kernel, so pick the division flavour
    # once, outside the kernel body.
    use_floor_division = dtype in INT_DTYPES

    @tk.gen.thread(M)
    def softmax_kernel(
        input: tk.lang.InputBuffer[M, K, dtype],
        output: tk.lang.OutputBuffer[M, K, dtype],
    ):
        row = tk.lang.program_id(0)
        values = input[row, :]
        # Subtract the row maximum before exponentiating.
        exps = tkl.exp2(values - tkl.max(values))
        total = tkl.sum(exps)
        normalized = exps // total if use_floor_division else exps / total
        output[row, :] = normalized

    with tk.gen.TestLaunchContext():
        softmax_kernel(input, output)
50+
51+
52+
def gemm_fx_kernel(dtype, A, B, output):
    """Define a blocked GEMM kernel for ``dtype`` and launch it.

    The kernel is declared over symbolic dimensions ``N``, ``M``, ``K`` and a
    symbolic ``BLOCK_SIZE``, on a ``(N // BLOCK_SIZE, M // BLOCK_SIZE)`` grid.
    Each program starts from a ``BLOCK_SIZE x BLOCK_SIZE`` constant accumulator
    and loops ``K // BLOCK_SIZE`` times, loading one tile of ``A`` and one of
    ``B`` and accumulating with ``tkl.dot``. The loop result is then stored to
    ``output``. The launch binds ``BLOCK_SIZE`` to 32.

    Args:
        dtype: Kernel-language element type shared by all three buffers.
        A: Tensor passed as the ``[N, K]`` input buffer.
        B: Tensor passed as the ``[K, M]`` input buffer.
        output: Tensor passed as the ``[N, M]`` output buffer.
    """
    N = tkl.sym.N
    M = tkl.sym.M
    K = tkl.sym.K
    BLOCK_SIZE = tkl.sym.BLOCK_SIZE

    # TODO: Only considering the float and integer cases.
    # Integer-like dtypes get an int literal as the accumulator's initial
    # value; everything else gets a float literal.
    initial_value = 0 if dtype in INT_DTYPES else 0.0
    tile_shape = (BLOCK_SIZE, BLOCK_SIZE)

    @tk.gen.thread(N // BLOCK_SIZE, M // BLOCK_SIZE)
    def gemm_kernel(
        A: tkl.InputBuffer[N, K, dtype],
        B: tkl.InputBuffer[K, M, dtype],
        output: tkl.OutputBuffer[N, M, dtype],
    ):
        grid_n = tkl.program_id(0)
        grid_m = tkl.program_id(1)

        acc = tkl.constant(tile_shape, dtype, initial_value)

        @tkl.for_loop(0, K // BLOCK_SIZE, init_args=[acc])
        def body(i, c):
            k_offset = i * BLOCK_SIZE
            lhs_tile = tkl.load(A, (grid_n, k_offset), tile_shape)
            rhs_tile = tkl.load(B, (k_offset, grid_m), tile_shape)
            return (tkl.dot(lhs_tile, rhs_tile, c),)

        # body[0] is the loop's first result, i.e. the final accumulator.
        tkl.store(output, (grid_n, grid_m), body[0])

    with tk.gen.TestLaunchContext({BLOCK_SIZE: 32}):
        gemm_kernel(A, B, output)
84+
85+
86+
@pytest.mark.parametrize(("dtype",), [(dt,) for dt in FLOAT_DTYPES])
def test_softmax_krnl(dtype):
    """Run the softmax kernel on 128x64 random tensors for each float dtype.

    Does nothing when ``dtype`` has no entry in ``TKL_TO_TORCH_DTYPE``.
    """
    if dtype not in TKL_TO_TORCH_DTYPE:
        return

    torch_dtype = TKL_TO_TORCH_DTYPE[dtype]
    source = torch.randn(128, 64).to(torch_dtype)
    destination = torch.randn(128, 64).to(torch_dtype)
    softmax_krnl(dtype, source, destination)
95+
96+
97+
@pytest.mark.parametrize(
    ("dtype",),
    [(x,) for x in FLOAT_DTYPES + INT_DTYPES],
)
def test_gemm_krnl(dtype):
    """Run the GEMM kernel for every float and integer kernel dtype.

    Builds a 512x1024 ``A``, a 1024x2048 ``B`` and a zeroed 512x2048 output,
    all converted to the torch dtype that ``TKL_TO_TORCH_DTYPE`` maps ``dtype``
    to, then hands them to ``gemm_fx_kernel``.

    Some parametrized dtypes (``tkl.i4`` and ``tkl.index``) have no entry in
    ``TKL_TO_TORCH_DTYPE``. Those cases are reported as skipped rather than
    silently passing without running anything.
    """
    torch_dtype = TKL_TO_TORCH_DTYPE.get(dtype)
    if torch_dtype is None:
        pytest.skip(f"no torch dtype mapped for kernel dtype {dtype}")

    A = torch.randn(512, 1024).to(torch_dtype)
    B = torch.randn(1024, 2048).to(torch_dtype)
    output = torch.zeros(512, 2048).to(torch_dtype)
    gemm_fx_kernel(dtype, A, B, output)

0 commit comments

Comments
 (0)