Skip to content

Commit 0822a56

Browse files
committed
rust: support the ? operator in doctests
Now it is possible to have tests that use the `?` operator: /// ``` /// # use kernel::{spawn_work_item, workqueue}; /// spawn_work_item!(workqueue::system(), || pr_info!("x"))?; /// # Ok::<(), Error>(()) /// ``` The logic is also simplified: instead of extracting the source code and the originally-crate-level attributes separately, wrap everything into an extra scope. This allows to reuse the `rustdoc`-generated `unwrap()` and `Result` type. This could break if `rustdoc` happens to emit an attribute that can only be used at the crate-level, like a `feature(...)` one, but the previous approach could also break in other ways. There is a workaround for `Result` not being qualified in the `rustdoc`-generated code -- it picks up our own `Result`, and with the simplified logic here we cannot easily move our `use` of the prelude inside the source code. Suggested-by: Andreas Hindborg <andreas.hindborg@wdc.com> Signed-off-by: Miguel Ojeda <ojeda@kernel.org>
1 parent dbb685d commit 0822a56

File tree

2 files changed

+23
-41
lines changed

2 files changed

+23
-41
lines changed

scripts/rustdoc_test_builder.py

Lines changed: 16 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -13,45 +13,28 @@
1313

1414
# `[^\s]*` removes the prefix (e.g. `_doctest_main_`) plus any
1515
# leading path (for `O=` builds).
16-
MAIN_RE = re.compile(
17-
r"^"
18-
r"fn main\(\) { "
19-
r"#\[allow\(non_snake_case\)\] "
20-
r"fn ([^\s]*rust_kernel_([a-zA-Z0-9_]+))\(\) {"
21-
r"$"
22-
)
16+
TEST_NAME_RE = re.compile(r"fn [^\s]*rust_kernel_([a-zA-Z0-9_]+)\(\)")
2317

2418
def main():
25-
found_main = False
26-
test_header = ""
27-
test_body = ""
28-
for line in sys.stdin.readlines():
29-
main_match = MAIN_RE.match(line)
30-
if main_match:
31-
if found_main:
32-
raise Exception("More than one `main` line found.")
33-
found_main = True
34-
function_name = main_match.group(1)
35-
test_name = f"rust_kernel_doctest_{main_match.group(2)}"
36-
continue
37-
38-
if found_main:
39-
test_body += line
40-
else:
41-
test_header += line
42-
43-
if not found_main:
44-
raise Exception("No `main` line found.")
45-
46-
call_line = f"}} {function_name}() }}"
47-
if not test_body.endswith(call_line):
48-
raise Exception("Unexpected end of test body.")
49-
test_body = test_body[:-len(call_line)]
19+
content = sys.stdin.read()
20+
matches = TEST_NAME_RE.findall(content)
21+
if len(matches) == 0:
22+
raise Exception("No test name found.")
23+
if len(matches) > 1:
24+
raise Exception("More than one test name found.")
25+
26+
test_name = f"rust_kernel_doctest_{matches[0]}"
27+
28+
# Qualify `Result` to avoid the collision with our own `Result`
29+
# coming from the prelude.
30+
test_body = content.replace(
31+
f'rust_kernel_{matches[0]}() -> Result<(), impl core::fmt::Debug> {{',
32+
f'rust_kernel_{matches[0]}() -> core::result::Result<(), impl core::fmt::Debug> {{',
33+
)
5034

5135
with open(TESTS_DIR / f"{test_name}.json", "w") as fd:
5236
json.dump({
5337
"name": test_name,
54-
"header": test_header,
5538
"body": test_body,
5639
}, fd, sort_keys=True, indent=4)
5740

scripts/rustdoc_test_gen.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
__KUNIT_TEST.store(__kunit_test, core::sync::atomic::Ordering::SeqCst);
3030
3131
/// Overrides the usual [`assert!`] macro with one that calls KUnit instead.
32+
#[allow(unused)]
3233
macro_rules! assert {{
3334
($cond:expr $(,)?) => {{{{
3435
kernel::kunit_assert!(
@@ -39,6 +40,7 @@
3940
}}
4041
4142
/// Overrides the usual [`assert_eq!`] macro with one that calls KUnit instead.
43+
#[allow(unused)]
4244
macro_rules! assert_eq {{
4345
($left:expr, $right:expr $(,)?) => {{{{
4446
kernel::kunit_assert_eq!(
@@ -50,9 +52,13 @@
5052
}}
5153
5254
// Many tests need the prelude, so provide it by default.
55+
#[allow(unused)]
5356
use kernel::prelude::*;
5457
55-
{test_body}
58+
{{
59+
{test_body}
60+
main();
61+
}}
5662
}}
5763
"""
5864
RUST_TEMPLATE = """// SPDX-License-Identifier: GPL-2.0
@@ -93,8 +99,6 @@
9399
// an `AtomicPtr` to hold the context (though each test only writes once before
94100
// threads may be created).
95101
96-
{rust_header}
97-
98102
const __LOG_PREFIX: &[u8] = b"rust_kernel_doctests\\0";
99103
100104
{rust_tests}
@@ -127,15 +131,12 @@
127131
"""
128132

129133
def main():
130-
rust_header = set()
131134
rust_tests = ""
132135
c_test_declarations = ""
133136
c_test_cases = ""
134137
for filename in sorted(os.listdir(TESTS_DIR)):
135138
with open(TESTS_DIR / filename, "r") as fd:
136139
test = json.load(fd)
137-
for line in test["header"].strip().split("\n"):
138-
rust_header.add(line)
139140
rust_tests += RUST_TEMPLATE_TEST.format(
140141
test_name = test["name"],
141142
test_body = test["body"]
@@ -146,11 +147,9 @@ def main():
146147
c_test_cases += C_TEMPLATE_TEST_CASE.format(
147148
test_name = test["name"]
148149
)
149-
rust_header = sorted(rust_header)
150150

151151
with open(RUST_FILE, "w") as fd:
152152
fd.write(RUST_TEMPLATE.format(
153-
rust_header = "\n".join(rust_header).strip(),
154153
rust_tests = rust_tests.strip(),
155154
))
156155

0 commit comments

Comments
 (0)