Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 9763f96

Browse files
author
Jules Pondard
committed
Add utilities needed for all the different tuning experiments
Also introduce a new folder that will hold all the files related to experimenting with new ways of tuning options.
1 parent c607c68 commit 9763f96

File tree

1 file changed

+208
-0
lines changed

1 file changed

+208
-0
lines changed
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
import time
2+
import torch
3+
import torch.optim as optim
4+
import torch.nn as nn
5+
import torch.utils.data
6+
import torch.nn.functional as F
7+
import tensor_comprehensions as tc
8+
import numpy as np
9+
10+
# NB_HYPERPARAMS: length of the categorical option vector used throughout
# this module. INIT_INPUT_SZ: matches the 7 entries of the size vector built
# in get_convolution_example (N, C, H, W, O, kH, kW).
NB_HYPERPARAMS, INIT_INPUT_SZ = 26, 7
# Whether the maxSharedMemory option is part of the search space. The
# functions in this module read USE_MAX_SHARED_MEMORY; the constant used to
# be spelled "MENORY", which left the real name undefined until
# get_convolution_example() had been called.
USE_MAX_SHARED_MEMORY = 0
# Misspelled legacy name, kept so existing references keep working.
USE_MAX_SHARED_MENORY = USE_MAX_SHARED_MEMORY
12+
13+
def getrand(l):
    """Return one randomly chosen element of ``l`` as a plain Python scalar."""
    picked = np.random.choice(l)
    # .item() unwraps the NumPy scalar into a native Python value.
    return picked.item()
15+
16+
def get_convolution_example(size_type="default", inp_sz_list=[], use_max_shared_memory=False):
    """Set up a convolution TC and random CUDA inputs for a tuning experiment.

    Args:
        size_type: How the problem sizes (N, C, H, W, O, kH, kW) are chosen.
            "input" unpacks them from ``inp_sz_list``, "default" uses a fixed
            configuration, "random" draws each one from a short candidate
            list. Any other value prints a message and exits.
        inp_sz_list: The seven sizes, only read when ``size_type == "input"``.
        use_max_shared_memory: Stored in the global USE_MAX_SHARED_MEMORY.

    Returns:
        ``(tc_code, tc_name, init_input, init_input_sz)`` where ``init_input``
        is the pair of input tensors and ``init_input_sz`` is a float tensor
        holding the seven sizes.

    Side effects: updates the globals INIT_INPUT_SZ and USE_MAX_SHARED_MEMORY,
    prints the size vector, and registers the inputs and the TC through
    computeCat() and set_tc().
    """
    global INIT_INPUT_SZ, USE_MAX_SHARED_MEMORY

    USE_MAX_SHARED_MEMORY = use_max_shared_memory
    INIT_INPUT_SZ = 7

    tc_name = "convolution"
    tc_code = """
def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1) -> (O) {
O(n, m, h, w) +=! I(n, r_c, h + r_kh, w + r_kw) * W1(m, r_c, r_kh, r_kw)
}
"""

    if size_type == "input":
        N, C, H, W, O, kH, kW = tuple(inp_sz_list)
    elif size_type == "default":
        # A smaller alternative noted by the author: 8, 2, 28, 28, 8, 1, 1
        N, C, H, W, O, kH, kW = 16, 4, 56, 56, 16, 1, 1
    elif size_type == "random":
        # One candidate list per size, drawn in N, C, H, W, O, kH, kW order.
        candidates = (
            [8, 16, 32, 64],
            [2, 4, 8, 16],
            [28, 56, 112],
            [28, 56, 112],
            [8, 16, 32],
            [1, 2, 4],
            [1, 2, 4],
        )
        N, C, H, W, O, kH, kW = (getrand(choices) for choices in candidates)
    else:
        print("Unknown size type")
        exit()

    I = torch.randn(N, C, H, W, device='cuda')
    W1 = torch.randn(O, C, kH, kW, device='cuda')
    init_input = (I, W1)

    size_array = np.array([N, C, H, W, O, kH, kW])
    print(size_array)
    init_input_sz = torch.from_numpy(size_array).float()

    # Register the inputs (and the option categories derived from them),
    # then the TC itself, in the module-level state.
    computeCat(init_input)
    set_tc(tc_code, tc_name)

    return (tc_code, tc_name, init_input, init_input_sz)
55+
56+
def print_opt(options):
    """Print ``options`` converted to a plain list via its ``tolist()`` method."""
    as_list = options.tolist()
    print(as_list)
58+
59+
def set_tc(tc_code_arg, tc_name_arg):
    """Store the TC source and its entry-point name in the module globals."""
    global tc_code, tc_name
    tc_code, tc_name = tc_code_arg, tc_name_arg
63+
64+
def catVec_to_optVec(catVec):
    """Translate category indices into the option values they select.

    Entry ``i`` of the returned list is ``cat_val[i][catVec[i]]`` for the
    first NB_HYPERPARAMS positions.
    """
    global cat_val
    values = []
    for position in range(NB_HYPERPARAMS):
        choices = cat_val[position]
        values.append(choices[catVec[position]])
    return values
68+
69+
def evalTime(opt, iters=50, warmup=10, estimator="mean", naive=False, prune=-1, curr_best=-1):
    """Compile the registered TC with the given options and time its kernel.

    Args:
        opt: Vector of category indices, decoded with catVec_to_optVec() and
            optionsFromVector(). Ignored when ``naive`` is true.
        iters: Number of profiled runs fed to the estimator.
        warmup: Number of profiled runs whose results are discarded first.
        estimator: "mean", "median" or "p25" (25th percentile).
        naive: Use ``tc.MappingOptions("naive")`` instead of ``opt``.
        prune: Early-exit factor; -1 disables both early exits.
        curr_best: Reference time the early exits compare against.

    Returns:
        The estimate computed from the values returned by ``profile_kernel``
        (in whatever unit it reports). Returns 30000 when compiling or the
        first run raises, or when ``estimator`` is unknown. When an early
        exit triggers, returns the single measurement that triggered it.
    """
    global tc_code, tc_name, inp

    infty = 30000  # sentinel time for options that fail to compile or run

    if naive:
        # The naive baseline does not depend on ``opt``, so it is not decoded.
        mapping_options = tc.MappingOptions("naive")
    else:
        mapping_options = optionsFromVector(catVec_to_optVec(opt))

    try:
        tc_prog = tc.compile(tc_code, tc_name, mapping_options, *inp)
        first_ft = tc_prog.executor.profile_kernel(inp)
    except Exception:
        # Deliberate best effort: any failing configuration gets the sentinel.
        # KeyboardInterrupt and SystemExit are not Exception subclasses and
        # therefore still propagate.
        return infty

    # First early exit: the very first run is already 100x worse than the
    # reference. The factor 100 is hard-coded here, independent of ``prune``.
    if prune != -1 and first_ft > 100 * curr_best:
        return first_ft

    profile = tc_prog.executor.profile_kernel
    for _ in range(warmup):
        profile(inp)

    # Second early exit: one run after warmup exceeds prune * curr_best.
    first_t = profile(inp)
    if prune != -1 and first_t > prune * curr_best:
        return first_t

    tc_time_list = [profile(inp) for _ in range(iters)]

    if estimator == "mean":
        return np.mean(tc_time_list)
    if estimator == "median":
        return np.median(tc_time_list)
    if estimator == "p25":
        return np.percentile(tc_time_list, 25)

    print("Unknown estimator")
    return infty
110+
111+
def getRawVectorFromTcOpt(tc_opt):
    """Flatten a mapping-options lookup into an integer vector.

    ``tc_opt`` is indexed by option name. The returned array has
    NB_HYPERPARAMS entries laid out as:

        0      outerScheduleFusionStrategy (Max=0, Preserve3Coincident=1, Min=2)
        1      intraTileScheduleFusionStrategy (same coding)
        2      fixParametersBeforeScheduling
        3      number of tile sizes, 4..9 the tile sizes
        10     unroll
        11     matchLibraryCalls
        12     number of mapToBlocks entries, 13..15 the entries
        16     number of mapToThreads entries, 17..19 the entries
        20     useSharedMemory
        21     usePrivateMemory
        22     unrollCopyShared
        23     maxSharedMemory (only when USE_MAX_SHARED_MEMORY is set and the
               key is present; otherwise left at 0)
        24     useReadOnlyCache
        25     privateDepth

    Slots that are not written stay 0.
    """
    fusion_codes = {"Max":0, "Preserve3Coincident":1, "Min":2}
    vec = np.zeros(NB_HYPERPARAMS, dtype=int)

    vec[0] = fusion_codes[tc_opt["outerScheduleFusionStrategy"]]
    vec[1] = fusion_codes[tc_opt["intraTileScheduleFusionStrategy"]]
    vec[2] = tc_opt["fixParametersBeforeScheduling"]

    # Tile sizes: a count followed by up to six values.
    tile_sizes = tc_opt["tile"]
    vec[3] = len(tile_sizes)
    assert vec[3] < 7, "Too many tilings"
    vec[4:4 + vec[3]] = tile_sizes

    vec[10] = tc_opt["unroll"]
    # TODO: tileImperfectlyNested is not stored yet (pending pybind support).
    vec[11] = tc_opt["matchLibraryCalls"]

    # Block and thread mappings share the "count, then values" layout.
    for count_at, key in ((12, "mapToBlocks"), (16, "mapToThreads")):
        mapping = tc_opt[key]
        vec[count_at] = len(mapping)
        first = count_at + 1
        vec[first:first + vec[count_at]] = mapping

    for slot, key in ((20, "useSharedMemory"),
                      (21, "usePrivateMemory"),
                      (22, "unrollCopyShared")):
        vec[slot] = tc_opt[key]

    if USE_MAX_SHARED_MEMORY and "maxSharedMemory" in tc_opt:
        vec[23] = tc_opt["maxSharedMemory"]

    vec[24] = tc_opt["useReadOnlyCache"]
    vec[25] = tc_opt["privateDepth"]
    return vec
135+
136+
def optionsFromVector(vect):
    """Build a ``tc.MappingOptions`` object from a flat option vector.

    Starts from the "naive" options and overrides each setting from ``vect``,
    using the same slot layout that getRawVectorFromTcOpt() produces
    (counts at 3, 12 and 16 give the lengths of the tile, block and thread
    lists that follow them). Slot 23 (maxSharedMemory) is only applied when
    USE_MAX_SHARED_MEMORY is set.
    """
    fusion_names = ["Max", "Preserve3Coincident", "Min"]

    options = tc.MappingOptions("naive")
    options.outerScheduleFusionStrategy(fusion_names[vect[0]])
    options.intraTileScheduleFusionStrategy(fusion_names[vect[1]])
    options.fixParametersBeforeScheduling(vect[2])

    tile_end = 4 + vect[3]
    options.tile(list(vect[4:tile_end]))

    options.unroll(vect[10])
    options.matchLibraryCalls(vect[11])

    blocks_end = 13 + vect[12]
    options.mapToBlocks(list(vect[13:blocks_end]))

    threads_end = 17 + vect[16]
    options.mapToThreads(list(vect[17:threads_end]))

    # Three consecutive flags, applied in slot order.
    for setter_name, slot in (("useSharedMemory", 20),
                              ("usePrivateMemory", 21),
                              ("unrollCopyShared", 22)):
        getattr(options, setter_name)(vect[slot])

    if USE_MAX_SHARED_MEMORY:
        options.maxSharedMemory(vect[23])

    options.useReadOnlyCache(vect[24])
    options.privateDepth(vect[25])
    return options
155+
156+
def computeDivs(sz):
    """Return ``ceil(sz / 2**i)`` for i = 0, 1, ... while ``2**i <= sz``.

    The result starts at ``sz`` itself and ends at 1 when ``sz`` is a
    positive integer; it is empty for ``sz <= 0``.
    """
    quotients = []
    for exponent in range(sz):
        step = 1 << exponent
        if step > sz:
            break
        # Ceiling division without going through floats.
        quotients.append(-(-sz // step))
    return quotients
163+
164+
def getAllDivs(inp, maxp2=8):
    """Collect candidate sizes derived from the shapes of ``inp``.

    For every dimension of every element of ``inp`` (read from its ``shape``)
    the values from computeDivs() are gathered, together with the powers of
    two from 1 up to ``2**maxp2``. Returns them deduplicated, in ascending
    order.
    """
    candidates = {2**k for k in range(maxp2 + 1)}
    for elem in inp:
        for sz in elem.shape:
            candidates.update(computeDivs(sz))
    return sorted(candidates)
172+
173+
def computeCat(inp_arg):
    """Register the inputs and build the per-option lists of allowed values.

    Sets three module globals:
        inp:     the inputs passed in.
        cat_val: one list of allowed values per option slot (the slot layout
                 matches getRawVectorFromTcOpt() / optionsFromVector()).
        cat_sz:  integer array with the number of allowed values per slot.

    Tile, block and thread sizes are chosen among the values getAllDivs()
    derives from the input shapes. When USE_MAX_SHARED_MEMORY is set, the
    maxSharedMemory candidates are derived from
    ``tc.tclib.shared_memory_size()``.
    """
    global cat_sz, cat_val, inp
    inp = inp_arg
    cat_sz = np.zeros(NB_HYPERPARAMS).astype(int)
    cat_val = []

    divs = getAllDivs(inp)
    if USE_MAX_SHARED_MEMORY:
        # getAllDivs() iterates over ``.shape``, so wrapping the size in a
        # one-element array only ever produced the divisors of 1. Derive the
        # candidates from the size value itself, merged with the same powers
        # of two that getAllDivs() adds by default.
        # NOTE(review): assumes shared_memory_size() returns an int - confirm.
        shared_sz = tc.tclib.shared_memory_size()
        divs2 = sorted(set(computeDivs(shared_sz) + [2**i for i in range(9)]))

    cat_val.append([0, 1, 2])  # 0: outer schedule fusion strategy
    cat_val.append([0, 1, 2])  # 1: intra-tile schedule fusion strategy
    cat_val.append([0, 1])  # 2: fixParametersBeforeScheduling
    cat_val.append([i + 1 for i in range(6)])  # 3: number of tile sizes
    for _ in range(6):  # 4-9: tile sizes
        cat_val.append(divs + [0])
    cat_val.append([2**i for i in range(8)])  # 10: unroll
    cat_val.append([0, 1])  # 11: matchLibraryCalls
    cat_val.append([i + 1 for i in range(3)])  # 12: number of block sizes
    for _ in range(3):  # 13-15: block sizes
        # Author's note: maximum 2^31-1 for the first value and 65535 for
        # the second and third.
        cat_val.append(divs)
    cat_val.append([i + 1 for i in range(3)])  # 16: number of thread sizes
    for _ in range(3):  # 17-19: thread sizes
        # Author's note: maximum 1024 for the first and second value, 32 for
        # the third, product below 1024.
        cat_val.append(divs)
    cat_val.append([0, 1])  # 20: useSharedMemory
    cat_val.append([0, 1])  # 21: usePrivateMemory
    cat_val.append([0, 1])  # 22: unrollCopyShared
    if USE_MAX_SHARED_MEMORY:  # 23: maxSharedMemory
        cat_val.append(divs2)
    else:
        cat_val.append([0])
    cat_val.append([0, 1])  # 24: useReadOnlyCache
    cat_val.append([i for i in range(6)])  # 25: privateDepth

    for i in range(NB_HYPERPARAMS):
        cat_sz[i] = len(cat_val[i])

0 commit comments

Comments
 (0)