Skip to content

Commit 7ee2e29

Browse files
authored
Merge pull request #155 from Pennycook/custom-argparse-actions
Add custom argparse actions
2 parents f58df1f + 1430ab6 commit 7ee2e29

File tree

2 files changed

+206
-0
lines changed

2 files changed

+206
-0
lines changed

codebasin/config.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import logging
1111
import os
1212
import re
13+
import string
1314

1415
from codebasin import CompilationDatabase, util
1516

@@ -72,6 +73,83 @@ def load_importcfg():
7273
_importcfg[name] = compiler["options"]
7374

7475

76+
class _StoreSplitAction(argparse.Action):
77+
"""
78+
A custom argparse.Action that splits the value based on a user-provided
79+
separator, then stores the resulting list.
80+
"""
81+
82+
def __init__(
83+
self,
84+
option_strings: list[str],
85+
dest: str,
86+
nargs=None,
87+
**kwargs,
88+
):
89+
self.sep = kwargs.pop("sep", None)
90+
self.format = kwargs.pop("format", None)
91+
super().__init__(option_strings, dest, nargs=nargs, **kwargs)
92+
93+
def __call__(
94+
self,
95+
parser: argparse.ArgumentParser,
96+
namespace: argparse.Namespace,
97+
values: str,
98+
option_string: str,
99+
):
100+
if not isinstance(values, str):
101+
raise TypeError("store_split expects string values")
102+
split_values = values.split(self.sep)
103+
if self.format:
104+
template = string.Template(self.format)
105+
split_values = [template.substitute(value=v) for v in split_values]
106+
if self.dest == "passes":
107+
passes = getattr(namespace, self.dest)
108+
passes[option_string] = split_values
109+
else:
110+
setattr(namespace, self.dest, split_values)
111+
112+
113+
class _ExtendMatchAction(argparse.Action):
114+
"""
115+
A custom argparse.Action that matches the value against a user-provided
116+
pattern, then extends the destination list using the result(s).
117+
"""
118+
119+
def __init__(
120+
self,
121+
option_strings: list[str],
122+
dest: str,
123+
nargs=None,
124+
**kwargs,
125+
):
126+
self.pattern = kwargs.pop("pattern", None)
127+
self.format = kwargs.pop("format", None)
128+
super().__init__(option_strings, dest, nargs=nargs, **kwargs)
129+
130+
def __call__(
131+
self,
132+
parser: argparse.ArgumentParser,
133+
namespace: argparse.Namespace,
134+
value: str,
135+
option_string: str,
136+
):
137+
if not isinstance(value, str):
138+
raise TypeError("extend_match expects string value")
139+
matches = re.findall(self.pattern, value)
140+
if self.format:
141+
template = string.Template(self.format)
142+
matches = [template.substitute(value=v) for v in matches]
143+
if self.dest == "passes":
144+
passes = getattr(namespace, self.dest)
145+
if option_string not in passes:
146+
passes[option_string] = []
147+
passes[option_string].extend(matches)
148+
else:
149+
dest = getattr(namespace, self.dest)
150+
dest.extend(matches)
151+
152+
75153
class Compiler:
76154
"""
77155
Represents the behavior of a specific compiler, including:

tests/compilers/test_actions.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Copyright (C) 2019-2024 Intel Corporation
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
4+
import argparse
5+
import logging
6+
import unittest
7+
8+
from codebasin.config import _ExtendMatchAction, _StoreSplitAction
9+
10+
11+
class TestActions(unittest.TestCase):
12+
"""
13+
Test that custom argparse.Action classes work as expected.
14+
These classes enable handling of complex user-defined compiler options.
15+
"""
16+
17+
def setUp(self):
18+
logging.disable()
19+
20+
def test_store_split_init(self):
21+
"""Check that store_split recognizes custom arguments"""
22+
action = _StoreSplitAction(["--foo"], "foo", sep=",", format="$value")
23+
self.assertEqual(action.sep, ",")
24+
self.assertEqual(action.format, "$value")
25+
26+
action = _StoreSplitAction(["--foo"], "foo")
27+
self.assertEqual(action.sep, None)
28+
self.assertEqual(action.format, None)
29+
30+
def test_store_split(self):
31+
"""Check that argparse calls store_split correctly"""
32+
namespace = argparse.Namespace()
33+
namespace.passes = {}
34+
35+
parser = argparse.ArgumentParser()
36+
parser.add_argument("--foo", action=_StoreSplitAction, sep=",")
37+
parser.add_argument(
38+
"--bar",
39+
action=_StoreSplitAction,
40+
sep=",",
41+
format="prefix-$value-suffix",
42+
)
43+
parser.add_argument("--baz", action=_StoreSplitAction, type=int)
44+
parser.add_argument(
45+
"--qux",
46+
action=_StoreSplitAction,
47+
sep=",",
48+
dest="passes",
49+
)
50+
51+
args, _ = parser.parse_known_args(["--foo=one,two"], namespace)
52+
self.assertEqual(args.foo, ["one", "two"])
53+
54+
args, _ = parser.parse_known_args(["--bar=one,two"], namespace)
55+
self.assertEqual(args.bar, ["prefix-one-suffix", "prefix-two-suffix"])
56+
57+
with self.assertRaises(TypeError):
58+
args, _ = parser.parse_known_args(["--baz=1"], namespace)
59+
60+
args, _ = parser.parse_known_args(["--qux=one,two"], namespace)
61+
self.assertEqual(args.passes, {"--qux": ["one", "two"]})
62+
63+
def test_extend_match_init(self):
64+
"""Check that extend_match recognizes custom arguments"""
65+
action = _ExtendMatchAction(
66+
["--foo"],
67+
"foo",
68+
pattern="*",
69+
format="$value",
70+
)
71+
self.assertEqual(action.pattern, "*")
72+
self.assertEqual(action.format, "$value")
73+
74+
action = _ExtendMatchAction(["--foo"], "foo")
75+
self.assertEqual(action.pattern, None)
76+
self.assertEqual(action.format, None)
77+
78+
def test_extend_match(self):
79+
"""Check that argparse calls store_split correctly"""
80+
namespace = argparse.Namespace()
81+
namespace.passes = {}
82+
83+
parser = argparse.ArgumentParser()
84+
parser.add_argument(
85+
"--foo",
86+
action=_ExtendMatchAction,
87+
pattern=r"option_(\d+)",
88+
default=[],
89+
)
90+
parser.add_argument(
91+
"--bar",
92+
action=_ExtendMatchAction,
93+
pattern=r"option_(\d+)",
94+
format="prefix-$value-suffix",
95+
default=[],
96+
)
97+
parser.add_argument("--baz", action=_ExtendMatchAction, type=int)
98+
parser.add_argument(
99+
"--qux",
100+
action=_ExtendMatchAction,
101+
pattern=r"option_(\d+)",
102+
dest="passes",
103+
)
104+
105+
args, _ = parser.parse_known_args(
106+
["--foo=option_1,option_2"],
107+
namespace,
108+
)
109+
self.assertEqual(args.foo, ["1", "2"])
110+
111+
args, _ = parser.parse_known_args(
112+
["--bar=option_1,option_2"],
113+
namespace,
114+
)
115+
self.assertEqual(args.bar, ["prefix-1-suffix", "prefix-2-suffix"])
116+
117+
with self.assertRaises(TypeError):
118+
args, _ = parser.parse_known_args(["--baz=1"], namespace)
119+
120+
args, _ = parser.parse_known_args(
121+
["--qux=option_1,option_2"],
122+
namespace,
123+
)
124+
self.assertEqual(args.passes, {"--qux": ["1", "2"]})
125+
126+
127+
if __name__ == "__main__":
128+
unittest.main()

0 commit comments

Comments
 (0)