Skip to content

Commit 6d3129e

Browse files
committed
feat: allow async methods to accept &self/&mut self
1 parent 2ca9f59 commit 6d3129e

File tree

6 files changed

+171
-22
lines changed

6 files changed

+171
-22
lines changed

guide/src/async-await.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@ Resulting future of an `async fn` decorated by `#[pyfunction]` must be `Send + '
3030

3131
As a consequence, `async fn` parameters and return types must also be `Send + 'static`, so it is not possible to have a signature like `async fn does_not_compile(arg: &PyAny, py: Python<'_>) -> &PyAny`.
3232

33-
It also means that methods cannot use `&self`/`&mut self`, *but this restriction should be dropped in the future.*
34-
33+
However, there is an exception for method receiver, so async methods can accept `&self`/`&mut self`
3534

3635
## Implicit GIL holding
3736

newsfragments/3609.changed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Allow async methods to accept `&self`/`&mut self`

pyo3-macros-backend/src/method.rs

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
use std::fmt::Display;
22

3-
use crate::attributes::{TextSignatureAttribute, TextSignatureAttributeValue};
4-
use crate::deprecations::{Deprecation, Deprecations};
5-
use crate::params::impl_arg_params;
6-
use crate::pyfunction::{FunctionSignature, PyFunctionArgPyO3Attributes};
7-
use crate::pyfunction::{PyFunctionOptions, SignatureAttribute};
8-
use crate::quotes;
9-
use crate::utils::{self, PythonDoc};
103
use proc_macro2::{Span, TokenStream};
11-
use quote::ToTokens;
12-
use quote::{quote, quote_spanned};
13-
use syn::ext::IdentExt;
14-
use syn::spanned::Spanned;
15-
use syn::{Ident, Result};
4+
use quote::{quote, quote_spanned, ToTokens};
5+
use syn::{ext::IdentExt, spanned::Spanned, Ident, Result};
6+
7+
use crate::{
8+
attributes::{TextSignatureAttribute, TextSignatureAttributeValue},
9+
deprecations::{Deprecation, Deprecations},
10+
params::impl_arg_params,
11+
pyfunction::{
12+
FunctionSignature, PyFunctionArgPyO3Attributes, PyFunctionOptions, SignatureAttribute,
13+
},
14+
quotes,
15+
utils::{self, PythonDoc},
16+
};
1617

1718
#[derive(Clone, Debug)]
1819
pub struct FnArg<'a> {
@@ -473,8 +474,7 @@ impl<'a> FnSpec<'a> {
473474
}
474475

475476
let rust_call = |args: Vec<TokenStream>| {
476-
let mut call = quote! { function(#self_arg #(#args),*) };
477-
if self.asyncness.is_some() {
477+
let call = if self.asyncness.is_some() {
478478
let throw_callback = if cancel_handle.is_some() {
479479
quote! { Some(__throw_callback) }
480480
} else {
@@ -485,8 +485,19 @@ impl<'a> FnSpec<'a> {
485485
Some(cls) => quote!(Some(<#cls as _pyo3::PyTypeInfo>::NAME)),
486486
None => quote!(None),
487487
};
488-
call = quote! {{
489-
let future = #call;
488+
let future = match self.tp {
489+
FnType::Fn(SelfType::Receiver { mutable: false, .. }) => quote! {{
490+
let __guard = _pyo3::impl_::coroutine::RefGuard::<#cls>::new(py.from_borrowed_ptr::<_pyo3::types::PyAny>(_slf))?;
491+
async move { function(&__guard, #(#args),*).await }
492+
}},
493+
FnType::Fn(SelfType::Receiver { mutable: true, .. }) => quote! {{
494+
let mut __guard = _pyo3::impl_::coroutine::RefMutGuard::<#cls>::new(py.from_borrowed_ptr::<_pyo3::types::PyAny>(_slf))?;
495+
async move { function(&mut __guard, #(#args),*).await }
496+
}},
497+
_ => quote! { function(#self_arg #(#args),*) },
498+
};
499+
let mut call = quote! {{
500+
let future = #future;
490501
_pyo3::impl_::coroutine::new_coroutine(
491502
_pyo3::intern!(py, stringify!(#python_name)),
492503
#qualname_prefix,
@@ -501,7 +512,10 @@ impl<'a> FnSpec<'a> {
501512
#call
502513
}};
503514
}
504-
}
515+
call
516+
} else {
517+
quote! { function(#self_arg #(#args),*) }
518+
};
505519
quotes::map_result_into_ptr(quotes::ok_wrap(call))
506520
};
507521

src/impl_/coroutine.rs

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
1-
use std::future::Future;
1+
use std::{
2+
future::Future,
3+
mem,
4+
ops::{Deref, DerefMut},
5+
};
26

3-
use crate::coroutine::cancel::ThrowCallback;
4-
use crate::{coroutine::Coroutine, types::PyString, IntoPy, PyErr, PyObject};
7+
use crate::{
8+
coroutine::{cancel::ThrowCallback, Coroutine},
9+
pyclass::boolean_struct::False,
10+
types::PyString,
11+
IntoPy, Py, PyAny, PyCell, PyClass, PyErr, PyObject, PyRef, PyRefMut, PyResult, Python,
12+
};
513

614
pub fn new_coroutine<F, T, E>(
715
name: &PyString,
@@ -16,3 +24,67 @@ where
1624
{
1725
Coroutine::new(Some(name.into()), qualname_prefix, throw_callback, future)
1826
}
27+
28+
fn get_ptr<T: PyClass>(obj: &Py<T>) -> *mut T {
29+
// SAFETY: Py<T> can be casted as *const PyCell<T>
30+
unsafe { &*(obj.as_ptr() as *const PyCell<T>) }.get_ptr()
31+
}
32+
33+
pub struct RefGuard<T: PyClass>(Py<T>);
34+
35+
impl<T: PyClass> RefGuard<T> {
36+
pub fn new(obj: &PyAny) -> PyResult<Self> {
37+
let ref_: PyRef<'_, T> = obj.extract()?;
38+
// SAFETY: `PyRef::as_ptr` returns a borrowed reference
39+
let guard = RefGuard(unsafe { Py::<T>::from_borrowed_ptr(obj.py(), ref_.as_ptr()) });
40+
mem::forget(ref_);
41+
Ok(guard)
42+
}
43+
}
44+
45+
impl<T: PyClass> Deref for RefGuard<T> {
46+
type Target = T;
47+
fn deref(&self) -> &Self::Target {
48+
// SAFETY: `RefGuard` has been built from `PyRef` and provides the same guarantees
49+
unsafe { &*get_ptr(&self.0) }
50+
}
51+
}
52+
53+
impl<T: PyClass> Drop for RefGuard<T> {
54+
fn drop(&mut self) {
55+
Python::with_gil(|gil| self.0.as_ref(gil).release_ref())
56+
}
57+
}
58+
59+
pub struct RefMutGuard<T: PyClass<Frozen = False>>(Py<T>);
60+
61+
impl<T: PyClass<Frozen = False>> RefMutGuard<T> {
62+
pub fn new(obj: &PyAny) -> PyResult<Self> {
63+
let mut_: PyRefMut<'_, T> = obj.extract()?;
64+
// // SAFETY: `PyRefMut::as_ptr` returns a borrowed reference
65+
let guard = RefMutGuard(unsafe { Py::<T>::from_borrowed_ptr(obj.py(), mut_.as_ptr()) });
66+
mem::forget(mut_);
67+
Ok(guard)
68+
}
69+
}
70+
71+
impl<T: PyClass<Frozen = False>> Deref for RefMutGuard<T> {
72+
type Target = T;
73+
fn deref(&self) -> &Self::Target {
74+
// SAFETY: `RefMutGuard` has been built from `PyRefMut` and provides the same guarantees
75+
unsafe { &*get_ptr(&self.0) }
76+
}
77+
}
78+
79+
impl<T: PyClass<Frozen = False>> DerefMut for RefMutGuard<T> {
80+
fn deref_mut(&mut self) -> &mut Self::Target {
81+
// SAFETY: `RefMutGuard` has been built from `PyRefMut` and provides the same guarantees
82+
unsafe { &mut *get_ptr(&self.0) }
83+
}
84+
}
85+
86+
impl<T: PyClass<Frozen = False>> Drop for RefMutGuard<T> {
87+
fn drop(&mut self) {
88+
Python::with_gil(|gil| self.0.as_ref(gil).release_mut())
89+
}
90+
}

src/pycell.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,16 @@ impl<T: PyClass> PyCell<T> {
516516
#[allow(clippy::useless_conversion)]
517517
offset.try_into().expect("offset should fit in Py_ssize_t")
518518
}
519+
520+
#[cfg(feature = "macros")]
521+
pub(crate) fn release_ref(&self) {
522+
self.borrow_checker().release_borrow();
523+
}
524+
525+
#[cfg(feature = "macros")]
526+
pub(crate) fn release_mut(&self) {
527+
self.borrow_checker().release_borrow_mut();
528+
}
519529
}
520530

521531
impl<T: PyClassImpl> PyCell<T> {

tests/test_coroutine.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,3 +234,56 @@ fn coroutine_panic() {
234234
py_run!(gil, panic, &handle_windows(test));
235235
})
236236
}
237+
238+
#[test]
239+
fn test_async_method_receiver() {
240+
#[pyclass]
241+
struct Counter(usize);
242+
#[pymethods]
243+
impl Counter {
244+
#[new]
245+
fn new() -> Self {
246+
Self(0)
247+
}
248+
async fn get(&self) -> usize {
249+
self.0
250+
}
251+
async fn incr(&mut self) -> usize {
252+
self.0 += 1;
253+
self.0
254+
}
255+
}
256+
Python::with_gil(|gil| {
257+
let test = r#"
258+
import asyncio
259+
260+
obj = Counter()
261+
coro1 = obj.get()
262+
coro2 = obj.get()
263+
try:
264+
obj.incr() # borrow checking should fail
265+
except RuntimeError as err:
266+
pass
267+
else:
268+
assert False
269+
assert asyncio.run(coro1) == 0
270+
coro2.close()
271+
coro3 = obj.incr()
272+
try:
273+
obj.incr() # borrow checking should fail
274+
except RuntimeError as err:
275+
pass
276+
else:
277+
assert False
278+
try:
279+
obj.get() # borrow checking should fail
280+
except RuntimeError as err:
281+
pass
282+
else:
283+
assert False
284+
assert asyncio.run(coro3) == 1
285+
"#;
286+
let locals = [("Counter", gil.get_type::<Counter>())].into_py_dict(gil);
287+
py_run!(gil, *locals, test);
288+
})
289+
}

0 commit comments

Comments
 (0)