Skip to content

Commit a0a1812

Browse files
committed
Add default __richcmp__ for enums.
1 parent 7f7415a commit a0a1812

File tree

2 files changed

+44
-15
lines changed

2 files changed

+44
-15
lines changed

pyo3-macros-backend/src/pyclass.rs

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,6 @@ struct PyClassEnum<'a> {
386386
// The underyling representation of the enum.
387387
// It's used to implement __int__ and __richcmp__.
388388
// This matters when the underyling representation may not fit in `isize`.
389-
#[allow(unused, dead_code)]
390389
repr: syn::Ident,
391390
variants: Vec<PyClassEnumVariant<'a>>,
392391
doc: PythonDoc,
@@ -502,7 +501,29 @@ fn impl_enum_class(
502501
}
503502
};
504503

505-
let default_impls = gen_default_slot_impls(cls, vec![default_repr_impl]);
504+
let default_richcmp = {
505+
let variants_eq = variants.iter().map(|variant| {
506+
let variant_name = variant.ident;
507+
quote! {(#cls::#variant_name, #cls::#variant_name) => true.to_object(py),}
508+
});
509+
quote! {
510+
#[allow(non_snake_case)]
511+
#[pyo3(name = "__richcmp__")]
512+
fn __pyo3__richcmp__(&self, py: ::pyo3::Python, other: &Self, op: ::pyo3::basic::CompareOp) -> PyObject {
513+
match op {
514+
::pyo3::basic::CompareOp::Eq => {
515+
match (self, other) {
516+
#(#variants_eq)*
517+
_ => py.NotImplemented(),
518+
}
519+
}
520+
_ => py.NotImplemented(),
521+
}
522+
}
523+
}
524+
};
525+
526+
let default_impls = gen_default_slot_impls(cls, vec![default_repr_impl, default_richcmp]);
506527
Ok(quote! {
507528

508529
#pytypeinfo

tests/test_enum.rs

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,11 @@ pub enum MyEnum {
1212

1313
#[test]
1414
fn test_enum_class_attr() {
15-
let gil = Python::acquire_gil();
16-
let py = gil.python();
17-
let my_enum = py.get_type::<MyEnum>();
18-
py_assert!(py, my_enum, "getattr(my_enum, 'Variant', None) is not None");
19-
py_assert!(py, my_enum, "getattr(my_enum, 'foobar', None) is None");
20-
py_run!(py, my_enum, "my_enum.Variant = None");
15+
Python::with_gil(|py| {
16+
let my_enum = py.get_type::<MyEnum>();
17+
let var = Py::new(py, MyEnum::Variant).unwrap();
18+
py_assert!(py, my_enum var, "my_enum.Variant == var");
19+
})
2120
}
2221

2322
#[pyfunction]
@@ -26,7 +25,6 @@ fn return_enum() -> MyEnum {
2625
}
2726

2827
#[test]
29-
#[ignore] // need to implement __eq__
3028
fn test_return_enum() {
3129
let gil = Python::acquire_gil();
3230
let py = gil.python();
@@ -42,14 +40,24 @@ fn enum_arg(e: MyEnum) {
4240
}
4341

4442
#[test]
45-
#[ignore] // need to implement __eq__
4643
fn test_enum_arg() {
47-
let gil = Python::acquire_gil();
48-
let py = gil.python();
49-
let f = wrap_pyfunction!(enum_arg)(py).unwrap();
50-
let mynum = py.get_type::<MyEnum>();
44+
Python::with_gil(|py| {
45+
let f = wrap_pyfunction!(enum_arg)(py).unwrap();
46+
let mynum = py.get_type::<MyEnum>();
5147

52-
py_run!(py, f mynum, "f(mynum.Variant)")
48+
py_run!(py, f mynum, "f(mynum.OtherVariant)")
49+
})
50+
}
51+
52+
#[test]
53+
fn test_enum_eq() {
54+
Python::with_gil(|py| {
55+
let var1 = Py::new(py, MyEnum::Variant).unwrap();
56+
let var2 = Py::new(py, MyEnum::Variant).unwrap();
57+
let other_var = Py::new(py, MyEnum::OtherVariant).unwrap();
58+
py_assert!(py, var1 var2, "var1 == var2");
59+
py_assert!(py, var1 other_var, "var1 != other_var");
60+
})
5361
}
5462

5563
#[test]

0 commit comments

Comments
 (0)