This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 58caf8a

Jules Pondard authored and committed
Add utilities needed for all the different tuning experiments
Also introduce a new folder that will hold all the files related to experimenting with new ways of tuning options.
1 parent c607c68 commit 58caf8a

File tree: 1 file changed, +212 −0 lines

@@ -0,0 +1,212 @@
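# Shared utilities for the option-tuning experiments: building a TC benchmark,
# encoding/decoding tc.MappingOptions as a fixed-length vector of NB_HYPERPARAMS
# categorical choices, and timing compiled kernels.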
import time
import torch
import torch.optim as optim
import torch.nn as nn
import torch.utils.data
import torch.nn.functional as F
import tensor_comprehensions as tc
import numpy as np

NB_HYPERPARAMS, INIT_INPUT_SZ = 26, 7

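# Draw one value uniformly at random from a list of candidates.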
def getrand(l):
    return np.random.choice(l).item()

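# Build a 2-D convolution benchmark: returns the TC source, its entry-point name,
# the CUDA input tensors and a float tensor holding the problem sizes
# (N, C, H, W, O, kH, kW). size_type picks fixed "default" sizes, caller-supplied
# "input" sizes, or "random" sizes drawn from small candidate lists.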
def get_convolution_example(size_type="default", inp_sz_list=[]):
    global INIT_INPUT_SZ
    INIT_INPUT_SZ = 7
    tc_name = "convolution"
    tc_code = """
def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1) -> (O) {
    O(n, m, h, w) +=! I(n, r_c, h + r_kh, w + r_kw) * W1(m, r_c, r_kh, r_kw)
}
"""

    if(size_type=="input"):
        N, C, H, W, O, kH, kW = tuple(inp_sz_list)
    elif(size_type=="default"):
        N, C, H, W, O, kH, kW = 16, 4, 56, 56, 16, 1, 1 #8, 2, 28, 28, 8, 1, 1
    elif(size_type=="random"):
        N, C, H, W, O, kH, kW = \
            getrand([8, 16, 32, 64]), \
            getrand([2, 4, 8, 16]), \
            getrand([28, 56, 112]), \
            getrand([28, 56, 112]), \
            getrand([8, 16, 32]), \
            getrand([1, 2, 4]), \
            getrand([1, 2, 4])
    else:
        print("Unknown size type")
        exit()
    I, W1 = torch.randn(N, C, H, W, device='cuda'), torch.randn(O, C, kH, kW, device='cuda')
    init_input = (I, W1)
    init_input_sz = np.array([N, C, H, W, O, kH, kW])
    print(init_input_sz)
    init_input_sz = torch.from_numpy(init_input_sz).float()

    return (tc_code, tc_name, init_input, init_input_sz)

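# print_opt dumps an option vector as a plain Python list; the set_* helpers stash
# the TC source, inputs and categorical tables in module-level globals shared by
# the tuning scripts.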
def print_opt(options):
    print(options.tolist())

def set_tc(tc_code_arg, tc_name_arg):
    global tc_code, tc_name
    tc_code = tc_code_arg
    tc_name = tc_name_arg

def set_inp(inp_arg):
    global inp
    inp = inp_arg

def set_vars(tc_prog_arg, inp_arg, cat_val_arg, cat_sz_arg):
    global tc_prog, inp, cat_val, cat_sz
    tc_prog = tc_prog_arg
    inp = inp_arg
    cat_val = cat_val_arg
    cat_sz = cat_sz_arg

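# Translate a vector of categorical indices (one per hyperparameter) into the
# corresponding concrete option values via cat_val.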
def catVec_to_optVec(catVec):
    global cat_val
    opt = [cat_val[i][catVec[i]] for i in range(NB_HYPERPARAMS)]
    #if(opt[18] * opt[17] > 1024)
    #opt[18] = min(opt[18], 1024//opt[17])
    #opt[19] = min(opt[19], 1024//(opt[17] * opt[18]))
    return opt

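# Compile the current TC with the options encoded by `opt` and return a runtime
# estimate: mean, median or 25th percentile of profile_kernel over `iters` runs
# after `warmup` runs. Returns a large constant (30000) when compilation or the
# first run fails, and can return early when pruning against `curr_best`.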
def evalTime(opt, iters=50, warmup=10, estimator="mean", naive=False, prune=-1, curr_best=-1):
    global tc_code, tc_name, inp, cat_val

    infty = 30000
    opt = catVec_to_optVec(opt)
    if naive:
        opt = tc.MappingOptions("naive")
    else:
        opt = optionsFromVector(opt)
    try:
        tc_prog = tc.compile(tc_code, tc_name, opt, *inp)
        first_ft = tc_prog.executor.profile_kernel(inp)
    except (KeyboardInterrupt, SystemExit):
        raise
    except:
        return infty
    if(prune != -1 and first_ft > 100*curr_best):
        return first_ft
    for _ in range(warmup):
        tc_prog.executor.profile_kernel(inp)

    first_t = tc_prog.executor.profile_kernel(inp)

    if(prune != -1 and first_t > prune*curr_best):
        return first_t

    tc_time_list = []
    for i in range(iters):
        iter_time = tc_prog.executor.profile_kernel(inp)
        tc_time_list.append(iter_time)
    if(estimator == "mean"):
        mean_time = np.mean(tc_time_list)
        return mean_time
    elif(estimator == "median"):
        median_time = np.median(tc_time_list)
        return median_time
    elif(estimator == "p25"):
        p25_time = np.percentile(tc_time_list, 25)
        return p25_time
    print("Unknown estimator")
    return infty

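# Flatten a mapping-options dictionary into the fixed-length integer encoding
# (one slot per hyperparameter; unused slots stay 0).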
def getRawVectorFromTcOpt(tc_opt):
    tr_dic = {"Max":0, "Preserve3Coincident":1, "Min":2}
    opt_vect = np.zeros(NB_HYPERPARAMS).astype(int)
    opt_vect[0] = tr_dic[tc_opt["outerScheduleFusionStrategy"]]
    opt_vect[1] = tr_dic[tc_opt["intraTileScheduleFusionStrategy"]]
    opt_vect[2] = tc_opt["fixParametersBeforeScheduling"]
    opt_vect[3] = len(tc_opt["tile"])
    opt_vect[4:4+opt_vect[3]] = tc_opt["tile"]
    opt_vect[10] = tc_opt["unroll"]
    #opt_vect[11] = tc_opt["tileImperfectlyNested"] #todo: pybind
    opt_vect[11] = tc_opt["matchLibraryCalls"]
    opt_vect[12] = len(tc_opt["mapToBlocks"])
    opt_vect[13:13+opt_vect[12]] = tc_opt["mapToBlocks"]
    opt_vect[16] = len(tc_opt["mapToThreads"])
    opt_vect[17:17+opt_vect[16]] = tc_opt["mapToThreads"]
    opt_vect[20] = tc_opt["useSharedMemory"]
    opt_vect[21] = tc_opt["usePrivateMemory"]
    opt_vect[22] = tc_opt["unrollCopyShared"]
    #opt_vect[23] = tc_opt["maxSharedMemory"]
    opt_vect[24] = tc_opt["useReadOnlyCache"]
    opt_vect[25] = tc_opt["privateDepth"]
    return opt_vect

def optionsFromVector(vect):
142+
strat_str = ["Max", "Preserve3Coincident", "Min"]
143+
options = tc.MappingOptions("naive")
144+
options.outerScheduleFusionStrategy(strat_str[vect[0]])
145+
options.intraTileScheduleFusionStrategy(strat_str[vect[1]])
146+
options.fixParametersBeforeScheduling(vect[2])
147+
options.tile(list(vect[4:(4+vect[3])]))
148+
options.unroll(vect[10])
149+
options.matchLibraryCalls(vect[11])
150+
options.mapToBlocks(list(vect[13:13+vect[12]]))
151+
options.mapToThreads(list(vect[17:17+vect[16]])) #grid?
152+
options.useSharedMemory(vect[20])
153+
options.usePrivateMemory(vect[21])
154+
options.unrollCopyShared(vect[22])
155+
#options.maxSharedMemory(vect[23])
156+
options.useReadOnlyCache(vect[24])
157+
options.privateDepth(vect[25])
158+
return options
159+
160+
def computeDivs(sz):
161+
l = []
162+
for i in range(sz): #or 10?
163+
if(2**i > sz):
164+
break
165+
l.append((sz+2**i-1)//(2**i))
166+
return l
167+
168+
def getAllDivs(inp, maxp2=8):
169+
p2 = []
170+
pp=1
171+
for i in range(maxp2+1):
172+
p2.append(pp)
173+
pp*=2
174+
l = []
175+
for elem in inp:
176+
for sz in elem.shape:
177+
l += computeDivs(sz)
178+
my_list = list(set(l+p2))
179+
return np.sort(my_list).tolist()
180+
181+
def computeCat(inp_arg):
182+
global cat_sz, cat_val, inp
183+
inp = inp_arg
184+
cat_sz = np.zeros(NB_HYPERPARAMS).astype(int)
185+
cat_val = []
186+
187+
divs = getAllDivs(inp)
188+
#divs2 = getAllDivs([np.array([tc.tclib.shared_memory_size()])])
189+
190+
cat_val.append([0,1,2]) #0
191+
cat_val.append([0,1,2]) #1
192+
cat_val.append([0,1]) #2
193+
cat_val.append([i+1 for i in range(6)]) #3
194+
for i in range(6): #tiling #4-9
195+
cat_val.append(divs + [0]) #4-9
196+
cat_val.append([2**i for i in range(8)]) #10
197+
cat_val.append([0,1]) #11
198+
cat_val.append([i+1 for i in range(3)]) #12
199+
for i in range(3): #13-15
200+
cat_val.append(divs) #blocks #maximum 2^31-1 for the first value and 65535 for the second and third
201+
cat_val.append([i+1 for i in range(3)]) #16
202+
for i in range(3): #17-19
203+
cat_val.append(divs) #threads #maximum 1024 for the first and second value, 32 for the third, product below 1024
204+
cat_val.append([0,1]) #20
205+
cat_val.append([0,1]) #21
206+
cat_val.append([0,1]) #22
207+
cat_val.append([0]) #cat_val.append(divs2) #23
208+
cat_val.append([0,1]) #24
209+
cat_val.append([i for i in range(6)]) #25
210+
211+
for i in range(NB_HYPERPARAMS):
212+
cat_sz[i] = len(cat_val[i])
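
As a rough usage sketch (not part of this commit): assuming the file is importable as a module named `utils` (the actual file name is not visible here) and that the legacy tensor_comprehensions Python API used above (tc.compile, executor.profile_kernel) is available on a CUDA machine, the helpers could be driven by a simple random search over the encoded option space:

import numpy as np
import utils  # assumed module name for the file added in this commit

# Register the default convolution benchmark with the module-level globals.
tc_code, tc_name, inp, inp_sz = utils.get_convolution_example("default")
utils.set_tc(tc_code, tc_name)
utils.computeCat(inp)  # fills utils.cat_val / utils.cat_sz and stores the inputs

# Illustrative random search over the 26 categorical hyperparameters.
best_time, best_cat = float("inf"), None
for _ in range(20):
    cat = [np.random.randint(utils.cat_sz[i]) for i in range(utils.NB_HYPERPARAMS)]
    t = utils.evalTime(cat, iters=10, warmup=2)  # large constant if compilation fails
    if t < best_time:
        best_time, best_cat = t, cat

if best_cat is not None:
    print("best runtime estimate:", best_time)
    utils.print_opt(np.array(utils.catVec_to_optVec(best_cat)))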

0 commit comments
