diff --git a/Cargo.lock b/Cargo.lock index 01788a3..90c9fc0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2956,7 +2956,7 @@ dependencies = [ [[package]] name = "pvq-program-metadata-gen" -version = "0.1.0" +version = "0.2.0" dependencies = [ "clap", "parity-scale-codec", diff --git a/Makefile b/Makefile index 6cdcf85..45f80cb 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ guests: $(GUEST_TARGETS) dummy-guests: $(DUMMY_GUEST_TARGETS) guest-%: - cd guest-examples; METADATA_OUTPUT_DIR=$(realpath output) cargo build -q --release --bin guest-$* -p guest-$* + cd guest-examples; METADATA_OUTPUT_DIR=$(realpath output) cargo build --release --bin guest-$* -p guest-$* mkdir -p output polkatool link --run-only-if-newer -s guest-examples/target/riscv32emac-unknown-none-polkavm/release/guest-$* -o output/guest-$*.polkavm diff --git a/guest-examples/Cargo.lock b/guest-examples/Cargo.lock index f082258..00dbc58 100644 --- a/guest-examples/Cargo.lock +++ b/guest-examples/Cargo.lock @@ -93,6 +93,16 @@ dependencies = [ "pvq-program", ] +[[package]] +name = "guest-swap-info" +version = "0.1.0" +dependencies = [ + "cfg-if", + "parity-scale-codec", + "polkavm-derive", + "pvq-program", +] + [[package]] name = "guest-test-swap-extension" version = "0.1.0" diff --git a/guest-examples/Cargo.toml b/guest-examples/Cargo.toml index 786e079..bc31bea 100644 --- a/guest-examples/Cargo.toml +++ b/guest-examples/Cargo.toml @@ -7,6 +7,7 @@ members = [ "total-supply-hand-written", "transparent-call-hand-written", "test-swap-extension", + "swap-info", ] resolver = "2" @@ -17,4 +18,5 @@ parity-scale-codec = { version = "3", default-features = false, features = [ pvq-program = { path = "../pvq-program", default-features = false } pvq-program-metadata-gen = { path = "../pvq-program-metadata-gen" } polkavm-derive = { path = "../vendor/polkavm/crates/polkavm-derive" } +acala-primitives = { git = "https://github.com/AcalaNetwork/Acala", branch = "master", default-features = false } cfg-if = "1.0" diff --git a/guest-examples/rust-toolchain.toml b/guest-examples/rust-toolchain.toml index 650393f..912f4d0 100644 --- a/guest-examples/rust-toolchain.toml +++ b/guest-examples/rust-toolchain.toml @@ -1,3 +1,3 @@ [toolchain] -channel = "nightly-2024-11-19" +channel = "nightly-2025-06-09" components = ["rust-src", "clippy"] diff --git a/guest-examples/swap-info/Cargo.toml b/guest-examples/swap-info/Cargo.toml new file mode 100644 index 0000000..0e8d03f --- /dev/null +++ b/guest-examples/swap-info/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "guest-swap-info" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +parity-scale-codec = { workspace = true } +polkavm-derive = { workspace = true } +pvq-program = { workspace = true } +cfg-if = { workspace = true } + +[features] +asset-hub = [] +acala = [] diff --git a/guest-examples/swap-info/build.rs b/guest-examples/swap-info/build.rs new file mode 100644 index 0000000..f20692e --- /dev/null +++ b/guest-examples/swap-info/build.rs @@ -0,0 +1,27 @@ +use std::env; +use std::path::PathBuf; +use std::process::Command; + +fn main() { + // Tell Cargo to rerun this build script if the source file changes + // println!("cargo:rerun-if-changed=src/main.rs"); + let current_dir = env::current_dir().expect("Failed to get current directory"); + // Determine the output directory for the metadata + let output_dir = PathBuf::from(env::var("METADATA_OUTPUT_DIR").expect("METADATA_OUTPUT_DIR is not set")) + .canonicalize() + .expect("Failed to canonicalize output directory"); + + // Build and run the command + let status = Command::new("pvq-program-metadata-gen") + .arg("--crate-path") + .arg(¤t_dir) + .arg("--output-dir") + .arg(&output_dir) + .env("RUST_LOG", "info") + .status() + .expect("Failed to execute pvq-program-metadata-gen"); + + if !status.success() { + panic!("Failed to generate program metadata"); + } +} diff --git a/guest-examples/swap-info/src/main.rs b/guest-examples/swap-info/src/main.rs new file mode 100644 index 0000000..5e89504 --- /dev/null +++ b/guest-examples/swap-info/src/main.rs @@ -0,0 +1,72 @@ +#![no_std] +#![no_main] + +#[pvq_program::program] +mod swap_info { + + cfg_if::cfg_if! { + if #[cfg(feature = "asset-hub")] { + // Actually AssetHub uses xcm::Location as AssetId, but we use opaque Vec because some compilation issues. + type AssetId = alloc::vec::Vec; + type Balance = u128; + } else if #[cfg(feature = "acala")] { + type AssetId = alloc::vec::Vec; + type Balance = u128; + } else { + type AssetId = alloc::vec::Vec; + type Balance = u128; + } + } + + #[program::extension_fn(extension_id = 13206387959972970661u64, fn_index = 0)] + fn quote_price_tokens_for_exact_tokens( + asset1: AssetId, + asset2: AssetId, + amount: Balance, + include_fee: bool, + ) -> Option { + } + + #[program::extension_fn(extension_id = 13206387959972970661u64, fn_index = 1)] + fn quote_price_exact_tokens_for_tokens( + asset1: AssetId, + asset2: AssetId, + amount: Balance, + include_fee: bool, + ) -> Option { + } + + #[program::extension_fn(extension_id = 13206387959972970661u64, fn_index = 2)] + fn get_liquidity_pool(asset1: AssetId, asset2: AssetId) -> Option<(Balance, Balance)> {} + + #[program::extension_fn(extension_id = 13206387959972970661u64, fn_index = 3)] + fn list_pools() -> alloc::vec::Vec<(AssetId, AssetId, Balance, Balance)> {} + + #[program::entrypoint] + fn entrypoint_quote_price_exact_tokens_for_tokens( + asset1: AssetId, + asset2: AssetId, + amount: Balance, + ) -> Option { + quote_price_exact_tokens_for_tokens(asset1, asset2, amount, true) + } + + #[program::entrypoint] + fn entrypoint_quote_price_tokens_for_exact_tokens( + asset1: AssetId, + asset2: AssetId, + amount: Balance, + ) -> Option { + quote_price_tokens_for_exact_tokens(asset1, asset2, amount, true) + } + + #[program::entrypoint] + fn entrypoint_get_liquidity_pool(asset1: AssetId, asset2: AssetId) -> Option<(Balance, Balance)> { + get_liquidity_pool(asset1, asset2) + } + + #[program::entrypoint] + fn entrypoint_list_pools() -> alloc::vec::Vec<(AssetId, AssetId, Balance, Balance)> { + list_pools() + } +} diff --git a/pvq-program-metadata-gen/Cargo.toml b/pvq-program-metadata-gen/Cargo.toml index 5f3ead5..88c1201 100644 --- a/pvq-program-metadata-gen/Cargo.toml +++ b/pvq-program-metadata-gen/Cargo.toml @@ -1,11 +1,11 @@ [package] name = "pvq-program-metadata-gen" description = "PVQ program metadata generation" +version = "0.2.0" authors.workspace = true edition.workspace = true license.workspace = true repository.workspace = true -version.workspace = true [dependencies] quote = { workspace = true } diff --git a/pvq-program-metadata-gen/src/metadata_gen.rs b/pvq-program-metadata-gen/src/metadata_gen.rs index 0e00607..07dd7cc 100644 --- a/pvq-program-metadata-gen/src/metadata_gen.rs +++ b/pvq-program-metadata-gen/src/metadata_gen.rs @@ -30,12 +30,12 @@ pub fn metadata_gen_src(source: &str, pkg_name: &str, output_dir: &str) -> syn:: let program_mod_items = &mut program_mod.content.as_mut().expect("This is checked before").1; // Find entrypoint and extension functions - let mut entrypoint_metadata = None; + let mut entrypoints_metadata = Vec::new(); let mut extension_fns_metadata = Vec::new(); + let mut remaining_items = Vec::new(); - for i in (0..program_mod_items.len()).rev() { - let item = &mut program_mod_items[i]; - if let Some(attr) = crate::helper::take_first_program_attr(item)? { + for mut item in program_mod_items.drain(..) { + if let Some(attr) = crate::helper::take_first_program_attr(&mut item)? { if let Some(last_segment) = attr.path().segments.last() { if last_segment.ident == "extension_fn" { let mut extension_id = None; @@ -55,7 +55,6 @@ pub fn metadata_gen_src(source: &str, pkg_name: &str, output_dir: &str) -> syn:: } Ok(()) })?; - let removed_item = program_mod_items.remove(i); if extension_id.is_none() || fn_index.is_none() { return Err(syn::Error::new( attr.span(), @@ -66,14 +65,11 @@ pub fn metadata_gen_src(source: &str, pkg_name: &str, output_dir: &str) -> syn:: extension_id.ok_or_else(|| syn::Error::new(attr.span(), "Extension ID is required"))?; let fn_index = fn_index.ok_or_else(|| syn::Error::new(attr.span(), "Function index is required"))?; - let extension_fn_metadata = generate_extension_fn_metadata(removed_item, extension_id, fn_index)?; + let extension_fn_metadata = generate_extension_fn_metadata(item, extension_id, fn_index)?; extension_fns_metadata.push(extension_fn_metadata); } else if last_segment.ident == "entrypoint" { - if entrypoint_metadata.is_some() { - return Err(syn::Error::new(attr.span(), "Multiple entrypoint functions found")); - } - let removed_item = program_mod_items.remove(i); - entrypoint_metadata = Some(generate_entrypoint_metadata(removed_item)?); + let entrypoint_metadata = generate_entrypoint_metadata(item)?; + entrypoints_metadata.push(entrypoint_metadata); } else { return Err(syn::Error::new( attr.span(), @@ -81,23 +77,29 @@ pub fn metadata_gen_src(source: &str, pkg_name: &str, output_dir: &str) -> syn:: )); } } + } else { + remaining_items.push(item) } } - let entrypoint_metadata = entrypoint_metadata - .ok_or_else(|| syn::Error::new(proc_macro2::Span::call_site(), "No entrypoint function found"))?; + if entrypoints_metadata.is_empty() { + return Err(syn::Error::new( + proc_macro2::Span::call_site(), + "No entrypoint function found", + )); + } let metadata_defs = metadata_defs(); let import_packages = import_packages(); let new_items = quote! { - #(#program_mod_items)* + #(#remaining_items)* #import_packages #metadata_defs fn main() { let extension_fns = vec![ #( #extension_fns_metadata, )* ]; - let entrypoint = #entrypoint_metadata; - let metadata = Metadata::new(extension_fns, entrypoint); + let entrypoints = vec![ #( #entrypoints_metadata, )* ]; + let metadata = Metadata::new(extension_fns, entrypoints); // Serialize to both formats let encoded = parity_scale_codec::Encode::encode(&metadata); let json = serde_json::to_string(&metadata).expect("Failed to serialize metadata to JSON"); @@ -231,21 +233,24 @@ fn metadata_defs() -> proc_macro2::TokenStream { pub struct Metadata { pub types: PortableRegistry, pub extension_fns: Vec<(ExtensionId, FnIndex, FunctionMetadata)>, - pub entrypoint: FunctionMetadata, + pub entrypoints: Vec>, } impl Metadata { - pub fn new(extension_fns: Vec<(ExtensionId, FnIndex, FunctionMetadata)>, entrypoint: FunctionMetadata) -> Self { + pub fn new(extension_fns: Vec<(ExtensionId, FnIndex, FunctionMetadata)>, entrypoints: Vec) -> Self { let mut registry = Registry::new(); let extension_fns = extension_fns .into_iter() .map(|(id, index, metadata)| (id, index, metadata.into_portable(&mut registry))) .collect(); - let entrypoint = entrypoint.into_portable(&mut registry); + let entrypoints = entrypoints + .into_iter() + .map(|metadata| metadata.into_portable(&mut registry)) + .collect(); Self { types: registry.into(), extension_fns, - entrypoint, + entrypoints, } } } diff --git a/pvq-program/procedural/src/program/expand/mod.rs b/pvq-program/procedural/src/program/expand/mod.rs index d2769c5..e685fab 100644 --- a/pvq-program/procedural/src/program/expand/mod.rs +++ b/pvq-program/procedural/src/program/expand/mod.rs @@ -79,42 +79,45 @@ fn expand_extension_fn(extension_fn: &mut ExtensionFn, parity_scale_codec: &syn: fn expand_main(def: &Def) -> TokenStream2 { let parity_scale_codec = &def.parity_scale_codec; - // Get `ident: Type`s - let arg_pats = def.entrypoint.item_fn.sig.inputs.iter().collect::>(); - // Get `ident`s - let arg_identifiers = arg_pats - .iter() - .map(|arg| { - if let syn::FnArg::Typed(pat_type) = arg { - pat_type.pat.to_token_stream() - } else { - unreachable!("Checked in parse stage") - } - }) - .collect::>(); - let arg_identifiers_str = arg_identifiers.iter().map(|arg| arg.to_string()).collect::>(); - - let decode_args = quote! { - #(let #arg_pats = #parity_scale_codec::Decode::decode(&mut arg_bytes).expect(concat!("Failed to decode ", #arg_identifiers_str));)* - }; + // Generate match arms for each entrypoint + let match_arms = def.entrypoints.iter().enumerate().map(|(index, entrypoint)| { + let entrypoint_ident = &entrypoint.item_fn.sig.ident; + let arg_pats = entrypoint.item_fn.sig.inputs.iter().collect::>(); + let arg_identifiers = arg_pats + .iter() + .map(|arg| { + if let syn::FnArg::Typed(pat_type) = arg { + pat_type.pat.to_token_stream() + } else { + unreachable!("Checked in parse stage") + } + }) + .collect::>(); - let entrypoint_ident = &def.entrypoint.item_fn.sig.ident; - let call_entrypoint = quote! { - let res = #entrypoint_ident(#(#arg_identifiers),*); - }; + quote! { + #index => { + #(let #arg_pats = #parity_scale_codec::Decode::decode(&mut arg_bytes) + .expect(concat!("Failed to decode arguments for ", stringify!(#entrypoint_ident)));)* + let res = #entrypoint_ident(#(#arg_identifiers),*); + let encoded_res = #parity_scale_codec::Encode::encode(&res); + (encoded_res.len() as u64) << 32 | (encoded_res.as_ptr() as u64) + } + } + }); quote! { #[polkavm_derive::polkavm_export] extern "C" fn pvq(arg_ptr: u32, size: u32) -> u64 { - let mut arg_bytes = unsafe { core::slice::from_raw_parts(arg_ptr as *const u8, size as usize) }; + // First stage: read fn_index + let fn_index = unsafe { *(arg_ptr as *const u8) } as usize; - #decode_args - - #call_entrypoint - - let encoded_res = #parity_scale_codec::Encode::encode(&res); - (encoded_res.len() as u64) << 32 | (encoded_res.as_ptr() as u64) + // Second stage: read arg_bytes + let mut arg_bytes = unsafe { core::slice::from_raw_parts((arg_ptr + 1) as *const u8, (size - 1) as usize) }; + match fn_index { + #(#match_arms,)* + _ => panic!("Invalid function index"), + } } } } diff --git a/pvq-program/procedural/src/program/parse/mod.rs b/pvq-program/procedural/src/program/parse/mod.rs index 187ef53..ba9fa65 100644 --- a/pvq-program/procedural/src/program/parse/mod.rs +++ b/pvq-program/procedural/src/program/parse/mod.rs @@ -8,7 +8,7 @@ mod helper; pub struct Def { pub item: syn::ItemMod, pub extension_fns: Vec, - pub entrypoint: EntrypointDef, + pub entrypoints: Vec, pub parity_scale_codec: syn::Path, pub polkavm_derive: syn::Path, } @@ -28,11 +28,11 @@ impl Def { .1; let mut extension_fns = Vec::new(); - let mut entrypoint = None; + let mut entrypoints = Vec::new(); + let mut remaining_items = Vec::new(); - for i in (0..items.len()).rev() { - let item = &mut items[i]; - if let Some(attr) = helper::take_first_program_attr(item)? { + for mut item in items.drain(..) { + if let Some(attr) = helper::take_first_program_attr(&mut item)? { if let Some(last_segment) = attr.path().segments.last() { if last_segment.ident == "extension_fn" { let mut extension_id = None; @@ -53,15 +53,13 @@ impl Def { Ok(()) })?; - let removed_item = items.remove(i); - let extension_fn = ExtensionFn::try_from(attr.span(), removed_item, extension_id, fn_index)?; + let extension_fn = ExtensionFn::try_from(attr.span(), item, extension_id, fn_index)?; extension_fns.push(extension_fn); continue; } else if last_segment.ident == "entrypoint" { - if entrypoint.is_some() { - return Err(syn::Error::new(attr.span(), "Only one entrypoint function is allowed")); - } - entrypoint = Some(EntrypointDef::try_from(attr.span(), item)?); + let entrypoint = EntrypointDef::try_from(attr.span(), &mut item)?; + entrypoints.push(entrypoint); + remaining_items.push(item); continue; } else { return Err(syn::Error::new( @@ -71,14 +69,23 @@ impl Def { } } } + remaining_items.push(item); } - let entrypoint = - entrypoint.ok_or_else(|| syn::Error::new(mod_span, "At least one entrypoint function is required"))?; + if entrypoints.is_empty() { + return Err(syn::Error::new( + mod_span, + "At least one entrypoint function is required", + )); + } + + // Put remaining items back + items.extend(remaining_items); + let def = Def { item, extension_fns, - entrypoint, + entrypoints, parity_scale_codec, polkavm_derive, }; diff --git a/pvq-test-runner/Cargo.toml b/pvq-test-runner/Cargo.toml index da12207..3729175 100644 --- a/pvq-test-runner/Cargo.toml +++ b/pvq-test-runner/Cargo.toml @@ -14,8 +14,8 @@ tracing-subscriber = { workspace = true } parity-scale-codec = { workspace = true, features = ["std"] } scale-info = { workspace = true, features = ["std", "serde"] } sp-core = { workspace = true, features = ["std"] } -serde = { workspace = true } -serde_json = { workspace = true } +serde = { workspace = true, features = ["std"] } +serde_json = { workspace = true, features = ["std"] } pvq-executor = { workspace = true, features = ["std"] } pvq-extension = { workspace = true, features = ["std"] } diff --git a/pvq-test-runner/src/lib.rs b/pvq-test-runner/src/lib.rs index 33fb957..e9f1b2a 100644 --- a/pvq-test-runner/src/lib.rs +++ b/pvq-test-runner/src/lib.rs @@ -42,7 +42,7 @@ pub mod extensions { #[extensions_impl::extension] impl pvq_extension_swap::extension::ExtensionSwap for ExtensionsImpl { type AssetId = Vec; - type Balance = u64; + type Balance = u128; fn quote_price_tokens_for_exact_tokens( _asset1: Self::AssetId, _asset2: Self::AssetId, @@ -88,14 +88,15 @@ impl TestRunner { let mut input_data = Vec::new(); if program_path.contains("sum-balance") { + input_data.extend_from_slice(&[0u8]); input_data.extend_from_slice(&21u32.encode()); - let alice_account: [u8; 32] = AccountId32::from_ss58check("5GrwvaEF5zXb26Fz9rcQpDWS57CtERHpNehXCPcNoHGKutQY") .expect("Failed to decode Alice's address") .into(); input_data.extend_from_slice(&vec![alice_account].encode()); } else if program_path.contains("total-supply") { + input_data.extend_from_slice(&[0u8]); input_data.extend_from_slice(&21u32.encode()); } else if program_path.contains("transparent-call") { input_data.extend_from_slice(&4071833530116166512u64.encode()); @@ -113,6 +114,12 @@ impl TestRunner { let asset2 = u32::encode(&22); input_data.extend_from_slice(&asset1.encode()); input_data.extend_from_slice(&asset2.encode()); + } else if program_path.contains("swap-info") { + input_data.extend_from_slice(&[0u8]); + let asset1 = u32::encode(&21); + let asset2 = u32::encode(&22); + input_data.extend_from_slice(&asset1.encode()); + input_data.extend_from_slice(&asset2.encode()); } tracing::info!("Input data (hex): {}", HexDisplay::from(&input_data)); input_data