Skip to content

Commit f9c94b5

Browse files
committed
use safe match pattern
1 parent ff22137 commit f9c94b5

File tree

5 files changed

+82
-115
lines changed

5 files changed

+82
-115
lines changed

riscv-pac/macros/src/lib.rs

Lines changed: 38 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -6,89 +6,66 @@ extern crate syn;
66
use proc_macro::TokenStream;
77
use proc_macro2::TokenStream as TokenStream2;
88
use quote::quote;
9-
use std::{collections::HashMap, ops::Range, str::FromStr};
9+
use std::str::FromStr;
1010
use syn::{parse_macro_input, Data, DeriveInput, Ident};
1111

1212
struct PacNumberEnum {
1313
name: Ident,
14-
valid_ranges: Vec<Range<usize>>,
14+
numbers: Vec<(Ident, usize)>,
1515
}
1616

1717
impl PacNumberEnum {
1818
fn new(input: &DeriveInput) -> Self {
19+
let name = input.ident.clone();
20+
1921
let variants = match &input.data {
2022
Data::Enum(data) => &data.variants,
2123
_ => panic!("Input is not an enum"),
2224
};
23-
24-
// Collect the variants and their associated number discriminants
25-
let mut var_map = HashMap::new();
26-
let mut numbers = Vec::new();
27-
for variant in variants {
28-
let ident = &variant.ident;
29-
let value = match &variant.discriminant {
30-
Some(d) => match &d.1 {
31-
syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
32-
syn::Lit::Int(lit_int) => match lit_int.base10_parse::<usize>() {
33-
Ok(num) => num,
34-
Err(_) => panic!("All variant discriminants must be unsigned integers"),
25+
let numbers = variants
26+
.iter()
27+
.map(|variant| {
28+
let ident = &variant.ident;
29+
let value = match &variant.discriminant {
30+
Some(d) => match &d.1 {
31+
syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
32+
syn::Lit::Int(lit_int) => match lit_int.base10_parse::<usize>() {
33+
Ok(num) => num,
34+
Err(_) => {
35+
panic!("All variant discriminants must be unsigned integers")
36+
}
37+
},
38+
_ => panic!("All variant discriminants must be unsigned integers"),
3539
},
3640
_ => panic!("All variant discriminants must be unsigned integers"),
3741
},
38-
_ => panic!("All variant discriminants must be unsigned integers"),
39-
},
40-
_ => panic!("Variant must have a discriminant"),
41-
};
42-
var_map.insert(value, ident);
43-
numbers.push(value);
44-
}
42+
_ => panic!("Variant must have a discriminant"),
43+
};
44+
(ident.clone(), value)
45+
})
46+
.collect();
4547

46-
// sort the number discriminants and generate a list of valid ranges
47-
numbers.sort_unstable();
48-
let mut valid_ranges = Vec::new();
49-
let mut start = numbers[0];
50-
let mut end = start;
51-
for &number in &numbers[1..] {
52-
if number == end + 1 {
53-
end = number;
54-
} else {
55-
valid_ranges.push(start..end + 1);
56-
start = number;
57-
end = start;
58-
}
59-
}
60-
valid_ranges.push(start..end + 1);
61-
62-
Self {
63-
name: input.ident.clone(),
64-
valid_ranges,
65-
}
66-
}
67-
68-
fn valid_condition(&self) -> TokenStream2 {
69-
let mut arms = Vec::new();
70-
for range in &self.valid_ranges {
71-
let (start, end) = (range.start, range.end);
72-
if end - start == 1 {
73-
arms.push(TokenStream2::from_str(&format!("number == {start}")).unwrap());
74-
} else {
75-
arms.push(
76-
TokenStream2::from_str(&format!("({start}..{end}).contains(&number)")).unwrap(),
77-
);
78-
}
79-
}
80-
quote! { #(#arms) || * }
48+
Self { name, numbers }
8149
}
8250

8351
fn max_discriminant(&self) -> TokenStream2 {
84-
let max_discriminant = self.valid_ranges.last().expect("invalid range").end - 1;
52+
let max_discriminant = self.numbers.iter().map(|(_, num)| num).max().unwrap();
8553
TokenStream2::from_str(&format!("{max_discriminant}")).unwrap()
8654
}
8755

56+
fn valid_matches(&self) -> Vec<TokenStream2> {
57+
self.numbers
58+
.iter()
59+
.map(|(ident, num)| {
60+
TokenStream2::from_str(&format!("{num} => Ok(Self::{ident})")).unwrap()
61+
})
62+
.collect()
63+
}
64+
8865
fn quote(&self, trait_name: &str, num_type: &str, const_name: &str) -> TokenStream2 {
8966
let name = &self.name;
9067
let max_discriminant = self.max_discriminant();
91-
let valid_condition = self.valid_condition();
68+
let valid_matches = self.valid_matches();
9269

9370
let trait_name = TokenStream2::from_str(trait_name).unwrap();
9471
let num_type = TokenStream2::from_str(num_type).unwrap();
@@ -105,11 +82,9 @@ impl PacNumberEnum {
10582

10683
#[inline]
10784
fn from_number(number: #num_type) -> Result<Self, #num_type> {
108-
if #valid_condition {
109-
// SAFETY: The number is valid for this enum
110-
Ok(unsafe { core::mem::transmute::<#num_type, Self>(number) })
111-
} else {
112-
Err(number)
85+
match number {
86+
#(#valid_matches,)*
87+
_ => Err(number),
11388
}
11489
}
11590
}
@@ -125,20 +100,6 @@ impl PacNumberEnum {
125100
/// The trait name must be one of `ExceptionNumber`, `InterruptNumber`, `PriorityNumber`, or `HartIdNumber`.
126101
/// Marker traits `CoreInterruptNumber` and `ExternalInterruptNumber` cannot be implemented using this macro.
127102
///
128-
/// # Note
129-
///
130-
/// To implement number-to-enum operation, the macro works with ranges of valid discriminant numbers.
131-
/// If the number is within any of the valid ranges, the number is transmuted to the enum variant.
132-
/// In this way, the macro achieves better performance for enums with a large number of consecutive variants.
133-
/// Thus, the enum must comply with the following requirements:
134-
///
135-
/// - All the enum variants must have a valid discriminant number (i.e., a number that is within the valid range of the enum).
136-
/// - For the `ExceptionNumber`, `InterruptNumber`, and `HartIdNumber` traits, the enum must be annotated as `#[repr(u16)]`
137-
/// - For the `PriorityNumber` trait, the enum must be annotated as `#[repr(u8)]`
138-
///
139-
/// If the enum does not meet these requirements, you will have to implement the traits manually (e.g., `riscv::mcause::Interrupt`).
140-
/// For enums with a small number of consecutive variants, it might be better to implement the traits manually.
141-
///
142103
/// # Safety
143104
///
144105
/// The struct to be implemented must comply with the requirements of the specified trait.

riscv-pac/tests/ui/fail_wrong_repr.rs

Lines changed: 0 additions & 9 deletions
This file was deleted.

riscv-pac/tests/ui/fail_wrong_repr.stderr

Lines changed: 0 additions & 9 deletions
This file was deleted.

riscv/src/register/mcause.rs

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,14 @@ unsafe impl InterruptNumber for Interrupt {
2626

2727
#[inline]
2828
fn from_number(value: usize) -> Result<Self, usize> {
29-
if value > 11 || value % 2 == 0 {
30-
Err(value)
31-
} else {
32-
// SAFETY: valid interrupt number
33-
unsafe { Ok(core::mem::transmute::<usize, Self>(value)) }
29+
match value {
30+
1 => Ok(Self::SupervisorSoft),
31+
3 => Ok(Self::MachineSoft),
32+
5 => Ok(Self::SupervisorTimer),
33+
7 => Ok(Self::MachineTimer),
34+
9 => Ok(Self::SupervisorExternal),
35+
11 => Ok(Self::MachineExternal),
36+
_ => Err(value),
3437
}
3538
}
3639
}
@@ -69,11 +72,22 @@ unsafe impl ExceptionNumber for Exception {
6972

7073
#[inline]
7174
fn from_number(value: usize) -> Result<Self, usize> {
72-
if value == 10 || value == 14 || value > 15 {
73-
Err(value)
74-
} else {
75-
// SAFETY: valid exception number
76-
unsafe { Ok(core::mem::transmute::<usize, Self>(value)) }
75+
match value {
76+
0 => Ok(Self::InstructionMisaligned),
77+
1 => Ok(Self::InstructionFault),
78+
2 => Ok(Self::IllegalInstruction),
79+
3 => Ok(Self::Breakpoint),
80+
4 => Ok(Self::LoadMisaligned),
81+
5 => Ok(Self::LoadFault),
82+
6 => Ok(Self::StoreMisaligned),
83+
7 => Ok(Self::StoreFault),
84+
8 => Ok(Self::UserEnvCall),
85+
9 => Ok(Self::SupervisorEnvCall),
86+
11 => Ok(Self::MachineEnvCall),
87+
12 => Ok(Self::InstructionPageFault),
88+
13 => Ok(Self::LoadPageFault),
89+
15 => Ok(Self::StorePageFault),
90+
_ => Err(value),
7791
}
7892
}
7993
}

riscv/src/register/scause.rs

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@ unsafe impl InterruptNumber for Interrupt {
2323

2424
#[inline]
2525
fn from_number(value: usize) -> Result<Self, usize> {
26-
if value == 1 || value == 5 || value == 9 {
27-
// SAFETY: valid interrupt number
28-
Ok(unsafe { core::mem::transmute::<usize, Self>(value) })
29-
} else {
30-
Err(value)
26+
match value {
27+
1 => Ok(Self::SupervisorSoft),
28+
5 => Ok(Self::SupervisorTimer),
29+
9 => Ok(Self::SupervisorExternal),
30+
_ => Err(value),
3131
}
3232
}
3333
}
@@ -65,11 +65,21 @@ unsafe impl ExceptionNumber for Exception {
6565

6666
#[inline]
6767
fn from_number(value: usize) -> Result<Self, usize> {
68-
if value == 10 || value == 11 || value == 14 || value > 15 {
69-
Err(value)
70-
} else {
71-
// SAFETY: valid exception number
72-
unsafe { Ok(core::mem::transmute::<usize, Self>(value)) }
68+
match value {
69+
0 => Ok(Self::InstructionMisaligned),
70+
1 => Ok(Self::InstructionFault),
71+
2 => Ok(Self::IllegalInstruction),
72+
3 => Ok(Self::Breakpoint),
73+
4 => Ok(Self::LoadMisaligned),
74+
5 => Ok(Self::LoadFault),
75+
6 => Ok(Self::StoreMisaligned),
76+
7 => Ok(Self::StoreFault),
77+
8 => Ok(Self::UserEnvCall),
78+
9 => Ok(Self::SupervisorEnvCall),
79+
12 => Ok(Self::InstructionPageFault),
80+
13 => Ok(Self::LoadPageFault),
81+
15 => Ok(Self::StorePageFault),
82+
_ => Err(value),
7383
}
7484
}
7585
}

0 commit comments

Comments
 (0)