import time
import torch
import torch.optim as optim
import torch.nn as nn
import torch.utils.data
import torch.nn.functional as F
import tensor_comprehensions as tc
import numpy as np
from enum import IntEnum

NB_HYPERPARAMS = 26
# Module-level default for the standalone helpers getRawVectorFromTcOpt() and
# optionsFromVector(), which read this flag; the original referenced the name
# without defining it at module scope. ExpTunerConfig keeps its own copy.
USE_MAX_SHARED_MEMORY = 0

class ExpTunerConfig:
    def __init__(self, use_max_shared_memory=0):
        self.INIT_INPUT_SZ = -1
        self.USE_MAX_SHARED_MEMORY = use_max_shared_memory
        self.tc_code = ""
        self.tc_name = ""
        self.inp = -1
        self.cat_val = -1
        self.cat_sz = -1

    def set_convolution_tc(self, size_type="default", inp_sz_list=[], use_max_shared_memory=False):
        self.INIT_INPUT_SZ = 7
        self.tc_name = "convolution"
        self.tc_code = """
def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1) -> (O) {
    O(n, m, h, w) +=! I(n, r_c, h + r_kh, w + r_kw) * W1(m, r_c, r_kh, r_kw)
}
"""

        if size_type == "input":
            N, C, H, W, O, kH, kW = tuple(inp_sz_list)
        elif size_type == "default":
            N, C, H, W, O, kH, kW = 16, 4, 56, 56, 16, 1, 1  # 8, 2, 28, 28, 8, 1, 1
        elif size_type == "random":
            N, C, H, W, O, kH, kW = \
                get_rand([8, 16, 32, 64]), \
                get_rand([2, 4, 8, 16]), \
                get_rand([28, 56, 112]), \
                get_rand([28, 56, 112]), \
                get_rand([8, 16, 32]), \
                get_rand([1, 2, 4]), \
                get_rand([1, 2, 4])
        else:
            print("Unknown size type")
            exit()
        I, W1 = torch.randn(N, C, H, W, device='cuda'), torch.randn(O, C, kH, kW, device='cuda')
        self.inp = (I, W1)
        self.init_input_sz = np.array([N, C, H, W, O, kH, kW])
        print(self.init_input_sz)
        self.init_input_sz = torch.from_numpy(self.init_input_sz).float()

        self.computeCat()

    def computeCat(self):
        inp = self.inp
        self.cat_sz = np.zeros(NB_HYPERPARAMS).astype(int)
        self.cat_val = [[] for _ in range(NB_HYPERPARAMS)]

        divs = getAllDivs(inp)
        if self.USE_MAX_SHARED_MEMORY:
            divs2 = getAllDivs([np.array([tc.tclib.shared_memory_size()])])

        self.cat_val[MappingOptionsIdx.outerScheduleFusionStrategy] = [0, 1, 2]
        self.cat_val[MappingOptionsIdx.intraTileScheduleFusionStrategy] = [0, 1, 2]
        self.cat_val[MappingOptionsIdx.fixParametersBeforeScheduling] = [0, 1]
        self.cat_val[MappingOptionsIdx.nTiledDims] = [i + 1 for i in range(6)]
        for i in range(6):  # tiling
            self.cat_val[MappingOptionsIdx.tiling1 + i] = divs + [0]
        self.cat_val[MappingOptionsIdx.unroll] = [2 ** i for i in range(8)]
        self.cat_val[MappingOptionsIdx.matchLibraryCalls] = [0, 1]
        self.cat_val[MappingOptionsIdx.nMappedToBlocksDims] = [i + 1 for i in range(3)]
        for i in range(3):  # mapping to blocks
            self.cat_val[MappingOptionsIdx.mappingToBlocks1 + i] = divs
        self.cat_val[MappingOptionsIdx.nMappedToThreadsDims] = [i + 1 for i in range(3)]
        for i in range(3):  # mapping to threads
            self.cat_val[MappingOptionsIdx.mappingToThreads1 + i] = divs
        self.cat_val[MappingOptionsIdx.useSharedMemory] = [0, 1]
        self.cat_val[MappingOptionsIdx.usePrivateMemory] = [0, 1]
        self.cat_val[MappingOptionsIdx.unrollCopyShared] = [0, 1]
        self.cat_val[MappingOptionsIdx.maxSharedMemory] = divs2 if self.USE_MAX_SHARED_MEMORY else [0]
        self.cat_val[MappingOptionsIdx.useReadOnlyCache] = [0, 1]
        self.cat_val[MappingOptionsIdx.privateDepth] = [i for i in range(6)]

        for i in range(NB_HYPERPARAMS):
            self.cat_sz[i] = len(self.cat_val[i])

    def catVec_to_optVec(self, catVec):
        opt = [self.cat_val[i][catVec[i]] for i in range(NB_HYPERPARAMS)]
        return opt

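# Illustrative helper (hypothetical name, not part of the original tuner code):
# draw a uniformly random categorical vector within the per-option ranges
# computed by computeCat(), and return it together with the concrete option
# values it maps to.
def sample_random_catVec(exptuner_config):
    catVec = [np.random.randint(exptuner_config.cat_sz[i]) for i in range(NB_HYPERPARAMS)]
    return catVec, exptuner_config.catVec_to_optVec(catVec)
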
class MappingOptionsIdx(IntEnum):
    outerScheduleFusionStrategy = 0
    intraTileScheduleFusionStrategy = 1
    fixParametersBeforeScheduling = 2
    nTiledDims = 3
    tiling1 = 4
    tiling2 = 5
    tiling3 = 6
    tiling4 = 7
    tiling5 = 8
    tiling6 = 9
    unroll = 10
    matchLibraryCalls = 11
    nMappedToBlocksDims = 12
    mappingToBlocks1 = 13
    mappingToBlocks2 = 14
    mappingToBlocks3 = 15
    nMappedToThreadsDims = 16
    mappingToThreads1 = 17
    mappingToThreads2 = 18
    mappingToThreads3 = 19
    useSharedMemory = 20
    usePrivateMemory = 21
    unrollCopyShared = 22
    maxSharedMemory = 23
    useReadOnlyCache = 24
    privateDepth = 25

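# Sanity check (added here, not in the original): the enum must enumerate
# exactly NB_HYPERPARAMS entries, one per tunable mapping option.
assert len(MappingOptionsIdx) == NB_HYPERPARAMS
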
def get_rand(l):
    return np.random.choice(l).item()


def print_opt(options):
    print(options.tolist())

def evalTime(opt, exptuner_config, iters=50, warmup=10, estimator="mean", prune=-1, curr_best=-1):
    tc_code, tc_name, inp = \
        exptuner_config.tc_code, exptuner_config.tc_name, exptuner_config.inp
    infty = 30000
    opt = exptuner_config.catVec_to_optVec(opt)
    opt = optionsFromVector(opt)
    try:
        tc_prog = tc.compile(tc_code, tc_name, opt, *inp)
        first_ft = tc_prog.executor.profile_kernel(inp)
    except (KeyboardInterrupt, SystemExit):
        raise
    except:
        return infty
    if prune != -1 and first_ft > 100 * curr_best:
        return first_ft
    for _ in range(warmup - 1):
        tc_prog.executor.profile_kernel(inp)

    first_t = tc_prog.executor.profile_kernel(inp)

    if prune != -1 and first_t > prune * curr_best:
        return first_t

    tc_time_list = [first_t]
    for i in range(iters - 1):
        iter_time = tc_prog.executor.profile_kernel(inp)
        tc_time_list.append(iter_time)
    if estimator == "mean":
        mean_time = np.mean(tc_time_list)
        return mean_time
    elif estimator == "median":
        median_time = np.median(tc_time_list)
        return median_time
    elif estimator == "p25":
        p25_time = np.percentile(tc_time_list, 25)
        return p25_time
    print("Unknown estimator")
    return infty

def getRawVectorFromTcOpt(tc_opt):
    tr_dic = {"Max": 0, "Preserve3Coincident": 1, "Min": 2}
    opt_vect = np.zeros(NB_HYPERPARAMS).astype(int)
    opt_vect[MappingOptionsIdx.outerScheduleFusionStrategy] = tr_dic[tc_opt["outerScheduleFusionStrategy"]]
    opt_vect[MappingOptionsIdx.intraTileScheduleFusionStrategy] = tr_dic[tc_opt["intraTileScheduleFusionStrategy"]]
    opt_vect[MappingOptionsIdx.fixParametersBeforeScheduling] = tc_opt["fixParametersBeforeScheduling"]
    opt_vect[MappingOptionsIdx.nTiledDims] = len(tc_opt["tile"])
    assert opt_vect[MappingOptionsIdx.nTiledDims] < 7, "Too many tilings"
    opt_vect[MappingOptionsIdx.tiling1:
             MappingOptionsIdx.tiling1 + opt_vect[MappingOptionsIdx.nTiledDims]] = tc_opt["tile"]
    opt_vect[MappingOptionsIdx.unroll] = tc_opt["unroll"]
    # opt_vect[MappingOptionsIdx.tileImperfectlyNested] = tc_opt["tileImperfectlyNested"]  # todo: pybind
    opt_vect[MappingOptionsIdx.matchLibraryCalls] = tc_opt["matchLibraryCalls"]
    opt_vect[MappingOptionsIdx.nMappedToBlocksDims] = len(tc_opt["mapToBlocks"])
    opt_vect[MappingOptionsIdx.mappingToBlocks1:
             MappingOptionsIdx.mappingToBlocks1 + opt_vect[MappingOptionsIdx.nMappedToBlocksDims]] = tc_opt["mapToBlocks"]
    opt_vect[MappingOptionsIdx.nMappedToThreadsDims] = len(tc_opt["mapToThreads"])
    opt_vect[MappingOptionsIdx.mappingToThreads1:
             MappingOptionsIdx.mappingToThreads1 + opt_vect[MappingOptionsIdx.nMappedToThreadsDims]] = tc_opt["mapToThreads"]
    opt_vect[MappingOptionsIdx.useSharedMemory] = tc_opt["useSharedMemory"]
    opt_vect[MappingOptionsIdx.usePrivateMemory] = tc_opt["usePrivateMemory"]
    opt_vect[MappingOptionsIdx.unrollCopyShared] = tc_opt["unrollCopyShared"]
    if USE_MAX_SHARED_MEMORY and "maxSharedMemory" in tc_opt:
        opt_vect[MappingOptionsIdx.maxSharedMemory] = tc_opt["maxSharedMemory"]
    opt_vect[MappingOptionsIdx.useReadOnlyCache] = tc_opt["useReadOnlyCache"]
    opt_vect[MappingOptionsIdx.privateDepth] = tc_opt["privateDepth"]
    return opt_vect

def optionsFromVector(vect):
    strat_str = ["Max", "Preserve3Coincident", "Min"]
    options = tc.MappingOptions("naive")
    options.outerScheduleFusionStrategy(strat_str[vect[MappingOptionsIdx.outerScheduleFusionStrategy]])
    options.intraTileScheduleFusionStrategy(strat_str[vect[MappingOptionsIdx.intraTileScheduleFusionStrategy]])
    options.fixParametersBeforeScheduling(vect[MappingOptionsIdx.fixParametersBeforeScheduling])
    options.tile(list(vect[MappingOptionsIdx.tiling1:
                           MappingOptionsIdx.tiling1 + vect[MappingOptionsIdx.nTiledDims]]))
    options.unroll(vect[MappingOptionsIdx.unroll])
    options.matchLibraryCalls(vect[MappingOptionsIdx.matchLibraryCalls])
    options.mapToBlocks(list(vect[MappingOptionsIdx.mappingToBlocks1:
                                  MappingOptionsIdx.mappingToBlocks1 + vect[MappingOptionsIdx.nMappedToBlocksDims]]))
    options.mapToThreads(list(vect[MappingOptionsIdx.mappingToThreads1:
                                   MappingOptionsIdx.mappingToThreads1 + vect[MappingOptionsIdx.nMappedToThreadsDims]]))
    options.useSharedMemory(vect[MappingOptionsIdx.useSharedMemory])
    options.usePrivateMemory(vect[MappingOptionsIdx.usePrivateMemory])
    options.unrollCopyShared(vect[MappingOptionsIdx.unrollCopyShared])
    if USE_MAX_SHARED_MEMORY:
        options.maxSharedMemory(vect[MappingOptionsIdx.maxSharedMemory])
    options.useReadOnlyCache(vect[MappingOptionsIdx.useReadOnlyCache])
    options.privateDepth(vect[MappingOptionsIdx.privateDepth])
    return options

def computeDivs(sz):
    l = []
    for i in range(sz):
        if 2 ** i > sz:
            break
        l.append((sz + 2 ** i - 1) // (2 ** i))
    return l

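# Worked example (not in the original): computeDivs enumerates ceil(sz / 2**i)
# for increasing i until 2**i exceeds sz, e.g. computeDivs(56) == [56, 28, 14, 7, 4, 2].
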
def getAllDivs(inp, maxp2=8):
    p2 = [2 ** i for i in range(maxp2 + 1)]
    l = []
    for elem in inp:
        for sz in elem.shape:
            l += computeDivs(sz)
    divs_list = list(set(l + p2))
    return sorted(divs_list)
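
# Minimal usage sketch (added for illustration; not part of the original file).
# It requires a CUDA device and the tensor_comprehensions package: build the
# default convolution benchmark, draw one random candidate configuration from
# the categorical search space, and time it with evalTime().
if __name__ == "__main__":
    exptuner_config = ExpTunerConfig()
    exptuner_config.set_convolution_tc(size_type="default")
    catVec = [np.random.randint(exptuner_config.cat_sz[i]) for i in range(NB_HYPERPARAMS)]
    print("candidate:", catVec)
    print("measured kernel time:", evalTime(catVec, exptuner_config, iters=10, warmup=2))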