Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit e0e13e1

Browse files
Merge pull request #341 from mingzhe09088/sync-fbcode
TC: Sync fbcode to github
2 parents df36d16 + fab553d commit e0e13e1

File tree

2 files changed

+45
-11
lines changed

2 files changed

+45
-11
lines changed

tensor_comprehensions/pybinds/pybind_options.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ PYBIND11_MODULE(mapping_options, m) {
8989
},
9090
"Require TC to try and execute different TC expressions interleaved (Max), separately (Min)\nor interleaved as long as sufficient parallelism is exploited (Preserve3Coincident) by\nperforming loop fusion and fission. Applies before tiling")
9191
.def(
92-
"serializeToProtobuf",
92+
"serialize",
9393
[](tc::CudaMappingOptions& instance) {
9494
std::string str = instance.toProtobufSerializedString();
9595
return py::bytes(str);

test_python/test_c2.py

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717
import numpy as np
1818
import hypothesis.strategies as st
1919
import caffe2.python.hypothesis_test_util as hu
20+
import tensor_comprehensions as tc
2021

21-
from hypothesis import given
22+
from hypothesis import given, settings
2223
from caffe2.python import core, dyndep
2324

2425

@@ -28,6 +29,11 @@
2829
else:
2930
dyndep.InitOpsLibrary("@/tc/tc:tc_c2")
3031

32+
MATMUL_LANG = """
33+
def matmul(float(M,N) A, float(N,K) B) -> (output) {
34+
output(i, j) +=! A(i, kk) * B(kk, j)
35+
}
36+
"""
3137

3238
class TestCaffe2(hu.HypothesisTestCase):
3339
@given(n=st.integers(1, 128),
@@ -38,14 +44,6 @@ class TestCaffe2(hu.HypothesisTestCase):
3844
def test_matmul(self, n, m, k, seed, gc, dc):
3945
np.random.seed(seed)
4046

41-
tc_forward = """
42-
def matmul(float(M,N) A, float(N,K) B) -> (output) {
43-
output(i, j) +=! A(i, kk) * B(kk, j)
44-
}
45-
"""
46-
47-
# TODO: (prigoyal) serialize the options
48-
# options = Options("mlp")
4947
X = np.random.rand(m, k).astype(np.float32)
5048
W = np.random.rand(k, n).astype(np.float32)
5149

@@ -54,7 +52,7 @@ def ref(X, W):
5452

5553
op = core.CreateOperator(
5654
"TcOp", ["X", "Y"], "out",
57-
tcDef=tc_forward,
55+
tcDef=MATMUL_LANG,
5856
tcName="matmul",
5957
)
6058

@@ -65,6 +63,42 @@ def ref(X, W):
6563
reference=ref,
6664
)
6765

66+
@given(n=st.integers(1, 128),
67+
m=st.integers(1, 128),
68+
k=st.integers(1, 128),
69+
seed=st.integers(min_value=0, max_value=2**32 - 1),
70+
**hu.gcs_gpu_only)
71+
@settings(max_examples=2)
72+
def test_matmul_tune_and_run(self, n, m, k, seed, gc, dc):
73+
matmul = tc.define(MATMUL_LANG, name="matmul")
74+
75+
mapping_options = matmul.autotune(
76+
(n, k), (k, m),
77+
generations=1,
78+
threads=32,
79+
pop_size=2,
80+
tuner_min_launch_total_threads=1,
81+
)
82+
83+
X = np.random.rand(m, k).astype(np.float32)
84+
W = np.random.rand(k, n).astype(np.float32)
85+
86+
def ref(X, W):
87+
return [np.dot(X, W)]
88+
89+
op = core.CreateOperator(
90+
"TcOp", ["X", "Y"], "out",
91+
tcDef=MATMUL_LANG,
92+
tcName="matmul",
93+
mappingOptions=mapping_options.serialize(),
94+
)
95+
96+
self.assertReferenceChecks(
97+
device_option=gc,
98+
op=op,
99+
inputs=[X, W],
100+
reference=ref,
101+
)
68102

69103
if __name__ == '__main__':
70104
unittest.main()

0 commit comments

Comments
 (0)