diff --git a/.github/workflows/starknet-blocks.yml b/.github/workflows/starknet-blocks.yml index 5513f8b29..b89f5bd53 100644 --- a/.github/workflows/starknet-blocks.yml +++ b/.github/workflows/starknet-blocks.yml @@ -32,7 +32,7 @@ jobs: with: repository: lambdaclass/starknet-replay path: starknet-replay - ref: 1b8e2e0be21a8df9f5f6b8f8514d1a40b456ef58 + ref: d36491aa5fca3f48b4d7fb25eba599603ff48225 # We need native to use the linux deps ci action - name: Checkout Native uses: actions/checkout@v4 @@ -43,8 +43,7 @@ jobs: with: repository: lambdaclass/sequencer path: sequencer - ref: 40331042c1149f5cb84b27f9dd8d47994a010bbe - + ref: 14be65ca995ac702bad26ac20f2a522d9515f70a - name: Cache RPC Calls uses: actions/cache@v4.2.0 with: diff --git a/Cargo.toml b/Cargo.toml index 0f4789419..08b46bc6b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,6 +60,7 @@ scarb = ["build-cli", "dep:scarb-ui", "dep:scarb-metadata"] with-cheatcode = [] with-debug-utils = [] with-mem-tracing = [] +with-libfunc-profiling = [] with-segfault-catcher = [] with-trace-dump = ["dep:sierra-emu"] diff --git a/src/bin/cairo-native-run.rs b/src/bin/cairo-native-run.rs index 4440026ff..ce40390b2 100644 --- a/src/bin/cairo-native-run.rs +++ b/src/bin/cairo-native-run.rs @@ -3,7 +3,11 @@ use cairo_lang_compiler::{ compile_prepared_db, db::RootDatabase, project::setup_project, CompilerConfig, }; use cairo_lang_runner::short_string::as_cairo_short_string; +#[cfg(feature = "with-libfunc-profiling")] +use cairo_lang_sierra::ids::ConcreteLibfuncId; use cairo_lang_sierra_to_casm::metadata::MetadataComputationConfig; +#[cfg(feature = "with-libfunc-profiling")] +use cairo_native::metadata::profiler::LibfuncProfileData; use cairo_native::{ context::NativeContext, executor::{AotNativeExecutor, JitNativeExecutor}, @@ -11,6 +15,8 @@ use cairo_native::{ starknet_stub::StubSyscallHandler, }; use clap::{Parser, ValueEnum}; +#[cfg(feature = "with-libfunc-profiling")] +use std::collections::HashMap; use std::path::PathBuf; use tracing_subscriber::{EnvFilter, FmtSubscriber}; use utils::{find_function, result_to_runresult}; @@ -46,6 +52,11 @@ struct Args { #[arg(short = 'O', long, default_value_t = 0)] opt_level: u8, + #[cfg(feature = "with-libfunc-profiling")] + #[arg(long)] + /// The output path for the libfunc profilling results + profiler_output: Option, + #[cfg(feature = "with-trace-dump")] #[arg(long)] /// The output path for the execution trace @@ -131,6 +142,18 @@ fn main() -> anyhow::Result<()> { } } + #[cfg(feature = "with-libfunc-profiling")] + { + use cairo_native::metadata::profiler::ProfilerBinding; + + if let Some(trace_id) = + executor.find_symbol_ptr(ProfilerBinding::ProfileId.symbol()) + { + let trace_id = trace_id.cast::(); + unsafe { *trace_id = 0 }; + } + } + Box::new(move |function_id, args, gas, syscall_handler| { executor.invoke_dynamic_with_syscall_handler( function_id, @@ -153,6 +176,16 @@ fn main() -> anyhow::Result<()> { ); } + #[cfg(feature = "with-libfunc-profiling")] + { + use cairo_native::metadata::profiler::{ProfilerImpl, LIBFUNC_PROFILE}; + + LIBFUNC_PROFILE + .lock() + .unwrap() + .insert(0, ProfilerImpl::new()); + } + let gas_metadata = GasMetadata::new(&sierra_program, Some(MetadataComputationConfig::default())).unwrap(); @@ -188,6 +221,66 @@ fn main() -> anyhow::Result<()> { println!("Remaining gas: {gas}"); } + #[cfg(feature = "with-libfunc-profiling")] + { + use std::{fs::File, io::Write}; + + let profile = cairo_native::metadata::profiler::LIBFUNC_PROFILE + .lock() + .unwrap(); + + assert_eq!(profile.values().len(), 1); + + let profile = profile.values().next().unwrap(); + + if let Some(profiler_output_path) = args.profiler_output { + let mut output = File::create(profiler_output_path)?; + + let raw_profile = profile.get_profile(&sierra_program); + let mut processed_profile = process_profile(raw_profile); + + processed_profile.sort_by_key(|LibfuncProfileSummary { libfunc_idx, .. }| { + sierra_program + .libfunc_declarations + .iter() + .enumerate() + .find_map(|(i, x)| (x.id == *libfunc_idx).then_some(i)) + .unwrap() + }); + + for LibfuncProfileSummary { + libfunc_idx, + samples, + total_time, + average_time, + std_deviation, + quartiles, + } in processed_profile + { + writeln!(output, "{libfunc_idx}")?; + writeln!(output, " Total Samples: {samples}")?; + + let (Some(total_time), Some(average_time), Some(std_deviation), Some(quartiles)) = + (total_time, average_time, std_deviation, quartiles) + else { + writeln!(output, " Total Execution Time: none")?; + writeln!(output, " Average Execution Time: none")?; + writeln!(output, " Standard Deviation: none")?; + writeln!(output, " Quartiles: none")?; + writeln!(output)?; + + continue; + }; + + writeln!(output, " Total Execution Time: {total_time}")?; + writeln!(output, " Average Execution Time: {average_time}")?; + writeln!(output, " Standard Deviation: {std_deviation}")?; + writeln!(output, " Quartiles: {quartiles:?}")?; + writeln!(output)?; + } + } + } + #[cfg(feature = "with-trace-dump")] if let Some(trace_output) = args.trace_output { let traces = cairo_native::metadata::trace_dump::trace_dump_runtime::TRACE_DUMP @@ -205,3 +298,92 @@ fn main() -> anyhow::Result<()> { Ok(()) } + +#[cfg(feature = "with-libfunc-profiling")] +struct LibfuncProfileSummary { + pub libfunc_idx: ConcreteLibfuncId, + pub samples: u64, + pub total_time: Option, + pub average_time: Option, + pub std_deviation: Option, + pub quartiles: Option<[u64; 5]>, +} + +#[cfg(feature = "with-libfunc-profiling")] +fn process_profile( + profiles: HashMap, +) -> Vec { + profiles + .into_iter() + .map( + |( + libfunc_idx, + LibfuncProfileData { + mut deltas, + extra_counts, + }, + )| { + // if no deltas were registered, we only return the libfunc's calls amount + if deltas.is_empty() { + return LibfuncProfileSummary { + libfunc_idx, + samples: extra_counts, + total_time: None, + average_time: None, + std_deviation: None, + quartiles: None, + }; + } + + deltas.sort(); + + // Drop outliers. + { + let q1 = deltas[deltas.len() / 4]; + let q3 = deltas[3 * deltas.len() / 4]; + let iqr = q3 - q1; + + let q1_thr = q1.saturating_sub(iqr + iqr / 2); + let q3_thr = q3 + (iqr + iqr / 2); + + deltas.retain(|x| *x >= q1_thr && *x <= q3_thr); + } + + // Compute the quartiles. + let quartiles = [ + *deltas.first().unwrap(), + deltas[deltas.len() / 4], + deltas[deltas.len() / 2], + deltas[3 * deltas.len() / 4], + *deltas.last().unwrap(), + ]; + + // Compute the average. + let average = deltas.iter().copied().sum::() as f64 / deltas.len() as f64; + + // Compute the standard deviation. + let std_dev = { + let sum = deltas + .iter() + .copied() + .map(|x| x as f64) + .map(|x| (x - average)) + .map(|x| x * x) + .sum::(); + sum / (deltas.len() as u64 + extra_counts) as f64 + }; + + LibfuncProfileSummary { + libfunc_idx, + samples: deltas.len() as u64 + extra_counts, + total_time: Some( + deltas.iter().sum::() + (extra_counts as f64 * average).round() as u64, + ), + average_time: Some(average), + std_deviation: Some(std_dev), + quartiles: Some(quartiles), + } + }, + ) + .collect::>() +} diff --git a/src/compiler.rs b/src/compiler.rs index 2e998d241..cecf22f5f 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -550,26 +550,6 @@ fn compile_func( let (state, _) = edit_state::take_args(state, invocation.args.iter())?; - let helper = LibfuncHelper { - module, - init_block: &pre_entry_block, - region: ®ion, - blocks_arena: &blocks_arena, - last_block: Cell::new(block), - branches: generate_branching_targets( - &blocks, - statements, - statement_idx, - invocation, - &state, - ), - results: invocation - .branches - .iter() - .map(|x| vec![Cell::new(None); x.results.len()]) - .collect::>(), - }; - let libfunc = registry.get_libfunc(&invocation.libfunc_id)?; if is_recursive { if let Some(target) = libfunc.is_function_call() { @@ -620,6 +600,46 @@ fn compile_func( } } + #[allow(unused_mut)] + let mut helper = LibfuncHelper { + module, + init_block: &pre_entry_block, + region: ®ion, + blocks_arena: &blocks_arena, + last_block: Cell::new(block), + branches: generate_branching_targets( + &blocks, + statements, + statement_idx, + invocation, + &state, + ), + results: invocation + .branches + .iter() + .map(|x| vec![Cell::new(None); x.results.len()]) + .collect::>(), + + #[cfg(feature = "with-libfunc-profiling")] + profiler: match libfunc { + CoreConcreteLibfunc::FunctionCall(_) => { + // Tail-recursive function calls are broken beacuse a stack of timestamps is required, + // which would invalidate tail recursion. Also, since each libfunc is measured individually, + // it doesn't make sense to take function calls into account, therefore it's ignored on purpose. + None + } + _ => match metadata.remove::() + { + Some(profiler_meta) => { + let t0 = profiler_meta + .measure_timestamp(context, block, location)?; + Some((profiler_meta, statement_idx, t0)) + } + None => None, + }, + }, + }; + libfunc.build( context, registry, @@ -651,6 +671,11 @@ fn compile_func( libfunc_name ); + #[cfg(feature = "with-libfunc-profiling")] + if let Some((profiler_meta, _, _)) = helper.profiler.take() { + metadata.insert(profiler_meta); + } + if let Some(tailrec_meta) = metadata.remove::() { if let Some(return_block) = tailrec_meta.return_target() { tailrec_state = Some((tailrec_meta.depth_counter(), return_block)); diff --git a/src/context.rs b/src/context.rs index de8802c43..a6167780f 100644 --- a/src/context.rs +++ b/src/context.rs @@ -163,6 +163,9 @@ impl NativeContext { // already some metadata of the same type. metadata.insert(gas_metadata); + #[cfg(feature = "with-libfunc-profiling")] + metadata.insert(crate::metadata::profiler::ProfilerMeta::new()); + // Create the Sierra program registry let registry = ProgramRegistry::::new(program)?; diff --git a/src/executor/aot.rs b/src/executor/aot.rs index a792b8c2b..56ee1b2b1 100644 --- a/src/executor/aot.rs +++ b/src/executor/aot.rs @@ -60,6 +60,9 @@ impl AotNativeExecutor { #[cfg(feature = "with-trace-dump")] crate::metadata::trace_dump::setup_runtime(|name| executor.find_symbol_ptr(name)); + #[cfg(feature = "with-libfunc-profiling")] + crate::metadata::profiler::setup_runtime(|name| executor.find_symbol_ptr(name)); + executor } diff --git a/src/executor/contract.rs b/src/executor/contract.rs index b53e06a2e..833a4d24a 100644 --- a/src/executor/contract.rs +++ b/src/executor/contract.rs @@ -330,6 +330,9 @@ impl AotContractExecutor { #[cfg(feature = "with-trace-dump")] crate::metadata::trace_dump::setup_runtime(|name| executor.find_symbol_ptr(name)); + #[cfg(feature = "with-libfunc-profiling")] + crate::metadata::profiler::setup_runtime(|name| executor.find_symbol_ptr(name)); + Ok(Some(executor)) } diff --git a/src/executor/jit.rs b/src/executor/jit.rs index 899c955db..f7cbe58c7 100644 --- a/src/executor/jit.rs +++ b/src/executor/jit.rs @@ -71,6 +71,9 @@ impl<'m> JitNativeExecutor<'m> { #[cfg(feature = "with-trace-dump")] crate::metadata::trace_dump::setup_runtime(|name| executor.find_symbol_ptr(name)); + #[cfg(feature = "with-libfunc-profiling")] + crate::metadata::profiler::setup_runtime(|name| executor.find_symbol_ptr(name)); + Ok(executor) } diff --git a/src/libfuncs.rs b/src/libfuncs.rs index b29c33b27..b0718cf49 100644 --- a/src/libfuncs.rs +++ b/src/libfuncs.rs @@ -26,7 +26,7 @@ use cairo_lang_sierra::{ }; use melior::{ dialect::{arith, cf}, - ir::{Block, BlockLike, BlockRef, Location, Module, Operation, Region, Value}, + ir::{Block, BlockLike, BlockRef, Location, Module, Region, Value}, Context, }; use num_bigint::BigInt; @@ -279,6 +279,14 @@ where pub branches: Vec<(&'this Block<'ctx>, Vec>)>, pub results: Vec>>>>, + + #[cfg(feature = "with-libfunc-profiling")] + // Since function calls don't get profiled, this field is optional + pub profiler: Option<( + crate::metadata::profiler::ProfilerMeta, + cairo_lang_sierra::program::StatementIdx, + (Value<'ctx, 'this>, Value<'ctx, 'this>), + )>, } impl<'ctx, 'this> LibfuncHelper<'ctx, 'this> @@ -330,10 +338,11 @@ where /// used later on when required. fn br( &self, + block: &'this Block<'ctx>, branch: usize, results: &[Value<'ctx, 'this>], location: Location<'ctx>, - ) -> Operation<'ctx> { + ) -> Result<()> { let (successor, operands) = &self.branches[branch]; for (dst, src) in self.results[branch].iter().zip(results) { @@ -349,7 +358,16 @@ where }) .collect::>(); - cf::br(successor, &destination_operands, location) + #[cfg(feature = "with-libfunc-profiling")] + self.push_profiler_frame( + unsafe { self.context().to_ref() }, + self.module, + block, + location, + )?; + + block.append_operation(cf::br(successor, &destination_operands, location)); + Ok(()) } /// Creates a conditional binary branching operation, potentially jumping out of the libfunc and @@ -364,11 +382,12 @@ where fn cond_br( &self, context: &'ctx Context, + block: &'this Block<'ctx>, condition: Value<'ctx, 'this>, branches: [usize; 2], results: [&[Value<'ctx, 'this>]; 2], location: Location<'ctx>, - ) -> Operation<'ctx> { + ) -> Result<()> { let (block_true, args_true) = { let (successor, operands) = &self.branches[branches[0]]; @@ -407,7 +426,10 @@ where (*successor, destination_operands) }; - cf::cond_br( + #[cfg(feature = "with-libfunc-profiling")] + self.push_profiler_frame(context, self.module, block, location)?; + + block.append_operation(cf::cond_br( context, condition, block_true, @@ -415,7 +437,26 @@ where &args_true, &args_false, location, - ) + )); + Ok(()) + } + + #[cfg(feature = "with-libfunc-profiling")] + fn push_profiler_frame( + &self, + context: &'ctx Context, + module: &'this Module, + block: &'this Block<'ctx>, + location: Location<'ctx>, + ) -> Result<()> { + if let Some((profiler_meta, statement_idx, t0)) = self.profiler.as_ref() { + let t0 = *t0; + let t1 = profiler_meta.measure_timestamp(context, block, location)?; + + profiler_meta.push_frame(context, module, block, statement_idx.0, t0, t1, location)?; + } + + Ok(()) } } @@ -488,6 +529,5 @@ fn build_noop<'ctx, 'this, const N: usize, const PROCESS_BUILTINS: bool>( params.push(param_val); } - entry.append_operation(helper.br(0, ¶ms, location)); - Ok(()) + helper.br(entry, 0, ¶ms, location) } diff --git a/src/libfuncs/array.rs b/src/libfuncs/array.rs index 4971aab65..f43130867 100644 --- a/src/libfuncs/array.rs +++ b/src/libfuncs/array.rs @@ -149,8 +149,7 @@ pub fn build_new<'ctx, 'this>( ))?; let value = entry.insert_values(context, location, value, &[nullptr, k0, k0, k0])?; - entry.append_operation(helper.br(0, &[value], location)); - Ok(()) + helper.br(entry, 0, &[value], location) } /// Buils a span (a cairo native array) from a boxed tuple of same-type elements. @@ -262,8 +261,7 @@ pub fn build_span_from_tuple<'ctx, 'this>( &[array_ptr_ptr, k0, array_len, array_len], )?; - entry.append_operation(helper.br(0, &[value], location)); - Ok(()) + helper.br(entry, 0, &[value], location) } /// Buils a tuple (struct) from an span (a cairo native array) @@ -457,7 +455,7 @@ pub fn build_tuple_from_span<'ctx, 'this>( location, )); - valid_block.append_operation(helper.br(0, &[value], location)); + helper.br(valid_block, 0, &[value], location)?; } { @@ -473,10 +471,8 @@ pub fn build_tuple_from_span<'ctx, 'this>( entry.argument(0)?.into(), )?; - error_block.append_operation(helper.br(1, &[], location)); + helper.br(error_block, 1, &[], location) } - - Ok(()) } /// Generate MLIR operations for the `array_append` libfunc. @@ -723,8 +719,7 @@ pub fn build_append<'ctx, 'this>( )?; entry.store(context, location, max_len_ptr, array_end)?; - entry.append_operation(helper.br(0, &[array_obj], location)); - Ok(()) + helper.br(entry, 0, &[array_obj], location) } #[derive(Clone, Copy)] @@ -941,7 +936,7 @@ fn build_pop<'ctx, 'this, const CONSUME: bool, const REVERSE: bool>( branch_values.push(array_obj); branch_values.push(target_ptr); - valid_block.append_operation(helper.br(0, &branch_values, location)); + helper.br(valid_block, 0, &branch_values, location)?; } { @@ -956,7 +951,7 @@ fn build_pop<'ctx, 'this, const CONSUME: bool, const REVERSE: bool>( branch_values.push(array_obj); } - error_block.append_operation(helper.br(1, &branch_values, location)); + helper.br(error_block, 1, &branch_values, location)?; } Ok(()) @@ -1074,7 +1069,7 @@ pub fn build_get<'ctx, 'this>( entry.argument(1)?.into(), )?; - valid_block.append_operation(helper.br(0, &[range_check, target_ptr], location)); + helper.br(valid_block, 0, &[range_check, target_ptr], location)?; } { @@ -1089,7 +1084,7 @@ pub fn build_get<'ctx, 'this>( entry.argument(1)?.into(), )?; - error_block.append_operation(helper.br(1, &[range_check], location)); + helper.br(error_block, 1, &[range_check], location)?; } Ok(()) @@ -1155,7 +1150,7 @@ pub fn build_slice<'ctx, 'this>( let array_obj = valid_block.insert_value(context, location, array_obj, array_start, 1)?; let array_obj = valid_block.insert_value(context, location, array_obj, array_end, 2)?; - valid_block.append_operation(helper.br(0, &[range_check, array_obj], location)); + helper.br(valid_block, 0, &[range_check, array_obj], location)?; } { @@ -1170,7 +1165,7 @@ pub fn build_slice<'ctx, 'this>( array_obj, )?; - error_block.append_operation(helper.br(1, &[range_check], location)); + helper.br(error_block, 1, &[range_check], location)?; } Ok(()) @@ -1205,8 +1200,7 @@ pub fn build_len<'ctx, 'this>( entry.argument(0)?.into(), )?; - entry.append_operation(helper.br(0, &[array_len], location)); - Ok(()) + helper.br(entry, 0, &[array_len], location) } fn is_shared<'ctx, 'this>( diff --git a/src/libfuncs/bool.rs b/src/libfuncs/bool.rs index c9ca75bf3..7f2fd5d07 100644 --- a/src/libfuncs/bool.rs +++ b/src/libfuncs/bool.rs @@ -18,7 +18,7 @@ use cairo_lang_sierra::{ }; use melior::{ dialect::{arith, llvm}, - ir::{r#type::IntegerType, Block, BlockLike, Location}, + ir::{r#type::IntegerType, Block, Location}, Context, }; @@ -126,8 +126,7 @@ fn build_bool_binary<'ctx, 'this>( let res = entry.insert_value(context, location, res, new_tag_value, 0)?; - entry.append_operation(helper.br(0, &[res], location)); - Ok(()) + helper.br(entry, 0, &[res], location) } /// Generate MLIR operations for the `bool_not_impl` libfunc. @@ -168,8 +167,7 @@ pub fn build_bool_not<'ctx, 'this>( ))?; let res = entry.insert_value(context, location, res, new_tag_value, 0)?; - entry.append_operation(helper.br(0, &[res], location)); - Ok(()) + helper.br(entry, 0, &[res], location) } /// Generate MLIR operations for the `unbox` libfunc. @@ -203,8 +201,7 @@ pub fn build_bool_to_felt252<'ctx, 'this>( let result = entry.extui(tag_value, felt252_ty, location)?; - entry.append_operation(helper.br(0, &[result], location)); - Ok(()) + helper.br(entry, 0, &[result], location) } #[cfg(test)] diff --git a/src/libfuncs/bounded_int.rs b/src/libfuncs/bounded_int.rs index 3e048a31f..4140e5b7e 100644 --- a/src/libfuncs/bounded_int.rs +++ b/src/libfuncs/bounded_int.rs @@ -193,8 +193,7 @@ fn build_add<'ctx, 'this>( res_value }; - entry.append_operation(helper.br(0, &[res_value], location)); - Ok(()) + helper.br(entry, 0, &[res_value], location) } /// Generate MLIR operations for the `bounded_int_sub` libfunc. @@ -321,8 +320,7 @@ fn build_sub<'ctx, 'this>( res_value }; - entry.append_operation(helper.br(0, &[res_value], location)); - Ok(()) + helper.br(entry, 0, &[res_value], location) } /// Generate MLIR operations for the `bounded_int_mul` libfunc. @@ -441,8 +439,7 @@ fn build_mul<'ctx, 'this>( res_value }; - entry.append_operation(helper.br(0, &[res_value], location)); - Ok(()) + helper.br(entry, 0, &[res_value], location) } /// Generate MLIR operations for the `bounded_int_divrem` libfunc. @@ -586,8 +583,7 @@ fn build_divrem<'ctx, 'this>( rem_value }; - entry.append_operation(helper.br(0, &[range_check, div_value, rem_value], location)); - Ok(()) + helper.br(entry, 0, &[range_check, div_value, rem_value], location) } /// Generate MLIR operations for the `bounded_int_constrain` libfunc. @@ -669,7 +665,7 @@ fn build_constrain<'ctx, 'this>( res_value }; - lower_block.append_operation(helper.br(0, &[range_check, res_value], location)); + helper.br(lower_block, 0, &[range_check, res_value], location)?; } { @@ -696,7 +692,7 @@ fn build_constrain<'ctx, 'this>( res_value }; - upper_block.append_operation(helper.br(1, &[range_check, res_value], location)); + helper.br(upper_block, 1, &[range_check, res_value], location)?; } Ok(()) @@ -747,9 +743,14 @@ fn build_trim<'ctx, 'this>( value }; - entry.append_operation(helper.cond_br(context, is_invalid, [0, 1], [&[], &[value]], location)); - - Ok(()) + helper.cond_br( + context, + entry, + is_invalid, + [0, 1], + [&[], &[value]], + location, + ) } /// Generate MLIR operations for the `bounded_int_is_zero` libfunc. @@ -775,14 +776,14 @@ fn build_is_zero<'ctx, 'this>( let k0 = entry.const_int_from_type(context, location, 0, src_value.r#type())?; let src_is_zero = entry.cmpi(context, CmpiPredicate::Eq, src_value, k0, location)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, src_is_zero, [0, 1], [&[], &[src_value]], location, - )); - Ok(()) + ) } /// Generate MLIR operations for the `bounded_int_wrap_non_zero` libfunc. diff --git a/src/libfuncs/box.rs b/src/libfuncs/box.rs index 96a50ef52..b23abbade 100644 --- a/src/libfuncs/box.rs +++ b/src/libfuncs/box.rs @@ -107,8 +107,7 @@ pub fn build_into_box<'ctx, 'this>( ))), )); - entry.append_operation(helper.br(0, &[ptr], location)); - Ok(()) + helper.br(entry, 0, &[ptr], location) } /// Generate MLIR operations for the `unbox` libfunc. @@ -146,8 +145,7 @@ pub fn build_unbox<'ctx, 'this>( entry.append_operation(ReallocBindingsMeta::free(context, entry.arg(0)?, location)?); - entry.append_operation(helper.br(0, &[value], location)); - Ok(()) + helper.br(entry, 0, &[value], location) } #[cfg(test)] diff --git a/src/libfuncs/bytes31.rs b/src/libfuncs/bytes31.rs index d64cbeb3f..08d000411 100644 --- a/src/libfuncs/bytes31.rs +++ b/src/libfuncs/bytes31.rs @@ -74,9 +74,7 @@ pub fn build_const<'ctx, 'this>( location, )); - entry.append_operation(helper.br(0, &[op0.result(0)?.into()], location)); - - Ok(()) + helper.br(entry, 0, &[op0.result(0)?.into()], location) } /// Generate MLIR operations for the `bytes31_to_felt252` libfunc. @@ -99,9 +97,7 @@ pub fn build_to_felt252<'ctx, 'this>( let result = entry.extui(value, felt252_ty, location)?; - entry.append_operation(helper.br(0, &[result], location)); - - Ok(()) + helper.br(entry, 0, &[result], location) } /// Generate MLIR operations for the `u8_from_felt252` libfunc. @@ -153,9 +149,9 @@ pub fn build_from_felt252<'ctx, 'this>( )); let value = block_success.trunci(value, result_ty, location)?; - block_success.append_operation(helper.br(0, &[range_check, value], location)); + helper.br(block_success, 0, &[range_check, value], location)?; - block_failure.append_operation(helper.br(1, &[range_check], location)); + helper.br(block_failure, 1, &[range_check], location)?; Ok(()) } diff --git a/src/libfuncs/cast.rs b/src/libfuncs/cast.rs index 04536ef4e..3cfc79ddb 100644 --- a/src/libfuncs/cast.rs +++ b/src/libfuncs/cast.rs @@ -19,7 +19,7 @@ use cairo_lang_sierra::{ }; use melior::{ dialect::arith::{self, CmpiPredicate}, - ir::{r#type::IntegerType, Block, BlockLike, Location, Value, ValueLike}, + ir::{r#type::IntegerType, Block, Location, Value, ValueLike}, Context, }; use num_bigint::{BigInt, Sign}; @@ -60,14 +60,14 @@ pub fn build_downcast<'ctx, 'this>( if info.signature.param_signatures[1].ty == info.signature.branch_signatures[0].vars[1].ty { let k0 = entry.const_int(context, location, 0, 1)?; - entry.append_operation(helper.cond_br( + return helper.cond_br( context, + entry, k0, [0, 1], [&[range_check, src_value], &[range_check]], location, - )); - return Ok(()); + ); } let src_ty = registry.get_type(&info.signature.param_signatures[1].ty)?; @@ -183,13 +183,14 @@ pub fn build_downcast<'ctx, 'this>( let is_in_bounds = entry.const_int(context, location, 1, 1)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, is_in_bounds, [0, 1], [&[range_check, dst_value], &[range_check]], location, - )); + )?; } else { let lower_check = if dst_range.lower > src_range.lower { let dst_lower = entry.const_int_from_type( @@ -267,13 +268,14 @@ pub fn build_downcast<'ctx, 'this>( } else { dst_value }; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, is_in_bounds, [0, 1], [&[range_check, dst_value], &[range_check]], location, - )); + )?; } Ok(()) @@ -292,8 +294,7 @@ pub fn build_upcast<'ctx, 'this>( let src_value = entry.arg(0)?; if info.signature.param_signatures[0].ty == info.signature.branch_signatures[0].vars[0].ty { - entry.append_operation(helper.br(0, &[src_value], location)); - return Ok(()); + return helper.br(entry, 0, &[src_value], location); } let src_ty = registry.get_type(&info.signature.param_signatures[0].ty)?; @@ -384,8 +385,7 @@ pub fn build_upcast<'ctx, 'this>( dst_value }; - entry.append_operation(helper.br(0, &[dst_value], location)); - Ok(()) + helper.br(entry, 0, &[dst_value], location) } #[cfg(test)] diff --git a/src/libfuncs/circuit.rs b/src/libfuncs/circuit.rs index 54ff0cb53..48a0d39e4 100644 --- a/src/libfuncs/circuit.rs +++ b/src/libfuncs/circuit.rs @@ -151,9 +151,7 @@ fn build_init_circuit_data<'ctx, 'this>( &[k0, ptr], )?; - entry.append_operation(helper.br(0, &[rc, accumulator], location)); - - Ok(()) + helper.br(entry, 0, &[rc, accumulator], location) } /// Generate MLIR operations for the `add_circuit_input` libfunc. @@ -235,7 +233,7 @@ fn build_add_input<'ctx, 'this>( // If not last insert, then return accumulator { - middle_insert_block.append_operation(helper.br(1, &[accumulator], location)); + helper.br(middle_insert_block, 1, &[accumulator], location)?; } // If is last insert, then return accumulator.pointer @@ -249,7 +247,7 @@ fn build_add_input<'ctx, 'this>( 1, )?; - last_insert_block.append_operation(helper.br(0, &[inputs_ptr], location)); + helper.br(last_insert_block, 0, &[inputs_ptr], location)?; } Ok(()) @@ -271,9 +269,14 @@ fn build_try_into_circuit_modulus<'ctx, 'this>( let is_valid = entry.cmpi(context, arith::CmpiPredicate::Ugt, modulus, k1, location)?; - entry.append_operation(helper.cond_br(context, is_valid, [0, 1], [&[modulus], &[]], location)); - - Ok(()) + helper.cond_br( + context, + entry, + is_valid, + [0, 1], + [&[modulus], &[]], + location, + ) } /// Generate MLIR operations for the `get_circuit_descriptor` libfunc. @@ -293,9 +296,7 @@ fn build_get_descriptor<'ctx, 'this>( let unit = entry.append_op_result(llvm::undef(descriptor_type, location))?; - entry.append_operation(helper.br(0, &[unit], location)); - - Ok(()) + helper.br(entry, 0, &[unit], location) } /// Generate MLIR operations for the `eval_circuit` libfunc. @@ -413,7 +414,7 @@ fn build_eval<'ctx, 'this>( &[outputs_ptr, modulus_struct], )?; - ok_block.append_operation(helper.br(0, &[add_mod, mul_mod, outputs], location)); + helper.br(ok_block, 0, &[add_mod, mul_mod, outputs], location)?; } // Error case @@ -449,7 +450,12 @@ fn build_eval<'ctx, 'this>( registry.build_type(context, helper, metadata, failure_type_id)?, location, ))?; - err_block.append_operation(helper.br(1, &[add_mod, mul_mod, partial, failure], location)); + helper.br( + err_block, + 1, + &[add_mod, mul_mod, partial, failure], + location, + )?; } Ok(()) @@ -728,9 +734,7 @@ fn build_failure_guarantee_verify<'ctx, 'this>( let guarantee = entry.append_op_result(llvm::undef(guarantee_type, location))?; - entry.append_operation(helper.br(0, &[rc, mul_mod, guarantee], location)); - - Ok(()) + helper.br(entry, 0, &[rc, mul_mod, guarantee], location) } /// Generate MLIR operations for the `u96_limbs_less_than_guarantee_verify` libfunc. @@ -778,7 +782,7 @@ fn build_u96_limbs_less_than_guarantee_verify<'ctx, 'this>( { // if there is diff, return it - diff_block.append_operation(helper.br(1, &[diff], location)); + helper.br(diff_block, 1, &[diff], location)?; } { // if there is no diff, build a new guarantee, skipping last limb @@ -817,7 +821,7 @@ fn build_u96_limbs_less_than_guarantee_verify<'ctx, 'this>( &[new_gate, new_modulus], )?; - next_block.append_operation(helper.br(0, &[new_guarantee], location)); + helper.br(next_block, 0, &[new_guarantee], location)?; } Ok(()) @@ -849,9 +853,7 @@ fn build_u96_single_limb_less_than_guarantee_verify<'ctx, 'this>( // calcualte diff between limbs let diff = entry.append_op_result(arith::subi(modulus_limb, gate_limb, location))?; - entry.append_operation(helper.br(0, &[diff], location)); - - Ok(()) + helper.br(entry, 0, &[diff], location) } /// Generate MLIR operations for the `get_circuit_output` libfunc. @@ -937,7 +939,7 @@ fn build_get_output<'ctx, 'this>( )?; } - entry.append_operation(helper.br(0, &[output_struct, guarantee], location)); + helper.br(entry, 0, &[output_struct, guarantee], location)?; Ok(()) } @@ -1181,8 +1183,7 @@ fn build_into_u96_guarantee<'ctx, 'this>( dst = entry.addi(dst, klower, location)? } - entry.append_operation(helper.br(0, &[dst], location)); - Ok(()) + helper.br(entry, 0, &[dst], location) } #[cfg(test)] diff --git a/src/libfuncs/const.rs b/src/libfuncs/const.rs index 92ec36d31..2986e99ee 100644 --- a/src/libfuncs/const.rs +++ b/src/libfuncs/const.rs @@ -24,7 +24,7 @@ use cairo_lang_sierra::{ }; use melior::{ dialect::llvm::{self, r#type::pointer}, - ir::{Block, BlockLike, Location, Value}, + ir::{Block, Location, Value}, Context, }; use num_bigint::Sign; @@ -89,8 +89,7 @@ pub fn build_const_as_box<'ctx, 'this>( // Store constant in box entry.store(context, location, ptr, value)?; - entry.append_operation(helper.br(0, &[ptr], location)); - Ok(()) + helper.br(entry, 0, &[ptr], location) } /// Generate MLIR operations for the `const_as_immediate` libfunc. @@ -114,8 +113,7 @@ pub fn build_const_as_immediate<'ctx, 'this>( context, registry, entry, location, helper, metadata, const_type, )?; - entry.append_operation(helper.br(0, &[value], location)); - Ok(()) + helper.br(entry, 0, &[value], location) } pub fn build_const_type_value<'ctx, 'this>( diff --git a/src/libfuncs/coupon.rs b/src/libfuncs/coupon.rs index 4a698d233..14b420238 100644 --- a/src/libfuncs/coupon.rs +++ b/src/libfuncs/coupon.rs @@ -20,7 +20,7 @@ use cairo_lang_sierra::{ }; use melior::{ dialect::llvm, - ir::{Block, BlockLike, Location}, + ir::{Block, Location}, Context, }; @@ -72,9 +72,7 @@ pub fn build_buy<'ctx, 'this>( )?; let coupon = entry.append_op_result(llvm::undef(ty, location))?; - entry.append_operation(helper.br(0, &[coupon], location)); - - Ok(()) + helper.br(entry, 0, &[coupon], location) } /// Generate MLIR operations for the `coupon` libfunc. @@ -91,7 +89,5 @@ pub fn build_refund<'ctx, 'this>( // let gas = metadata.get::().ok_or(Error::MissingMetadata)?; // let gas_cost = gas.initial_required_gas(&info.function.id); - entry.append_operation(helper.br(0, &[], location)); - - Ok(()) + helper.br(entry, 0, &[], location) } diff --git a/src/libfuncs/debug.rs b/src/libfuncs/debug.rs index 16b81a9fa..a6449f6b6 100644 --- a/src/libfuncs/debug.rs +++ b/src/libfuncs/debug.rs @@ -27,16 +27,16 @@ use cairo_lang_sierra::{ }; use melior::{ dialect::{arith, cf, llvm}, - ir::{r#type::IntegerType, Block, BlockLike, Location}, + ir::{r#type::IntegerType, Block, Location}, Context, }; -pub fn build<'ctx>( +pub fn build<'ctx, 'this>( context: &'ctx Context, registry: &ProgramRegistry, - entry: &Block<'ctx>, + entry: &'this Block<'ctx>, location: Location<'ctx>, - helper: &LibfuncHelper<'ctx, '_>, + helper: &LibfuncHelper<'ctx, 'this>, metadata: &mut MetadataStorage, selector: &DebugConcreteLibfunc, ) -> Result<()> { @@ -47,12 +47,12 @@ pub fn build<'ctx>( } } -pub fn build_print<'ctx>( +pub fn build_print<'ctx, 'this>( context: &'ctx Context, registry: &ProgramRegistry, - entry: &Block<'ctx>, + entry: &'this Block<'ctx>, location: Location<'ctx>, - helper: &LibfuncHelper<'ctx, '_>, + helper: &LibfuncHelper<'ctx, 'this>, metadata: &mut MetadataStorage, info: &SignatureOnlyConcreteLibfunc, ) -> Result<()> { @@ -121,7 +121,5 @@ pub fn build_print<'ctx>( location, ); - entry.append_operation(helper.br(0, &[], location)); - - Ok(()) + helper.br(entry, 0, &[], location) } diff --git a/src/libfuncs/drop.rs b/src/libfuncs/drop.rs index 8f32dcca9..2cc73ef3d 100644 --- a/src/libfuncs/drop.rs +++ b/src/libfuncs/drop.rs @@ -19,7 +19,7 @@ use cairo_lang_sierra::{ program_registry::ProgramRegistry, }; use melior::{ - ir::{Block, BlockLike, Location}, + ir::{Block, Location}, Context, }; @@ -50,6 +50,5 @@ pub fn build<'ctx, 'this>( )?; } - entry.append_operation(helper.br(0, &[], location)); - Ok(()) + helper.br(entry, 0, &[], location) } diff --git a/src/libfuncs/dup.rs b/src/libfuncs/dup.rs index eb9477c50..620fedfde 100644 --- a/src/libfuncs/dup.rs +++ b/src/libfuncs/dup.rs @@ -18,7 +18,7 @@ use cairo_lang_sierra::{ program_registry::ProgramRegistry, }; use melior::{ - ir::{Block, BlockLike, Location}, + ir::{Block, Location}, Context, }; @@ -45,7 +45,5 @@ pub fn build<'ctx, 'this>( &info.signature.param_signatures[0].ty, entry.arg(0)?, )?; - entry.append_operation(helper.br(0, &[values.0, values.1], location)); - - Ok(()) + helper.br(entry, 0, &[values.0, values.1], location) } diff --git a/src/libfuncs/ec.rs b/src/libfuncs/ec.rs index 5563d1f6b..15a5f1a61 100644 --- a/src/libfuncs/ec.rs +++ b/src/libfuncs/ec.rs @@ -20,7 +20,7 @@ use melior::{ arith::{self, CmpiPredicate}, llvm, }, - ir::{operation::OperationBuilder, r#type::IntegerType, Block, BlockLike, Location}, + ir::{operation::OperationBuilder, r#type::IntegerType, Block, Location}, Context, }; @@ -91,14 +91,14 @@ pub fn build_is_zero<'ctx, 'this>( let k0 = entry.const_int(context, location, 0, 252)?; let y_is_zero = entry.cmpi(context, CmpiPredicate::Eq, y, k0, location)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, y_is_zero, [0, 1], [&[], &[entry.arg(0)?]], location, - )); - Ok(()) + ) } /// Generate MLIR operations for the `ec_neg` libfunc. @@ -134,8 +134,7 @@ pub fn build_neg<'ctx, 'this>( let result = entry.insert_value(context, location, entry.arg(0)?, y_neg, 1)?; - entry.append_operation(helper.br(0, &[result], location)); - Ok(()) + helper.br(entry, 0, &[result], location) } /// Generate MLIR operations for the `ec_point_from_x_nz` libfunc. @@ -179,14 +178,14 @@ pub fn build_point_from_x<'ctx, 'this>( let point = entry.load(context, location, point_ptr, ec_point_ty)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, result, [0, 1], [&[range_check, point], &[range_check]], location, - )); - Ok(()) + ) } /// Generate MLIR operations for the `ec_state_add` libfunc. @@ -233,8 +232,7 @@ pub fn build_state_add<'ctx, 'this>( let state = entry.load(context, location, state_ptr, ec_state_ty)?; - entry.append_operation(helper.br(0, &[state], location)); - Ok(()) + helper.br(entry, 0, &[state], location) } /// Generate MLIR operations for the `ec_state_add_mul` libfunc. @@ -289,8 +287,7 @@ pub fn build_state_add_mul<'ctx, 'this>( let state = entry.load(context, location, state_ptr, ec_state_ty)?; - entry.append_operation(helper.br(0, &[ec_op, state], location)); - Ok(()) + helper.br(entry, 0, &[ec_op, state], location) } /// Generate MLIR operations for the `ec_state_try_finalize_nz` libfunc. @@ -335,8 +332,7 @@ pub fn build_state_finalize<'ctx, 'this>( let point = entry.load(context, location, point_ptr, ec_point_ty)?; - entry.append_operation(helper.cond_br(context, is_zero, [0, 1], [&[point], &[]], location)); - Ok(()) + helper.cond_br(context, entry, is_zero, [0, 1], [&[point], &[]], location) } /// Generate MLIR operations for the `ec_state_init` libfunc. @@ -374,8 +370,7 @@ pub fn build_state_init<'ctx, 'this>( let state = entry.load(context, location, state_ptr, ec_state_ty)?; - entry.append_operation(helper.br(0, &[state], location)); - Ok(()) + helper.br(entry, 0, &[state], location) } /// Generate MLIR operations for the `ec_point_try_new_nz` libfunc. @@ -417,8 +412,7 @@ pub fn build_try_new<'ctx, 'this>( .result(0)? .into(); - entry.append_operation(helper.cond_br(context, result, [0, 1], [&[point], &[]], location)); - Ok(()) + helper.cond_br(context, entry, result, [0, 1], [&[point], &[]], location) } /// Generate MLIR operations for the `ec_point_unwrap` libfunc. @@ -457,8 +451,7 @@ pub fn build_unwrap_point<'ctx, 'this>( 1, )?; - entry.append_operation(helper.br(0, &[x, y], location)); - Ok(()) + helper.br(entry, 0, &[x, y], location) } /// Generate MLIR operations for the `ec_point_zero` libfunc. @@ -486,8 +479,7 @@ pub fn build_zero<'ctx, 'this>( let point = entry.insert_value(context, location, point, k0, 1)?; - entry.append_operation(helper.br(0, &[point], location)); - Ok(()) + helper.br(entry, 0, &[point], location) } #[cfg(test)] diff --git a/src/libfuncs/enum.rs b/src/libfuncs/enum.rs index 694f33054..7be0f2de8 100644 --- a/src/libfuncs/enum.rs +++ b/src/libfuncs/enum.rs @@ -94,9 +94,8 @@ pub fn build_init<'ctx, 'this>( &info.signature.param_signatures[0].ty, info.index, )?; - entry.append_operation(helper.br(0, &[val], location)); - Ok(()) + helper.br(entry, 0, &[val], location) } #[allow(clippy::too_many_arguments)] @@ -240,9 +239,7 @@ pub fn build_from_bounded_int<'ctx, 'this>( let value = entry.append_op_result(llvm::undef(enum_ty, location))?; let value = entry.insert_value(context, location, value, tag_value, 0)?; - entry.append_operation(helper.br(0, &[value], location)); - - Ok(()) + helper.br(entry, 0, &[value], location) } /// Generate MLIR operations for the `enum_match` libfunc. @@ -277,7 +274,7 @@ pub fn build_match<'ctx, 'this>( entry.append_operation(llvm::unreachable(location)); } 1 => { - entry.append_operation(helper.br(0, &[entry.arg(0)?], location)); + helper.br(entry, 0, &[entry.arg(0)?], location)?; } _ => { let (layout, (tag_ty, _), variant_tys) = crate::types::r#enum::get_type_for_variants( @@ -394,7 +391,7 @@ pub fn build_match<'ctx, 'this>( } }; - block.append_operation(helper.br(i, &[payload_val], location)); + helper.br(block, i, &[payload_val], location)?; } } } @@ -438,7 +435,7 @@ pub fn build_snapshot_match<'ctx, 'this>( entry.append_operation(llvm::unreachable(location)); } 1 => { - entry.append_operation(helper.br(0, &[entry.arg(0)?], location)); + helper.br(entry, 0, &[entry.arg(0)?], location)?; } _ => { let (layout, (tag_ty, _), variant_tys) = crate::types::r#enum::get_type_for_variants( @@ -536,7 +533,7 @@ pub fn build_snapshot_match<'ctx, 'this>( } }; - block.append_operation(helper.br(i, &[payload_val], location)); + helper.br(block, i, &[payload_val], location)?; } } } diff --git a/src/libfuncs/felt252.rs b/src/libfuncs/felt252.rs index 7604ce09d..33c3fb060 100644 --- a/src/libfuncs/felt252.rs +++ b/src/libfuncs/felt252.rs @@ -271,14 +271,11 @@ pub fn build_binary_operation<'ctx, 'this>( ))?; let result = inverse_result_block.trunci(result, felt252_ty, location)?; - inverse_result_block.append_operation(helper.br(0, &[result], location)); - return Ok(()); + return helper.br(inverse_result_block, 0, &[result], location); } }; - entry.append_operation(helper.br(0, &[result], location)); - - Ok(()) + helper.br(entry, 0, &[result], location) } /// Generate MLIR operations for the `felt252_const` libfunc. @@ -306,8 +303,8 @@ pub fn build_const<'ctx, 'this>( )?; let value = entry.const_int_from_type(context, location, value, felt252_ty)?; - entry.append_operation(helper.br(0, &[value], location)); - Ok(()) + + helper.br(entry, 0, &[value], location) } /// Generate MLIR operations for the `felt252_is_zero` libfunc. @@ -325,8 +322,7 @@ pub fn build_is_zero<'ctx, 'this>( let k0 = entry.const_int_from_type(context, location, 0, arg0.r#type())?; let condition = entry.cmpi(context, CmpiPredicate::Eq, arg0, k0, location)?; - entry.append_operation(helper.cond_br(context, condition, [0, 1], [&[], &[arg0]], location)); - Ok(()) + helper.cond_br(context, entry, condition, [0, 1], [&[], &[arg0]], location) } #[cfg(test)] diff --git a/src/libfuncs/felt252_dict.rs b/src/libfuncs/felt252_dict.rs index 43d2165e3..b0ef5c72d 100644 --- a/src/libfuncs/felt252_dict.rs +++ b/src/libfuncs/felt252_dict.rs @@ -20,7 +20,7 @@ use cairo_lang_sierra::{ }; use melior::{ dialect::{llvm, ods}, - ir::{Block, BlockLike, Location}, + ir::{Block, Location}, Context, }; @@ -103,8 +103,7 @@ pub fn build_new<'ctx, 'this>( registry.get_type(value_type_id)?.layout(registry)?, )?; - entry.append_operation(helper.br(0, &[segment_arena, dict_ptr], location)); - Ok(()) + helper.br(entry, 0, &[segment_arena, dict_ptr], location) } pub fn build_squash<'ctx, 'this>( @@ -132,13 +131,12 @@ pub fn build_squash<'ctx, 'this>( let new_gas_builtin = entry.addi(gas_builtin, gas_refund, location)?; - entry.append_operation(helper.br( + helper.br( + entry, 0, &[range_check, new_gas_builtin, segment_arena, entry.arg(3)?], location, - )); - - Ok(()) + ) } #[cfg(test)] diff --git a/src/libfuncs/felt252_dict_entry.rs b/src/libfuncs/felt252_dict_entry.rs index e0a2fb2d3..fe91f5f71 100644 --- a/src/libfuncs/felt252_dict_entry.rs +++ b/src/libfuncs/felt252_dict_entry.rs @@ -119,6 +119,9 @@ pub fn build_get<'ctx, 'this>( last_block: Cell::new(&block), branches: Vec::new(), results: Vec::new(), + + #[cfg(feature = "with-libfunc-profiling")] + profiler: helper.profiler.clone(), }; // When the entry is vacant we need to create the default value. @@ -143,8 +146,7 @@ pub fn build_get<'ctx, 'this>( // `get`), the memory it occupied is not modified because we're expecting it to be overwritten // by the finalizer (in other words, the extracted element will be dropped twice). - entry.append_operation(helper.br(0, &[dict_entry, value], location)); - Ok(()) + helper.br(entry, 0, &[dict_entry, value], location) } /// The felt252_dict_entry_finalize libfunc receives the dict entry and a new value, @@ -185,8 +187,7 @@ pub fn build_finalize<'ctx, 'this>( entry.store(context, location, value_ptr, new_value)?; - entry.append_operation(helper.br(0, &[dict_ptr], location)); - Ok(()) + helper.br(entry, 0, &[dict_ptr], location) } #[cfg(test)] diff --git a/src/libfuncs/function_call.rs b/src/libfuncs/function_call.rs index ba8400909..7122d3ab2 100644 --- a/src/libfuncs/function_call.rs +++ b/src/libfuncs/function_call.rs @@ -122,7 +122,7 @@ pub fn build<'ctx, 'this>( } } - cont_block.append_operation(helper.br(0, &results, location)); + helper.br(cont_block, 0, &results, location)?; } else { let mut result_types = Vec::new(); let return_types = info @@ -294,7 +294,7 @@ pub fn build<'ctx, 'this>( } } - entry.append_operation(helper.br(0, &results, location)); + helper.br(entry, 0, &results, location)?; } if let Some(tailrec_meta) = tailrec_meta { diff --git a/src/libfuncs/gas.rs b/src/libfuncs/gas.rs index 1f724d25b..d164437b9 100644 --- a/src/libfuncs/gas.rs +++ b/src/libfuncs/gas.rs @@ -17,7 +17,7 @@ use cairo_lang_sierra::{ }; use melior::{ dialect::{arith::CmpiPredicate, ods}, - ir::{r#type::IntegerType, Block, BlockLike, Location, Value}, + ir::{r#type::IntegerType, Block, Location, Value}, Context, }; @@ -63,11 +63,12 @@ pub fn build_get_available_gas<'ctx, 'this>( _metadata: &mut MetadataStorage, _info: &SignatureOnlyConcreteLibfunc, ) -> Result<()> { - let gas = entry.arg(0)?; - let gas_u128 = entry.extui(gas, IntegerType::new(context, 128).into(), location)?; + let i128_ty = IntegerType::new(context, 128).into(); + + let gas_u128 = entry.extui(entry.arg(0)?, i128_ty, location)?; + // The gas is returned as u128 on the second arg. - entry.append_operation(helper.br(0, &[entry.arg(0)?, gas_u128], location)); - Ok(()) + helper.br(entry, 0, &[entry.arg(0)?, gas_u128], location) } /// Generate MLIR operations for the `withdraw_gas` libfunc. @@ -112,15 +113,14 @@ pub fn build_withdraw_gas<'ctx, 'this>( ods::llvm::intr_usub_sat(context, current_gas, total_gas_cost_value, location).into(), )?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, is_enough, [0, 1], [&[range_check, resulting_gas], &[range_check, current_gas]], location, - )); - - Ok(()) + ) } /// Returns the unused gas to the remaining @@ -161,9 +161,7 @@ pub fn build_redeposit_gas<'ctx, 'this>( ods::llvm::intr_uadd_sat(context, current_gas, total_gas_cost_value, location).into(), )?; - entry.append_operation(helper.br(0, &[resulting_gas], location)); - - Ok(()) + helper.br(entry, 0, &[resulting_gas], location) } /// Generate MLIR operations for the `withdraw_gas_all` libfunc. @@ -200,15 +198,14 @@ pub fn build_builtin_withdraw_gas<'ctx, 'this>( ods::llvm::intr_usub_sat(context, current_gas, total_gas_cost_value, location).into(), )?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, is_enough, [0, 1], [&[range_check, resulting_gas], &[range_check, current_gas]], location, - )); - - Ok(()) + ) } /// Generate MLIR operations for the `get_builtin_costs` libfunc. @@ -232,9 +229,7 @@ pub fn build_get_builtin_costs<'ctx, 'this>( .into() }; - entry.append_operation(helper.br(0, &[builtin_ptr], location)); - - Ok(()) + helper.br(entry, 0, &[builtin_ptr], location) } /// Calculate the current gas cost, given the constant `GasCost` configuration, diff --git a/src/libfuncs/int.rs b/src/libfuncs/int.rs index 63b532cd4..f9114d0eb 100644 --- a/src/libfuncs/int.rs +++ b/src/libfuncs/int.rs @@ -219,12 +219,12 @@ fn build_bitwise<'ctx, 'this>( let logical_xor = entry.append_op_result(arith::xori(lhs, rhs, location))?; let logical_or = entry.append_op_result(arith::ori(lhs, rhs, location))?; - entry.append_operation(helper.br( + helper.br( + entry, 0, &[bitwise, logical_and, logical_xor, logical_or], location, - )); - Ok(()) + ) } fn build_byte_reverse<'ctx, 'this>( @@ -241,8 +241,7 @@ fn build_byte_reverse<'ctx, 'this>( let value = entry.append_op_result(ods::llvm::intr_bswap(context, entry.arg(1)?, location).into())?; - entry.append_operation(helper.br(0, &[bitwise, value], location)); - Ok(()) + helper.br(entry, 0, &[bitwise, value], location) } fn build_const<'ctx, 'this, T>( @@ -266,8 +265,7 @@ where let value = entry.const_int_from_type(context, location, info.c, value_ty)?; - entry.append_operation(helper.br(0, &[value], location)); - Ok(()) + helper.br(entry, 0, &[value], location) } fn build_diff<'ctx, 'this>( @@ -287,14 +285,14 @@ fn build_diff<'ctx, 'this>( let is_greater_equal = entry.cmpi(context, CmpiPredicate::Sge, lhs, rhs, location)?; let value_difference = entry.append_op_result(arith::subi(lhs, rhs, location))?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, is_greater_equal, [0, 1], [&[range_check, value_difference]; 2], location, - )); - Ok(()) + ) } fn build_divmod<'ctx, 'this>( @@ -314,8 +312,7 @@ fn build_divmod<'ctx, 'this>( let result_div = entry.append_op_result(arith::divui(lhs, rhs, location))?; let result_rem = entry.append_op_result(arith::remui(lhs, rhs, location))?; - entry.append_operation(helper.br(0, &[range_check, result_div, result_rem], location)); - Ok(()) + helper.br(entry, 0, &[range_check, result_div, result_rem], location) } fn build_equal<'ctx, 'this>( @@ -335,8 +332,7 @@ fn build_equal<'ctx, 'this>( location, )?; - entry.append_operation(helper.cond_br(context, are_equal, [1, 0], [&[]; 2], location)); - Ok(()) + helper.cond_br(context, entry, are_equal, [1, 0], [&[]; 2], location) } fn build_from_felt252<'ctx, 'this>( @@ -451,15 +447,14 @@ fn build_from_felt252<'ctx, 'this>( let value = entry.trunci(value, value_ty, location)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, is_in_range, [0, 1], [&[range_check, value], &[range_check]], location, - )); - - Ok(()) + ) } fn build_guarantee_mul<'ctx, 'this>( @@ -488,8 +483,7 @@ fn build_guarantee_mul<'ctx, 'this>( let hi = mul_op.result(1)?.into(); let guarantee = entry.append_op_result(llvm::undef(guarantee_ty, location))?; - entry.append_operation(helper.br(0, &[hi, lo, guarantee], location)); - Ok(()) + helper.br(entry, 0, &[hi, lo, guarantee], location) } fn build_is_zero<'ctx, 'this>( @@ -506,8 +500,7 @@ fn build_is_zero<'ctx, 'this>( let k0 = entry.const_int_from_type(context, location, 0, input.r#type())?; let is_zero = entry.cmpi(context, CmpiPredicate::Eq, input, k0, location)?; - entry.append_operation(helper.cond_br(context, is_zero, [0, 1], [&[], &[input]], location)); - Ok(()) + helper.cond_br(context, entry, is_zero, [0, 1], [&[], &[input]], location) } fn build_mul_guarantee_verify<'ctx, 'this>( @@ -521,8 +514,7 @@ fn build_mul_guarantee_verify<'ctx, 'this>( ) -> Result<()> { let range_check = super::increment_builtin_counter(context, entry, location, entry.arg(0)?)?; - entry.append_operation(helper.br(0, &[range_check], location)); - Ok(()) + helper.br(entry, 0, &[range_check], location) } fn build_operation<'ctx, 'this>( @@ -586,30 +578,31 @@ fn build_operation<'ctx, 'this>( location, )); - block_in_range.append_operation(helper.br(0, &[range_check, result], location)); + helper.br(block_in_range, 0, &[range_check, result], location)?; { let k0 = block_overflow.const_int_from_type(context, location, 0, result.r#type())?; let is_positive = block_overflow.cmpi(context, CmpiPredicate::Sge, result, k0, location)?; - block_overflow.append_operation(helper.cond_br( + helper.cond_br( context, + block_overflow, is_positive, [1, 2], [&[range_check, result]; 2], location, - )); + ) } } else { - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, overflow, [1, 0], [&[range_check, result]; 2], location, - )); + ) } - Ok(()) } fn build_square_root<'ctx, 'this>( @@ -763,8 +756,7 @@ fn build_square_root<'ctx, 'this>( location, ))?; - entry.append_operation(helper.br(0, &[range_check, value], location)); - Ok(()) + helper.br(entry, 0, &[range_check, value], location) } fn build_to_felt252<'ctx, 'this>( @@ -820,8 +812,7 @@ fn build_to_felt252<'ctx, 'this>( entry.extui(entry.arg(0)?, felt252_ty, location)? }; - entry.append_operation(helper.br(0, &[value], location)); - Ok(()) + helper.br(entry, 0, &[value], location) } fn build_u128s_from_felt252<'ctx, 'this>( @@ -846,14 +837,14 @@ fn build_u128s_from_felt252<'ctx, 'this>( let k0 = entry.const_int_from_type(context, location, 0, target_ty)?; let is_wide = entry.cmpi(context, CmpiPredicate::Ne, hi, k0, location)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, is_wide, [1, 0], [&[range_check, hi, lo], &[range_check, lo]], location, - )); - Ok(()) + ) } fn build_wide_mul<'ctx, 'this>( @@ -887,8 +878,7 @@ fn build_wide_mul<'ctx, 'this>( let rhs = ext_fn(entry, entry.arg(1)?, result_ty, location)?; let result = entry.muli(lhs, rhs, location)?; - entry.append_operation(helper.br(0, &[result], location)); - Ok(()) + helper.br(entry, 0, &[result], location) } #[cfg(test)] diff --git a/src/libfuncs/int_range.rs b/src/libfuncs/int_range.rs index de2f00e0c..953a44ef0 100644 --- a/src/libfuncs/int_range.rs +++ b/src/libfuncs/int_range.rs @@ -21,7 +21,7 @@ use melior::{ arith::{self, CmpiPredicate}, ods, }, - ir::{Block, BlockLike, Location}, + ir::{Block, Location}, Context, }; use num_bigint::BigInt; @@ -82,14 +82,14 @@ pub fn build_int_range_try_new<'ctx, 'this>( let x_val = entry.append_op_result(arith::select(is_valid, x, y, location))?; let range = entry.insert_values(context, location, range, &[x_val, y])?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, is_valid, [0, 1], [&[range_check, range], &[range_check, range]], location, - )); - Ok(()) + ) } /// Generate MLIR operations for the `int_range_pop_front` libfunc. @@ -128,14 +128,14 @@ pub fn build_int_range_pop_front<'ctx, 'this>( }; let range = entry.insert_value(context, location, range, x_p_1, 0)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, is_valid, [1, 0], // failure, success [&[range, x], &[]], location, - )); - Ok(()) + ) } #[cfg(test)] diff --git a/src/libfuncs/mem.rs b/src/libfuncs/mem.rs index 224491f9f..95226f294 100644 --- a/src/libfuncs/mem.rs +++ b/src/libfuncs/mem.rs @@ -20,7 +20,7 @@ use cairo_lang_sierra::{ }; use melior::{ dialect::llvm, - ir::{Block, BlockLike, Location}, + ir::{Block, Location}, Context, }; @@ -84,8 +84,7 @@ pub fn build_alloc_local<'ctx, 'this>( let value = entry.append_op_result(llvm::undef(target_type, location))?; - entry.append_operation(helper.br(0, &[value], location)); - Ok(()) + helper.br(entry, 0, &[value], location) } /// Generate MLIR operations for the `store_local` libfunc. @@ -98,6 +97,5 @@ pub fn build_store_local<'ctx, 'this>( _metadata: &mut MetadataStorage, _info: &SignatureAndTypeConcreteLibfunc, ) -> Result<()> { - entry.append_operation(helper.br(0, &[entry.arg(1)?], location)); - Ok(()) + helper.br(entry, 0, &[entry.arg(1)?], location) } diff --git a/src/libfuncs/nullable.rs b/src/libfuncs/nullable.rs index 8e11a1f0f..15bebccdd 100644 --- a/src/libfuncs/nullable.rs +++ b/src/libfuncs/nullable.rs @@ -65,8 +65,7 @@ fn build_null<'ctx, 'this>( let value = entry .append_op_result(ods::llvm::mlir_zero(context, pointer(context, 0), location).into())?; - entry.append_operation(helper.br(0, &[value], location)); - Ok(()) + helper.br(entry, 0, &[value], location) } /// Generate MLIR operations for the `match_nullable` libfunc. @@ -109,8 +108,8 @@ fn build_match_nullable<'ctx, 'this>( location, )); - block_is_null.append_operation(helper.br(0, &[], location)); - block_is_not_null.append_operation(helper.br(1, &[arg], location)); + helper.br(block_is_null, 0, &[], location)?; + helper.br(block_is_not_null, 1, &[arg], location)?; Ok(()) } diff --git a/src/libfuncs/pedersen.rs b/src/libfuncs/pedersen.rs index 4c42fe98b..580878a74 100644 --- a/src/libfuncs/pedersen.rs +++ b/src/libfuncs/pedersen.rs @@ -17,7 +17,7 @@ use cairo_lang_sierra::{ program_registry::ProgramRegistry, }; use melior::{ - ir::{r#type::IntegerType, Block, BlockLike, Location}, + ir::{r#type::IntegerType, Block, Location}, Context, }; @@ -91,8 +91,7 @@ pub fn build_pedersen<'ctx>( let result = entry.load(context, location, dst_ptr, i256_ty)?; let result = entry.trunci(result, felt252_ty, location)?; - entry.append_operation(helper.br(0, &[pedersen_builtin, result], location)); - Ok(()) + helper.br(entry, 0, &[pedersen_builtin, result], location) } #[cfg(test)] diff --git a/src/libfuncs/poseidon.rs b/src/libfuncs/poseidon.rs index 67d6dce97..f0a5727d2 100644 --- a/src/libfuncs/poseidon.rs +++ b/src/libfuncs/poseidon.rs @@ -18,7 +18,7 @@ use cairo_lang_sierra::{ }; use melior::{ dialect::ods, - ir::{r#type::IntegerType, Block, BlockLike, Location}, + ir::{r#type::IntegerType, Block, Location}, Context, }; @@ -104,9 +104,7 @@ pub fn build_hades_permutation<'ctx>( let op1 = entry.trunci(op1_i256, felt252_ty, location)?; let op2 = entry.trunci(op2_i256, felt252_ty, location)?; - entry.append_operation(helper.br(0, &[poseidon_builtin, op0, op1, op2], location)); - - Ok(()) + helper.br(entry, 0, &[poseidon_builtin, op0, op1, op2], location) } #[cfg(test)] diff --git a/src/libfuncs/starknet.rs b/src/libfuncs/starknet.rs index e5a26546b..815d4ae63 100644 --- a/src/libfuncs/starknet.rs +++ b/src/libfuncs/starknet.rs @@ -346,8 +346,9 @@ pub fn build_call_contract<'ctx, 'this>( IntegerType::new(context, 64).into(), )?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, result_tag, [1, 0], [ @@ -355,8 +356,7 @@ pub fn build_call_contract<'ctx, 'this>( &[remaining_gas, entry.arg(1)?, payload_ok], ], location, - )); - Ok(()) + ) } pub fn build_class_hash_const<'ctx, 'this>( @@ -378,8 +378,7 @@ pub fn build_class_hash_const<'ctx, 'this>( 252, )?; - entry.append_operation(helper.br(0, &[value], location)); - Ok(()) + helper.br(entry, 0, &[value], location) } pub fn build_class_hash_try_from_felt252<'ctx, 'this>( @@ -406,14 +405,14 @@ pub fn build_class_hash_try_from_felt252<'ctx, 'this>( ))?; let is_in_range = entry.cmpi(context, CmpiPredicate::Ult, value, limit, location)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, is_in_range, [0, 1], [&[range_check, value], &[range_check]], location, - )); - Ok(()) + ) } pub fn build_contract_address_const<'ctx, 'this>( @@ -435,8 +434,7 @@ pub fn build_contract_address_const<'ctx, 'this>( 252, )?; - entry.append_operation(helper.br(0, &[value], location)); - Ok(()) + helper.br(entry, 0, &[value], location) } pub fn build_contract_address_try_from_felt252<'ctx, 'this>( @@ -463,14 +461,14 @@ pub fn build_contract_address_try_from_felt252<'ctx, 'this>( ))?; let is_in_range = entry.cmpi(context, CmpiPredicate::Ult, value, limit, location)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, is_in_range, [0, 1], [&[range_check, value], &[range_check]], location, - )); - Ok(()) + ) } pub fn build_storage_read<'ctx, 'this>( @@ -613,8 +611,9 @@ pub fn build_storage_read<'ctx, 'this>( let remaining_gas = entry.load(context, location, gas_builtin_ptr, gas_ty)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, result_tag, [1, 0], [ @@ -622,8 +621,7 @@ pub fn build_storage_read<'ctx, 'this>( &[remaining_gas, entry.arg(1)?, payload_ok], ], location, - )); - Ok(()) + ) } pub fn build_storage_write<'ctx, 'this>( @@ -774,8 +772,9 @@ pub fn build_storage_write<'ctx, 'this>( let remaining_gas = entry.load(context, location, gas_builtin_ptr, gas_ty)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, result_tag, [1, 0], [ @@ -783,8 +782,7 @@ pub fn build_storage_write<'ctx, 'this>( &[remaining_gas, entry.arg(1)?, payload_ok], ], location, - )); - Ok(()) + ) } pub fn build_storage_base_address_const<'ctx, 'this>( @@ -806,8 +804,7 @@ pub fn build_storage_base_address_const<'ctx, 'this>( 252, )?; - entry.append_operation(helper.br(0, &[value], location)); - Ok(()) + helper.br(entry, 0, &[value], location) } pub fn build_storage_base_address_from_felt252<'ctx, 'this>( @@ -847,8 +844,7 @@ pub fn build_storage_base_address_from_felt252<'ctx, 'this>( location, ))?; - entry.append_operation(helper.br(0, &[range_check, value], location)); - Ok(()) + helper.br(entry, 0, &[range_check, value], location) } pub fn build_storage_address_from_base_and_offset<'ctx, 'this>( @@ -863,8 +859,7 @@ pub fn build_storage_address_from_base_and_offset<'ctx, 'this>( let offset = entry.extui(entry.arg(1)?, entry.argument(0)?.r#type(), location)?; let addr = entry.addi(entry.arg(0)?, offset, location)?; - entry.append_operation(helper.br(0, &[addr], location)); - Ok(()) + helper.br(entry, 0, &[addr], location) } pub fn build_storage_address_try_from_felt252<'ctx, 'this>( @@ -891,14 +886,14 @@ pub fn build_storage_address_try_from_felt252<'ctx, 'this>( ))?; let is_in_range = entry.cmpi(context, CmpiPredicate::Ult, value, limit, location)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, is_in_range, [0, 1], [&[range_check, value], &[range_check]], location, - )); - Ok(()) + ) } pub fn build_emit_event<'ctx, 'this>( @@ -1090,8 +1085,9 @@ pub fn build_emit_event<'ctx, 'this>( let remaining_gas = entry.load(context, location, gas_builtin_ptr, gas_ty)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, result_tag, [1, 0], [ @@ -1099,8 +1095,7 @@ pub fn build_emit_event<'ctx, 'this>( &[remaining_gas, entry.arg(1)?, payload_ok], ], location, - )); - Ok(()) + ) } pub fn build_get_block_hash<'ctx, 'this>( @@ -1238,8 +1233,9 @@ pub fn build_get_block_hash<'ctx, 'this>( let remaining_gas = entry.load(context, location, gas_builtin_ptr, gas_ty)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, result_tag, [1, 0], [ @@ -1247,8 +1243,7 @@ pub fn build_get_block_hash<'ctx, 'this>( &[remaining_gas, entry.arg(1)?, payload_ok], ], location, - )); - Ok(()) + ) } pub fn build_get_execution_info<'ctx, 'this>( @@ -1380,8 +1375,9 @@ pub fn build_get_execution_info<'ctx, 'this>( let remaining_gas = entry.load(context, location, gas_builtin_ptr, gas_ty)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, result_tag, [1, 0], [ @@ -1389,8 +1385,7 @@ pub fn build_get_execution_info<'ctx, 'this>( &[remaining_gas, entry.arg(1)?, payload_ok], ], location, - )); - Ok(()) + ) } pub fn build_get_execution_info_v2<'ctx, 'this>( @@ -1522,8 +1517,9 @@ pub fn build_get_execution_info_v2<'ctx, 'this>( let remaining_gas = entry.load(context, location, gas_builtin_ptr, gas_ty)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, result_tag, [1, 0], [ @@ -1531,8 +1527,7 @@ pub fn build_get_execution_info_v2<'ctx, 'this>( &[remaining_gas, entry.arg(1)?, payload_ok], ], location, - )); - Ok(()) + ) } pub fn build_deploy<'ctx, 'this>( @@ -1756,8 +1751,9 @@ pub fn build_deploy<'ctx, 'this>( let remaining_gas = entry.load(context, location, gas_builtin_ptr, gas_ty)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, result_tag, [1, 0], [ @@ -1782,8 +1778,7 @@ pub fn build_deploy<'ctx, 'this>( ], ], location, - )); - Ok(()) + ) } pub fn build_keccak<'ctx, 'this>( @@ -1931,8 +1926,9 @@ pub fn build_keccak<'ctx, 'this>( }; let remaining_gas = entry.load(context, location, gas_builtin_ptr, gas_ty)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, result_tag, [1, 0], [ @@ -1940,8 +1936,7 @@ pub fn build_keccak<'ctx, 'this>( &[remaining_gas, entry.arg(1)?, payload_ok], ], location, - )); - Ok(()) + ) } pub fn build_library_call<'ctx, 'this>( @@ -2110,8 +2105,9 @@ pub fn build_library_call<'ctx, 'this>( let remaining_gas = entry.load(context, location, gas_builtin_ptr, gas_ty)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, result_tag, [1, 0], [ @@ -2119,8 +2115,7 @@ pub fn build_library_call<'ctx, 'this>( &[remaining_gas, entry.arg(1)?, payload_ok], ], location, - )); - Ok(()) + ) } /// Executes the `meta_tx_v0_syscall`. @@ -2325,8 +2320,9 @@ pub fn build_meta_tx_v0<'ctx, 'this>( IntegerType::new(context, 64).into(), )?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, result_tag, [1, 0], [ @@ -2334,8 +2330,7 @@ pub fn build_meta_tx_v0<'ctx, 'this>( &[remaining_gas, entry.arg(1)?, payload_ok], ], location, - )); - Ok(()) + ) } pub fn build_replace_class<'ctx, 'this>( @@ -2474,8 +2469,9 @@ pub fn build_replace_class<'ctx, 'this>( let remaining_gas = entry.load(context, location, gas_builtin_ptr, gas_ty)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, result_tag, [1, 0], [ @@ -2483,8 +2479,7 @@ pub fn build_replace_class<'ctx, 'this>( &[remaining_gas, entry.arg(1)?, payload_ok], ], location, - )); - Ok(()) + ) } pub fn build_send_message_to_l1<'ctx, 'this>( @@ -2648,8 +2643,9 @@ pub fn build_send_message_to_l1<'ctx, 'this>( let remaining_gas = entry.load(context, location, gas_builtin_ptr, gas_ty)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, result_tag, [1, 0], [ @@ -2657,8 +2653,7 @@ pub fn build_send_message_to_l1<'ctx, 'this>( &[remaining_gas, entry.arg(1)?, payload_ok], ], location, - )); - Ok(()) + ) } pub fn build_sha256_process_block_syscall<'ctx, 'this>( @@ -2798,8 +2793,9 @@ pub fn build_sha256_process_block_syscall<'ctx, 'this>( let remaining_gas = entry.load(context, location, gas_builtin_ptr, gas_ty)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, result_tag, [1, 0], [ @@ -2807,8 +2803,7 @@ pub fn build_sha256_process_block_syscall<'ctx, 'this>( &[remaining_gas, entry.arg(1)?, payload_ok], ], location, - )); - Ok(()) + ) } pub fn build_get_class_hash_at<'ctx, 'this>( @@ -2950,8 +2945,9 @@ pub fn build_get_class_hash_at<'ctx, 'this>( let remaining_gas = entry.load(context, location, gas_builtin_ptr, gas_ty)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, result_tag, [1, 0], [ @@ -2959,8 +2955,7 @@ pub fn build_get_class_hash_at<'ctx, 'this>( &[remaining_gas, entry.arg(1)?, payload_ok], ], location, - )); - Ok(()) + ) } #[cfg(test)] diff --git a/src/libfuncs/starknet/secp256.rs b/src/libfuncs/starknet/secp256.rs index a4f400ad8..56ebd8700 100644 --- a/src/libfuncs/starknet/secp256.rs +++ b/src/libfuncs/starknet/secp256.rs @@ -268,8 +268,9 @@ pub fn build_k1_new<'ctx, 'this>( let remaining_gas = entry.load(context, location, gas_builtin_ptr, gas_ty)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, result_tag, [1, 0], [ @@ -277,8 +278,7 @@ pub fn build_k1_new<'ctx, 'this>( &[remaining_gas, entry.argument(1)?.into(), payload_ok], ], location, - )); - Ok(()) + ) } pub fn build_k1_add<'ctx, 'this>( @@ -478,8 +478,9 @@ pub fn build_k1_add<'ctx, 'this>( let remaining_gas = entry.load(context, location, gas_builtin_ptr, gas_ty)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, result_tag, [1, 0], [ @@ -487,8 +488,7 @@ pub fn build_k1_add<'ctx, 'this>( &[remaining_gas, entry.argument(1)?.into(), payload_ok], ], location, - )); - Ok(()) + ) } pub fn build_k1_mul<'ctx, 'this>( @@ -688,8 +688,9 @@ pub fn build_k1_mul<'ctx, 'this>( let remaining_gas = entry.load(context, location, gas_builtin_ptr, gas_ty)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, result_tag, [1, 0], [ @@ -697,8 +698,7 @@ pub fn build_k1_mul<'ctx, 'this>( &[remaining_gas, entry.argument(1)?.into(), payload_ok], ], location, - )); - Ok(()) + ) } pub fn build_k1_get_point_from_x<'ctx, 'this>( @@ -891,8 +891,9 @@ pub fn build_k1_get_point_from_x<'ctx, 'this>( let remaining_gas = entry.load(context, location, gas_builtin_ptr, gas_ty)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, result_tag, [1, 0], [ @@ -900,8 +901,7 @@ pub fn build_k1_get_point_from_x<'ctx, 'this>( &[remaining_gas, entry.argument(1)?.into(), payload_ok], ], location, - )); - Ok(()) + ) } pub fn build_k1_get_xy<'ctx, 'this>( @@ -1135,8 +1135,9 @@ pub fn build_k1_get_xy<'ctx, 'this>( let remaining_gas = entry.load(context, location, gas_builtin_ptr, gas_ty)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, result_tag, [1, 0], [ @@ -1149,8 +1150,7 @@ pub fn build_k1_get_xy<'ctx, 'this>( ], ], location, - )); - Ok(()) + ) } pub fn build_r1_new<'ctx, 'this>( @@ -1350,8 +1350,9 @@ pub fn build_r1_new<'ctx, 'this>( let remaining_gas = entry.load(context, location, gas_builtin_ptr, gas_ty)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, result_tag, [1, 0], [ @@ -1359,8 +1360,7 @@ pub fn build_r1_new<'ctx, 'this>( &[remaining_gas, entry.argument(1)?.into(), payload_ok], ], location, - )); - Ok(()) + ) } pub fn build_r1_add<'ctx, 'this>( @@ -1560,8 +1560,9 @@ pub fn build_r1_add<'ctx, 'this>( let remaining_gas = entry.load(context, location, gas_builtin_ptr, gas_ty)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, result_tag, [1, 0], [ @@ -1569,8 +1570,7 @@ pub fn build_r1_add<'ctx, 'this>( &[remaining_gas, entry.argument(1)?.into(), payload_ok], ], location, - )); - Ok(()) + ) } pub fn build_r1_mul<'ctx, 'this>( @@ -1773,8 +1773,9 @@ pub fn build_r1_mul<'ctx, 'this>( let remaining_gas = entry.load(context, location, gas_builtin_ptr, gas_ty)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, result_tag, [1, 0], [ @@ -1782,8 +1783,7 @@ pub fn build_r1_mul<'ctx, 'this>( &[remaining_gas, entry.argument(1)?.into(), payload_ok], ], location, - )); - Ok(()) + ) } pub fn build_r1_get_point_from_x<'ctx, 'this>( @@ -1978,8 +1978,9 @@ pub fn build_r1_get_point_from_x<'ctx, 'this>( let remaining_gas = entry.load(context, location, gas_builtin_ptr, gas_ty)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, result_tag, [1, 0], [ @@ -1987,8 +1988,7 @@ pub fn build_r1_get_point_from_x<'ctx, 'this>( &[remaining_gas, entry.argument(1)?.into(), payload_ok], ], location, - )); - Ok(()) + ) } pub fn build_r1_get_xy<'ctx, 'this>( @@ -2224,8 +2224,9 @@ pub fn build_r1_get_xy<'ctx, 'this>( let remaining_gas = entry.load(context, location, gas_builtin_ptr, gas_ty)?; - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, result_tag, [1, 0], [ @@ -2238,6 +2239,5 @@ pub fn build_r1_get_xy<'ctx, 'this>( ], ], location, - )); - Ok(()) + ) } diff --git a/src/libfuncs/starknet/testing.rs b/src/libfuncs/starknet/testing.rs index 2854c0336..e46e59940 100644 --- a/src/libfuncs/starknet/testing.rs +++ b/src/libfuncs/starknet/testing.rs @@ -115,7 +115,5 @@ pub fn build<'ctx, 'this>( location, LoadStoreOptions::new(), ))?; - entry.append_operation(helper.br(0, &[result], location)); - - Ok(()) + helper.br(entry, 0, &[result], location) } diff --git a/src/libfuncs/struct.rs b/src/libfuncs/struct.rs index 890aa42ac..dfbb02d94 100644 --- a/src/libfuncs/struct.rs +++ b/src/libfuncs/struct.rs @@ -71,9 +71,7 @@ pub fn build_construct<'ctx, 'this>( &fields, )?; - entry.append_operation(helper.br(0, &[value], location)); - - Ok(()) + helper.br(entry, 0, &[value], location) } /// Generate MLIR operations for the `struct_construct` libfunc. @@ -117,7 +115,5 @@ pub fn build_deconstruct<'ctx, 'this>( fields.push(value); } - entry.append_operation(helper.br(0, &fields, location)); - - Ok(()) + helper.br(entry, 0, &fields, location) } diff --git a/src/libfuncs/uint256.rs b/src/libfuncs/uint256.rs index fba56f598..b23a0c6d9 100644 --- a/src/libfuncs/uint256.rs +++ b/src/libfuncs/uint256.rs @@ -250,12 +250,12 @@ pub fn build_divmod<'ctx, 'this>( let op = entry.append_operation(llvm::undef(guarantee_type, location)); let guarantee = op.result(0)?.into(); - entry.append_operation(helper.br( + helper.br( + entry, 0, &[range_check, result_div, result_rem, guarantee], location, - )); - Ok(()) + ) } /// Generate MLIR operations for the `u256_is_zero` libfunc. @@ -326,14 +326,14 @@ pub fn build_is_zero<'ctx, 'this>( .result(0)? .into(); - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, val_is_zero, [0, 1], [&[], &[val_struct]], location, - )); - Ok(()) + ) } /// Generate MLIR operations for the `u256_sqrt` libfunc. @@ -629,8 +629,7 @@ pub fn build_square_root<'ctx, 'this>( .result(0)? .into(); - entry.append_operation(helper.br(0, &[range_check, result], location)); - Ok(()) + helper.br(entry, 0, &[range_check, result], location) } /// Generate MLIR operations for the `u256_guarantee_inv_mod_n` libfunc. @@ -922,8 +921,9 @@ pub fn build_u256_guarantee_inv_mod_n<'ctx, 'this>( let op = entry.append_operation(llvm::undef(guarantee_type, location)); let guarantee = op.result(0)?.into(); - entry.append_operation(helper.cond_br( + helper.cond_br( context, + entry, condition, [0, 1], [ @@ -942,9 +942,7 @@ pub fn build_u256_guarantee_inv_mod_n<'ctx, 'this>( &[entry.arg(0)?, guarantee, guarantee], ], location, - )); - - Ok(()) + ) } #[cfg(test)] diff --git a/src/libfuncs/uint512.rs b/src/libfuncs/uint512.rs index 3ba0c23ed..bb31def01 100644 --- a/src/libfuncs/uint512.rs +++ b/src/libfuncs/uint512.rs @@ -17,7 +17,7 @@ use cairo_lang_sierra::{ }; use melior::{ dialect::{arith, llvm}, - ir::{r#type::IntegerType, Block, BlockLike, Location, Value}, + ir::{r#type::IntegerType, Block, Location, Value}, Context, }; @@ -140,7 +140,8 @@ pub fn build_divmod_u256<'ctx, 'this>( let guarantee = entry.append_op_result(llvm::undef(guarantee_type, location))?; - entry.append_operation(helper.br( + helper.br( + entry, 0, &[ range_check, @@ -153,8 +154,7 @@ pub fn build_divmod_u256<'ctx, 'this>( guarantee, ], location, - )); - Ok(()) + ) } #[cfg(test)] diff --git a/src/metadata.rs b/src/metadata.rs index b4328d69c..e33656b0c 100644 --- a/src/metadata.rs +++ b/src/metadata.rs @@ -21,6 +21,7 @@ pub mod dup_overrides; pub mod enum_snapshot_variants; pub mod felt252_dict; pub mod gas; +pub mod profiler; pub mod realloc_bindings; pub mod runtime_bindings; pub mod tail_recursion; diff --git a/src/metadata/profiler.rs b/src/metadata/profiler.rs new file mode 100644 index 000000000..4a8905052 --- /dev/null +++ b/src/metadata/profiler.rs @@ -0,0 +1,411 @@ +#![cfg(feature = "with-libfunc-profiling")] +//! The libfunc profiling feature is used to generate information about every libfunc executed in a sierra program. +//! +//! When this feature is used, the compiler will call the important methods: +//! +//! 1. `measure_timestamp`: called before every libfunc execution. +//! +//! 2. `push_frame`: called before every branching operation. This method will also call `measure_timestamp`. This, +//! with the timestamp calculated before the execution, will allow to measure each statement's execution time. +//! If for some reason, the statement delta time could not be gathered, we just record an unit value, recording that +//! we executed the given statement. +//! +//! Once the program execution finished and the information was gathered, the `get_profile` method can be called. +//! It groups the samples by libfunc, and returns all data related to each libfunc. +//! +//! As well as with the trace-dump feature, in the context of Starknet contracts, we need to add support for building +//! profiles for multiple executions. To do so, we need two important elements, which must be set before every contract +//! execution: +//! +//! 1. A global static hashmap to map every profile ID to its respective profiler. See `LIBFUNC_PROFILE`. +//! +//! 2. A counter to track the ID of the current profiler, which gets updated every time we switch to another +//! contract. Since a contract can call other contracts, we need a way of restoring the counter after every execution. +//! +//! See `cairo-native-run` for an example on how to do it. + +use crate::{ + error::{Error, Result}, + utils::BlockExt, +}; +use cairo_lang_sierra::{ + ids::ConcreteLibfuncId, + program::{Program, Statement, StatementIdx}, +}; +use melior::{ + dialect::{ + arith::{self, CmpiPredicate}, + llvm::{self}, + memref, ods, + }, + ir::{ + attribute::{FlatSymbolRefAttribute, StringAttribute, TypeAttribute}, + operation::OperationBuilder, + r#type::{IntegerType, MemRefType}, + Attribute, Block, BlockLike, Identifier, Location, Module, Region, Value, + }, + Context, +}; + +use std::{ + cell::RefCell, + collections::{HashMap, HashSet}, + ffi::c_void, + ptr, + sync::{LazyLock, Mutex}, +}; + +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] +pub enum ProfilerBinding { + PushStmt, + ProfileId, +} + +impl ProfilerBinding { + pub const fn symbol(self) -> &'static str { + match self { + ProfilerBinding::PushStmt => "cairo_native__profiler__push_stmt", + ProfilerBinding::ProfileId => "cairo_native__profiler__profile_id", + } + } + + const fn function_ptr(self) -> *const () { + match self { + ProfilerBinding::PushStmt => ProfilerImpl::push_stmt as *const (), + ProfilerBinding::ProfileId => ptr::null(), + } + } +} + +#[derive(Clone, Default)] +pub struct ProfilerMeta { + active_map: RefCell>, +} + +impl ProfilerMeta { + pub fn new() -> Self { + Self { + active_map: RefCell::new(HashSet::new()), + } + } + + /// Register the global for the given binding, if not yet registered, and return + /// a pointer to the stored value. + /// + /// For the function to be available, `setup_runtime` must be called before running the module + fn build_function<'c, 'a>( + &self, + context: &'c Context, + module: &Module, + block: &'a Block<'c>, + location: Location<'c>, + binding: ProfilerBinding, + ) -> Result> { + if self.active_map.borrow_mut().insert(binding) { + module.body().append_operation( + ods::llvm::mlir_global( + context, + Region::new(), + TypeAttribute::new(llvm::r#type::pointer(context, 0)), + StringAttribute::new(context, binding.symbol()), + Attribute::parse(context, "#llvm.linkage") + .ok_or(Error::ParseAttributeError)?, + location, + ) + .into(), + ); + } + + let global_address = block.append_op_result( + ods::llvm::mlir_addressof( + context, + llvm::r#type::pointer(context, 0), + FlatSymbolRefAttribute::new(context, binding.symbol()), + location, + ) + .into(), + )?; + + block.load( + context, + location, + global_address, + llvm::r#type::pointer(context, 0), + ) + } + + pub fn build_profile_id<'c, 'a>( + &self, + context: &'c Context, + module: &Module, + block: &'a Block<'c>, + location: Location<'c>, + ) -> Result> { + if self + .active_map + .borrow_mut() + .insert(ProfilerBinding::ProfileId) + { + module.body().append_operation(memref::global( + context, + ProfilerBinding::ProfileId.symbol(), + None, + MemRefType::new(IntegerType::new(context, 64).into(), &[], None, None), + None, + false, + None, + location, + )); + } + + let trace_profile_ptr = block + .append_op_result(memref::get_global( + context, + ProfilerBinding::ProfileId.symbol(), + MemRefType::new(IntegerType::new(context, 64).into(), &[], None, None), + location, + )) + .unwrap(); + + block.append_op_result(memref::load(trace_profile_ptr, &[], location)) + } + + /// Gets the current timestamp. + /// + /// The values returned are: + /// 1. Timestamp: CPU cycles since its reset. + /// 2. CPU's id core in which the execution is running (only for x86 arch). + /// In case of arm, 0 is always returned as there's no way to know in which + /// CPU core the execution was run. + /// + /// We use the last value to ensure that both the initial and the end timestamp of + /// a libfunc's execution were calculated by the same core. This is to avoid gathering + /// invalid data + #[cfg(target_arch = "x86_64")] + pub fn measure_timestamp<'c, 'a>( + &self, + context: &'c Context, + block: &'a Block<'c>, + location: Location<'c>, + ) -> Result<(Value<'c, 'a>, Value<'c, 'a>)> { + let i32_ty = IntegerType::new(context, 32).into(); + let i64_ty = IntegerType::new(context, 64).into(); + let k32 = block.const_int_from_type(context, location, 32, i64_ty)?; + + // edx:eax := TimeStampCounter (clock value) + // ecx := IA32_TSC_AUX[31:0] (core ID) + let value = block.append_op_result( + OperationBuilder::new("llvm.inline_asm", location) + .add_attributes(&[ + ( + Identifier::new(context, "asm_string"), + StringAttribute::new(context, "mfence\nrdtscp\nlfence").into(), + ), + ( + Identifier::new(context, "has_side_effects"), + Attribute::unit(context), + ), + ( + Identifier::new(context, "constraints"), + StringAttribute::new(context, "={edx},={eax},={ecx}").into(), + ), + ]) + .add_results(&[llvm::r#type::r#struct( + context, + &[i32_ty, i32_ty, i32_ty], + false, + )]) + .build()?, + )?; + let value_hi = block.extract_value(context, location, value, i32_ty, 0)?; + let value_lo = block.extract_value(context, location, value, i32_ty, 1)?; + let core_idx = block.extract_value(context, location, value, i32_ty, 2)?; + + let value_hi = block.extui(value_hi, i64_ty, location)?; + let value_lo = block.extui(value_lo, i64_ty, location)?; + let value = block.shli(value_hi, k32, location)?; + let value = block.append_op_result(arith::ori(value, value_lo, location))?; + + Ok((value, core_idx)) + } + + /// Gets the current timestamp. + /// + /// The values returned are: + /// 1. Timestamp: CPU cycles since its reset. + /// 2. CPU's id core in which the program is running (only for x86 arch). + /// In case of arm, 0 is always returned as there's no way to know in which + /// CPU core the execution was run. + /// + /// We use the last value to ensure that both the initial and the end timestamp of + /// a libfunc's execution were calculated by the same core. This is to avoid gathering + /// invalid data + #[cfg(target_arch = "aarch64")] + pub fn measure_timestamp<'c, 'a>( + &self, + context: &'c Context, + block: &'a Block<'c>, + location: Location<'c>, + ) -> Result<(Value<'c, 'a>, Value<'c, 'a>)> { + let i64_ty = IntegerType::new(context, 64).into(); + + let value = block.append_op_result( + OperationBuilder::new("llvm.inline_asm", location) + .add_attributes(&[ + ( + Identifier::new(context, "asm_string"), + StringAttribute::new(context, "isb\nmrs $0, CNTVCT_EL0\nisb").into(), + ), + ( + Identifier::new(context, "has_side_effects"), + Attribute::unit(context), + ), + ( + Identifier::new(context, "constraints"), + StringAttribute::new(context, "=r").into(), + ), + ]) + .add_results(&[i64_ty]) + .build()?, + )?; + let core_idx = block.const_int(context, location, 0, 64)?; + + Ok((value, core_idx)) + } + + #[allow(clippy::too_many_arguments)] + /// Receives two timestamps, if they were originated in the same CPU core, + /// the delta time between these two is calculated. If not, then the delta time is + /// assigned to -1. Then it pushes the frame, which is made of the statement index + /// the delta time. + pub fn push_frame<'c>( + &self, + context: &'c Context, + module: &Module, + block: &Block<'c>, + statement_idx: usize, + // (timestamp, core_idx) + t0: (Value<'c, '_>, Value<'c, '_>), + t1: (Value<'c, '_>, Value<'c, '_>), + location: Location<'c>, + ) -> Result<()> { + // If core idx matches: + // Calculate time delta. + // Write statement idx and time delta. + // If core idx does not match: + // Write statement idx and -1. + + let trace_id = self.build_profile_id(context, module, block, location)?; + + let i64_ty = IntegerType::new(context, 64).into(); + + let statement_idx = block.const_int_from_type(context, location, statement_idx, i64_ty)?; + let is_same_core = block.cmpi(context, CmpiPredicate::Eq, t0.1, t1.1, location)?; + + let delta_value = block.append_op_result(arith::subi(t1.0, t0.0, location))?; + let invalid_value = block.const_int_from_type(context, location, u64::MAX, i64_ty)?; + let delta_value = block.append_op_result(arith::select( + is_same_core, + delta_value, + invalid_value, + location, + ))?; + + let callback_ptr = + self.build_function(context, module, block, location, ProfilerBinding::PushStmt)?; + + block.append_operation( + ods::llvm::call( + context, + &[callback_ptr, trace_id, statement_idx, delta_value], + location, + ) + .into(), + ); + + Ok(()) + } +} + +/// Represents the entire profile of the execution. +/// +/// It maps the libfunc ID to a libfunc profile. +type Profile = HashMap; + +/// Represents the profile data for a particular libfunc. +#[derive(Default)] +pub struct LibfuncProfileData { + /// A vector of execution times, for each time the libfunc was executed. + /// It expreses the number of CPU cycles completed during the execution. + pub deltas: Vec, + /// If the time delta for a particular execution could not be gathered, + /// we just increase `extra_counts` by 1. + pub extra_counts: u64, +} + +pub static LIBFUNC_PROFILE: LazyLock>> = + LazyLock::new(|| Mutex::new(HashMap::new())); + +#[derive(Default)] +pub struct ProfilerImpl { + /// The samples recorded by the profiler. A value of `u64::MAX` implies + /// that the delta time for a statement could not be gathered. + pub samples: Vec<(StatementIdx, u64)>, +} + +impl ProfilerImpl { + pub fn new() -> Self { + Self { + samples: Vec::new(), + } + } + + // Push a profiler frame + pub extern "C" fn push_stmt(profile_id: u64, statement_idx: u64, tick_delta: u64) { + let mut profiler = LIBFUNC_PROFILE.lock().unwrap(); + + let Some(profiler) = profiler.get_mut(&profile_id) else { + eprintln!("Could not find libfunc profiler!"); + return; + }; + + profiler + .samples + .push((StatementIdx(statement_idx as usize), tick_delta)); + } + + /// Returns the execution profile, grouped by libfunc + pub fn get_profile(&self, sierra_program: &Program) -> Profile { + let mut profile = HashMap::::new(); + + for (statement_idx, tick_delta) in self.samples.iter() { + if let Statement::Invocation(invocation) = &sierra_program.statements[statement_idx.0] { + let LibfuncProfileData { + deltas, + extra_counts, + } = profile.entry(invocation.libfunc_id.clone()).or_default(); + + // A tick_delta equal to u64::MAX implies it is invalid, so we don't take it + // into account + if *tick_delta != u64::MAX { + deltas.push(*tick_delta); + } else { + *extra_counts += 1; + } + } + } + + profile + } +} + +pub fn setup_runtime(find_symbol_ptr: impl Fn(&str) -> Option<*mut c_void>) { + let bindings = &[ProfilerBinding::PushStmt]; + + for binding in bindings { + if let Some(global) = find_symbol_ptr(binding.symbol()) { + let global = global.cast::<*const ()>(); + unsafe { *global = binding.function_ptr() }; + } + } +}