Skip to content

Commit b322b49

Browse files
committed
Generate in parallel
1 parent 2e8a1c7 commit b322b49

9 files changed

+36525
-1471
lines changed

data-cast-fullfp16.json

Lines changed: 2114 additions & 0 deletions
Large diffs are not rendered by default.

data-cast-sve2.json

Lines changed: 3458 additions & 0 deletions
Large diffs are not rendered by default.

data-fp-fullfp16.json

Lines changed: 8450 additions & 0 deletions
Large diffs are not rendered by default.

data-fp-sve2.json

Lines changed: 13954 additions & 0 deletions
Large diffs are not rendered by default.

data-fp.json

Lines changed: 88 additions & 88 deletions
Large diffs are not rendered by default.

data-int-sve2.json

Lines changed: 4515 additions & 1075 deletions
Large diffs are not rendered by default.

data-int.json

Lines changed: 253 additions & 253 deletions
Large diffs are not rendered by default.

data-vec-sve2.json

Lines changed: 3650 additions & 0 deletions
Large diffs are not rendered by default.

generate.py

Lines changed: 43 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import sys, os, subprocess, argparse, logging, json
1+
import sys, os, subprocess, argparse, logging, json, tempfile, multiprocessing
22

33
# Try to more extensively check the cost model figures coming out of the cost model, for every operation x type combo.
44
# Currently it looks at costsize costs, as those are easier to measure.
@@ -10,7 +10,7 @@
1010
# and this to serve is to pert 8081, inside a venv with pandas
1111
# python llvm/utils/costmodeltest.py --servellvm/utils/costmodeltest.py
1212

13-
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, format='')
13+
logging.basicConfig(stream=sys.stdout, level=logging.WARNING, format='')
1414
logging.getLogger("requests").setLevel(logging.WARNING)
1515
logging.getLogger("urllib3").setLevel(logging.WARNING)
1616

@@ -19,21 +19,21 @@ def run(cmd):
1919
cmd = cmd.split()
2020
return subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode('utf-8')
2121

22-
def getcost(costkind, print):
23-
text = run(f"opt {'-mtriple='+args.mtriple if args.mtriple else ''} {'-mattr='+args.mattr if args.mattr else ''} costtest.ll -passes=print<cost-model> -cost-kind={costkind} -disable-output")
22+
def getcost(path, costkind, print):
23+
text = run(f"opt {'-mtriple='+args.mtriple if args.mtriple else ''} {'-mattr='+args.mattr if args.mattr else ''} {os.path.join(path, 'costtest.ll')} -passes=print<cost-model> -cost-kind={costkind} -disable-output")
2424
costpre = 'Cost Model: Found an estimated cost of '
2525
if print:
2626
logging.debug(text.strip())
2727
costs = [x for x in text.split('\n') if 'instruction: ret ' not in x]
2828
cost = sum([int(x[len(costpre):len(costpre)+x[len(costpre):].find(' ')]) for x in costs if x.startswith(costpre)])
2929
return (cost, text.strip())
3030

31-
def getasm(extraflags):
31+
def getasm(path, extraflags):
3232
try:
33-
run(f"llc {'-mtriple='+args.mtriple if args.mtriple else ''} {'-mattr='+args.mattr if args.mattr else ''} {extraflags} costtest.ll -o costtest.s")
33+
run(f"llc {'-mtriple='+args.mtriple if args.mtriple else ''} {'-mattr='+args.mattr if args.mattr else ''} {extraflags} {os.path.join(path, 'costtest.ll')} -o {os.path.join(path, 'costtest.s')}")
3434
except subprocess.CalledProcessError as e:
3535
return ([e.output.decode('utf-8').split('\n')[0]], -1)
36-
with open("costtest.s") as f:
36+
with open(os.path.join(path, "costtest.s")) as f:
3737
lines = [l.strip() for l in f]
3838
# This tries to remove .declarations, comments etc
3939
lines = [l for l in lines if l[0] != '.' and l[0] != '/' and not l.startswith('test:')]
@@ -50,20 +50,21 @@ def getasm(extraflags):
5050

5151
def checkcosts(llasm):
5252
logging.debug(llasm)
53-
with open("costtest.ll", "w") as f:
54-
f.write(llasm)
53+
with tempfile.TemporaryDirectory() as tmp:
54+
with open(os.path.join(tmp, "costtest.ll"), "w") as f:
55+
f.write(llasm)
5556

56-
lines, size = getasm('')
57+
lines, size = getasm(tmp, '')
5758

58-
gilines, gisize = getasm('-global-isel')
59+
gilines, gisize = getasm(tmp, '-global-isel')
5960

60-
codesize = getcost('code-size', True)
61-
thru = getcost('throughput', False)
62-
lat = getcost('latency', False)
63-
sizelat = getcost('size-latency', False)
61+
codesize = getcost(tmp, 'code-size', True)
62+
thru = getcost(tmp, 'throughput', False)
63+
lat = getcost(tmp, 'latency', False)
64+
sizelat = getcost(tmp, 'size-latency', False)
6465

65-
logging.debug(f"cost = codesize:{codesize[0]} throughput:{thru[0]} lat:{lat[0]} sizelat:{sizelat[0]}")
66-
return (size, gisize, [codesize, thru, lat, sizelat], llasm, ('\n'.join(lines)).replace('\t', ' '), ('\n'.join(gilines)).replace('\t', ' '))
66+
logging.debug(f"cost = codesize:{codesize[0]} throughput:{thru[0]} lat:{lat[0]} sizelat:{sizelat[0]}")
67+
return (size, gisize, [codesize, thru, lat, sizelat], llasm, ('\n'.join(lines)).replace('\t', ' '), ('\n'.join(gilines)).replace('\t', ' '))
6768

6869
# TODOD:
6970
#if args.checkopted:
@@ -191,40 +192,37 @@ def binop_variants(ty):
191192
args = parser.parse_args()
192193

193194

194-
def do(instr, variant, ty, ty2, extrasize, data, tyoverride=None):
195+
def do(instr, variant, ty, ty2, extrasize, tyoverride):
195196
logging.info(f"{variant} {instr} with {ty.str()}")
196197
(size, gisize, costs, ll, asm, giasm) = checkcosts(generate(variant, instr, ty, ty2))
197198
tystr = str(ty) if not tyoverride else tyoverride
198199
if costs[0][0] != size - extrasize:
199200
logging.warning(f">>> {variant} {instr} with {tystr} size = {size} vs cost = {costs[0][0]} (expected extrasize={extrasize})")
200-
data.append({"instr":instr, "ty":tystr, "variant":variant, "codesize":costs[0][0], "thru":costs[1][0], "lat":costs[2][0], "sizelat":costs[3][0], "size":size, "gisize":gisize, "extrasize":extrasize, "asm":asm, "giasm":giasm, "ll":ll, "costoutput":costs[0][1]})
201-
logging.debug('')
201+
return {"instr":instr, "ty":tystr, "variant":variant, "codesize":costs[0][0], "thru":costs[1][0], "lat":costs[2][0], "sizelat":costs[3][0], "size":size, "gisize":gisize, "extrasize":extrasize, "asm":asm, "giasm":giasm, "ll":ll, "costoutput":costs[0][1]}
202202

203203
# Operations are the ones in https://github.com/llvm/llvm-project/issues/115133
204204
# TODO: load/store, bitcast, getelementptr, phi
205205

206-
exit = False
207206
if args.type == 'all' or args.type == 'int':
208-
data = []
209-
try:
207+
def enumint():
210208
# Int Binops
211209
for instr in ['add', 'sub', 'mul', 'sdiv', 'srem', 'udiv', 'urem', 'and', 'or', 'xor', 'shl', 'ashr', 'lshr', 'smin', 'smax', 'umin', 'umax', 'uadd.sat', 'usub.sat', 'sadd.sat', 'ssub.sat', 'rotr', 'rotl']:
212210
for ty in inttypes():
213211
for (variant, extrasize) in binop_variants(ty):
214-
do(instr, variant, ty, ty, extrasize, data)
212+
yield (instr, variant, ty, ty, extrasize, None)
215213

216214
# Int unops
217215
for instr in ['abs', 'bitreverse', 'bswap', 'ctlz', 'cttz', 'ctpop']:
218216
for ty in inttypes():
219217
if instr == 'bswap' and ty.scalar == 'i8':
220218
continue
221-
do(instr, 'unop', ty, ty, 0, data)
219+
yield (instr, 'unop', ty, ty, 0, None)
222220
# TODO: not?
223221

224222
# Int triops
225223
for instr in ['fshl', 'fshr']:
226224
for ty in inttypes():
227-
do(instr, 'triop', ty, ty, 0, data)
225+
yield (instr, 'triop', ty, ty, 0, None)
228226
# TODO: select, icmp, fcmp, mla?
229227
# TODO: fshl+const
230228

@@ -239,29 +237,27 @@ def do(instr, variant, ty, ty2, extrasize, data, tyoverride=None):
239237

240238
# TODO: vecreduce.add, vecreduce.mul, vecreduce.and, vecreduce.or, vecreduce.xor, vecreduce.min/max's
241239

242-
except KeyboardInterrupt:
243-
exit=True
240+
pool = multiprocessing.Pool(16)
241+
data = pool.starmap(do, enumint())
244242
with open(f"data-int{'-'+args.mattr if args.mattr else ''}.json", "w") as f:
245243
json.dump(data, f, indent=1)
246-
if exit:
247-
sys.exit(1)
244+
248245

249246
if args.type == 'all' or args.type == 'fp':
250-
data = []
251-
try:
247+
def enumfp():
252248
# Floating point Binops
253249
for instr in ['fadd', 'fsub', 'fmul', 'fdiv', 'frem', 'minnum', 'maxnum', 'minimum', 'maximum', 'copysign', 'pow']:
254250
for ty in fptypes():
255251
for (variant, extrasize) in binop_variants(ty):
256-
do(instr, variant, ty, ty, extrasize, data)
252+
yield (instr, variant, ty, ty, extrasize, None)
257253

258254
# FP unops
259255
for instr in ['fneg', 'fabs', 'sqrt', 'ceil', 'floor', 'trunc', 'rint', 'nearbyint']:
260256
for ty in fptypes():
261-
do(instr, 'unop', ty, ty, 0, data)
257+
yield (instr, 'unop', ty, ty, 0, None)
262258
for instr in ['fma', 'fmuladd']:
263259
for ty in fptypes():
264-
do(instr, 'triop', ty, ty, 0, data)
260+
yield (instr, 'triop', ty, ty, 0, None)
265261

266262
# TODO: fmul+fadd? select+fcmp
267263
# TODO: fminimumnum, fmaximumnum
@@ -272,17 +268,14 @@ def do(instr, variant, ty, ty2, extrasize, data, tyoverride=None):
272268

273269
# TODO: vecreduce.fadd, vecreduce.fmul, vecreduce.fmin/max's
274270

275-
except KeyboardInterrupt:
276-
exit=True
271+
pool = multiprocessing.Pool(16)
272+
data = pool.starmap(do, enumfp())
277273
with open(f"data-fp{'-'+args.mattr if args.mattr else ''}.json", "w") as f:
278274
json.dump(data, f, indent=1)
279-
if exit:
280-
sys.exit(1)
281275

282276

283277
if args.type == 'all' or args.type == 'cast':
284-
data = []
285-
try:
278+
def enumcast():
286279
# TODO: zext, sext, trunc
287280
# TODO: fpext, fptrunc, fptosisat, fptouisat
288281
# TODO: lrint, llrint, lround, llround
@@ -293,42 +286,37 @@ def do(instr, variant, ty, ty2, extrasize, data, tyoverride=None):
293286
for ty2 in inttypes():
294287
if ty1.elts != ty2.elts or ty1.scalable != ty2.scalable:
295288
continue
296-
do(instr, 'cast '+ty2.scalar, ty1, ty2, 0, data)
289+
yield (instr, 'cast '+ty2.scalar, ty1, ty2, 0, None)
297290
for instr in ['sitofp', 'uitofp']:
298291
for ty1 in fptypes():
299292
for ty2 in inttypes():
300293
if ty1.elts != ty2.elts or ty1.scalable != ty2.scalable:
301294
continue
302-
do(instr, 'cast '+ty2.scalar, ty2, ty1, 0, data, str(ty1))
295+
yield (instr, 'cast '+ty2.scalar, ty2, ty1, 0, str(ty1))
303296

304-
except KeyboardInterrupt:
305-
exit=True
297+
pool = multiprocessing.Pool(16)
298+
data = pool.starmap(do, enumcast())
306299
with open(f"data-cast{'-'+args.mattr if args.mattr else ''}.json", "w") as f:
307300
json.dump(data, f, indent=1)
308-
if exit:
309-
sys.exit(1)
310301

311302

312303
if args.type == 'all' or args.type == 'vec':
313-
data = []
314-
try:
304+
def enumvec():
315305
for instr in ['insertelement', 'extractelement']:
316306
for ty in inttypes():
317307
if ty.elts == 1:
318308
continue
319309
for variant in ['vecop0', 'vecop1', 'vecopvar']:
320-
do(instr, variant, ty, ty, 0, data)
310+
yield (instr, variant, ty, ty, 0, None)
321311
for ty in fptypes():
322312
if ty.elts == 1:
323313
continue
324314
for variant in ['vecop0', 'vecop1', 'vecopvar']:
325-
do(instr, variant, ty, ty, 0, data)
315+
yield (instr, variant, ty, ty, 0, None)
326316

327317
# TODO: shuffles
328318

329-
except KeyboardInterrupt:
330-
exit=True
319+
pool = multiprocessing.Pool(16)
320+
data = pool.starmap(do, enumvec())
331321
with open(f"data-vec{'-'+args.mattr if args.mattr else ''}.json", "w") as f:
332322
json.dump(data, f, indent=1)
333-
if exit:
334-
sys.exit(1)

0 commit comments

Comments
 (0)