Skip to content

Commit 77c2d60

Browse files
committed
2 parents 7d7aa97 + 0684737 commit 77c2d60

File tree

2 files changed

+137
-44
lines changed

2 files changed

+137
-44
lines changed

cgu.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,22 @@ def find_typedefs(fully_qualified_name, source):
464464
return typedefs, typedef_names
465465

466466

467+
# return list of any typedefs for a particular type
468+
def find_typedef_decls(source):
469+
pos = 0
470+
typedef_decls = []
471+
while True:
472+
start_pos = find_token("typedef", source[pos:])
473+
if start_pos != -1:
474+
start_pos += pos
475+
end_pos = start_pos + source[start_pos:].find(";")
476+
typedef_decls.append(source[start_pos:end_pos])
477+
pos = end_pos
478+
else:
479+
break
480+
return typedef_decls
481+
482+
467483
def find_type_attributes(source, type_pos):
468484
delimiters = [";", "}"]
469485
attr = source[:type_pos].rfind("[[")
@@ -723,7 +739,7 @@ def find_functions(source):
723739
pos = 0
724740
attributes = []
725741
while True:
726-
statement_end, statement_token = find_first(source, [";", "{"], pos)
742+
statement_end, statement_token = find_first(source, [";", "{", "}"], pos)
727743
if statement_end == -1:
728744
break
729745
statement = source[pos:statement_end].strip()
@@ -795,6 +811,28 @@ def get_funtion_prototype(func):
795811
return "(" + args + ")"
796812

797813

814+
# find the line, column position within source
815+
def position_to_line_column(source, position):
816+
if position < 0 or position > len(source):
817+
raise ValueError("position out of bounds")
818+
819+
# split the string into lines
820+
lines = source.splitlines(keepends=True)
821+
822+
# find the line and column
823+
current_pos = 0
824+
for line_number, line in enumerate(lines, start=1):
825+
line_length = len(line)
826+
if current_pos + line_length > position:
827+
# Found the line
828+
column = position - current_pos + 1 # Convert to 1-based
829+
return line_number, column
830+
current_pos += line_length
831+
832+
# If we exit the loop, something went wrong
833+
raise ValueError("position not found in string")
834+
835+
798836
# main function for scope
799837
def test():
800838
# read source from file

pmfx_pipeline.py

Lines changed: 98 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,19 @@ def pmfx_hash(src):
4242
return zlib.adler32(bytes(str(src).encode("utf8")))
4343

4444

45+
# combine 2, 32 bit hashes
46+
def pmfx_hash_combine(h1: int, h2: int) -> int:
47+
combined_data = h1.to_bytes(4, 'little') + h2.to_bytes(4, 'little')
48+
return zlib.adler32(combined_data)
49+
50+
4551
# return names of supported shader stages
4652
def get_shader_stages():
4753
return [
4854
"vs",
4955
"ps",
50-
"cs"
56+
"cs",
57+
"lib"
5158
]
5259

5360

@@ -92,7 +99,8 @@ def get_bindable_resource_keys():
9299
"RWTexture2DArray",
93100
"RWTexture3D",
94101
"SamplerState",
95-
"SamplerComparisonState"
102+
"SamplerComparisonState",
103+
"RaytracingAccelerationStructure"
96104
]
97105

98106

@@ -124,6 +132,7 @@ def get_resource_mappings():
124132
{"category": "textures", "identifier": "RWTexture2D"},
125133
{"category": "textures", "identifier": "RWTexture2DArray"},
126134
{"category": "textures", "identifier": "RWTexture3D"},
135+
{"category": "acceleration_structures", "identifier": "RaytracingAccelerationStructure"},
127136
]
128137

129138

@@ -134,7 +143,8 @@ def get_resource_categories():
134143
"cbuffers",
135144
"structured_buffers",
136145
"textures",
137-
"samplers"
146+
"samplers",
147+
"acceleration_structures"
138148
]
139149

140150

@@ -225,9 +235,10 @@ def get_shader_visibility(vis):
225235
stages = {
226236
"vs": "Vertex",
227237
"ps": "Fragment",
228-
"cs": "Compute"
238+
"cs": "Compute",
229239
}
230-
return stages[vis[0]]
240+
if vis[0] in stages:
241+
return stages[vis[0]]
231242
return "All"
232243

233244

@@ -940,19 +951,33 @@ def generate_shader_info(pmfx, entry_point, stage, permute=None):
940951
res += "{}\n".format(pragma)
941952

942953
# resources input structs, textures, buffers etc
954+
added_resources = []
943955
if len(resources) > 0:
944956
res += "// resource declarations\n"
945957
for resource in recursive_resources:
958+
if resource in added_resources:
959+
continue
946960
if recursive_resources[resource]["depth"] > 0:
947961
res += recursive_resources[resource]["declaration"] + ";\n"
962+
added_resources.append(resource)
948963

949964
for resource in resources:
965+
if resource in added_resources:
966+
continue
950967
res += resources[resource]["declaration"] + ";\n"
968+
added_resources.append(resource)
951969

952970
# extract vs_input (input layout)
953971
if stage == "vs":
954972
vertex_elements = get_vertex_elements(pmfx, entry_point)
955973

974+
# typedefs
975+
typedef_decls = cgu.find_typedef_decls(pmfx["source"])
976+
if len(typedef_decls) > 0:
977+
res += "// typedefs\n"
978+
for typedef_decl in typedef_decls:
979+
res += typedef_decl + ";\n"
980+
956981
# add fwd function decls
957982
if len(forward_decls) > 0:
958983
res += "// function foward declarations\n"
@@ -1035,42 +1060,63 @@ def generate_pipeline_permutation(pipeline_name, pipeline, output_pmfx, shaders,
10351060
print(" pipeline: {} {}".format(pipeline_name, permutation_name))
10361061
resources = dict()
10371062
output_pipeline = dict(pipeline)
1038-
# lookup info from compiled shaders and combine resources
1063+
1064+
# gather entry points
1065+
entry_points = list()
10391066
for stage in get_shader_stages():
10401067
if stage in pipeline:
1041-
entry_point = pipeline[stage]
1042-
if entry_point not in shaders[stage]:
1043-
output_pipeline["error_code"] = 1
1044-
continue
1045-
# lookup shader info, and redirect to shared shaders
1046-
shader_info = shaders[stage][entry_point][pemutation_id]
1047-
if "lookup" in shader_info:
1048-
lookup = shader_info["lookup"]
1049-
shader_info = dict(shaders[stage][lookup[0]][lookup[1]])
1068+
if type(pipeline[stage]) is list:
1069+
for entry_point in pipeline[stage]:
1070+
entry_points.append((stage, entry_point, True))
1071+
else:
1072+
entry_points.append((stage, pipeline[stage], False))
1073+
1074+
# clear lib
1075+
if "lib" in output_pipeline:
1076+
output_pipeline["lib_hash"] = 0
1077+
output_pipeline["lib"].clear()
1078+
1079+
# lookup info from compiled shaders and combine resources
1080+
for (stage, entry_point, lib) in entry_points:
1081+
# check entry exists
1082+
if entry_point not in shaders[stage]:
1083+
output_pipeline["error_code"] = 1
1084+
continue
1085+
# lookup shader info, and redirect to shared shaders
1086+
shader_info = shaders[stage][entry_point][pemutation_id]
1087+
if "lookup" in shader_info:
1088+
lookup = shader_info["lookup"]
1089+
shader_info = dict(shaders[stage][lookup[0]][lookup[1]])
1090+
1091+
if lib:
1092+
output_pipeline[stage].append(shader_info["filename"])
1093+
output_pipeline["lib_hash"] = pmfx_hash_combine(output_pipeline["lib_hash"], pmfx_hash(shader_info["src_hash"]))
1094+
else:
10501095
output_pipeline[stage] = shader_info["filename"]
1051-
output_pipeline["{}_hash:".format(stage)] = pmfx_hash(shader_info["src_hash"])
1052-
shader = shader_info
1053-
resources = merge_dicts(resources, dict(shader["resources"]), ["visibility"])
1054-
# generate vertex layout
1055-
if stage == "vs":
1056-
pmfx_vertex_layout = dict()
1057-
if "vertex_layout" in pipeline:
1058-
pmfx_vertex_layout = pipeline["vertex_layout"]
1059-
output_pipeline["vertex_layout"] = generate_vertex_layout(shader["vertex_elements"], pmfx_vertex_layout)
1060-
# extract numthreads
1061-
if stage == "cs":
1062-
for attrib in shader["attributes"]:
1063-
if attrib.find("numthreads") != -1:
1064-
start, end = cgu.enclose_start_end("(", ")", attrib, 0)
1065-
xyz = attrib[start:end].split(",")
1066-
numthreads = []
1067-
for a in xyz:
1068-
numthreads.append(int(a.strip()))
1069-
output_pipeline["numthreads"] = numthreads
1070-
1071-
# set non zero error codes to track failures
1072-
if shader_info["error_code"] != 0:
1073-
output_pipeline["error_code"] = shader_info["error_code"]
1096+
output_pipeline["{}_hash".format(stage)] = pmfx_hash(shader_info["src_hash"])
1097+
1098+
shader = shader_info
1099+
resources = merge_dicts(resources, dict(shader["resources"]), ["visibility"])
1100+
# generate vertex layout
1101+
if stage == "vs":
1102+
pmfx_vertex_layout = dict()
1103+
if "vertex_layout" in pipeline:
1104+
pmfx_vertex_layout = pipeline["vertex_layout"]
1105+
output_pipeline["vertex_layout"] = generate_vertex_layout(shader["vertex_elements"], pmfx_vertex_layout)
1106+
# extract numthreads
1107+
if stage == "cs":
1108+
for attrib in shader["attributes"]:
1109+
if attrib.find("numthreads") != -1:
1110+
start, end = cgu.enclose_start_end("(", ")", attrib, 0)
1111+
xyz = attrib[start:end].split(",")
1112+
numthreads = []
1113+
for a in xyz:
1114+
numthreads.append(int(a.strip()))
1115+
output_pipeline["numthreads"] = numthreads
1116+
1117+
# set non zero error codes to track failures
1118+
if shader_info["error_code"] != 0:
1119+
output_pipeline["error_code"] = shader_info["error_code"]
10741120

10751121
# build pipeline layout
10761122
output_pipeline["pipeline_layout"] = generate_pipeline_layout(output_pmfx, pipeline, resources)
@@ -1309,9 +1355,13 @@ def generate_pmfx(file, root):
13091355
pipeline = pipelines[pipeline_key]
13101356
for stage in get_shader_stages():
13111357
if stage in pipeline:
1312-
stage_shader = (stage, pipeline[stage])
1313-
if stage_shader not in shader_list:
1314-
shader_list.append(stage_shader)
1358+
if type(pipeline[stage]) is list:
1359+
for shader in pipeline[stage]:
1360+
stage_shader = (stage, shader)
1361+
else:
1362+
stage_shader = (stage, pipeline[stage])
1363+
if stage_shader not in shader_list:
1364+
shader_list.append(stage_shader)
13151365

13161366
# gather permutations
13171367
permutation_jobs = []
@@ -1326,8 +1376,13 @@ def generate_pmfx(file, root):
13261376
pipeline_jobs.append((pipeline_key, id))
13271377
for stage in get_shader_stages():
13281378
if stage in pipeline:
1329-
permutation_jobs.append(
1330-
pool.apply_async(generate_shader_info_permutation, (pmfx, pipeline[stage], stage, permute, define_list)))
1379+
if type(pipeline[stage]) is list:
1380+
for shader in pipeline[stage]:
1381+
permutation_jobs.append(
1382+
pool.apply_async(generate_shader_info_permutation, (pmfx, shader, stage, permute, define_list)))
1383+
else:
1384+
permutation_jobs.append(
1385+
pool.apply_async(generate_shader_info_permutation, (pmfx, pipeline[stage], stage, permute, define_list)))
13311386

13321387
# wait on shader permutations
13331388
shaders = dict()

0 commit comments

Comments
 (0)