diff --git a/python/experimental/options_search/eval_genetic_options.py b/python/experimental/options_search/eval_genetic_options.py
new file mode 100644
index 000000000..3edb13358
--- /dev/null
+++ b/python/experimental/options_search/eval_genetic_options.py
@@ -0,0 +1,23 @@
+import tensor_comprehensions as tc
+import tensor_comprehensions.tclib as tclib
+import utils
+
+cache = tc.MappingOptionsCache("genetic_savedopt_conv_default.txt")
+
+exptuner_config = utils.ExpTunerConfig()
+exptuner_config.set_convolution_tc()
+tc_code, tc_name, inp = exptuner_config.tc_code, exptuner_config.tc_name, exptuner_config.inp
+
+print("divs : " + str(utils.getAllDivs(inp)))
+tup = cache.load(tc_code, tc_name, inp, 1)
+if(tup == []):
+    exit()
+best_options, = tup
+best_options = best_options.getDict()
+optsVect = utils.getRawVectorFromTcOpt(best_options)
+opts = utils.optionsFromVector(optsVect)
+print(opts)
+
+time = utils.evalTime(opts, exptuner_config, estimator="median")
+print(time)
+
diff --git a/python/experimental/options_search/eval_options.py b/python/experimental/options_search/eval_options.py
new file mode 100644
index 000000000..242442f2c
--- /dev/null
+++ b/python/experimental/options_search/eval_options.py
@@ -0,0 +1,21 @@
+import numpy as np
+import time
+import utils
+import torch
+import torch.optim as optim
+import torch.nn as nn
+import torch.utils.data
+import torch.nn.functional as F
+import tensor_comprehensions as tc
+
+set_options = [
+[1, 1, 0, 1, 2, 1, 3, 6, 8, 0, 3, 0, 2, 11, 9, 8, 2, 0, 0, 3, 0, 0, 1, 0, 1, 2]
+]
+
+exptuner_config = utils.ExpTunerConfig()
+exptuner_config.set_convolution_tc()
+
+for i in range(len(set_options)):
+    opts = np.array(set_options[i])
+    time = utils.evalTime(opts, exptuner_config)
+    print(time)
diff --git a/python/experimental/options_search/generate_genetic_options.py b/python/experimental/options_search/generate_genetic_options.py
new file mode 100644
index 000000000..45c55deda
--- /dev/null
+++ b/python/experimental/options_search/generate_genetic_options.py
@@ -0,0 +1,16 @@
+import numpy as np
+import torch
+import tensor_comprehensions as tc
+
+import utils
+
+exptuner_config = utils.ExpTunerConfig()
+exptuner_config.set_convolution_tc()
+tc_code, tc_name, inp = exptuner_config.tc_code, exptuner_config.tc_name, exptuner_config.inp
+#config = tc.autotuner_settings
+#config["pop_size"]=50
+#config["generations"]=1
+opts = tc.MappingOptions("naive")
+print(opts)
+
+tc.autotune(tc_code, tc_name, *inp, starting_options=opts, cache_filename="genetic_savedopt_conv_default.txt", store_to_cache=True)
diff --git a/python/experimental/options_search/mcts.py b/python/experimental/options_search/mcts.py
new file mode 100644
index 000000000..9a2bc8c8b
--- /dev/null
+++ b/python/experimental/options_search/mcts.py
@@ -0,0 +1,179 @@
+import tensor_comprehensions as tc
+import torch
+import utils
+import numpy as np
+#from tqdm import tqdm
+from visdom import Visdom
+
+viz = Visdom()
+
+class Node:
+    def __init__(self, father=None, new_act=0):
+        self.value = 0
+        self.values = []
+        self.nbVisits=0
+        self.nbChildrenSeen = 0
+        self.pos=0
+        #self.hasSeen = {} #todo
+        self.children=[]
+        self.parent = father
+        self.stateVector = [0] * utils.NB_HYPERPARAMS
+        if(father != None):
+            self.pos = father.pos+1
+            #self.hasSeen = {} #todo
+            self.stateVector = father.stateVector[:]
+            self.stateVector[self.pos-1] = new_act
+
+    def getRoot(self):
+        return self
+
+    def getParent(self):
+        return self.parent
+
+    def notRoot(self):
+        return (self.parent != None)
+
+class MCTS:
+    def __init__(self):
+        self.C = 1 #to tune
+
+        self.exptuner_config = utils.ExpTunerConfig()
+        self.exptuner_config.set_convolution_tc()
+
+        self.nbActions = self.exptuner_config.cat_sz
+        self.tree = Node()
+
+        self.best_rewards = []
+        self.rws = []
+
+        self.curIter=0
+        self.curr_best=0
+        self.running_reward=0
+        self.win0 = viz.line(X=np.arange(5), Y=np.random.rand(5))
+
+    def main_search(self, starting_pos): #, init_inp):
+        node = starting_pos
+        #node.nbVisits+=1
+        ttNbIters = 10 #2*self.nbActions[node.pos]
+        for _ in range(max(ttNbIters, self.nbActions[node.pos])):
+            leaf = self.getLeaf(node)
+            val = self.evaluate(leaf)
+            self.backup(leaf, val)
+            #print(node.value / node.nbVisits)
+        _, action = self.getBestChild2(node)
+        return action
+
+    def take_action(self, node, act):
+        if(node.nbChildrenSeen > act):
+            return node.children[act]
+        new_child = Node(father=node, new_act=act)
+        node.children.append(new_child)
+        #node.hasSeen[act]=1
+        node.nbChildrenSeen += 1
+        return node.children[-1]
+
+    def getLeaf(self, node):
+        first=True
+        while(node.pos < utils.NB_HYPERPARAMS and (first or node.nbVisits != 0)):
+            first=False
+            pos = node.pos
+            if(node.nbChildrenSeen == self.nbActions[pos]):
+                node, _ = self.getBestChild(node)
+            else:
+                act=node.nbChildrenSeen
+                self.take_action(node, act)
+                return node.children[-1]
+        return node
+
+    def getBestChild2(self, node):
+        bestIndic = 0.
+        bestAction = 0
+        first=True
+        pos = node.pos
+        for act in range(self.nbActions[pos]):
+            child = node.children[act]
+            #indic = np.percentile(child.values, 20)
+            indic = child.value / child.nbVisits
+            if(first or indic > bestIndic):
+                bestIndic = indic
+                bestAction = act
+                first=False
+        return node.children[bestAction], bestAction
+
+    def getBestChild(self, node):
+        bestIndic = 0.
+        bestAction = 0
+        first=True
+        pos = node.pos
+        for act in range(self.nbActions[pos]):
+            child = node.children[act]
+            #indic = np.percentile(child.values, 20) + self.C * np.sqrt(2*np.log(node.nbVisits) / child.nbVisits)
+            indic = child.value / child.nbVisits + self.C * np.sqrt(2*np.log(node.nbVisits) / child.nbVisits)
+            if(first or indic > bestIndic):
+                bestIndic = indic
+                bestAction = act
+                first=False
+        return node.children[bestAction], bestAction
+
+    def saveReward(self, reward, opts):
+        INTER_DISP = 20
+        #print(-reward)
+        if(self.curIter == 0):
+            self.running_reward = reward
+            self.curr_best = reward
+        if(self.curIter == 0 or reward > self.curr_best):
+            print(-reward)
+            print(opts)
+        self.curIter += 1
+        self.running_reward = self.running_reward * 0.99 + reward * 0.01
+        self.curr_best = max(self.curr_best, reward)
+        #self.rewards.append(-reward)
+        self.best_rewards.append(-self.curr_best)
+        self.rws.append(-self.running_reward)
+        if self.curIter % INTER_DISP == 0:
+            viz.line(X=np.column_stack((np.arange(self.curIter), np.arange(self.curIter))), \
+                Y=np.column_stack((np.array(self.rws), np.array(self.best_rewards))), \
+                win=self.win0, opts=dict(legend=["Geometric run", "Best time"]))
+
+    def randomSampleScoreFrom(self, node):
+        pos = node.pos
+        optsVector = node.stateVector
+        for i in range(utils.NB_HYPERPARAMS - (pos)):
+            a = np.random.randint(self.nbActions[i+pos])
+            optsVector[i+(pos)] = a
+        #print(optsVector)
+        reward = -np.log(utils.evalTime(optsVector, self.exptuner_config))
+        self.saveReward(reward, optsVector)
+        return reward
+
+    def evaluate(self, leaf):
+        score = 0
+        nb_iters=5
+        for _ in range(nb_iters):
+            score += self.randomSampleScoreFrom(leaf)
+        return score / nb_iters
+
+    def backup(self, leaf, val):
+        #if(val > 10.): #infty
+        #    return
+        node = leaf
+        while(node.notRoot()):
+            node.nbVisits += 1
+            #node.values.append(val)
+            node.value += val
+            node = node.getParent()
+        node.nbVisits += 1
+        node.value += val
+        node.values.append(val)
+
+mcts = MCTS()
+
+opts = []
+curr_node = mcts.tree
+for i in range(utils.NB_HYPERPARAMS):
+    opts.append(mcts.main_search(curr_node))
+    curr_node = mcts.take_action(curr_node, opts[-1])
+    print(opts)
+opts = np.array(opts).astype(int)
+print(utils.evalTime(opts.tolist(), mcts.exptuner_config))
+utils.print_opt(opts)
diff --git a/python/experimental/options_search/predict_time.py b/python/experimental/options_search/predict_time.py
new file mode 100644
index 000000000..6b58df251
--- /dev/null
+++ b/python/experimental/options_search/predict_time.py
@@ -0,0 +1,57 @@
+import time
+import torch
+import tensor_comprehensions as tc
+#import sklearn
+#from sklearn.linear_model import LinearRegression
+#from sklearn.ensemble import GradientBoostingRegressor
+import numpy as np
+#from sklearn.model_selection import train_test_split
+#from tensor_comprehensions.mapping_options import Options
+from multiprocessing import Pool
+from itertools import repeat
+import utils
+#from tqdm import tqdm
+
+exptuner_config = utils.ExpTunerConfig()
+exptuner_config.set_convolution_tc()
+
+NB_HYPERPARAMS = utils.NB_HYPERPARAMS
+
+def createY(x):
+    y = utils.evalTime(x, exptuner_config)
+    return y
+
+def getRandom():
+    opt_v = np.zeros(NB_HYPERPARAMS).astype(int)
+    for i in range(opt_v.shape[0]):
+        opt_v[i] = np.random.randint(exptuner_config.cat_sz[i])
+    return opt_v
+
+def makeDataset():
+    from tqdm import tqdm
+    sz = 500
+    datasetX, datasetY = [], []
+    for _ in tqdm(range(sz)):
+        opt = getRandom()
+        yi = createY(opt)
+        datasetX.append(opt)
+        datasetY.append(yi)
+    #with Pool(sz) as p:
+    #    datasetY = p.starmap(createY, datasetX)
+    return np.array(datasetX), np.array(datasetY)
+
+def learn():
+    #from sklearn.linear_model import LinearRegression
+    from sklearn.ensemble import GradientBoostingRegressor
+    from sklearn.model_selection import train_test_split
+    datasetX, datasetY = makeDataset()
+    print(min(datasetY))
+    Xtrain, Xtest, Ytrain, Ytest = train_test_split(datasetX, datasetY, test_size=0.2, random_state = 42)
+    model1 = GradientBoostingRegressor(n_estimators=1000)
+    model1.fit(Xtrain, Ytrain)
+    pred0 = model1.predict(Xtrain)
+    pred1 = model1.predict(Xtest)
+    print(np.corrcoef(pred0, Ytrain)[0, 1]**2)
+    print(np.corrcoef(pred1, Ytest)[0,1]**2)
+
+#learn()
diff --git a/python/experimental/options_search/random_search.py b/python/experimental/options_search/random_search.py
new file mode 100644
index 000000000..e01c7cab3
--- /dev/null
+++ b/python/experimental/options_search/random_search.py
@@ -0,0 +1,63 @@
+import numpy as np
+#import ipdb
+import torch
+import torch.optim as optim
+import torch.nn as nn
+import torch.utils.data
+import torch.nn.functional as F
+import tensor_comprehensions as tc
+from visdom import Visdom
+
+import utils
+
+NB_EPOCHS = 1000
+BATCH_SZ = 1
+
+viz = Visdom()
+win0 = viz.line(X=np.arange(NB_EPOCHS), Y=np.random.rand(NB_EPOCHS))
+
+exptuner_config = utils.ExpTunerConfig()
+exptuner_config.set_convolution_tc()
+
+NB_HYPERPARAMS = utils.NB_HYPERPARAMS
+
+def getRandom():
+    opt_v = np.zeros(NB_HYPERPARAMS).astype(int)
+    for i in range(opt_v.shape[0]):
+        opt_v[i] = np.random.randint(exptuner_config.cat_sz[i])
+    return opt_v
+
+INTER_DISP = 20
+
+running_reward = -0.5
+tab_rewards=[]
+tab_best=[]
+best=-12
+best_options = -1
+for i in range(NB_EPOCHS):
+    rewards = []
+    opts=[]
+    for j in range(BATCH_SZ):
+        out = getRandom()
+        reward = utils.evalTime(out.astype(int), exptuner_config, prune=2, curr_best=np.exp(-best))
+        reward = -np.log(reward)
+        rewards.append(reward)
+        opts.append(out.astype(int))
+    if(best < np.max(rewards) or i==0):
+        best = np.max(rewards)
+        ind=np.argmax(rewards)
+        best_options = opts[ind]
+        utils.print_opt(best_options)
+    if(i==0):
+        running_reward = reward
+    running_reward = running_reward * 0.99 + np.mean(rewards) * 0.01
+    tab_rewards.append(-running_reward)
+    tab_best.append(-best)
+    if i % INTER_DISP == 0:
+        viz.line(X=np.column_stack((np.arange(i+1), np.arange(i+1))), Y=np.column_stack((np.array(tab_rewards), np.array(tab_best))), win=win0, opts=dict(legend=["Geometric run", "Best time"]))
+        print(-running_reward)
+        print(-best)
+tab_best = np.array(tab_best)
+np.save("randomsearch.npy", tab_best)
+print("Finally, best options are:")
+utils.print_opt(best_options)
diff --git a/python/experimental/options_search/rl_memory_replay.py b/python/experimental/options_search/rl_memory_replay.py
new file mode 100644
index 000000000..9c4d99584
--- /dev/null
+++ b/python/experimental/options_search/rl_memory_replay.py
@@ -0,0 +1,172 @@
+import numpy as np
+import torch
+import torch.optim as optim
+import torch.nn as nn
+import torch.utils.data
+import torch.nn.functional as F
+#import ipdb
+from itertools import count
+from collections import namedtuple
+from torch.distributions import Categorical
+import tensor_comprehensions as tc
+from visdom import Visdom
+from collections import deque
+from heapq import heappush, heappop
+
+import utils
+
+NB_EPOCHS = 1000
+BATCH_SZ = 16
+buff = deque()
+MAXI_BUFF_SZ = 50
+
+exptuner_config = utils.ExpTunerConfig()
+exptuner_config.set_convolution_tc()
+
+NB_HYPERPARAMS = utils.NB_HYPERPARAMS
+INIT_INPUT_SZ = exptuner_config.INIT_INPUT_SZ
+init_input_sz = exptuner_config.init_input_sz
+
+viz = Visdom()
+win0 = viz.line(X=np.arange(NB_EPOCHS), Y=np.random.rand(NB_EPOCHS))
+win1 = viz.line(X=np.arange(NB_EPOCHS), Y=np.random.rand(NB_EPOCHS))
+
+SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])
+
+layer_sz = 32
+
+class Predictor(nn.Module):
+    def __init__(self, nb_inputs, nb_actions):
+        super(Predictor, self).__init__()
+        self.affine1 = nn.Linear(nb_inputs, layer_sz)
+        self.affine15 = nn.Linear(layer_sz, layer_sz)
+        self.affine2 = nn.Linear(layer_sz, nb_actions)
+        self.affine3 = nn.Linear(layer_sz, 1)
+
+        self.W = nn.Linear(nb_inputs, nb_inputs)
+
+    def forward(self, x):
+        #ipdb.set_trace()
+        #x = F.softmax(self.W(x), dim=-1) * x #attention mechanism
+        tmp1 = F.relu(self.affine1(x))
+        #tmp1 = F.relu(self.affine15(tmp1))
+        out_action = F.softmax(self.affine2(tmp1), dim=-1)
+        out_value = self.affine3(tmp1)
+        return out_action, out_value
+
+class FullNetwork(nn.Module):
+    def __init__(self, nb_hyperparams, init_input_sz):
+        super(FullNetwork, self).__init__()
+        self.nb_hyperparams = nb_hyperparams
+        self.init_input_sz = init_input_sz
+        self.nets = [Predictor(init_input_sz + i, int(exptuner_config.cat_sz[i])) for i in range(nb_hyperparams)]
+        self.nets = nn.ModuleList(self.nets)
+
+    def select_action(self, x, i, out_sz):
+        geps = 0.1
+        proba = np.random.rand()
+        probs, state_value = self.nets[i](x)
+        if(proba <= geps):
+            probs = torch.FloatTensor([1./out_sz]*out_sz)
+        m = Categorical(probs)
+        action = m.sample()
+        return action.item(), m.log_prob(action), state_value
+
+    def forward(self, x):
+        actions_prob = []
+        values = []
+        for i in range(self.nb_hyperparams):
+            sym, action_prob, value = self.select_action(x, i, int(exptuner_config.cat_sz[i]))
+            actions_prob.append(action_prob)
+            values.append(value)
+            x = torch.cat([x, torch.FloatTensor([sym])])
+        return x[INIT_INPUT_SZ:], actions_prob, values
+
+net = FullNetwork(NB_HYPERPARAMS, INIT_INPUT_SZ)
+optimizer = optim.Adam(net.parameters(), lr=0.0001)
+eps = np.finfo(np.float32).eps.item()
+
+def finish_episode(actions_probs, values, final_rewards):
+    policy_losses = [[] for i in range(BATCH_SZ)]
+    value_losses = [[] for i in range(BATCH_SZ)]
+    final_rewards = torch.tensor(list(final_rewards))
+    #final_rewards = (final_rewards - final_rewards.mean()) / (final_rewards.std() + eps)
+    for batch_id in range(BATCH_SZ):
+        for (log_prob, value) in zip(actions_probs[batch_id], values[batch_id]):
+            reward = final_rewards[batch_id] - value.item()
+            policy_losses[batch_id].append(-log_prob * reward)
+            value_losses[batch_id].append(F.smooth_l1_loss(value, torch.tensor([final_rewards[batch_id]])))
+    optimizer.zero_grad()
+    vloss = torch.stack([torch.stack(value_losses[i]).sum() for i in range(BATCH_SZ)]).mean()
+    ploss = torch.stack([torch.stack(policy_losses[i]).sum() for i in range(BATCH_SZ)]).mean()
+    loss = ploss + vloss
+    loss.backward(retain_graph=True)
+    optimizer.step()
+    return vloss.item(), ploss.item()
+
+def add_to_buffer(actions_probs, values, reward):
+    global buff
+    #if(len(buff) > 0):
+    #    min_reward = np.min(np.array(buff)[:,2])
+    #    if(reward < 10*min_reward):
+    #        return
+    if len(buff) == MAXI_BUFF_SZ:
+        #heappop(buff)
+        buff.popleft()
+    #heappush(buff, (reward, actions_probs, values))
+    buff.append((reward, actions_probs, values))
+
+def select_batch():
+    #random.sample()
+    batch = [buff[np.random.randint(len(buff))] for i in range(BATCH_SZ)]
+    #batch.append(buff[-1])
+    batch=np.array(batch)
+    return batch[:,1], batch[:,2], batch[:,0]
+
+def get_best_buff():
+    return np.max(np.array(buff)[:,0])
+
+INTER_DISP = 20
+
+running_reward = -0.5
+tab_rewards=[]
+tab_best=[]
+best=-12
+v_losses=[]
+p_losses=[]
+best_options = np.zeros(NB_HYPERPARAMS).astype(int)
+for i in range(NB_EPOCHS):
+    rewards = []
+    out_actions, out_probs, out_values = net(init_input_sz)
+    #utils.print_opt(out_actions.numpy().astype(int))
+    reward = utils.evalTime(out_actions.numpy().astype(int), exptuner_config, prune=-1, curr_best=np.exp(-best))
+    #reward=100*reward
+    #reward = -((reward)/1000)
+    reward = -np.log(reward)
+    add_to_buffer(out_probs, out_values, reward)
+    best_in_buffer = get_best_buff()
+    if(i >= 20):
+        actions_probs, values, rewards = select_batch()
+        for j in range(1):
+            vloss, ploss = finish_episode(actions_probs, values, rewards)
+            v_losses.append(vloss)
+            p_losses.append(ploss)
+    if(best < reward or i==0):
+        best=reward
+        best_options = out_actions.numpy().astype(int)
+        utils.print_opt(best_options)
+    if(i==0):
+        running_reward = reward
+    running_reward = running_reward * 0.99 + reward * 0.01
+    tab_rewards.append(-(running_reward))
+    tab_best.append(-best)
+    if i % INTER_DISP == 0:
+        viz.line(X=np.column_stack((np.arange(i+1), np.arange(i+1))), Y=np.column_stack((np.array(tab_rewards), np.array(tab_best))), win=win0, opts=dict(legend=["Geometric run", "Best time"]))
+        if(len(v_losses) > 0):
+            viz.line(X=np.column_stack((np.arange(len(v_losses)), np.arange(len(v_losses)))), Y=np.column_stack((np.array(v_losses), np.array(p_losses))), win=win1, opts=dict(legend=["Value loss", "Policy loss"]))
+        print(-running_reward)
+        print(-best)
+        print("Best in buffer: " + str(-best_in_buffer))
+
+print("Finally, best options are:")
+utils.print_opt(best_options)
diff --git a/python/experimental/options_search/utils.py b/python/experimental/options_search/utils.py
new file mode 100644
index 000000000..2f897867b
--- /dev/null
+++ b/python/experimental/options_search/utils.py
@@ -0,0 +1,284 @@
+import time
+import torch
+import torch.optim as optim
+import torch.nn as nn
+import torch.utils.data
+import torch.nn.functional as F
+import tensor_comprehensions as tc
+import numpy as np
+from enum import IntEnum
+
+NB_HYPERPARAMS = 26
+
+class ExpTunerConfig:
+    def __init__(self, use_max_shared_memory=0):
+        self.INIT_INPUT_SZ = -1
+        self.USE_MAX_SHARED_MEMORY = use_max_shared_memory
+        self.tc_code = ""
+        self.tc_name = ""
+        self.inp = -1
+        self.cat_val = -1
+        self.cat_sz = -1
+
+    def set_convolution_tc(self, size_type="default", inp_sz_list=[], use_max_shared_memory=False):
+        self.INIT_INPUT_SZ = 7
+        self.tc_name = "convolution"
+        self.tc_code = """
+        def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1) -> (O) {
+            O(n, m, h, w) +=! I(n, r_c, h + r_kh, w + r_kw) * W1(m, r_c, r_kh, r_kw)
+        }
+        """
+
+        if(size_type=="input"):
+            N, C, H, W, O, kH, kW = tuple(inp_sz_list)
+        elif(size_type=="default"):
+            N, C, H, W, O, kH, kW = 16, 4, 56, 56, 16, 1, 1 #8, 2, 28, 28, 8, 1, 1
+        elif(size_type=="random"):
+            N, C, H, W, O, kH, kW = \
+                get_rand([8, 16, 32, 64]), \
+                get_rand([2, 4, 8, 16]), \
+                get_rand([28, 56, 112]), \
+                get_rand([28, 56, 112]), \
+                get_rand([8, 16, 32]), \
+                get_rand([1, 2, 4]), \
+                get_rand([1, 2, 4])
+        else:
+            print("Unknown size type")
+            exit()
+        I, W1 = torch.randn(N, C, H, W, device='cuda'), torch.randn(O, C, kH, kW, device='cuda')
+        self.inp = (I, W1)
+        self.init_input_sz = np.array([N,C,H,W,O, kH, kW])
+        print(self.init_input_sz)
+        self.init_input_sz = torch.from_numpy(self.init_input_sz).float()
+
+        self.computeCat()
+
+    def computeCat(self):
+        inp = self.inp
+        self.cat_sz = np.zeros(NB_HYPERPARAMS).astype(int)
+        self.cat_val = [[] for _ in range(NB_HYPERPARAMS)]
+
+        divs = getAllDivs(inp)
+        if(self.USE_MAX_SHARED_MEMORY):
+            divs2 = getAllDivs([np.array([tc.tclib.shared_memory_size()])])
+
+        self.cat_val[MappingOptionsIdx.outerScheduleFusionStrategy] = \
+                [0,1,2]
+        self.cat_val[MappingOptionsIdx.intraTileScheduleFusionStrategy] = \
+                [0,1,2]
+        self.cat_val[MappingOptionsIdx.fixParametersBeforeScheduling] = \
+                [0,1]
+        self.cat_val[MappingOptionsIdx.nTiledDims] = \
+                [i+1 for i in range(6)]
+        for i in range(6): #tiling
+            self.cat_val[MappingOptionsIdx.tiling1 + i] = \
+                    divs + [0]
+        self.cat_val[MappingOptionsIdx.unroll] = \
+                [2**i for i in range(8)]
+        self.cat_val[MappingOptionsIdx.matchLibraryCalls] = \
+                [0,1]
+        self.cat_val[MappingOptionsIdx.nMappedToBlocksDims] = \
+                [i+1 for i in range(3)]
+        for i in range(3): #mapping to blocks
+            self.cat_val[MappingOptionsIdx.mappingToBlocks1 + i] = \
+                    divs
+        self.cat_val[MappingOptionsIdx.nMappedToThreadsDims] = \
+                [i+1 for i in range(3)]
+        for i in range(3): #mapping to threads
+            self.cat_val[MappingOptionsIdx.mappingToThreads1 + i] = \
+                    divs
+        self.cat_val[MappingOptionsIdx.useSharedMemory] = \
+                [0,1]
+        self.cat_val[MappingOptionsIdx.usePrivateMemory] = \
+                [0,1]
+        self.cat_val[MappingOptionsIdx.unrollCopyShared] = \
+                [0,1]
+        self.cat_val[MappingOptionsIdx.maxSharedMemory] = \
+                divs2 if self.USE_MAX_SHARED_MEMORY else [0]
+        self.cat_val[MappingOptionsIdx.useReadOnlyCache] = \
+                [0,1]
+        self.cat_val[MappingOptionsIdx.privateDepth] = \
+                [i for i in range(6)]
+
+        for i in range(NB_HYPERPARAMS):
+            self.cat_sz[i] = len(self.cat_val[i])
+
+    def catVec_to_optVec(self, catVec):
+        opt = [self.cat_val[i][catVec[i]] for i in range(NB_HYPERPARAMS)]
+        return opt
+
+
+class MappingOptionsIdx(IntEnum):
+    outerScheduleFusionStrategy = 0
+    intraTileScheduleFusionStrategy = 1
+    fixParametersBeforeScheduling = 2
+    nTiledDims = 3
+    tiling1 = 4
+    tiling2 = 5
+    tiling3 = 6
+    tiling4 = 7
+    tiling5 = 8
+    tiling6 = 9
+    unroll = 10
+    matchLibraryCalls = 11
+    nMappedToBlocksDims = 12
+    mappingToBlocks1 = 13
+    mappingToBlocks2 = 14
+    mappingToBlocks3 = 15
+    nMappedToThreadsDims = 16
+    mappingToThreads1 = 17
+    mappingToThreads2 = 18
+    mappingToThreads3 = 19
+    useSharedMemory = 20
+    usePrivateMemory = 21
+    unrollCopyShared = 22
+    maxSharedMemory = 23
+    useReadOnlyCache = 24
+    privateDepth = 25
+
+def get_rand(l):
+    return np.random.choice(l).item()
+
+def print_opt(options):
+    print(options.tolist())
+
+def evalTime(opt, exptuner_config, iters=50, warmup=10, estimator="mean", prune=-1, curr_best=-1):
+    tc_code, tc_name, inp = \
+            exptuner_config.tc_code, exptuner_config.tc_name, exptuner_config.inp
+    infty = 30000
+    opt = exptuner_config.catVec_to_optVec(opt)
+    opt = optionsFromVector(opt)
+    try:
+        tc_prog = tc.compile(tc_code, tc_name, opt, *inp)
+        first_ft = tc_prog.executor.profile_kernel(inp)
+    except (KeyboardInterrupt, SystemExit):
+        raise
+    except:
+        return infty
+    if(prune != -1 and first_ft > 100*curr_best):
+        return first_ft
+    for _ in range(warmup-1):
+        tc_prog.executor.profile_kernel(inp)
+
+    first_t = tc_prog.executor.profile_kernel(inp)
+
+    if(prune != -1 and first_t > prune*curr_best):
+        return first_t
+
+    tc_time_list = [first_t]
+    for i in range(iters-1):
+        iter_time = tc_prog.executor.profile_kernel(inp)
+        tc_time_list.append(iter_time)
+    if(estimator == "mean"):
+        mean_time = np.mean(tc_time_list)
+        return mean_time
+    elif(estimator == "median"):
+        median_time = np.median(tc_time_list)
+        return median_time
+    elif(estimator == "p25"):
+        p25_time = np.percentile(tc_time_list, 25)
+        return p25_time
+    print("Unknown estimator")
+    return infty
+
+def getRawVectorFromTcOpt(tc_opt):
+    tr_dic = {"Max":0, "Preserve3Coincident":1, "Min":2}
+    opt_vect = np.zeros(NB_HYPERPARAMS).astype(int)
+    opt_vect[MappingOptionsIdx.outerScheduleFusionStrategy] = \
+            tr_dic[tc_opt["outerScheduleFusionStrategy"]]
+    opt_vect[MappingOptionsIdx.intraTileScheduleFusionStrategy] = \
+            tr_dic[tc_opt["intraTileScheduleFusionStrategy"]]
+    opt_vect[MappingOptionsIdx.fixParametersBeforeScheduling] = \
+            tc_opt["fixParametersBeforeScheduling"]
+    opt_vect[MappingOptionsIdx.nTiledDims] = \
+            len(tc_opt["tile"])
+    assert opt_vect[MappingOptionsIdx.nTiledDims] < 7, "Too many tilings"
+    opt_vect[
+        MappingOptionsIdx.tiling1 : MappingOptionsIdx.tiling1 + opt_vect[MappingOptionsIdx.nTiledDims]] = \
+            tc_opt["tile"]
+    opt_vect[MappingOptionsIdx.unroll] = \
+            tc_opt["unroll"]
+    #opt_vect[MappingOptionsIdx.tileImperfectlyNested] = \
+    #        tc_opt["tileImperfectlyNested"] #todo: pybind
+    opt_vect[MappingOptionsIdx.matchLibraryCalls] = \
+            tc_opt["matchLibraryCalls"]
+    opt_vect[MappingOptionsIdx.nMappedToBlocksDims] = \
+            len(tc_opt["mapToBlocks"])
+    opt_vect[
+        MappingOptionsIdx.mappingToBlocks1 : MappingOptionsIdx.mappingToBlocks1 + opt_vect[MappingOptionsIdx.nMappedToBlocksDims]] = \
+            tc_opt["mapToBlocks"]
+    opt_vect[MappingOptionsIdx.nMappedToThreadsDims] = \
+            len(tc_opt["mapToThreads"])
+    opt_vect[
+        MappingOptionsIdx.mappingToThreads1 : MappingOptionsIdx.mappingToThreads1 + opt_vect[MappingOptionsIdx.nMappedToThreadsDims]] = \
+            tc_opt["mapToThreads"]
+    opt_vect[MappingOptionsIdx.useSharedMemory] = \
tc_opt["useSharedMemory"] + opt_vect[MappingOptionsIdx.usePrivateMemory] = \ + tc_opt["usePrivateMemory"] + opt_vect[MappingOptionsIdx.unrollCopyShared] = \ + tc_opt["unrollCopyShared"] + if(USE_MAX_SHARED_MEMORY and "maxSharedMemory" in tc_opt): + opt_vect[MappingOptionsIdx.maxSharedMemory] = \ + tc_opt["maxSharedMemory"] + opt_vect[MappingOptionsIdx.useReadOnlyCache] = \ + tc_opt["useReadOnlyCache"] + opt_vect[MappingOptionsIdx.privateDepth] = \ + tc_opt["privateDepth"] + return opt_vect + +def optionsFromVector(vect): + strat_str = ["Max", "Preserve3Coincident", "Min"] + options = tc.MappingOptions("naive") + options.outerScheduleFusionStrategy( + strat_str[vect[ + MappingOptionsIdx.outerScheduleFusionStrategy]]) + options.intraTileScheduleFusionStrategy( + strat_str[vect[ + MappingOptionsIdx.intraTileScheduleFusionStrategy]]) + options.fixParametersBeforeScheduling( + vect[MappingOptionsIdx.fixParametersBeforeScheduling]) + options.tile( + list(vect[ + MappingOptionsIdx.tiling1 : MappingOptionsIdx.tiling1 + vect[MappingOptionsIdx.nTiledDims]])) + options.unroll( + vect[MappingOptionsIdx.unroll]) + options.matchLibraryCalls( + vect[MappingOptionsIdx.matchLibraryCalls]) + options.mapToBlocks( + list(vect[ + MappingOptionsIdx.mappingToBlocks1 : MappingOptionsIdx.mappingToBlocks1 + vect[MappingOptionsIdx.nMappedToBlocksDims]])) + options.mapToThreads( + list(vect[ + MappingOptionsIdx.mappingToThreads1 : MappingOptionsIdx.mappingToThreads1 + vect[MappingOptionsIdx.nMappedToThreadsDims]])) + options.useSharedMemory( + vect[MappingOptionsIdx.useSharedMemory]) + options.usePrivateMemory( + vect[MappingOptionsIdx.usePrivateMemory]) + options.unrollCopyShared( + vect[MappingOptionsIdx.unrollCopyShared]) + if(USE_MAX_SHARED_MEMORY): + options.maxSharedMemory( + vect[MappingOptionsIdx.maxSharedMemory]) + options.useReadOnlyCache( + vect[MappingOptionsIdx.useReadOnlyCache]) + options.privateDepth( + vect[MappingOptionsIdx.privateDepth]) + return options + +def computeDivs(sz): + l = [] + for i in range(sz): + if(2 ** i > sz): + break + l.append((sz + 2 ** i - 1) // (2 ** i)) + return l + +def getAllDivs(inp, maxp2=8): + p2 = [2**i for i in range(maxp2 + 1)] + l = [] + for elem in inp: + for sz in elem.shape: + l += computeDivs(sz) + divs_list = list(set(l + p2)) + return sorted(divs_list) diff --git a/tensor_comprehensions/pybinds/tclib.cc b/tensor_comprehensions/pybinds/tclib.cc index a18fb4ca7..12bed3bd5 100644 --- a/tensor_comprehensions/pybinds/tclib.cc +++ b/tensor_comprehensions/pybinds/tclib.cc @@ -30,6 +30,7 @@ #include "tc/aten/aten_autotuner.h" #include "tc/autotuner/genetic_search.h" #include "tc/autotuner/options_cache.h" +#include "tc/core/cuda/cuda.h" #include "tc/core/cuda/cuda_backend.h" #include "tc/core/cuda/cuda_tc_executor.h" #include "tc/core/flags.h" @@ -273,6 +274,17 @@ struct TcExecutor { return tupleOrTensor(convertToPyObjects(atOutputs)); } } + + size_t profile_kernel(const py::tuple& inputs, const py::tuple& outputs) { + auto atInputs = getATenTensors(inputs); + auto atOutputs = (outputs.size() > 0) + ? 
+        ? getATenTensors(outputs)
+        : tc::aten::prepareOutputs(tc, entryPoint, atInputs);
+    tc::ProfilingInfo profinfo =
+        tc::aten::profile(*executor, atInputs, atOutputs);
+    return profinfo.kernelRuntime.toMicroSeconds();
+  }
+
   std::string tc;
   std::string entryPoint;
   std::unique_ptr executor;
@@ -467,6 +479,11 @@ PYBIND11_MODULE(tclib, m) {
         return res;
       });
 
+  // Get GPU shared memory size
+  m.def("shared_memory_size", []() {
+    return CudaGPUInfo::GPUInfo().SharedMemorySize();
+  });
+
   // Low-level stateful API compile returns an executor on which run and
   // unchecked_run can be called.
   py::class_<TcExecutor>(m, "TcExecutor")
@@ -479,7 +496,13 @@
           "unchecked_run",
           &TcExecutor::uncheckedRun,
           py::arg("inputs"),
+          py::arg("outputs") = py::tuple())
+      .def(
+          "profile_kernel",
+          &TcExecutor::profile_kernel,
+          py::arg("inputs"),
           py::arg("outputs") = py::tuple());
+
   m.def(
       "compile",
       [](const std::string& tc,
@@ -651,6 +674,38 @@
             return str;
           },
           "Returns the CudaMappingOptions as a human-readable string")
+      .def(
+          "getDict",
+          [](tc::CudaMappingOptions& instance) {
+            py::dict rv;
+            rv["outerScheduleFusionStrategy"] = FusionStrategy_Name(
+                instance.generic.outerScheduleOptions.proto.fusion_strategy());
+            if (instance.generic.proto.has_intra_tile_schedule_options())
+              rv["intraTileScheduleFusionStrategy"] =
+                  FusionStrategy_Name(instance.generic.intraTileScheduleOptions
+                                          .proto.fusion_strategy());
+            rv["fixParametersBeforeScheduling"] =
+                instance.generic.proto.fix_parameters_before_scheduling();
+            if (instance.generic.proto.has_tiling())
+              rv["tile"] = instance.generic.tiling.extractVector();
+            if (instance.generic.proto.has_unroll())
+              rv["unroll"] = instance.generic.proto.unroll();
+            rv["tileImperfectlyNested"] =
+                instance.generic.proto.tile_imperfectly_nested();
+            rv["matchLibraryCalls"] =
+                instance.generic.proto.match_library_calls();
+            rv["mapToThreads"] = instance.block.extractVector();
+            rv["mapToBlocks"] = instance.grid.extractVector();
+            rv["useSharedMemory"] = instance.proto().use_shared_memory();
+            rv["usePrivateMemory"] = instance.proto().use_private_memory();
+            rv["unrollCopyShared"] = instance.proto().unroll_copy_shared();
+            rv["useReadOnlyCache"] = instance.proto().use_readonly_cache();
+            if (instance.proto().has_max_shared_memory())
+              rv["maxSharedMemory"] = instance.proto().max_shared_memory();
+            rv["privateDepth"] = instance.proto().private_depth();
+            return rv;
+          },
+          "Returns a dictionary with the CudaMappingOptions")
       .def(
           "serialize",
           [](tc::CudaMappingOptions& instance) {
@@ -677,6 +732,10 @@
           &tc::CudaMappingOptions::unrollCopyShared,
           "Also unroll the copies to and from shared memory. If an unroll "
           "value is not provided, has no effect")
+      .def(
+          "privateDepth",
+          &tc::CudaMappingOptions::privateDepth,
+          "Specify the private depth")
       .def(
           "useReadOnlyCache",
           &tc::CudaMappingOptions::useReadOnlyCache,
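
For reference, a minimal usage sketch of the utilities introduced above (an assumed workflow, not part of the diff; it requires a CUDA device and the new tclib bindings). It samples one categorical vector per hyperparameter, exactly as getRandom() does in random_search.py, and times it with utils.evalTime, which compiles the TC and measures it through the new profile_kernel binding:

    import numpy as np
    import utils

    # Build the convolution experiment configuration defined in utils.py.
    exptuner_config = utils.ExpTunerConfig()
    exptuner_config.set_convolution_tc()

    # One random index per hyperparameter, drawn from the per-parameter
    # category sizes computed by computeCat().
    opt = np.array([np.random.randint(exptuner_config.cat_sz[i])
                    for i in range(utils.NB_HYPERPARAMS)]).astype(int)

    # Median kernel time in microseconds for the corresponding mapping options.
    print(utils.evalTime(opt, exptuner_config, estimator="median"))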