|
12 | 12 | import subprocess
|
13 | 13 | import sys
|
14 | 14 |
|
15 |
| -from mako.template import Template |
16 |
| - |
17 |
| -HEADER_TEMPLATE = Template("""/* |
| 15 | +HEADER_TEMPLATE = """/* |
18 | 16 | *
|
19 | 17 | * Copyright (C) 2023 Intel Corporation
|
20 | 18 | *
|
21 | 19 | * Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions.
|
22 | 20 | * See LICENSE.TXT
|
23 | 21 | * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
24 | 22 | *
|
25 |
| - * @file ${file_name}.h |
| 23 | + * @file %s.h |
26 | 24 | *
|
27 | 25 | */
|
28 | 26 |
|
|
33 | 31 | namespace uur {
|
34 | 32 | namespace device_binaries {
|
35 | 33 | std::map<std::string, std::vector<std::string>> program_kernel_map = {
|
36 |
| -% for program, entry_points in kernel_name_dict.items(): |
37 |
| - {"${program}", { |
38 |
| - % for entry_point in entry_points: |
39 |
| - "${entry_point}", |
40 |
| - % endfor |
41 |
| - }}, |
42 |
| -% endfor |
| 34 | +%s |
43 | 35 | };
|
44 | 36 | }
|
45 | 37 | }
|
46 |
| -""") |
| 38 | +""" |
| 39 | + |
| 40 | +PROGRAM_TEMPLATE = """\ |
| 41 | + {"%s", { |
| 42 | +%s |
| 43 | + }}, |
| 44 | +""" |
47 | 45 |
|
| 46 | +ENTRY_POINT_TEMPLATE = """\ |
| 47 | + "%s", |
| 48 | +""" |
48 | 49 |
|
49 | 50 | def generate_header(output_file, kernel_name_dict):
|
50 | 51 | """Render the template and write it to the output file."""
|
51 | 52 | file_name = os.path.basename(output_file)
|
52 |
| - rendered = HEADER_TEMPLATE.render(file_name=file_name, |
53 |
| - kernel_name_dict=kernel_name_dict) |
| 53 | + device_binaries = "" |
| 54 | + for program, entry_points in kernel_name_dict.items(): |
| 55 | + content = "" |
| 56 | + for entry_point in entry_points: |
| 57 | + content += ENTRY_POINT_TEMPLATE % entry_point |
| 58 | + device_binaries += PROGRAM_TEMPLATE % (program, content) |
| 59 | + rendered = HEADER_TEMPLATE % (file_name, device_binaries) |
54 | 60 | rendered = re.sub(r"\r\n", r"\n", rendered)
|
55 |
| - |
56 | 61 | with open(output_file, "w") as fout:
|
57 | 62 | fout.write(rendered)
|
58 | 63 |
|
@@ -81,7 +86,9 @@ def get_mangled_names(dpcxx_path, source_file, output_header):
|
81 | 86 | for line in definition_lines:
|
82 | 87 | if kernel_name_regex.search(line) is None:
|
83 | 88 | continue
|
84 |
| - kernel_name = kernel_name_regex.search(line).group(1) |
| 89 | + match = kernel_name_regex.search(line) |
| 90 | + assert isinstance(match, re.Match) |
| 91 | + kernel_name = match.group(1) |
85 | 92 | if "kernel_wrapper" not in kernel_name and "with_offset" not in kernel_name:
|
86 | 93 | entry_point_names.append(kernel_name)
|
87 | 94 |
|
|
0 commit comments