This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit fab553d

Author: Mingzhe Li (committed)
Sync fbcode to github
1 parent 2cd56da commit fab553d

2 files changed: 45 additions & 11 deletions

tensor_comprehensions/pybinds/pybind_options.cc

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ PYBIND11_MODULE(mapping_options, m) {
       },
       "Require TC to try and execute different TC expressions interleaved (Max), separately (Min)\nor interleaved as long as sufficient parallelism is exploited (Preserve3Coincident) by\nperforming loop fusion and fission. Applies before tiling")
       .def(
-          "serializeToProtobuf",
+          "serialize",
           [](tc::CudaMappingOptions& instance) {
             std::string str = instance.toProtobufSerializedString();
             return py::bytes(str);
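
Note on the rename: the Python binding for CudaMappingOptions now exposes protobuf serialization as serialize() instead of serializeToProtobuf(), returning the serialized options as bytes. Below is a minimal usage sketch, assuming the options are obtained from the autotuner as in the updated test further down; the shapes are illustrative only, and the tuner settings simply mirror the test.

import tensor_comprehensions as tc

MATMUL_LANG = """
def matmul(float(M,N) A, float(N,K) B) -> (output) {
  output(i, j) +=! A(i, kk) * B(kk, j)
}
"""

# Define the TC and autotune it briefly (parameters taken from the test below).
matmul = tc.define(MATMUL_LANG, name="matmul")
options = matmul.autotune(
    (16, 32), (32, 8),
    generations=1,
    threads=32,
    pop_size=2,
    tuner_min_launch_total_threads=1,
)

# serialize() replaces serializeToProtobuf(): it returns the protobuf-serialized
# CudaMappingOptions as Python bytes, suitable for passing to the Caffe2 TcOp.
blob = options.serialize()
assert isinstance(blob, bytes)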

test_python/test_c2.py

Lines changed: 44 additions & 10 deletions
@@ -17,8 +17,9 @@
 import numpy as np
 import hypothesis.strategies as st
 import caffe2.python.hypothesis_test_util as hu
+import tensor_comprehensions as tc

-from hypothesis import given
+from hypothesis import given, settings
 from caffe2.python import core, dyndep


@@ -28,6 +29,11 @@
 else:
     dyndep.InitOpsLibrary("@/tc/tc:tc_c2")

+MATMUL_LANG = """
+def matmul(float(M,N) A, float(N,K) B) -> (output) {
+  output(i, j) +=! A(i, kk) * B(kk, j)
+}
+"""

 class TestCaffe2(hu.HypothesisTestCase):
     @given(n=st.integers(1, 128),
@@ -38,14 +44,6 @@ class TestCaffe2(hu.HypothesisTestCase):
     def test_matmul(self, n, m, k, seed, gc, dc):
         np.random.seed(seed)

-        tc_forward = """
-        def matmul(float(M,N) A, float(N,K) B) -> (output) {
-            output(i, j) +=! A(i, kk) * B(kk, j)
-        }
-        """
-
-        # TODO: (prigoyal) serialize the options
-        # options = Options("mlp")
         X = np.random.rand(m, k).astype(np.float32)
         W = np.random.rand(k, n).astype(np.float32)

@@ -54,7 +52,7 @@ def ref(X, W):

         op = core.CreateOperator(
             "TcOp", ["X", "Y"], "out",
-            tcDef=tc_forward,
+            tcDef=MATMUL_LANG,
             tcName="matmul",
         )

@@ -65,6 +63,42 @@ def ref(X, W):
             reference=ref,
         )

+    @given(n=st.integers(1, 128),
+           m=st.integers(1, 128),
+           k=st.integers(1, 128),
+           seed=st.integers(min_value=0, max_value=2**32 - 1),
+           **hu.gcs_gpu_only)
+    @settings(max_examples=2)
+    def test_matmul_tune_and_run(self, n, m, k, seed, gc, dc):
+        matmul = tc.define(MATMUL_LANG, name="matmul")
+
+        mapping_options = matmul.autotune(
+            (n, k), (k, m),
+            generations=1,
+            threads=32,
+            pop_size=2,
+            tuner_min_launch_total_threads=1,
+        )
+
+        X = np.random.rand(m, k).astype(np.float32)
+        W = np.random.rand(k, n).astype(np.float32)
+
+        def ref(X, W):
+            return [np.dot(X, W)]
+
+        op = core.CreateOperator(
+            "TcOp", ["X", "Y"], "out",
+            tcDef=MATMUL_LANG,
+            tcName="matmul",
+            mappingOptions=mapping_options.serialize(),
+        )
+
+        self.assertReferenceChecks(
+            device_option=gc,
+            op=op,
+            inputs=[X, W],
+            reference=ref,
+        )

 if __name__ == '__main__':
     unittest.main()
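
The new test_matmul_tune_and_run exercises the full flow this commit enables: define the TC, autotune it, serialize the resulting mapping options, and hand them to the Caffe2 TcOp via the mappingOptions argument. As a rough standalone sketch of the same flow outside the hypothesis harness: the workspace calls, CUDA device option, and concrete shapes below are assumptions of this sketch, not part of the commit; TcOp's inputs, outputs, and arguments are taken from the test above.

import numpy as np
import tensor_comprehensions as tc
from caffe2.python import core, dyndep, workspace
from caffe2.proto import caffe2_pb2

# Load the TC Caffe2 ops library (path taken from the test above).
dyndep.InitOpsLibrary("@/tc/tc:tc_c2")

MATMUL_LANG = """
def matmul(float(M,N) A, float(N,K) B) -> (output) {
  output(i, j) +=! A(i, kk) * B(kk, j)
}
"""

# Define and autotune the TC on the shapes we are about to run.
matmul = tc.define(MATMUL_LANG, name="matmul")
options = matmul.autotune((8, 16), (16, 4), generations=1, pop_size=2)

A = np.random.rand(8, 16).astype(np.float32)
B = np.random.rand(16, 4).astype(np.float32)

# TcOp runs on GPU; using the CUDA device option is an assumption of this sketch.
device = core.DeviceOption(caffe2_pb2.CUDA, 0)
workspace.FeedBlob("X", A, device_option=device)
workspace.FeedBlob("Y", B, device_option=device)

op = core.CreateOperator(
    "TcOp", ["X", "Y"], "out",
    tcDef=MATMUL_LANG,
    tcName="matmul",
    mappingOptions=options.serialize(),
    device_option=device,
)
workspace.RunOperatorOnce(op)
out = workspace.FetchBlob("out")  # expected to match np.dot(A, B)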
