Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit eb29ce7

Browse files
caffe2 python test and benchmark
This commit moves caffe2_benchmark.py and revives python/tests/test_caffe2.py. Unfortunately there is no caffe2 python + TC support in OSS, so we need to deactivate them in .jenkins.
1 parent 9d06687 commit eb29ce7

File tree

3 files changed

+150
-1
lines changed

3 files changed

+150
-1
lines changed

.jenkins/build.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ WITH_CAFFE2=ON CUDA_TOOLKIT_ROOT_DIR=/usr/local/cuda CLANG_PREFIX=$(${CONDA_PREF
6868

6969
python setup.py install
7070

71-
for f in $(find ./python/ -name "*.py"); do
71+
for f in $(find ./python/ -name "*.py" | grep -v caffe2); do
7272
python $f -v
7373
done
7474

python/tests/test_caffe2.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# Copyright (c) 2017-present, Facebook, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
##############################################################################
15+
16+
import unittest, os
17+
import numpy as np
18+
import hypothesis.strategies as st
19+
import caffe2.python.hypothesis_test_util as hu
20+
import tensor_comprehensions as tc
21+
22+
from hypothesis import given, settings
23+
from caffe2.python import core, dyndep
24+
25+
26+
# Load the Caffe2 operator library that provides "TcOp".  Inside a conda
# environment the prebuilt shared object is loaded from the env prefix;
# otherwise fall back to the Buck target used by the internal build.
CONDA_PREFIX = os.environ.get("CONDA_PREFIX")
if CONDA_PREFIX:
    tc_c2_lib = os.path.join(CONDA_PREFIX, "lib/libtc_c2.so")
    # FIX: the path was previously computed but never loaded, so the TcOp
    # operator would be missing when running from a conda install.
    dyndep.InitOpsLibrary(tc_c2_lib)
else:
    dyndep.InitOpsLibrary("@/tc/tc:tc_c2")

# TC definition of a plain matrix multiply:
#   output(m, k) = sum over r_n of A(m, r_n) * B(r_n, k)
MATMUL_LANG = """
def matmul(float(M,N) A, float(N,K) B) -> (output) {
output(m, k) +=! A(m, r_n) * B(r_n, k)
}
"""

# TC definition of the corresponding backward pass: gradients of the
# matmul output d_O w.r.t. both inputs A and B.
MATMUL_GRAD_LANG = """
def matmul_grad(float(M, N) A, float(N, K) B, float(M, K) d_O) -> (d_A, d_B) {
d_A(m, n) +=! d_O(m, r_k) * B(n, r_k)
d_B(n, k) +=! d_O(r_m, k) * A(r_m, n)
}
"""
45+
class TestCaffe2(hu.HypothesisTestCase):
    """Hypothesis-driven tests for the Tensor Comprehensions Caffe2 op.

    Both tests build a "TcOp" operator from the TC matmul definitions and
    validate its forward output against numpy plus its gradients via
    Caffe2's numerical gradient checker (GPU only, per hu.gcs_gpu_only).
    """

    @given(n=st.integers(1, 4),
           m=st.integers(1, 4),
           k=st.integers(1, 4),
           seed=st.integers(min_value=0, max_value=2**32 - 1),
           **hu.gcs_gpu_only)
    def test_matmul(self, n, m, k, seed, gc, dc):
        """Run TcOp with default mapping options and check fwd + grads."""
        np.random.seed(seed)

        X = np.random.rand(m, k).astype(np.float32)
        W = np.random.rand(k, n).astype(np.float32)

        # numpy reference for the forward pass; list-wrapped to match the
        # operator's single-output convention.
        def ref(X, W):
            return [np.dot(X, W)]

        op = core.CreateOperator(
            "TcOp", ["X", "Y"], "out",
            tc_def=MATMUL_LANG,
            tc_name="matmul",
            tc_grad_def=MATMUL_GRAD_LANG,
            tc_grad_name="matmul_grad",
            inputs_used_by_gradient=[0, 1],
            output_gradients_used_by_gradient=[0],
            inputs_to_compute_gradients_of=[0, 1],
        )

        self.assertReferenceChecks(
            device_option=gc,
            op=op,
            inputs=[X, W],
            reference=ref,
        )

        # Numerically check the gradient w.r.t. each of the two inputs.
        for input_idx in range(2):
            self.assertGradientChecks(
                device_option=gc,
                op=op,
                inputs=[X, W],
                outputs_to_check=input_idx,
                outputs_with_grads=[0],
            )

    @given(n=st.integers(1, 4),
           m=st.integers(1, 4),
           k=st.integers(1, 4),
           seed=st.integers(min_value=0, max_value=2**32 - 1),
           **hu.gcs_gpu_only)
    @settings(max_examples=2)  # autotuning is expensive; keep examples few
    def test_matmul_tune_and_run(self, n, m, k, seed, gc, dc):
        """Autotune mapping options for matmul + its gradient, then run
        TcOp with the serialized tuned options and check fwd + grads."""
        matmul = tc.define(MATMUL_LANG, name="matmul")
        matmul_grad = tc.define(MATMUL_GRAD_LANG, name="matmul_grad")

        # Tiny tuning budget: a few generations with a minimal population.
        # NOTE(review): tuning uses shapes (n, k)/(k, m) while the run below
        # feeds X:(m, k), W:(k, n) — the tuned sizes may not match the
        # executed ones; confirm this is intentional.
        mapping_options = matmul.autotune(
            (n, k), (k, m),
            generations=3,
            threads=32,
            pop_size=2,
            tuner_min_launch_total_threads=1,
        )

        grad_mapping_options = matmul_grad.autotune(
            (n, k), (k, m), (n, m),
            generations=1,
            threads=32,
            pop_size=2,
            tuner_min_launch_total_threads=1,
        )

        X = np.random.rand(m, k).astype(np.float32)
        W = np.random.rand(k, n).astype(np.float32)

        # numpy reference for the forward pass.
        def ref(X, W):
            return [np.dot(X, W)]

        op = core.CreateOperator(
            "TcOp", ["X", "Y"], "out",
            tc_def=MATMUL_LANG,
            tc_name="matmul",
            tc_grad_def=MATMUL_GRAD_LANG,
            tc_grad_name="matmul_grad",
            inputs_used_by_gradient=[0, 1],
            output_gradients_used_by_gradient=[0],
            inputs_to_compute_gradients_of=[0, 1],
            mapping_options=mapping_options.serialize(),
            grad_mapping_options=grad_mapping_options.serialize(),
        )

        self.assertReferenceChecks(
            device_option=gc,
            op=op,
            inputs=[X, W],
            reference=ref,
        )

        # Numerically check the gradient w.r.t. each of the two inputs.
        for input_idx in range(2):
            self.assertGradientChecks(
                device_option=gc,
                op=op,
                inputs=[X, W],
                outputs_to_check=input_idx,
                outputs_with_grads=[0],
            )
# Allow running this file directly (the Jenkins loop invokes each test
# file with `python $f -v`).
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)