Commit f09dfab

Prashant Kumar committed

Add comprehensive tests to exercise the kernels across the available dtypes.
Added softmax and gemm kernels to test across the available float and int dtypes.
1 parent 971231c commit f09dfab

File tree

1 file changed: +145 −0 lines changed

core/tests/kernel/coverage_test.py

Lines changed: 145 additions & 0 deletions
import torch
import shark_turbine.kernel as tk
import shark_turbine.kernel.lang as tkl
import pytest


# Maps turbine kernel element types to the torch dtypes used for test tensors.
TKL_TO_TORCH_DTYPE = {
    tkl.f16: torch.half,
    tkl.f32: torch.float,
    tkl.f64: torch.double,
    tkl.bool: torch.bool,
    tkl.i8: torch.int8,
    tkl.i16: torch.int16,
    tkl.i32: torch.int32,
    tkl.i64: torch.int64,
}

FLOAT_DTYPES = [tkl.f16, tkl.f32, tkl.f64]
INT_DTYPES = [
    tkl.bool,
    tkl.i4,
    tkl.i8,
    tkl.i16,
    tkl.i32,
    tkl.i64,
    tkl.index,
]


def iota_krnl(dtype, input):
    M = tkl.sym.M

    @tk.gen.thread(M)
    def iota_kernel(out: tkl.OutputBuffer[M, dtype]):
        # Integer dtypes get integer constants and floor division; float
        # dtypes get float constants and true division.
        a = (
            tkl.constant((17, 37, 19), dtype, 5)
            if dtype in INT_DTYPES
            else tkl.constant((17, 37, 19), dtype, 5.0)
        )
        b = (
            tkl.constant((17, 37, 19), dtype, 10)
            if dtype in INT_DTYPES
            else tkl.constant((17, 37, 19), dtype, 10.0)
        )
        c = (
            tkl.constant((17, 37, 19), dtype, 2)
            if dtype in INT_DTYPES
            else tkl.constant((17, 37, 19), dtype, 2.0)
        )
        if dtype in INT_DTYPES:
            c = (a * b) // c
        else:
            c = (a * b) / c
        c = c + a - b

    with tk.gen.TestLaunchContext():
        iota_kernel(input)


def softmax_krnl(dtype, input, output):
    M = tkl.sym.M
    K = tkl.sym.K

    @tk.gen.thread(M)
    def softmax_kernel(
        input: tk.lang.InputBuffer[M, K, dtype],
        output: tk.lang.OutputBuffer[M, K, dtype],
    ):
        row_index = tk.lang.program_id(0)
        input_row = input[row_index, :]
        # Subtract the row max before exponentiating for numerical stability.
        numerator = tkl.exp2(input_row - tkl.max(input_row))
        if dtype in INT_DTYPES:
            output_row = numerator // tkl.sum(numerator)
        else:
            output_row = numerator / tkl.sum(numerator)
        output[row_index, :] = output_row

    with tk.gen.TestLaunchContext():
        softmax_kernel(input, output)


def gemm_fx_kernel(dtype, A, B, output):
    N = tkl.sym.N
    M = tkl.sym.M
    K = tkl.sym.K
    BLOCK_SIZE = tkl.sym.BLOCK_SIZE

    # Tiled matmul: one program instance per (BLOCK_SIZE x BLOCK_SIZE) output tile.
    @tk.gen.thread(N // BLOCK_SIZE, M // BLOCK_SIZE)
    def gemm_kernel(
        A: tkl.InputBuffer[N, K, dtype],
        B: tkl.InputBuffer[K, M, dtype],
        output: tkl.OutputBuffer[N, M, dtype],
    ):
        grid_n = tkl.program_id(0)
        grid_m = tkl.program_id(1)

        # TODO: Only considering the float and integer cases.
        if dtype in INT_DTYPES:
            acc = tkl.constant((BLOCK_SIZE, BLOCK_SIZE), dtype, 0)
        else:
            acc = tkl.constant((BLOCK_SIZE, BLOCK_SIZE), dtype, 0.0)

        # Accumulate partial products over the K dimension, one tile at a time.
        @tkl.for_loop(0, K // BLOCK_SIZE, init_args=[acc])
        def body(i, c):
            a = tkl.load(A, (grid_n, i * BLOCK_SIZE), (BLOCK_SIZE, BLOCK_SIZE))
            b = tkl.load(B, (i * BLOCK_SIZE, grid_m), (BLOCK_SIZE, BLOCK_SIZE))
            return (tkl.dot(a, b, c),)

        tkl.store(output, (grid_n, grid_m), body[0])

    with tk.gen.TestLaunchContext({BLOCK_SIZE: 32}):
        gemm_kernel(A, B, output)


@pytest.mark.parametrize(
    ("dtype",),
    [(x,) for x in FLOAT_DTYPES + INT_DTYPES],
)
def test_iota_krnl(dtype):
    input = torch.zeros(17)
    iota_krnl(dtype, input)


@pytest.mark.parametrize(
    ("dtype",),
    [(x,) for x in FLOAT_DTYPES],
)
def test_softmax_krnl(dtype):
    if dtype in TKL_TO_TORCH_DTYPE:
        input = torch.randn(128, 64).to(TKL_TO_TORCH_DTYPE[dtype])
        output = torch.randn(128, 64).to(TKL_TO_TORCH_DTYPE[dtype])
        softmax_krnl(dtype, input, output)


@pytest.mark.parametrize(
    ("dtype",),
    [(x,) for x in FLOAT_DTYPES + INT_DTYPES],
)
def test_gemm_krnl(dtype):
    if dtype in TKL_TO_TORCH_DTYPE:
        A = torch.randn(512, 1024).to(TKL_TO_TORCH_DTYPE[dtype])
        B = torch.randn(1024, 2048).to(TKL_TO_TORCH_DTYPE[dtype])
        output = torch.zeros(512, 2048).to(TKL_TO_TORCH_DTYPE[dtype])
        gemm_fx_kernel(dtype, A, B, output)
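
A note on the softmax path above: softmax_kernel is built from tkl.exp2 rather
than a base-e exponential. For the float dtypes, an exp2-based softmax equals
torch.softmax applied to logits scaled by ln 2, since 2**t == e**(t * ln 2) and
the row-max shift cancels in the ratio. A minimal cross-check sketch in plain
torch (not part of this commit):

    import math
    import torch

    x = torch.randn(128, 64)
    # exp2-based softmax, with the row max subtracted for numerical stability.
    exp2 = torch.exp2(x - x.max(dim=-1, keepdim=True).values)
    out = exp2 / exp2.sum(dim=-1, keepdim=True)
    # Equivalent base-e softmax of the logits scaled by ln 2.
    ref = torch.softmax(x * math.log(2.0), dim=-1)
    assert torch.allclose(out, ref, atol=1e-6)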

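The launch geometry in test_gemm_krnl can be worked out from the shapes it
passes in (a worked example of the configuration above, not new behavior):

    # A is (N, K) = (512, 1024); B is (K, M) = (1024, 2048); BLOCK_SIZE = 32.
    N, K, M, BLOCK_SIZE = 512, 1024, 2048, 32
    grid = (N // BLOCK_SIZE, M // BLOCK_SIZE)  # (16, 64) program instances
    k_steps = K // BLOCK_SIZE                  # 32 accumulation steps per instance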
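
The parametrized tests expand to one case per dtype. A usage sketch for running
the new coverage, assuming pytest and shark_turbine are installed and the path
resolves from the repository root:

    import pytest

    # Equivalent to `pytest core/tests/kernel/coverage_test.py -v` on the CLI.
    pytest.main(["-v", "core/tests/kernel/coverage_test.py"])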