Skip to content

Commit 08e141b

Browse files
committed
Adjust tests to new ArgumentParser design
The new ArgumentParser interface is slightly different to the old one, and doesn't require a call to a recognize_compiler() function to construct. Signed-off-by: John Pennycook <john.pennycook@intel.com>
1 parent bfcc0f7 commit 08e141b

File tree

1 file changed

+57
-39
lines changed

1 file changed

+57
-39
lines changed

tests/compilers/test_compilers.py

Lines changed: 57 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pathlib import Path
99

1010
from codebasin import config
11+
from codebasin.config import ArgumentParser
1112

1213

1314
class TestCompilers(unittest.TestCase):
@@ -42,120 +43,136 @@ def test_common(self):
4243
"MACRO_AFTER_SPACE",
4344
"test.cpp",
4445
]
45-
args = config._parse_compiler_args(argv)
46-
self.assertEqual(
47-
args.defines,
46+
parser = ArgumentParser(argv[0])
47+
48+
passes = parser.parse_args(argv[1:])
49+
self.assertEqual(len(passes), 1)
50+
51+
self.assertEqual(passes[0].pass_name, "default")
52+
53+
self.assertCountEqual(
54+
passes[0].defines,
4855
["MACRO", "FUNCTION_MACRO=1", "MACRO_AFTER_SPACE"],
4956
)
50-
self.assertEqual(
51-
args.include_paths,
57+
self.assertCountEqual(
58+
passes[0].include_paths,
5259
["/path", "/path/after/space", "/system/path"],
5360
)
54-
self.assertEqual(args.include_files, ["foo.inc", "bar.inc"])
61+
self.assertCountEqual(passes[0].include_files, ["foo.inc", "bar.inc"])
5562

5663
def test_gnu(self):
5764
"""compilers/gnu"""
5865
argv = ["g++", "-fopenmp", "test.cpp"]
5966

60-
parser = config.recognize_compiler(argv[0])
61-
self.assertTrue(type(parser) is config.GnuArgumentParser)
67+
parser = ArgumentParser(argv[0])
6268

6369
passes = parser.parse_args(argv[1:])
6470
self.assertEqual(len(passes), 1)
6571

6672
self.assertEqual(passes[0].pass_name, "default")
6773

6874
defines = passes[0].defines
69-
self.assertEqual(defines, ["_OPENMP"])
75+
self.assertCountEqual(defines, ["_OPENMP"])
7076

7177
def test_clang(self):
7278
"""compilers/clang"""
7379
argv = ["clang", "-fsycl-is-device", "test.cpp"]
7480

75-
parser = config.recognize_compiler(argv[0])
76-
self.assertTrue(type(parser) is config.ClangArgumentParser)
81+
parser = ArgumentParser(argv[0])
7782

7883
passes = parser.parse_args(argv[1:])
7984
self.assertEqual(len(passes), 1)
8085

8186
self.assertEqual(passes[0].pass_name, "default")
8287

8388
defines = passes[0].defines
84-
self.assertEqual(defines, ["__SYCL_DEVICE_ONLY__"])
89+
self.assertCountEqual(defines, ["__SYCL_DEVICE_ONLY__"])
8590

8691
def test_intel_sycl(self):
8792
"""compilers/intel_sycl"""
8893
argv = ["icpx", "-fsycl", "test.cpp"]
8994

90-
parser = config.recognize_compiler(argv[0])
91-
self.assertTrue(type(parser) is config.ClangArgumentParser)
95+
parser = ArgumentParser(argv[0])
9296

9397
passes = parser.parse_args(argv[1:])
9498
self.assertEqual(len(passes), 2)
9599

96100
pass_names = {p.pass_name for p in passes}
97-
self.assertEqual(pass_names, {"default", "spir64"})
101+
self.assertCountEqual(pass_names, {"default", "sycl-spir64"})
98102

99103
for p in passes:
100104
if p.pass_name == "default":
101-
expected = []
105+
expected = ["SYCL_LANGUAGE_VERSION"]
102106
else:
103-
expected = ["__SYCL_DEVICE_ONLY__", "__SPIR__", "__SPIRV__"]
104-
self.assertEqual(p.defines, expected)
107+
expected = [
108+
"SYCL_LANGUAGE_VERSION",
109+
"__SYCL_DEVICE_ONLY__",
110+
"__SPIR__",
111+
"__SPIRV__",
112+
]
113+
self.assertCountEqual(p.defines, expected)
105114

106115
def test_intel_targets(self):
107116
"""compilers/intel_targets"""
108117
argv = [
109118
"icpx",
110119
"-fsycl",
111-
"-fsycl-targets=spir64,x86_64",
120+
"-fsycl-targets=spir64,spir64_x86_64",
112121
"-fopenmp",
113122
"test.cpp",
114123
]
115124

116-
parser = config.recognize_compiler(argv[0])
117-
self.assertTrue(type(parser) is config.ClangArgumentParser)
125+
parser = ArgumentParser(argv[0])
118126

119127
passes = parser.parse_args(argv[1:])
120128

121129
pass_names = {p.pass_name for p in passes}
122-
self.assertEqual(pass_names, {"default", "spir64", "x86_64"})
130+
self.assertCountEqual(
131+
pass_names,
132+
{"default", "sycl-spir64", "sycl-spir64_x86_64"},
133+
)
123134

124135
for p in passes:
125136
if p.pass_name == "default":
126-
expected = ["_OPENMP"]
127-
self.assertEqual(p.defines, ["_OPENMP"])
128-
elif p.pass_name == "spir64" or p.pass_name == "x86_64":
129-
expected = ["__SYCL_DEVICE_ONLY__", "__SPIR__", "__SPIRV__"]
130-
self.assertEqual(p.defines, expected)
137+
expected = ["SYCL_LANGUAGE_VERSION", "_OPENMP"]
138+
elif (
139+
p.pass_name == "sycl-spir64"
140+
or p.pass_name == "sycl-spir64_x86_64"
141+
):
142+
expected = [
143+
"SYCL_LANGUAGE_VERSION",
144+
"__SYCL_DEVICE_ONLY__",
145+
"__SPIR__",
146+
"__SPIRV__",
147+
]
148+
self.assertCountEqual(p.defines, expected)
131149

132150
def test_nvcc(self):
133151
"""compilers/nvcc"""
134152
argv = [
135153
"nvcc",
136154
"-fopenmp",
137-
"--gpu-architecture=compute_50",
138-
"--gpu-code=compute_50,sm_50,sm_52",
155+
"--gpu-architecture=compute_70",
156+
"--gpu-code=compute_70,sm_70,sm_75",
139157
"test.cpp",
140158
]
141159

142-
parser = config.recognize_compiler(argv[0])
143-
self.assertTrue(type(parser) is config.NvccArgumentParser)
160+
parser = ArgumentParser(argv[0])
144161

145162
passes = parser.parse_args(argv[1:])
146163

147164
pass_names = {p.pass_name for p in passes}
148-
self.assertEqual(pass_names, {"default", "50", "52"})
165+
self.assertCountEqual(pass_names, {"default", "sm_70", "sm_75"})
149166

150167
defaults = ["__NVCC__", "__CUDACC__"]
151168
for p in passes:
152169
if p.pass_name == "default":
153170
expected = defaults + ["_OPENMP"]
154-
elif p.pass_name == "50":
155-
expected = defaults + ["__CUDA_ARCH__=500"]
156-
elif p.pass_name == "52":
157-
expected = defaults + ["__CUDA_ARCH__=520"]
158-
self.assertEqual(p.defines, expected)
171+
elif p.pass_name == "sm_70":
172+
expected = defaults + ["__CUDA_ARCH__=700"]
173+
elif p.pass_name == "sm_75":
174+
expected = defaults + ["__CUDA_ARCH__=750"]
175+
self.assertCountEqual(p.defines, expected)
159176

160177
def test_user_options(self):
161178
"""Check that we import user-defined options"""
@@ -166,14 +183,15 @@ def test_user_options(self):
166183
with open(path / ".cbi" / "config", mode="w") as f:
167184
f.write('[compiler."c++"]\n')
168185
f.write('options = ["-D", "ASDF"]\n')
169-
config.load_importcfg()
186+
config._importcfg = None
187+
config._load_compilers()
170188

171189
argv = [
172190
"c++",
173191
"test.cpp",
174192
]
175193

176-
parser = config.recognize_compiler(argv[0])
194+
parser = ArgumentParser(argv[0])
177195
passes = parser.parse_args(argv[1:])
178196
self.assertEqual(len(passes), 1)
179197
self.assertCountEqual(passes[0].defines, ["ASDF"])

0 commit comments

Comments
 (0)