Skip to content

Commit 7f7415a

Browse files
committed
Parse #[repr(..)] for #[pyclass] enums.
1 parent 8a03778 commit 7f7415a

File tree

5 files changed

+128
-21
lines changed

5 files changed

+128
-21
lines changed

pyo3-macros-backend/src/pyclass.rs

Lines changed: 75 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,69 @@ struct PyClassEnumVariant<'a> {
381381
/* currently have no more options */
382382
}
383383

384+
struct PyClassEnum<'a> {
385+
ident: &'a syn::Ident,
386+
// The underyling representation of the enum.
387+
// It's used to implement __int__ and __richcmp__.
388+
// This matters when the underyling representation may not fit in `isize`.
389+
#[allow(unused, dead_code)]
390+
repr: syn::Ident,
391+
variants: Vec<PyClassEnumVariant<'a>>,
392+
doc: PythonDoc,
393+
}
394+
395+
impl<'a> PyClassEnum<'a> {
396+
fn new(enum_: &'a syn::ItemEnum) -> syn::Result<Self> {
397+
fn is_numeric_type(t: &syn::Ident) -> bool {
398+
[
399+
"u8", "i8", "u16", "i16", "u32", "i32", "u64", "i64", "u128", "i128", "usize",
400+
"isize",
401+
]
402+
.iter()
403+
.any(|&s| t == s)
404+
}
405+
struct Reprs(syn::punctuated::Punctuated<syn::Ident, Token![,]>);
406+
impl Parse for Reprs {
407+
fn parse(input: ParseStream) -> Result<Self> {
408+
let inner = Punctuated::parse_terminated(input)?;
409+
Ok(Self(inner))
410+
}
411+
}
412+
let ident = &enum_.ident;
413+
// According to the [reference](https://doc.rust-lang.org/reference/items/enumerations.html),
414+
// "Under the default representation, the specified discriminant is interpreted as an isize
415+
// value", so `isize` should be enough by default.
416+
// `cargo test` also tests the following facts:
417+
// - Rustc emits a compile error when the user use a discriminant larger than `isize`.
418+
// - Rustc emits a compile error when the default discriminant is larger than `isize`.
419+
let mut repr = syn::Ident::new("isize", proc_macro2::Span::call_site());
420+
for attr in &enum_.attrs {
421+
if attr.path.is_ident("repr") {
422+
let reprs: Reprs = attr.parse_args()?;
423+
for r in reprs.0 {
424+
if is_numeric_type(&r) {
425+
repr = r;
426+
break;
427+
}
428+
}
429+
}
430+
}
431+
let doc = utils::get_doc(&enum_.attrs, None);
432+
433+
let variants = enum_
434+
.variants
435+
.iter()
436+
.map(extract_variant_data)
437+
.collect::<syn::Result<_>>()?;
438+
Ok(Self {
439+
ident,
440+
repr,
441+
variants,
442+
doc,
443+
})
444+
}
445+
}
446+
384447
pub fn build_py_enum(
385448
enum_: &syn::ItemEnum,
386449
args: PyClassArgs,
@@ -389,38 +452,32 @@ pub fn build_py_enum(
389452
if enum_.variants.is_empty() {
390453
bail_spanned!(enum_.brace_token.span => "Empty enums can't be #[pyclass].");
391454
}
392-
let variants: Vec<PyClassEnumVariant> = enum_
393-
.variants
394-
.iter()
395-
.map(extract_variant_data)
396-
.collect::<syn::Result<_>>()?;
397-
impl_enum(enum_, args, variants, method_type)
455+
let enum_ = PyClassEnum::new(enum_)?;
456+
impl_enum(enum_, args, method_type)
398457
}
399458

400459
fn impl_enum(
401-
enum_: &syn::ItemEnum,
402-
attrs: PyClassArgs,
403-
variants: Vec<PyClassEnumVariant>,
460+
enum_: PyClassEnum,
461+
args: PyClassArgs,
404462
methods_type: PyClassMethodsType,
405463
) -> syn::Result<TokenStream> {
406-
let enum_name = &enum_.ident;
407-
let doc = utils::get_doc(&enum_.attrs, None);
408-
let enum_cls = impl_enum_class(enum_name, &attrs, variants, doc, methods_type)?;
464+
let enum_cls = impl_enum_class(enum_, &args, methods_type)?;
409465

410466
Ok(quote! {
411467
#enum_cls
412468
})
413469
}
414470

415471
fn impl_enum_class(
416-
cls: &syn::Ident,
417-
attr: &PyClassArgs,
418-
variants: Vec<PyClassEnumVariant>,
419-
doc: PythonDoc,
472+
enum_: PyClassEnum,
473+
args: &PyClassArgs,
420474
methods_type: PyClassMethodsType,
421475
) -> syn::Result<TokenStream> {
422-
let pytypeinfo = impl_pytypeinfo(cls, attr, None);
423-
let pyclass_impls = PyClassImplsBuilder::new(cls, attr, methods_type)
476+
let cls = enum_.ident;
477+
let doc = enum_.doc;
478+
let variants = enum_.variants;
479+
let pytypeinfo = impl_pytypeinfo(cls, args, None);
480+
let pyclass_impls = PyClassImplsBuilder::new(cls, args, methods_type)
424481
.doc(doc)
425482
.impl_all();
426483
let descriptors = unit_variants_as_descriptors(cls, variants.iter().map(|v| v.ident));
@@ -494,9 +551,6 @@ fn extract_variant_data(variant: &syn::Variant) -> syn::Result<PyClassEnumVarian
494551
Fields::Unit => &variant.ident,
495552
_ => bail_spanned!(variant.span() => "Currently only support unit variants."),
496553
};
497-
if let Some(discriminant) = variant.discriminant.as_ref() {
498-
bail_spanned!(discriminant.0.span() => "Currently does not support discriminats.")
499-
};
500554
Ok(PyClassEnumVariant { ident })
501555
}
502556

tests/test_compile_error.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ fn test_compile_errors() {
1515

1616
fn _test_compile_errors() {
1717
let t = trybuild::TestCases::new();
18+
t.compile_fail("tests/ui/invalid_enum_discriminant.rs");
1819
t.compile_fail("tests/ui/invalid_macro_args.rs");
1920
t.compile_fail("tests/ui/invalid_need_module_arg_position.rs");
2021
t.compile_fail("tests/ui/invalid_property_args.rs");

tests/test_enum.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,24 @@ fn test_default_repr_correct() {
6161
py_assert!(py, var2, "repr(var2) == 'MyEnum.OtherVariant'");
6262
})
6363
}
64+
65+
#[pyclass]
66+
enum CustomDiscriminant {
67+
One = 1,
68+
Two = 2,
69+
}
70+
71+
#[test]
72+
fn test_custom_discriminant() {
73+
Python::with_gil(|py| {
74+
#[allow(non_snake_case)]
75+
let CustomDiscriminant = py.get_type::<CustomDiscriminant>();
76+
let one = Py::new(py, CustomDiscriminant::One).unwrap();
77+
let two = Py::new(py, CustomDiscriminant::Two).unwrap();
78+
py_run!(py, CustomDiscriminant one two, r#"
79+
assert CustomDiscriminant.One == one
80+
assert CustomDiscriminant.Two == two
81+
assert one != two
82+
"#);
83+
})
84+
}

tests/ui/invalid_enum_discriminant.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
//! As of rust 1.56, enums without #[repr] cannot have discriminant larger than `isize`.
2+
//! The implementation of #[pyclass] depends on this implementation detail.
3+
//! This file tests if this implementation detail is still true.
4+
5+
use pyo3::prelude::*;
6+
7+
#[pyclass]
8+
enum DiscriminantTooLarge{
9+
Var = 1 << 64,
10+
}
11+
12+
#[pyclass]
13+
enum DiscriminantOverflow{
14+
Var1 = isize::MAX,
15+
Overflow,
16+
}
17+
18+
fn main() {}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
error[E0080]: evaluation of constant value failed
2+
--> tests/ui/invalid_enum_discriminant.rs:9:11
3+
|
4+
9 | Var = 1 << 64,
5+
| ^^^^^^^ attempt to shift left by `64_i32`, which would overflow
6+
7+
error[E0370]: enum discriminant overflowed
8+
--> tests/ui/invalid_enum_discriminant.rs:15:5
9+
|
10+
15 | Overflow,
11+
| ^^^^^^^^ overflowed on value after 9223372036854775807
12+
|
13+
= note: explicitly set `Overflow = -9223372036854775808` if that is desired outcome

0 commit comments

Comments
 (0)