
Commit 0684737

polymonster authored and GBDixonAlex committed
- refactor raytracing shaders: use a "lib" stage with an array of shaders, plus a hit group block to define which entry points to use
1 parent 188270c commit 0684737
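
For context, a raytracing pipeline description after this refactor carries one lib stage holding an array of shaders rather than separate rg/ch/ah/mi stages, with a hit group block selecting entry points. Below is a rough sketch of such a pipeline as the Python code would see it once parsed; the field names, in particular the hit group block, are illustrative guesses rather than the exact pmfx schema.

# illustrative only: the parsed shape of a raytracing pipeline after this
# change; "hit_groups" and its field names are assumptions, not confirmed schema
raytracing_pipeline = {
    # one lib stage with an array of shader entry points
    "lib": ["MyRaygenShader", "MyClosestHitShader", "MyAnyHitShader", "MyMissShader"],
    # a hit group block defining which entry points to use
    "hit_groups": {
        "default": {
            "closest_hit": "MyClosestHitShader",
            "any_hit": "MyAnyHitShader"
        }
    }
}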

File tree

1 file changed

pmfx_pipeline.py

Lines changed: 76 additions & 52 deletions
@@ -42,18 +42,19 @@ def pmfx_hash(src):
     return zlib.adler32(bytes(str(src).encode("utf8")))
 
 
+# combine 2, 32 bit hashes
+def pmfx_hash_combine(h1: int, h2: int) -> int:
+    combined_data = h1.to_bytes(4, 'little') + h2.to_bytes(4, 'little')
+    return zlib.adler32(combined_data)
+
+
 # return names of supported shader stages
 def get_shader_stages():
     return [
         "vs",
         "ps",
         "cs",
-        "rg",
-        "ch",
-        "ah",
-        "mi",
-        "is",
-        "ca"
+        "lib"
     ]
 
 
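The new pmfx_hash_combine helper folds two 32-bit adler32 values into one; later in this commit it accumulates a single hash over all shaders in a lib. Here is a small standalone sketch, repeating the two helpers so it runs on its own; the input strings are made up.

import zlib

# the two helpers from the change above, duplicated so the snippet is self-contained
def pmfx_hash(src):
    return zlib.adler32(bytes(str(src).encode("utf8")))

def pmfx_hash_combine(h1: int, h2: int) -> int:
    combined_data = h1.to_bytes(4, 'little') + h2.to_bytes(4, 'little')
    return zlib.adler32(combined_data)

# fold several per-shader source hashes into one value, mirroring how the
# pipeline code accumulates "lib_hash" starting from 0; the inputs are made up
lib_hash = 0
for src in ["raygen source", "closest hit source", "miss source"]:
    lib_hash = pmfx_hash_combine(lib_hash, pmfx_hash(src))
print(hex(lib_hash))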
@@ -773,13 +774,6 @@ def cross_compile_hlsl_metal(info, src, stage, entry_point, temp_filepath, output_filepath):
     return 0, error_list, output_list
 
 
-# convert satage to correct hlsl profile
-def hlsl_stage(stage):
-    if stage in ["rg", "ch", "ah", "mi"]:
-        return "lib"
-    return stage
-
-
 # compile a hlsl version 2
 def compile_shader_hlsl(info, src, stage, entry_point, temp_filepath, output_filepath):
     exe = os.path.join(info.tools_dir, "bin", "dxc", "dxc")
@@ -792,7 +786,7 @@ def compile_shader_hlsl(info, src, stage, entry_point, temp_filepath, output_filepath):
     if info.shader_platform == "metal":
         error_code, error_list, output_list = cross_compile_hlsl_metal(info, src, stage, entry_point, temp_filepath, output_filepath)
     elif info.shader_platform == "hlsl":
-        cmdline = "{} -T {}_{} -E {} -Fo {} {}".format(exe, hlsl_stage(stage), info.shader_version, entry_point, output_filepath, temp_filepath)
+        cmdline = "{} -T {}_{} -E {} -Fo {} {}".format(exe, stage, info.shader_version, entry_point, output_filepath, temp_filepath)
         cmdline += " " + build_pmfx.get_info().args
         error_code, error_list, output_list = build_pmfx.call_wait_subprocess(cmdline)
 
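With "lib" now a first-class stage name, the old hlsl_stage() remapping is no longer needed and the stage string feeds dxc's -T profile directly. A sketch of the command line the changed code builds follows; the paths, shader model and entry point are illustrative values, not ones taken from the repo.

# illustrative values only; this just reproduces the format string above
exe = "tools/bin/dxc/dxc"               # assumed dxc location
stage = "lib"                           # raytracing shaders now compile as a lib target
shader_version = "6_3"                  # assumed shader model
entry_point = "MyRaygenShader"          # hypothetical entry point
output_filepath = "build/example.lib"   # hypothetical output path
temp_filepath = "temp/example.hlsl"     # hypothetical temp source

cmdline = "{} -T {}_{} -E {} -Fo {} {}".format(
    exe, stage, shader_version, entry_point, output_filepath, temp_filepath)
print(cmdline)
# tools/bin/dxc/dxc -T lib_6_3 -E MyRaygenShader -Fo build/example.lib temp/example.hlsl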
@@ -1066,42 +1060,63 @@ def generate_pipeline_permutation(pipeline_name, pipeline, output_pmfx, shaders, pemutation_id):
     print(" pipeline: {} {}".format(pipeline_name, permutation_name))
     resources = dict()
     output_pipeline = dict(pipeline)
-    # lookup info from compiled shaders and combine resources
+
+    # gather entry points
+    entry_points = list()
     for stage in get_shader_stages():
         if stage in pipeline:
-            entry_point = pipeline[stage]
-            if entry_point not in shaders[stage]:
-                output_pipeline["error_code"] = 1
-                continue
-            # lookup shader info, and redirect to shared shaders
-            shader_info = shaders[stage][entry_point][pemutation_id]
-            if "lookup" in shader_info:
-                lookup = shader_info["lookup"]
-                shader_info = dict(shaders[stage][lookup[0]][lookup[1]])
+            if type(pipeline[stage]) is list:
+                for entry_point in pipeline[stage]:
+                    entry_points.append((stage, entry_point, True))
+            else:
+                entry_points.append((stage, pipeline[stage], False))
+
+    # clear lib
+    if "lib" in output_pipeline:
+        output_pipeline["lib_hash"] = 0
+        output_pipeline["lib"].clear()
+
+    # lookup info from compiled shaders and combine resources
+    for (stage, entry_point, lib) in entry_points:
+        # check entry exists
+        if entry_point not in shaders[stage]:
+            output_pipeline["error_code"] = 1
+            continue
+        # lookup shader info, and redirect to shared shaders
+        shader_info = shaders[stage][entry_point][pemutation_id]
+        if "lookup" in shader_info:
+            lookup = shader_info["lookup"]
+            shader_info = dict(shaders[stage][lookup[0]][lookup[1]])
+
+        if lib:
+            output_pipeline[stage].append(shader_info["filename"])
+            output_pipeline["lib_hash"] = pmfx_hash_combine(output_pipeline["lib_hash"], pmfx_hash(shader_info["src_hash"]))
+        else:
             output_pipeline[stage] = shader_info["filename"]
-            output_pipeline["{}_hash:".format(stage)] = pmfx_hash(shader_info["src_hash"])
-            shader = shader_info
-            resources = merge_dicts(resources, dict(shader["resources"]), ["visibility"])
-            # generate vertex layout
-            if stage == "vs":
-                pmfx_vertex_layout = dict()
-                if "vertex_layout" in pipeline:
-                    pmfx_vertex_layout = pipeline["vertex_layout"]
-                output_pipeline["vertex_layout"] = generate_vertex_layout(shader["vertex_elements"], pmfx_vertex_layout)
-            # extract numthreads
-            if stage == "cs":
-                for attrib in shader["attributes"]:
-                    if attrib.find("numthreads") != -1:
-                        start, end = cgu.enclose_start_end("(", ")", attrib, 0)
-                        xyz = attrib[start:end].split(",")
-                        numthreads = []
-                        for a in xyz:
-                            numthreads.append(int(a.strip()))
-                        output_pipeline["numthreads"] = numthreads
-
-            # set non zero error codes to track failures
-            if shader_info["error_code"] != 0:
-                output_pipeline["error_code"] = shader_info["error_code"]
+            output_pipeline["{}_hash".format(stage)] = pmfx_hash(shader_info["src_hash"])
+
+        shader = shader_info
+        resources = merge_dicts(resources, dict(shader["resources"]), ["visibility"])
+        # generate vertex layout
+        if stage == "vs":
+            pmfx_vertex_layout = dict()
+            if "vertex_layout" in pipeline:
+                pmfx_vertex_layout = pipeline["vertex_layout"]
+            output_pipeline["vertex_layout"] = generate_vertex_layout(shader["vertex_elements"], pmfx_vertex_layout)
+        # extract numthreads
+        if stage == "cs":
+            for attrib in shader["attributes"]:
+                if attrib.find("numthreads") != -1:
+                    start, end = cgu.enclose_start_end("(", ")", attrib, 0)
+                    xyz = attrib[start:end].split(",")
+                    numthreads = []
+                    for a in xyz:
+                        numthreads.append(int(a.strip()))
+                    output_pipeline["numthreads"] = numthreads
+
+        # set non zero error codes to track failures
+        if shader_info["error_code"] != 0:
+            output_pipeline["error_code"] = shader_info["error_code"]
 
     # build pipeline layout
     output_pipeline["pipeline_layout"] = generate_pipeline_layout(output_pmfx, pipeline, resources)
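Permutation generation now normalises every pipeline stage into (stage, entry_point, is_lib) tuples before looking up compiled shader info, so a lib stage can contribute several entry points while raster and compute stages keep a single one. Below is a standalone sketch of that gathering step; the pipeline dict and entry point names are illustrative.

# a standalone sketch of the entry-point gathering introduced above;
# the pipeline contents and entry point names are illustrative only
def gather_entry_points(pipeline, stages=("vs", "ps", "cs", "lib")):
    entry_points = list()
    for stage in stages:
        if stage in pipeline:
            if type(pipeline[stage]) is list:
                # a lib stage carries an array of shaders
                for entry_point in pipeline[stage]:
                    entry_points.append((stage, entry_point, True))
            else:
                entry_points.append((stage, pipeline[stage], False))
    return entry_points

example_pipeline = {
    "lib": ["MyRaygenShader", "MyClosestHitShader", "MyMissShader"]
}
print(gather_entry_points(example_pipeline))
# [('lib', 'MyRaygenShader', True), ('lib', 'MyClosestHitShader', True), ('lib', 'MyMissShader', True)]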
@@ -1340,9 +1355,13 @@ def generate_pmfx(file, root):
             pipeline = pipelines[pipeline_key]
             for stage in get_shader_stages():
                 if stage in pipeline:
-                    stage_shader = (stage, pipeline[stage])
-                    if stage_shader not in shader_list:
-                        shader_list.append(stage_shader)
+                    if type(pipeline[stage]) is list:
+                        for shader in pipeline[stage]:
+                            stage_shader = (stage, shader)
+                    else:
+                        stage_shader = (stage, pipeline[stage])
+                    if stage_shader not in shader_list:
+                        shader_list.append(stage_shader)
 
     # gather permutations
     permutation_jobs = []
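Shader gathering follows the same pattern: a stage value may be a single entry point or, for lib, a list, and each unique (stage, entry point) pair is queued once. A simplified standalone sketch of that idea, not the exact commit code; the pipeline contents are made up.

# simplified sketch: gather unique (stage, entry_point) pairs to compile,
# where a pipeline stage may hold one entry point or a list of them
def gather_stage_shaders(pipelines, stages=("vs", "ps", "cs", "lib")):
    shader_list = list()
    for pipeline in pipelines.values():
        for stage in stages:
            if stage in pipeline:
                entries = pipeline[stage] if type(pipeline[stage]) is list else [pipeline[stage]]
                for entry_point in entries:
                    stage_shader = (stage, entry_point)
                    if stage_shader not in shader_list:
                        shader_list.append(stage_shader)
    return shader_list

print(gather_stage_shaders({
    "raster": {"vs": "vs_main", "ps": "ps_main"},
    "raytrace": {"lib": ["MyRaygenShader", "MyMissShader"]}
}))
# [('vs', 'vs_main'), ('ps', 'ps_main'), ('lib', 'MyRaygenShader'), ('lib', 'MyMissShader')]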
@@ -1357,8 +1376,13 @@ def generate_pmfx(file, root):
             pipeline_jobs.append((pipeline_key, id))
             for stage in get_shader_stages():
                 if stage in pipeline:
-                    permutation_jobs.append(
-                        pool.apply_async(generate_shader_info_permutation, (pmfx, pipeline[stage], stage, permute, define_list)))
+                    if type(pipeline[stage]) is list:
+                        for shader in pipeline[stage]:
+                            permutation_jobs.append(
+                                pool.apply_async(generate_shader_info_permutation, (pmfx, shader, stage, permute, define_list)))
+                    else:
+                        permutation_jobs.append(
+                            pool.apply_async(generate_shader_info_permutation, (pmfx, pipeline[stage], stage, permute, define_list)))
 
     # wait on shader permutations
     shaders = dict()
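Compilation jobs are then fanned out per entry point over a multiprocessing pool. A minimal sketch of that apply_async pattern with a stand-in compile function; it is not the repo's actual generate_shader_info_permutation call, and the pipeline contents are illustrative.

# fan out one job per entry point, handling both scalar and list stage values
import multiprocessing

def compile_stub(entry_point, stage):
    # stand-in for a real compile job; just echoes its inputs
    return "{}: {}".format(stage, entry_point)

if __name__ == "__main__":
    pipeline = {"lib": ["MyRaygenShader", "MyMissShader"], "cs": "cs_main"}
    pool = multiprocessing.Pool()
    jobs = []
    for stage, value in pipeline.items():
        entries = value if type(value) is list else [value]
        for entry_point in entries:
            jobs.append(pool.apply_async(compile_stub, (entry_point, stage)))
    pool.close()
    pool.join()
    print([job.get() for job in jobs])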
