diff --git a/rustler/src/lib.rs b/rustler/src/lib.rs index ce700ef6..4c144a8c 100644 --- a/rustler/src/lib.rs +++ b/rustler/src/lib.rs @@ -49,7 +49,7 @@ pub mod dynamic; pub use crate::dynamic::TermType; pub mod schedule; -pub use crate::schedule::SchedulerFlags; +pub use crate::schedule::{Schedule, SchedulerFlags}; pub mod env; pub use crate::env::{Env, OwnedEnv}; pub mod thread; diff --git a/rustler/src/schedule.rs b/rustler/src/schedule.rs index ccc5de1c..8b5540d3 100644 --- a/rustler/src/schedule.rs +++ b/rustler/src/schedule.rs @@ -1,5 +1,9 @@ +use rustler_sys::c_char; + +use crate::codegen_runtime::{NifReturnable, NifReturned}; use crate::wrapper::ErlNifTaskFlags; use crate::Env; +use std::{ffi::CStr, marker::PhantomData}; pub enum SchedulerFlags { Normal = ErlNifTaskFlags::ERL_NIF_NORMAL_JOB as isize, @@ -11,3 +15,112 @@ pub fn consume_timeslice(env: Env, percent: i32) -> bool { let success = unsafe { rustler_sys::enif_consume_timeslice(env.as_c_arg(), percent) }; success == 1 } + +/// Convenience macro for scheduling a future invokation of a NIF. +#[macro_export] +macro_rules! reschedule { + ($flags:expr, $($arg:expr),*) => ( + rustler::Schedule::from(($flags, $($arg,)*)) + ) +} + +/// Convenience type for scheduling a future invokation of a NIF. +/// +/// ## Usage: +/// +/// The first generic type should be the NIF that will be scheduled, with a +/// current limitation being that it must be same throughout the lifetime of the +/// NIF. +/// +/// The second generic type defined should be the type of the return value. +/// +/// Every other generic type is optional, but should reflect the argument types +/// of the scheduled NIF, in the same order. +/// +/// ## Example: +/// ```rust,ignore +/// #[nif] +/// fn factorial(input: u8, result: Option) -> Schedule> { +/// let result = result.unwrap_or(1); +/// if input == 0 { +/// Schedule::Return(result) +/// } else { +/// // alternatively `Schedule::Continue2(std::marker::PhantomData, SchedulerFlags::Normal, input - 1, Some(result * input as u32))` +/// // or `Schedule::continue2(SchedulerFlags::Normal, input - 1, Some(result * input as u32))` +/// // or `Schedule::from((SchedulerFlags::Normal, input - 1, Some(result * input as u32)))` +/// // or `(SchedulerFlags::Normal, input - 1, Some(result * input as u32)).into()` +/// reschedule!(SchedulerFlags::Normal, input - 1, Some(result * input as u32)) +/// } +/// } +/// ``` +pub enum Schedule { + /// The final result type to return back to the BEAM. + Return(T), + /// Single- and multiple-argument variants that should reflect the scheduled + /// NIF's function signature. + Continue1(PhantomData, SchedulerFlags, A), + Continue2(PhantomData, SchedulerFlags, A, B), + Continue3(PhantomData, SchedulerFlags, A, B, C), + Continue4(PhantomData, SchedulerFlags, A, B, C, D), + Continue5(PhantomData, SchedulerFlags, A, B, C, D, E), + Continue6(PhantomData, SchedulerFlags, A, B, C, D, E, F), + Continue7(PhantomData, SchedulerFlags, A, B, C, D, E, F, G), +} + +macro_rules! impls { + ($($variant:ident $func_name:ident($($arg:ident : $ty:ty,)*);)*) => { + impl Schedule { + $(#[allow(clippy::many_single_char_names)] + #[allow(clippy::too_many_arguments)] + #[inline] + pub fn $func_name(flags: SchedulerFlags, $($arg: $ty),*) -> Self { + Self::$variant(PhantomData, flags, $($arg),*) + })* + } + + $(impl From<(SchedulerFlags, $($ty),*)> for Schedule { + #[allow(clippy::many_single_char_names)] + #[inline] + fn from((flags, $($arg),*): (SchedulerFlags, $($ty),*)) -> Self { + Self::$func_name(flags, $($arg),*) + } + })* + + unsafe impl NifReturnable for Schedule + where + N: crate::Nif, + T: crate::Encoder, + A: crate::Encoder, + B: crate::Encoder, + C: crate::Encoder, + D: crate::Encoder, + E: crate::Encoder, + F: crate::Encoder, + G: crate::Encoder, + { + #[inline] + unsafe fn into_returned(self, env: Env) -> NifReturned { + #[allow(clippy::many_single_char_names)] + match self { + Self::Return(res) => NifReturned::Term(res.encode(env).as_c_arg()), + $(Self::$variant(_, flags, $($arg),*) => NifReturned::Reschedule { + fun_name: CStr::from_ptr(N::NAME as *const c_char).into(), + flags, + fun: N::RAW_FUNC, + args: vec![$($arg.encode(env).as_c_arg()),*], + },)* + } + } + } + }; +} + +impls! { + Continue1 continue1(a: A,); + Continue2 continue2(a: A, b: B,); + Continue3 continue3(a: A, b: B, c: C,); + Continue4 continue4(a: A, b: B, c: C, d: D,); + Continue5 continue5(a: A, b: B, c: C, d: D, e: E,); + Continue6 continue6(a: A, b: B, c: C, d: D, e: E, f: F,); + Continue7 continue7(a: A, b: B, c: C, d: D, e: E, f: F, g: G,); +} diff --git a/rustler_tests/lib/rustler_test.ex b/rustler_tests/lib/rustler_test.ex index 1adaded7..5c895911 100644 --- a/rustler_tests/lib/rustler_test.ex +++ b/rustler_tests/lib/rustler_test.ex @@ -114,6 +114,8 @@ defmodule RustlerTest do def sum_range(_), do: err() + def scheduled_fac(_, _ \\ nil), do: err() + def bad_arg_error(), do: err() def atom_str_error(), do: err() def raise_atom_error(), do: err() diff --git a/rustler_tests/native/rustler_test/src/lib.rs b/rustler_tests/native/rustler_test/src/lib.rs index efe2372e..d07bc6b0 100644 --- a/rustler_tests/native/rustler_test/src/lib.rs +++ b/rustler_tests/native/rustler_test/src/lib.rs @@ -10,6 +10,7 @@ mod test_nif_attrs; mod test_primitives; mod test_range; mod test_resource; +mod test_schedule; mod test_term; mod test_thread; mod test_tuple; @@ -82,6 +83,7 @@ rustler::init!( test_codegen::tuplestruct_record_echo, test_dirty::dirty_cpu, test_dirty::dirty_io, + test_schedule::scheduled_fac, test_range::sum_range, test_error::bad_arg_error, test_error::atom_str_error, diff --git a/rustler_tests/native/rustler_test/src/test_schedule.rs b/rustler_tests/native/rustler_test/src/test_schedule.rs new file mode 100644 index 00000000..564fba3d --- /dev/null +++ b/rustler_tests/native/rustler_test/src/test_schedule.rs @@ -0,0 +1,15 @@ +use rustler::{reschedule, Schedule, SchedulerFlags}; + +#[rustler::nif] +fn scheduled_fac(input: u8, result: Option) -> Schedule> { + let result = result.unwrap_or(1); + if input == 0 { + Schedule::Return(result) + } else { + reschedule!( + SchedulerFlags::Normal, + input - 1, + Some(result * input as u32) + ) + } +} diff --git a/rustler_tests/test/schedule_test.exs b/rustler_tests/test/schedule_test.exs new file mode 100644 index 00000000..26884ead --- /dev/null +++ b/rustler_tests/test/schedule_test.exs @@ -0,0 +1,7 @@ +defmodule RustlerTest.ScheduleTest do + use ExUnit.Case, async: true + + test "scheduled factorial" do + assert 24 == RustlerTest.scheduled_fac(4) + end +end