18
18
import hypothesis .strategies as st
19
19
import caffe2 .python .hypothesis_test_util as hu
20
20
import tensor_comprehensions as tc
21
+ import torch
21
22
22
23
from hypothesis import given , settings
23
24
from caffe2 .python import core , dyndep
33
34
def matmul(float(M,N) A, float(N,K) B) -> (output) {
34
35
output(m, k) +=! A(m, r_n) * B(r_n, k)
35
36
}
36
- """
37
-
38
- MATMUL_GRAD_LANG = """
39
37
def matmul_grad(float(M, N) A, float(N, K) B, float(M, K) d_O) -> (d_A, d_B) {
40
38
d_A(m, n) +=! d_O(m, r_k) * B(n, r_k)
41
39
d_B(n, k) +=! d_O(r_m, k) * A(r_m, n)
@@ -61,7 +59,7 @@ def ref(X, W):
61
59
"TcOp" , ["X" , "Y" ], "out" ,
62
60
tc_def = MATMUL_LANG ,
63
61
tc_name = "matmul" ,
64
- tc_grad_def = MATMUL_GRAD_LANG ,
62
+ tc_grad_def = MATMUL_LANG ,
65
63
tc_grad_name = "matmul_grad" ,
66
64
inputs_used_by_gradient = [0 , 1 ],
67
65
output_gradients_used_by_gradient = [0 ],
@@ -91,24 +89,23 @@ def ref(X, W):
91
89
** hu .gcs_gpu_only )
92
90
@settings (max_examples = 2 )
93
91
def test_matmul_tune_and_run (self , n , m , k , seed , gc , dc ):
94
- matmul = tc .define (MATMUL_LANG , name = "matmul" )
95
- matmul_grad = tc .define (MATMUL_GRAD_LANG , name = "matmul_grad" )
96
-
97
- mapping_options = matmul .autotune (
98
- (n , k ), (k , m ),
99
- generations = 3 ,
100
- threads = 32 ,
101
- pop_size = 2 ,
102
- tuner_min_launch_total_threads = 1 ,
103
- )
104
-
105
- grad_mapping_options = matmul_grad .autotune (
106
- (n , k ), (k , m ), (n , m ),
107
- generations = 1 ,
108
- threads = 32 ,
109
- pop_size = 2 ,
110
- tuner_min_launch_total_threads = 1 ,
111
- )
92
+ tuner = tc .Tuner (MATMUL_LANG )
93
+ tuner_config = (
94
+ tc .TunerConfig ().generations (3 ).threads (32 ).pop_size (2 )
95
+ .tuner_min_launch_total_threads (1 ))
96
+ matmul_top1 = tuner .tune (
97
+ 'matmul' ,
98
+ (torch .randn (n , k , device = 'cuda' ),
99
+ torch .randn (k , m , device = 'cuda' )),
100
+ tc .MappingOptions ('naive' ),
101
+ tuner_config )
102
+ matmul_grad_top1 = tuner .tune (
103
+ 'matmul_grad' ,
104
+ (torch .randn (n , k , device = 'cuda' ),
105
+ torch .randn (k , m , device = 'cuda' ),
106
+ torch .randn (n , m , device = 'cuda' )),
107
+ tc .MappingOptions ('naive' ),
108
+ tuner_config )
112
109
113
110
X = np .random .rand (m , k ).astype (np .float32 )
114
111
W = np .random .rand (k , n ).astype (np .float32 )
@@ -120,13 +117,13 @@ def ref(X, W):
120
117
"TcOp" , ["X" , "Y" ], "out" ,
121
118
tc_def = MATMUL_LANG ,
122
119
tc_name = "matmul" ,
123
- tc_grad_def = MATMUL_GRAD_LANG ,
120
+ tc_grad_def = MATMUL_LANG ,
124
121
tc_grad_name = "matmul_grad" ,
125
122
inputs_used_by_gradient = [0 , 1 ],
126
123
output_gradients_used_by_gradient = [0 ],
127
124
inputs_to_compute_gradients_of = [0 , 1 ],
128
- mapping_options = mapping_options .serialize (),
129
- grad_mapping_options = grad_mapping_options .serialize (),
125
+ mapping_options = matmul_top1 .serialize (),
126
+ grad_mapping_options = matmul_grad_top1 .serialize (),
130
127
)
131
128
132
129
self .assertReferenceChecks (
0 commit comments