
Commit 4fdadf5

Author: Jules Pondard (committed)
Add utilities needed for all the different tuning experiments

Also introduce a new folder that will contain all the files related to experimenting with new ways of tuning options.
1 parent c607c68 commit 4fdadf5

File tree

1 file changed: +291 −0 lines changed

@@ -0,0 +1,291 @@
import time
import torch
import torch.optim as optim
import torch.nn as nn
import torch.utils.data
import torch.nn.functional as F
import tensor_comprehensions as tc
import numpy as np
from enum import IntEnum

NB_HYPERPARAMS, INIT_INPUT_SZ = 26, 7
USE_MAX_SHARED_MEMORY = 0

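# The experiments encode one CUDA mapping configuration as a flat vector of
# NB_HYPERPARAMS integers; MappingOptionsIdx below names the slot of each
# Tensor Comprehensions mapping option in that vector. INIT_INPUT_SZ is the
# number of problem-size parameters (N, C, H, W, O, kH, kW) of the
# convolution example.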
class MappingOptionsIdx(IntEnum):
    outerScheduleFusionStrategy = 0
    intraTileScheduleFusionStrategy = 1
    fixParametersBeforeScheduling = 2
    nTiledDims = 3
    tiling1 = 4
    tiling2 = 5
    tiling3 = 6
    tiling4 = 7
    tiling5 = 8
    tiling6 = 9
    unroll = 10
    matchLibraryCalls = 11
    nMappingToBlocksDims = 12
    mappingToBlocks1 = 13
    mappingToBlocks2 = 14
    mappingToBlocks3 = 15
    nMappingToThreadsDims = 16
    mappingToThreads1 = 17
    mappingToThreads2 = 18
    mappingToThreads3 = 19
    useSharedMemory = 20
    usePrivateMemory = 21
    unrollCopyShared = 22
    maxSharedMemory = 23
    useReadOnlyCache = 24
    privateDepth = 25

def getrand(l):
    return np.random.choice(l).item()

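# Builds the convolution benchmark. Returns (tc_code, tc_name, init_input,
# init_input_sz) and registers the inputs and the TC definition in module
# globals (via computeCat and set_tc) so evalTime can be used right away.
# Sizes are caller-provided ("input"), fixed ("default"), or sampled
# uniformly ("random").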
def get_convolution_example(size_type="default", inp_sz_list=[], use_max_shared_memory=False):
    global INIT_INPUT_SZ, USE_MAX_SHARED_MEMORY

    USE_MAX_SHARED_MEMORY = use_max_shared_memory

    INIT_INPUT_SZ = 7
    tc_name = "convolution"
    tc_code = """
def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1) -> (O) {
    O(n, m, h, w) +=! I(n, r_c, h + r_kh, w + r_kw) * W1(m, r_c, r_kh, r_kw)
}
"""

    if size_type == "input":
        N, C, H, W, O, kH, kW = tuple(inp_sz_list)
    elif size_type == "default":
        N, C, H, W, O, kH, kW = 16, 4, 56, 56, 16, 1, 1  # 8, 2, 28, 28, 8, 1, 1
    elif size_type == "random":
        N, C, H, W, O, kH, kW = \
            getrand([8, 16, 32, 64]), \
            getrand([2, 4, 8, 16]), \
            getrand([28, 56, 112]), \
            getrand([28, 56, 112]), \
            getrand([8, 16, 32]), \
            getrand([1, 2, 4]), \
            getrand([1, 2, 4])
    else:
        print("Unknown size type")
        exit()
    I, W1 = torch.randn(N, C, H, W, device='cuda'), torch.randn(O, C, kH, kW, device='cuda')
    init_input = (I, W1)
    init_input_sz = np.array([N, C, H, W, O, kH, kW])
    print(init_input_sz)
    init_input_sz = torch.from_numpy(init_input_sz).float()

    computeCat(init_input)
    set_tc(tc_code, tc_name)

    return (tc_code, tc_name, init_input, init_input_sz)

def print_opt(options):
    print(options.tolist())

def set_tc(tc_code_arg, tc_name_arg):
    global tc_code, tc_name
    tc_code = tc_code_arg
    tc_name = tc_name_arg

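# Maps a vector of categorical indices (one per hyperparameter) to concrete
# option values by looking them up in cat_val, which computeCat fills in.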
def catVec_to_optVec(catVec):
    global cat_val
    opt = [cat_val[i][catVec[i]] for i in range(NB_HYPERPARAMS)]
    return opt

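# Compiles the current TC with the candidate options and profiles it. An
# initial timing (and another right after warmup) allows early pruning of
# hopeless candidates relative to curr_best; otherwise the time over `iters`
# profiled runs is summarized with the chosen estimator (mean, median, or
# 25th percentile). Compilation or launch failures return the sentinel
# `infty`.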
def evalTime(opt, iters=50, warmup=10, estimator="mean", naive=False, prune=-1, curr_best=-1):
    global tc_code, tc_name, inp, cat_val

    infty = 30000
    opt = catVec_to_optVec(opt)
    if naive:
        opt = tc.MappingOptions("naive")
    else:
        opt = optionsFromVector(opt)
    try:
        tc_prog = tc.compile(tc_code, tc_name, opt, *inp)
        first_ft = tc_prog.executor.profile_kernel(inp)
    except (KeyboardInterrupt, SystemExit):
        raise
    except Exception:
        return infty
    if prune != -1 and first_ft > 100 * curr_best:
        return first_ft
    for _ in range(warmup):
        tc_prog.executor.profile_kernel(inp)

    first_t = tc_prog.executor.profile_kernel(inp)

    if prune != -1 and first_t > prune * curr_best:
        return first_t

    tc_time_list = []
    for _ in range(iters):
        iter_time = tc_prog.executor.profile_kernel(inp)
        tc_time_list.append(iter_time)
    if estimator == "mean":
        return np.mean(tc_time_list)
    elif estimator == "median":
        return np.median(tc_time_list)
    elif estimator == "p25":
        return np.percentile(tc_time_list, 25)
    print("Unknown estimator")
    return infty

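# Flattens a dictionary of TC mapping option values into the integer vector
# laid out by MappingOptionsIdx; fusion strategy strings are encoded through
# tr_dic. This is (up to the unsupported tileImperfectlyNested option) the
# inverse of optionsFromVector below.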
def getRawVectorFromTcOpt(tc_opt):
    tr_dic = {"Max": 0, "Preserve3Coincident": 1, "Min": 2}
    opt_vect = np.zeros(NB_HYPERPARAMS).astype(int)
    opt_vect[MappingOptionsIdx.outerScheduleFusionStrategy] = \
        tr_dic[tc_opt["outerScheduleFusionStrategy"]]
    opt_vect[MappingOptionsIdx.intraTileScheduleFusionStrategy] = \
        tr_dic[tc_opt["intraTileScheduleFusionStrategy"]]
    opt_vect[MappingOptionsIdx.fixParametersBeforeScheduling] = \
        tc_opt["fixParametersBeforeScheduling"]
    opt_vect[MappingOptionsIdx.nTiledDims] = len(tc_opt["tile"])
    assert opt_vect[MappingOptionsIdx.nTiledDims] < 7, "Too many tilings"
    opt_vect[MappingOptionsIdx.tiling1 :
             MappingOptionsIdx.tiling1 + opt_vect[MappingOptionsIdx.nTiledDims]] = \
        tc_opt["tile"]
    opt_vect[MappingOptionsIdx.unroll] = tc_opt["unroll"]
    # opt_vect[MappingOptionsIdx.tileImperfectlyNested] = \
    #     tc_opt["tileImperfectlyNested"]  # todo: pybind
    opt_vect[MappingOptionsIdx.matchLibraryCalls] = tc_opt["matchLibraryCalls"]
    opt_vect[MappingOptionsIdx.nMappingToBlocksDims] = len(tc_opt["mapToBlocks"])
    opt_vect[MappingOptionsIdx.mappingToBlocks1 :
             MappingOptionsIdx.mappingToBlocks1 + opt_vect[MappingOptionsIdx.nMappingToBlocksDims]] = \
        tc_opt["mapToBlocks"]
    opt_vect[MappingOptionsIdx.nMappingToThreadsDims] = len(tc_opt["mapToThreads"])
    opt_vect[MappingOptionsIdx.mappingToThreads1 :
             MappingOptionsIdx.mappingToThreads1 + opt_vect[MappingOptionsIdx.nMappingToThreadsDims]] = \
        tc_opt["mapToThreads"]
    opt_vect[MappingOptionsIdx.useSharedMemory] = tc_opt["useSharedMemory"]
    opt_vect[MappingOptionsIdx.usePrivateMemory] = tc_opt["usePrivateMemory"]
    opt_vect[MappingOptionsIdx.unrollCopyShared] = tc_opt["unrollCopyShared"]
    if USE_MAX_SHARED_MEMORY and "maxSharedMemory" in tc_opt:
        opt_vect[MappingOptionsIdx.maxSharedMemory] = tc_opt["maxSharedMemory"]
    opt_vect[MappingOptionsIdx.useReadOnlyCache] = tc_opt["useReadOnlyCache"]
    opt_vect[MappingOptionsIdx.privateDepth] = tc_opt["privateDepth"]
    return opt_vect

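# Rebuilds a tc.MappingOptions object from the flat integer vector, starting
# from the "naive" options and setting each knob in turn; the tile, block and
# thread slices keep only as many entries as their n*Dims slot indicates.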
def optionsFromVector(vect):
    strat_str = ["Max", "Preserve3Coincident", "Min"]
    options = tc.MappingOptions("naive")
    options.outerScheduleFusionStrategy(
        strat_str[vect[MappingOptionsIdx.outerScheduleFusionStrategy]])
    options.intraTileScheduleFusionStrategy(
        strat_str[vect[MappingOptionsIdx.intraTileScheduleFusionStrategy]])
    options.fixParametersBeforeScheduling(
        vect[MappingOptionsIdx.fixParametersBeforeScheduling])
    options.tile(
        list(vect[MappingOptionsIdx.tiling1 :
                  MappingOptionsIdx.tiling1 + vect[MappingOptionsIdx.nTiledDims]]))
    options.unroll(vect[MappingOptionsIdx.unroll])
    options.matchLibraryCalls(vect[MappingOptionsIdx.matchLibraryCalls])
    options.mapToBlocks(
        list(vect[MappingOptionsIdx.mappingToBlocks1 :
                  MappingOptionsIdx.mappingToBlocks1 + vect[MappingOptionsIdx.nMappingToBlocksDims]]))
    options.mapToThreads(
        list(vect[MappingOptionsIdx.mappingToThreads1 :
                  MappingOptionsIdx.mappingToThreads1 + vect[MappingOptionsIdx.nMappingToThreadsDims]]))
    options.useSharedMemory(vect[MappingOptionsIdx.useSharedMemory])
    options.usePrivateMemory(vect[MappingOptionsIdx.usePrivateMemory])
    options.unrollCopyShared(vect[MappingOptionsIdx.unrollCopyShared])
    if USE_MAX_SHARED_MEMORY:
        options.maxSharedMemory(vect[MappingOptionsIdx.maxSharedMemory])
    options.useReadOnlyCache(vect[MappingOptionsIdx.useReadOnlyCache])
    options.privateDepth(vect[MappingOptionsIdx.privateDepth])
    return options

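# Candidate sizes derived from one dimension sz: ceil(sz / 2**i) for every
# power of two 2**i <= sz.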
def computeDivs(sz):
    l = []
    for i in range(sz):
        if 2**i > sz:
            break
        l.append((sz + 2**i - 1) // (2**i))
    return l

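# Gathers the candidate tile/block/thread sizes: the computeDivs values of
# every dimension of every input tensor, plus the powers of two up to
# 2**maxp2, deduplicated and sorted.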
def getAllDivs(inp, maxp2=8):
    p2 = [2**i for i in range(maxp2 + 1)]
    l = []
    for elem in inp:
        for sz in elem.shape:
            l += computeDivs(sz)
    divs_list = list(set(l + p2))
    return sorted(divs_list)

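# Enumerates, for each hyperparameter slot, the list of values the tuners may
# pick from (cat_val) and its cardinality (cat_sz). Tiling sizes additionally
# allow 0 (presumably leaving that dimension untiled); maxSharedMemory
# candidates are derived from the device's shared memory size when
# USE_MAX_SHARED_MEMORY is set.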
def computeCat(inp_arg):
    global cat_sz, cat_val, inp
    inp = inp_arg
    cat_sz = np.zeros(NB_HYPERPARAMS).astype(int)
    cat_val = [[] for _ in range(NB_HYPERPARAMS)]

    divs = getAllDivs(inp)
    if USE_MAX_SHARED_MEMORY:
        divs2 = getAllDivs([np.array([tc.tclib.shared_memory_size()])])

    cat_val[MappingOptionsIdx.outerScheduleFusionStrategy] = [0, 1, 2]
    cat_val[MappingOptionsIdx.intraTileScheduleFusionStrategy] = [0, 1, 2]
    cat_val[MappingOptionsIdx.fixParametersBeforeScheduling] = [0, 1]
    cat_val[MappingOptionsIdx.nTiledDims] = [i + 1 for i in range(6)]
    for i in range(6):  # tiling
        cat_val[MappingOptionsIdx.tiling1 + i] = divs + [0]
    cat_val[MappingOptionsIdx.unroll] = [2**i for i in range(8)]
    cat_val[MappingOptionsIdx.matchLibraryCalls] = [0, 1]
    cat_val[MappingOptionsIdx.nMappingToBlocksDims] = [i + 1 for i in range(3)]
    for i in range(3):  # mapping to blocks
        cat_val[MappingOptionsIdx.mappingToBlocks1 + i] = divs
    cat_val[MappingOptionsIdx.nMappingToThreadsDims] = [i + 1 for i in range(3)]
    for i in range(3):  # mapping to threads
        cat_val[MappingOptionsIdx.mappingToThreads1 + i] = divs
    cat_val[MappingOptionsIdx.useSharedMemory] = [0, 1]
    cat_val[MappingOptionsIdx.usePrivateMemory] = [0, 1]
    cat_val[MappingOptionsIdx.unrollCopyShared] = [0, 1]
    cat_val[MappingOptionsIdx.maxSharedMemory] = divs2 if USE_MAX_SHARED_MEMORY else [0]
    cat_val[MappingOptionsIdx.useReadOnlyCache] = [0, 1]
    cat_val[MappingOptionsIdx.privateDepth] = list(range(6))

    for i in range(NB_HYPERPARAMS):
        cat_sz[i] = len(cat_val[i])
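
# Illustrative usage sketch (an assumption about the intended call order, not
# part of the original file): the tuning experiments can drive these helpers
# roughly as follows.
#
#   tc_code, tc_name, inp, inp_sz = get_convolution_example("default")
#   # get_convolution_example already called computeCat(inp) and set_tc(...),
#   # so a random categorical candidate can be evaluated directly:
#   cand = [getrand(range(int(cat_sz[i]))) for i in range(NB_HYPERPARAMS)]
#   print_opt(np.array(catVec_to_optVec(cand)))
#   t = evalTime(cand, estimator="median")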
