Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 6712ed3

Browse files
authored
Merge pull request #255 from salexspb/caffe2_py_test
caffe2 test fix + validation
2 parents f418bfa + 91e5f58 commit 6712ed3

File tree

1 file changed

+38
-16
lines changed

1 file changed

+38
-16
lines changed

test_python/test_c2.py

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,33 +15,55 @@
1515

1616
import unittest, os
1717
import numpy as np
18+
import hypothesis.strategies as st
19+
import caffe2.python.hypothesis_test_util as hu
1820

19-
from caffe2.proto import caffe2_pb2
20-
from caffe2.python import core, workspace, dyndep
21+
from hypothesis import given
22+
from caffe2.python import core, dyndep
2123

22-
tc_c2_lib = os.path.join(os.environ.get("CONDA_PREFIX"), "lib/libtc_c2.so")
23-
dyndep.InitOpsLibrary(tc_c2_lib)
2424

25+
CONDA_PREFIX = os.environ.get("CONDA_PREFIX")
26+
if CONDA_PREFIX:
27+
tc_c2_lib = os.path.join(CONDA_PREFIX, "lib/libtc_c2.so")
28+
else:
29+
dyndep.InitOpsLibrary("@/tc/tc:tc_c2")
2530

26-
class TestCaffe2(unittest.TestCase):
2731

28-
def test_matmul_caffe2(self):
29-
lang = """
32+
class TestCaffe2(hu.HypothesisTestCase):
33+
@given(n=st.integers(1, 128),
34+
m=st.integers(1, 128),
35+
k=st.integers(1, 128),
36+
seed=st.integers(min_value=0, max_value=2**32 - 1),
37+
**hu.gcs_gpu_only)
38+
def test_matmul(self, n, m, k, seed, gc, dc):
39+
np.random.seed(seed)
40+
41+
tc_forward = """
3042
def matmul(float(M,N) A, float(N,K) B) -> (output) {
3143
output(i, j) +=! A(i, kk) * B(kk, j)
3244
}
3345
"""
46+
3447
# TODO: (prigoyal) serialize the options
3548
# options = Options("mlp")
36-
mat1, mat2 = np.random.rand(100, 400), np.random.rand(400, 500)
37-
with core.DeviceScope(core.DeviceOption(caffe2_pb2.CUDA, 0)):
38-
workspace.FeedBlob('mat1', mat1.astype(np.float32))
39-
workspace.FeedBlob('mat2', mat2.astype(np.float32))
40-
matmul = core.CreateOperator(
41-
"TcOp", ["mat1", "mat2"], ["out"], lang=lang, tcName="matmul"
42-
)
43-
workspace.RunOperatorOnce(matmul)
44-
out = workspace.FetchBlob("out")
49+
X = np.random.rand(m, k).astype(np.float32)
50+
W = np.random.rand(k, n).astype(np.float32)
51+
52+
def ref(X, W):
53+
return [np.dot(X, W)]
54+
55+
op = core.CreateOperator(
56+
"TcOp", ["X", "Y"], "out",
57+
tcDef=tc_forward,
58+
tcName="matmul",
59+
)
60+
61+
self.assertReferenceChecks(
62+
device_option=gc,
63+
op=op,
64+
inputs=[X, W],
65+
reference=ref,
66+
)
4567

4668

4769
if __name__ == '__main__':

0 commit comments

Comments
 (0)