Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 7bdb6bc

Browse files
Nicolas VasilacheJules Pondard
authored andcommitted
Update test_caffe2 to latest python bindings
Tested internally
1 parent 2b92f05 commit 7bdb6bc

File tree

1 file changed

+22
-25
lines changed

1 file changed

+22
-25
lines changed

python/tests/test_caffe2.py

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import hypothesis.strategies as st
1919
import caffe2.python.hypothesis_test_util as hu
2020
import tensor_comprehensions as tc
21+
import torch
2122

2223
from hypothesis import given, settings
2324
from caffe2.python import core, dyndep
@@ -33,9 +34,6 @@
3334
def matmul(float(M,N) A, float(N,K) B) -> (output) {
3435
output(m, k) +=! A(m, r_n) * B(r_n, k)
3536
}
36-
"""
37-
38-
MATMUL_GRAD_LANG = """
3937
def matmul_grad(float(M, N) A, float(N, K) B, float(M, K) d_O) -> (d_A, d_B) {
4038
d_A(m, n) +=! d_O(m, r_k) * B(n, r_k)
4139
d_B(n, k) +=! d_O(r_m, k) * A(r_m, n)
@@ -61,7 +59,7 @@ def ref(X, W):
6159
"TcOp", ["X", "Y"], "out",
6260
tc_def=MATMUL_LANG,
6361
tc_name="matmul",
64-
tc_grad_def=MATMUL_GRAD_LANG,
62+
tc_grad_def=MATMUL_LANG,
6563
tc_grad_name="matmul_grad",
6664
inputs_used_by_gradient=[0, 1],
6765
output_gradients_used_by_gradient=[0],
@@ -91,24 +89,23 @@ def ref(X, W):
9189
**hu.gcs_gpu_only)
9290
@settings(max_examples=2)
9391
def test_matmul_tune_and_run(self, n, m, k, seed, gc, dc):
94-
matmul = tc.define(MATMUL_LANG, name="matmul")
95-
matmul_grad = tc.define(MATMUL_GRAD_LANG, name="matmul_grad")
96-
97-
mapping_options = matmul.autotune(
98-
(n, k), (k, m),
99-
generations=3,
100-
threads=32,
101-
pop_size=2,
102-
tuner_min_launch_total_threads=1,
103-
)
104-
105-
grad_mapping_options = matmul_grad.autotune(
106-
(n, k), (k, m), (n, m),
107-
generations=1,
108-
threads=32,
109-
pop_size=2,
110-
tuner_min_launch_total_threads=1,
111-
)
92+
tuner = tc.Tuner(MATMUL_LANG)
93+
tuner_config = (
94+
tc.TunerConfig().generations(3).threads(32).pop_size(2)
95+
.tuner_min_launch_total_threads(1))
96+
matmul_top1 = tuner.tune(
97+
'matmul',
98+
(torch.randn(n, k, device='cuda'),
99+
torch.randn(k, m, device='cuda')),
100+
tc.MappingOptions('naive'),
101+
tuner_config)
102+
matmul_grad_top1 = tuner.tune(
103+
'matmul_grad',
104+
(torch.randn(n, k, device='cuda'),
105+
torch.randn(k, m, device='cuda'),
106+
torch.randn(n, m, device='cuda')),
107+
tc.MappingOptions('naive'),
108+
tuner_config)
112109

113110
X = np.random.rand(m, k).astype(np.float32)
114111
W = np.random.rand(k, n).astype(np.float32)
@@ -120,13 +117,13 @@ def ref(X, W):
120117
"TcOp", ["X", "Y"], "out",
121118
tc_def=MATMUL_LANG,
122119
tc_name="matmul",
123-
tc_grad_def=MATMUL_GRAD_LANG,
120+
tc_grad_def=MATMUL_LANG,
124121
tc_grad_name="matmul_grad",
125122
inputs_used_by_gradient=[0, 1],
126123
output_gradients_used_by_gradient=[0],
127124
inputs_to_compute_gradients_of=[0, 1],
128-
mapping_options=mapping_options.serialize(),
129-
grad_mapping_options=grad_mapping_options.serialize(),
125+
mapping_options=matmul_top1.serialize(),
126+
grad_mapping_options=matmul_grad_top1.serialize(),
130127
)
131128

132129
self.assertReferenceChecks(

0 commit comments

Comments
 (0)