Skip to content

PyO3 Migration from v0.20.* to 0.24.* #1168

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jun 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
837 changes: 481 additions & 356 deletions Cargo.lock

Large diffs are not rendered by default.

11 changes: 6 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@ name = "robyn"
crate-type = ["cdylib", "rlib"]

[dependencies]
pyo3 = { version = "0.20.0", features = ["extension-module"] }
pyo3-asyncio = { version="0.20.0" , features = ["attributes", "tokio-runtime"] }
pyo3-log = "0.8.4"
tokio = { version = "1.26.0", features = ["full"] }
pyo3 = { version = "0.24.2", features = ["extension-module", "py-clone"]}
pyo3-async-runtimes = { version = "0.24", features = ["tokio-runtime"] }
pyo3-async-runtimes-macros = { version = "0.24" }
pyo3-log = "0.12.3"
tokio = { version = "1.40", features = ["full"] }
dashmap = "5.4.3"
anyhow = "1.0.69"
actix = "0.13.4"
Expand All @@ -31,7 +32,7 @@ matchit = "0.7.3"
socket2 = { version = "0.5.1", features = ["all"] }
uuid = { version = "1.3.0", features = ["serde", "v4"] }
log = "0.4.17"
pythonize = "0.20.0"
pythonize = "0.24"
serde = "1.0.187"
serde_json = "1.0.109"
once_cell = "1.8.0"
Expand Down
4 changes: 2 additions & 2 deletions integration_tests/base_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ async def message(ws: WebSocketConnector, msg: str, global_dependencies) -> str:
elif state == 1:
resp = "Whooo??"
elif state == 2:
await ws.async_broadcast(ws.query_params.get("one"))
ws.sync_send_to(websocket_id, ws.query_params.get("two"))
await ws.async_broadcast(ws.query_params.get("one", None))
ws.sync_send_to(websocket_id, ws.query_params.get("two", None))
resp = "*chika* *chika* Slim Shady."
elif state == 3:
ws.close()
Expand Down
22 changes: 12 additions & 10 deletions src/executors/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::sync::Arc;
use anyhow::Result;
use log::debug;
use pyo3::prelude::*;
use pyo3_asyncio::TaskLocals;
use pyo3_async_runtimes::TaskLocals;

use crate::types::{
function_info::FunctionInfo, request::Request, response::Response, MiddlewareReturn,
Expand All @@ -19,20 +19,22 @@ fn get_function_output<'a, T>(
function: &'a FunctionInfo,
py: Python<'a>,
function_args: &T,
) -> Result<&'a PyAny, PyErr>
) -> Result<pyo3::Bound<'a, pyo3::PyAny>, PyErr>
where
T: ToPyObject,
{
let handler = function.handler.as_ref(py);
let kwargs = function.kwargs.as_ref(py);
let handler = function.handler.bind(py).downcast()?;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential unhandled error case with downcast. The original code using as_ref() didn't require type checking, but the new code assumes the type can be downcast. If the type doesn't match expectations, it will fail at runtime. Need to add proper error handling or type verification.


React with 👍 to tell me that this comment was useful, or 👎 if not (and I'll stop posting more comments like this in the future)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@VishnuSanal , what do you think here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is still PyAny -- we can do the following if we want. but I feel like it would be a redundant & introduce unnecessary complexity. what do you think @sansyrox?

let handler = function.handler.bind(py).downcast::<pyo3::types::PyAny>()
    .map_err(|_| PyErr::new::<pyo3::exceptions::PyTypeError, _>(
        "Wrong type found for function.handler",
    ))?;

this is still handled gracefully anyways

Err(e) => {
                error!(
                    "Error while executing route function for endpoint `{}`: {}",
                    req.uri().path(),
                    get_traceback(&e)
                );

                Response::internal_server_error(None)
            }

let kwargs = function.kwargs.bind(py);
let function_args = function_args.to_object(py);
debug!("Function args: {:?}", function_args);

match function.number_of_params {
0 => handler.call0(),
1 => {
if kwargs.get_item("global_dependencies")?.is_some()
|| kwargs.get_item("router_dependencies")?.is_some()
if pyo3::types::PyDictMethods::get_item(kwargs, "global_dependencies")
.is_ok_and(|it| !it.is_none())
|| pyo3::types::PyDictMethods::get_item(kwargs, "router_dependencies")
.is_ok_and(|it| !it.is_none())
// these are reserved keywords
{
handler.call((), Some(kwargs))
Expand All @@ -57,7 +59,7 @@ where
{
if function.is_async {
let output: Py<PyAny> = Python::with_gil(|py| {
pyo3_asyncio::tokio::into_future(get_function_output(function, py, input)?)
pyo3_async_runtimes::tokio::into_future(get_function_output(function, py, input)?)
})?
.await?;

Expand Down Expand Up @@ -88,7 +90,7 @@ pub async fn execute_http_function(
if function.is_async {
let output = Python::with_gil(|py| {
let function_output = get_function_output(function, py, request)?;
pyo3_asyncio::tokio::into_future(function_output)
pyo3_async_runtimes::tokio::into_future(function_output)
})?
.await?;

Expand All @@ -108,9 +110,9 @@ pub async fn execute_startup_handler(
if function.is_async {
debug!("Startup event handler async");
Python::with_gil(|py| {
pyo3_asyncio::into_future_with_locals(
pyo3_async_runtimes::into_future_with_locals(
task_locals,
function.handler.as_ref(py).call0()?,
function.handler.bind(py).call0()?,
)
})?
.await?;
Expand Down
49 changes: 30 additions & 19 deletions src/executors/web_socket_executors.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use actix::prelude::*;
use actix::AsyncContext;
use actix_web_actors::ws;
use actix::{ActorFutureExt, AsyncContext, WrapFuture};
use actix_web_actors::ws::WebsocketContext;
use pyo3::prelude::*;
use pyo3_asyncio::TaskLocals;
use pyo3_async_runtimes::TaskLocals;

use crate::types::function_info::FunctionInfo;
use crate::websockets::WebSocketConnector;
Expand All @@ -12,42 +11,54 @@ fn get_function_output<'a>(
fn_msg: Option<String>,
py: Python<'a>,
ws: &WebSocketConnector,
) -> Result<&'a PyAny, PyErr> {
let handler = function.handler.as_ref(py);
) -> Result<pyo3::Bound<'a, pyo3::PyAny>, PyErr> {
let handler = function.handler.bind(py).downcast()?;

// this makes the request object accessible across every route

let args = function.args.as_ref(py);
let kwargs = function.kwargs.as_ref(py);
let args = function.args.bind(py).downcast()?;
let kwargs = function.kwargs.bind(py).downcast()?;

match function.number_of_params {
0 => handler.call0(),
1 => {
if args.get_item("ws")?.is_some() {
if pyo3::types::PyDictMethods::get_item(args, "ws").is_ok_and(|it| !it.is_none()) {
handler.call1((ws.clone(),))
} else if args.get_item("msg")?.is_some() {
} else if pyo3::types::PyDictMethods::get_item(args, "msg")
.is_ok_and(|it| !it.is_none())
{
handler.call1((fn_msg.unwrap_or_default(),))
} else {
handler.call((), Some(kwargs))
}
}
2 => {
if args.get_item("ws")?.is_some() && args.get_item("msg")?.is_some() {
if pyo3::types::PyDictMethods::get_item(args, "ws").is_ok_and(|it| !it.is_none())
&& pyo3::types::PyDictMethods::get_item(args, "msg").is_ok_and(|it| !it.is_none())
{
handler.call1((ws.clone(), fn_msg.unwrap_or_default()))
} else if args.get_item("ws")?.is_some() {
} else if pyo3::types::PyDictMethods::get_item(args, "ws").is_ok_and(|it| !it.is_none())
{
handler.call((ws.clone(),), Some(kwargs))
} else if args.get_item("msg")?.is_some() {
} else if pyo3::types::PyDictMethods::get_item(args, "msg")
.is_ok_and(|it| !it.is_none())
{
handler.call((fn_msg.unwrap_or_default(),), Some(kwargs))
} else {
handler.call((), Some(kwargs))
}
}
3 => {
if args.get_item("ws")?.is_some() && args.get_item("msg")?.is_some() {
if pyo3::types::PyDictMethods::get_item(args, "ws").is_ok_and(|it| !it.is_none())
&& pyo3::types::PyDictMethods::get_item(args, "msg").is_ok_and(|it| !it.is_none())
{
handler.call((ws.clone(), fn_msg.unwrap_or_default()), Some(kwargs))
} else if args.get_item("ws")?.is_some() {
} else if pyo3::types::PyDictMethods::get_item(args, "ws").is_ok_and(|it| !it.is_none())
{
handler.call((ws.clone(),), Some(kwargs))
} else if args.get_item("msg")?.is_some() {
} else if pyo3::types::PyDictMethods::get_item(args, "msg")
.is_ok_and(|it| !it.is_none())
{
handler.call((fn_msg.unwrap_or_default(),), Some(kwargs))
} else {
handler.call((), Some(kwargs))
Expand All @@ -61,13 +72,13 @@ pub fn execute_ws_function(
function: &FunctionInfo,
text: Option<String>,
task_locals: &TaskLocals,
ctx: &mut ws::WebsocketContext<WebSocketConnector>,
ctx: &mut WebsocketContext<WebSocketConnector>,
ws: &WebSocketConnector,
// add number of params here
) {
if function.is_async {
let fut = Python::with_gil(|py| {
pyo3_asyncio::into_future_with_locals(
pyo3_async_runtimes::into_future_with_locals(
task_locals,
get_function_output(function, text, py, ws).unwrap(),
)
Expand All @@ -84,7 +95,7 @@ pub fn execute_ws_function(
Python::with_gil(|py| {
if let Some(op) = get_function_output(function, text, py, ws)
.unwrap()
.extract::<Option<&str>>()
.extract::<Option<String>>()
.unwrap()
{
ctx.text(op);
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ fn get_version() -> String {
}

#[pymodule]
pub fn robyn(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
pub fn robyn(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
// the pymodule class/function to make the rustPyFunctions available
m.add_function(wrap_pyfunction!(get_version, m)?)?;

Expand Down
9 changes: 5 additions & 4 deletions src/routers/const_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::types::HttpMethod;
use anyhow::Context;
use log::debug;
use matchit::Router as MatchItRouter;
use pyo3::types::PyAny;
use pyo3::{Bound, Python};

use anyhow::{Error, Result};

Expand All @@ -25,12 +25,13 @@ pub struct ConstRouter {

impl Router<Response, HttpMethod> for ConstRouter {
/// Doesn't allow query params/body/etc as variables cannot be "memoized"/"const"ified
fn add_route(
fn add_route<'py>(
&self,
_py: Python,
route_type: &HttpMethod,
route: &str,
function: FunctionInfo,
event_loop: Option<&PyAny>,
event_loop: Option<Bound<'py, pyo3::PyAny>>,
) -> Result<(), Error> {
let table = self
.routes
Expand All @@ -42,7 +43,7 @@ impl Router<Response, HttpMethod> for ConstRouter {
let event_loop =
event_loop.context("Event loop must be provided to add a route to the const router")?;

pyo3_asyncio::tokio::run_until_complete(event_loop, async move {
pyo3_async_runtimes::tokio::run_until_complete(event_loop, async move {
let output = execute_http_function(&Request::default(), &function)
.await
.unwrap();
Expand Down
11 changes: 7 additions & 4 deletions src/routers/http_router.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use parking_lot::RwLock;
use pyo3::{Bound, Python};
use std::collections::HashMap;

use matchit::Router as MatchItRouter;
use pyo3::types::PyAny;

use anyhow::{Context, Result};

Expand All @@ -18,12 +18,13 @@ pub struct HttpRouter {
}

impl Router<(FunctionInfo, HashMap<String, String>), HttpMethod> for HttpRouter {
fn add_route(
fn add_route<'py>(
&self,
_py: Python,
route_type: &HttpMethod,
route: &str,
function: FunctionInfo,
_event_loop: Option<&PyAny>,
_event_loop: Option<Bound<'py, pyo3::PyAny>>,
) -> Result<()> {
let table = self.routes.get(route_type).context("No relevant map")?;

Expand All @@ -47,7 +48,9 @@ impl Router<(FunctionInfo, HashMap<String, String>), HttpMethod> for HttpRouter
route_params.insert(key.to_string(), value.to_string());
}

Some((res.value.to_owned(), route_params))
let function_info = Python::with_gil(|_| res.value.to_owned());

Some((function_info, route_params))
}
}

Expand Down
11 changes: 7 additions & 4 deletions src/routers/middleware_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::RwLock;

use anyhow::{Context, Error, Result};
use matchit::Router as MatchItRouter;
use pyo3::types::PyAny;
use pyo3::{Bound, Python};

use crate::routers::Router;
use crate::types::function_info::{FunctionInfo, MiddlewareType};
Expand All @@ -17,12 +17,13 @@ pub struct MiddlewareRouter {
}

impl Router<(FunctionInfo, HashMap<String, String>), MiddlewareType> for MiddlewareRouter {
fn add_route(
fn add_route<'py>(
&self,
_py: Python,
route_type: &MiddlewareType,
route: &str,
function: FunctionInfo,
_event_loop: Option<&PyAny>,
_event_loop: Option<Bound<'py, pyo3::PyAny>>,
) -> Result<(), Error> {
let table = self.routes.get(route_type).context("No relevant map")?;

Expand All @@ -45,7 +46,9 @@ impl Router<(FunctionInfo, HashMap<String, String>), MiddlewareType> for Middlew
route_params.insert(key.to_string(), value.to_string());
}

Some((res.value.to_owned(), route_params))
let function_info = Python::with_gil(|_| res.value.to_owned());

Some((function_info, route_params))
}
}

Expand Down
7 changes: 4 additions & 3 deletions src/routers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use anyhow::Result;
use pyo3::PyAny;
use pyo3::{Bound, Python};

use crate::types::function_info::FunctionInfo;

Expand All @@ -11,12 +11,13 @@ pub mod web_socket_router;
pub trait Router<T, U> {
/// Checks if the functions is an async function
/// Inserts them in the router according to their nature(CoRoutine/SyncFunction)
fn add_route(
fn add_route<'py>(
&self,
py: Python,
route_type: &U,
route: &str,
function: FunctionInfo,
event_loop: Option<&PyAny>,
event_loop: Option<Bound<'py, pyo3::PyAny>>,
) -> Result<()>;

/// Retrieve the correct function from the previously inserted routes
Expand Down
Loading