diff --git a/.github/workflows/daily.yml b/.github/workflows/daily.yml index 976b07e40..0ceea0b10 100644 --- a/.github/workflows/daily.yml +++ b/.github/workflows/daily.yml @@ -199,8 +199,7 @@ jobs: continue-on-error: true - name: Compare states - run: | - ./scripts/cmp_state_dumps.sh | tee output + run: python ./scripts/cmp_state_dumps.py | tee output - name: Upload Compare Results id: upload_compare_results diff --git a/.github/workflows/starknet-blocks.yml b/.github/workflows/starknet-blocks.yml index 9c3d3e320..5513f8b29 100644 --- a/.github/workflows/starknet-blocks.yml +++ b/.github/workflows/starknet-blocks.yml @@ -32,7 +32,7 @@ jobs: with: repository: lambdaclass/starknet-replay path: starknet-replay - ref: 135ba7cd5b45fe137b11b7f75654dcd472363033 + ref: 1b8e2e0be21a8df9f5f6b8f8514d1a40b456ef58 # We need native to use the linux deps ci action - name: Checkout Native uses: actions/checkout@v4 @@ -43,7 +43,7 @@ jobs: with: repository: lambdaclass/sequencer path: sequencer - ref: b61262980394dab0e0a4c4cab85f12d31d0ce878 + ref: 40331042c1149f5cb84b27f9dd8d47994a010bbe - name: Cache RPC Calls uses: actions/cache@v4.2.0 @@ -126,5 +126,4 @@ jobs: continue-on-error: true - name: Compare states - run: | - ./scripts/cmp_state_dumps.sh + run: python ./scripts/cmp_state_dumps.py diff --git a/benches/compile_time.rs b/benches/compile_time.rs index 5ec301c50..d6caa5a7c 100644 --- a/benches/compile_time.rs +++ b/benches/compile_time.rs @@ -15,7 +15,7 @@ pub fn bench_compile_time(c: &mut Criterion) { b.iter(|| { let native_context = NativeContext::new(); native_context - .compile(program, false, Some(Default::default())) + .compile(program, false, Some(Default::default()), None) .unwrap(); // pass manager internally verifies the MLIR output is correct. }) @@ -32,7 +32,7 @@ pub fn bench_compile_time(c: &mut Criterion) { c.bench_with_input(BenchmarkId::new(filename, 1), &program, |b, program| { b.iter(|| { native_context - .compile(program, false, Some(Default::default())) + .compile(program, false, Some(Default::default()), None) .unwrap(); // pass manager internally verifies the MLIR output is correct. }) @@ -48,9 +48,9 @@ pub fn bench_compile_time(c: &mut Criterion) { b.iter(|| { let native_context = NativeContext::new(); let module = native_context - .compile(black_box(program), false, Some(Default::default())) + .compile(black_box(program), false, Some(Default::default()), None) .unwrap(); - let object = module_to_object(module.module(), OptLevel::None) + let object = module_to_object(module.module(), OptLevel::None, None) .expect("to compile correctly to a object file"); black_box(object) }) @@ -67,9 +67,9 @@ pub fn bench_compile_time(c: &mut Criterion) { c.bench_with_input(BenchmarkId::new(filename, 1), &program, |b, program| { b.iter(|| { let module = native_context - .compile(black_box(program), false, Some(Default::default())) + .compile(black_box(program), false, Some(Default::default()), None) .unwrap(); - let object = module_to_object(module.module(), OptLevel::None) + let object = module_to_object(module.module(), OptLevel::None, None) .expect("to compile correctly to a object file"); black_box(object) }) @@ -86,9 +86,9 @@ pub fn bench_compile_time(c: &mut Criterion) { c.bench_with_input(BenchmarkId::new(filename, 1), &program, |b, program| { b.iter(|| { let module = native_context - .compile(black_box(program), false, Some(Default::default())) + .compile(black_box(program), false, Some(Default::default()), None) .unwrap(); - let object = module_to_object(module.module(), OptLevel::Aggressive) + let object = module_to_object(module.module(), OptLevel::Aggressive, None) .expect("to compile correctly to a object file"); black_box(object) }) diff --git a/benches/libfuncs.rs b/benches/libfuncs.rs index 706920420..537a64a6e 100644 --- a/benches/libfuncs.rs +++ b/benches/libfuncs.rs @@ -55,7 +55,7 @@ pub fn bench_libfuncs(c: &mut Criterion) { let native_context = NativeContext::new(); b.iter(|| { let module = native_context - .compile(program, false, Some(Default::default())) + .compile(program, false, Some(Default::default()), None) .unwrap(); // pass manager internally verifies the MLIR output is correct. let native_executor = @@ -77,7 +77,7 @@ pub fn bench_libfuncs(c: &mut Criterion) { |b, program| { let native_context = NativeContext::new(); let module = native_context - .compile(program, false, Some(Default::default())) + .compile(program, false, Some(Default::default()), None) .unwrap(); // pass manager internally verifies the MLIR output is correct. let native_executor = @@ -108,7 +108,7 @@ pub fn bench_libfuncs(c: &mut Criterion) { let native_context = NativeContext::new(); b.iter(|| { let module = native_context - .compile(program, false, Some(Default::default())) + .compile(program, false, Some(Default::default()), None) .unwrap(); // pass manager internally verifies the MLIR output is correct. let native_executor = @@ -130,7 +130,7 @@ pub fn bench_libfuncs(c: &mut Criterion) { |b, program| { let native_context = NativeContext::new(); let module = native_context - .compile(program, false, Some(Default::default())) + .compile(program, false, Some(Default::default()), None) .unwrap(); // pass manager internally verifies the MLIR output is correct. let native_executor = diff --git a/examples/easy_api.rs b/examples/easy_api.rs index fef4ab8ea..c2e2a7d46 100644 --- a/examples/easy_api.rs +++ b/examples/easy_api.rs @@ -16,7 +16,7 @@ fn main() { // Compile the sierra program into a MLIR module. let native_program = native_context - .compile(&sierra_program, false, Some(Default::default())) + .compile(&sierra_program, false, Some(Default::default()), None) .unwrap(); // The parameters of the entry point. diff --git a/examples/erc20.rs b/examples/erc20.rs index 8e444854e..c1cffd6a1 100644 --- a/examples/erc20.rs +++ b/examples/erc20.rs @@ -322,7 +322,7 @@ fn main() { let native_context = NativeContext::new(); let native_program = native_context - .compile(&sierra_program, false, Some(Default::default())) + .compile(&sierra_program, false, Some(Default::default()), None) .unwrap(); let entry_point_fn = diff --git a/examples/invoke.rs b/examples/invoke.rs index 2ac4a2494..5ea86c925 100644 --- a/examples/invoke.rs +++ b/examples/invoke.rs @@ -21,7 +21,7 @@ fn main() { let native_context = NativeContext::new(); let native_program = native_context - .compile(&sierra_program, false, Some(Default::default())) + .compile(&sierra_program, false, Some(Default::default()), None) .unwrap(); // Call the echo function from the contract using the generated wrapper. diff --git a/examples/starknet.rs b/examples/starknet.rs index ab66c6eb7..bae24586c 100644 --- a/examples/starknet.rs +++ b/examples/starknet.rs @@ -456,7 +456,7 @@ fn main() { let native_context = NativeContext::new(); let native_program = native_context - .compile(&sierra_program, false, Some(Default::default())) + .compile(&sierra_program, false, Some(Default::default()), None) .unwrap(); // Call the echo function from the contract using the generated wrapper. diff --git a/scripts/cmp_state_dumps.py b/scripts/cmp_state_dumps.py new file mode 100755 index 000000000..f28b03569 --- /dev/null +++ b/scripts/cmp_state_dumps.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python +# +# usage: cmp-state-dumps [-h] [-d] +# Compare all files in the state_dumps directory and outputs a summary +# options: +# -h, --help show this help message and exit +# -d, --delete removes matching files +# +# Uses a pool of worker threads that compare each state dump. +# possible improvements: use a pool of workers for file removing. + +import argparse +import glob +import re +import multiprocessing as mp +import os +from collections import defaultdict + +POOL_SIZE = 16 + +STATE_DUMPS_PATH = "state_dumps" +VM_DIRECTORY = "vm" +NATIVE_DIRECTORY = "native" + +LOG_PATH = "state_dumps/matching.log" + + +def compare(vm_dump_path: str): + native_dump_path = re.sub(VM_DIRECTORY, NATIVE_DIRECTORY, vm_dump_path, count=1) + + if not (m := re.findall(r"/(0x.*).json", vm_dump_path)): + raise Exception("bad path") + tx = m[0] + + if not (m := re.findall(r"block(\d+)", vm_dump_path)): + raise Exception("bad path") + block = m[0] + + try: + with open(native_dump_path) as f: + native_dump = f.read() + with open(vm_dump_path) as f: + vm_dump = f.read() + except: # noqa: E722 + return ("MISS", block, tx) + + native_dump = re.sub(r".*reverted.*", "", native_dump, count=1) + vm_dump = re.sub(r".*reverted.*", "", vm_dump, count=1) + + if native_dump == vm_dump: + return ("MATCH", block, tx, vm_dump_path, native_dump_path) + else: + return ("DIFF", block, tx) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="cmp-state-dumps", + description="Compare all files in the state_dumps directory and outputs a summary", + ) + parser.add_argument( + "-d", "--delete", action="store_true", help="removes matching files" + ) + config = parser.parse_args() + + files = glob.glob(f"{STATE_DUMPS_PATH}/{VM_DIRECTORY}/*/*.json") + files.sort(key=os.path.getmtime) + + print(f"Starting comparison with {POOL_SIZE} workers") + + stats = defaultdict(int) + with mp.Pool(POOL_SIZE) as pool, open(LOG_PATH, mode="a") as log: + for status, *info in pool.imap(compare, files): + stats[status] += 1 + + if status != "MATCH": + (block, tx) = info + print(status, block, tx) + + elif status == "MATCH" and config.delete: + (block, tx, vm_dump_path, native_dump_path) = info + + log.write(f"{block} {tx}\n") + log.flush() + os.remove(native_dump_path) + os.remove(vm_dump_path) + + print("Finished comparison") + + print() + for key, count in stats.items(): + print(key, count) + + if stats["DIFF"] != 0 or stats["MISS"] != 0: + exit(1) + else: + exit(0) diff --git a/scripts/cmp_state_dumps.sh b/scripts/cmp_state_dumps.sh deleted file mode 100755 index 5719a3228..000000000 --- a/scripts/cmp_state_dumps.sh +++ /dev/null @@ -1,51 +0,0 @@ -#!/usr/bin/env bash - -# Compares state dump files between two directories: 'state_dumps/vm' and 'state_dumps/native'. -# It iterates over all JSON files in the 'state_dumps/vm' directory and checks if the corresponding -# file exists in 'state_dumps/native'. -# If the corresponding file does not exist, it skips the comparison and counts the missing files. -# For existing pairs, it compares the contents, ignoring the lines containing the "reverted" field, because of error message diference in Native and VM. -# It counts and displays the number of matching, differing, and missing state dumps. - -matching=0 -diffing=0 -missing=0 - -# Iterate over state_dumps/vm dumps -for vm_dump in state_dumps/vm/*/*.json; do - [ -f "$vm_dump" ] || continue - - native_dump="${vm_dump//vm/native}" - - # Check if the corresponding native_dump file exists, if not, skip - if [ ! -f "$native_dump" ]; then - echo "Missing: $native_dump (file not found)" - missing=$((missing+1)) - continue - fi - - tx_name=$(basename "$vm_dump") - tx=${tx_name//.*/} - block_name=$(basename "$(dirname "$vm_dump")") - block=${block_name//block/} - - if ! cmp -s \ - <(sed '/"revert_error": /d' "$native_dump") \ - <(sed '/"revert_error": /d' "$vm_dump") - then - echo "Diff at block $block, tx $tx" - diffing=$((diffing+1)) - else - matching=$((matching+1)) - fi -done - -echo -echo "Finished comparison" -echo "- Matching: $matching" -echo "- Diffing: $diffing" -echo "- Missing: $missing" - -if ! [[ $diffing -eq 0 && $missing -eq 0 ]] ; then - exit 1 -fi diff --git a/src/bin/cairo-native-compile.rs b/src/bin/cairo-native-compile.rs index 1fceffd37..bac328a85 100644 --- a/src/bin/cairo-native-compile.rs +++ b/src/bin/cairo-native-compile.rs @@ -56,7 +56,7 @@ fn main() -> anyhow::Result<()> { // Compile the sierra program into a MLIR module. let native_module = native_context - .compile(&sierra_program, false, Some(Default::default())) + .compile(&sierra_program, false, Some(Default::default()), None) .unwrap(); let output_mlir = args @@ -79,9 +79,10 @@ fn main() -> anyhow::Result<()> { }) }); - let object_data = module_to_object(native_module.module(), args.opt_level.into()) + let object_data = module_to_object(native_module.module(), args.opt_level.into(), None) .context("Failed to convert module to object.")?; - object_to_shared_lib(&object_data, &output_lib).context("Failed to write shared library.")?; + object_to_shared_lib(&object_data, &output_lib, None) + .context("Failed to write shared library.")?; Ok(()) } diff --git a/src/bin/cairo-native-dump.rs b/src/bin/cairo-native-dump.rs index 2361530b3..21f3dc6b8 100644 --- a/src/bin/cairo-native-dump.rs +++ b/src/bin/cairo-native-dump.rs @@ -35,7 +35,7 @@ fn main() -> Result<(), Box> { let program = load_program(Path::new(&args.input), args.starknet)?; // Compile the program. - let module = context.compile(&program, false, Some(Default::default()))?; + let module = context.compile(&program, false, Some(Default::default()), None)?; // Write the output. let output_str = module diff --git a/src/bin/cairo-native-run.rs b/src/bin/cairo-native-run.rs index 2214c9bd6..4440026ff 100644 --- a/src/bin/cairo-native-run.rs +++ b/src/bin/cairo-native-run.rs @@ -92,7 +92,7 @@ fn main() -> anyhow::Result<()> { // Compile the sierra program into a MLIR module. let native_module = native_context - .compile(&sierra_program, false, Some(Default::default())) + .compile(&sierra_program, false, Some(Default::default()), None) .unwrap(); let native_executor: Box _> = match args.run_mode { diff --git a/src/bin/cairo-native-stress/main.rs b/src/bin/cairo-native-stress/main.rs index 167ef1869..349e6c213 100644 --- a/src/bin/cairo-native-stress/main.rs +++ b/src/bin/cairo-native-stress/main.rs @@ -280,7 +280,7 @@ where ) -> Arc { let native_module = self .context - .compile(program, false, Some(Default::default())) + .compile(program, false, Some(Default::default()), None) .expect("failed to compile program"); let registry = ProgramRegistry::new(program).expect("failed to get program registry"); @@ -291,7 +291,7 @@ where .expect("module should have gas metadata"); let shared_library = { - let object_data = module_to_object(native_module.module(), opt_level) + let object_data = module_to_object(native_module.module(), opt_level, None) .expect("failed to convert MLIR to object"); let shared_library_dir = Path::new(AOT_CACHE_DIR); @@ -299,7 +299,7 @@ where let shared_library_name = format!("lib{key}{SHARED_LIBRARY_EXT}"); let shared_library_path = shared_library_dir.join(shared_library_name); - object_to_shared_lib(&object_data, &shared_library_path) + object_to_shared_lib(&object_data, &shared_library_path, None) .expect("failed to link object into shared library"); unsafe { diff --git a/src/bin/scarb-native-dump.rs b/src/bin/scarb-native-dump.rs index ab779d4f4..6e2591a36 100644 --- a/src/bin/scarb-native-dump.rs +++ b/src/bin/scarb-native-dump.rs @@ -41,6 +41,7 @@ fn main() -> anyhow::Result<()> { &compiled.into_v1().unwrap().program, false, Some(Default::default()), + None, ) .unwrap(); diff --git a/src/bin/starknet-native-compile.rs b/src/bin/starknet-native-compile.rs index 8be9f5077..be9155c38 100644 --- a/src/bin/starknet-native-compile.rs +++ b/src/bin/starknet-native-compile.rs @@ -1,4 +1,6 @@ use anyhow::{anyhow, bail, Context}; +use cairo_native::statistics::Statistics; +use std::fs; use std::path::PathBuf; use cairo_lang_sierra::program::Program; @@ -20,6 +22,10 @@ struct Args { opt_level: u8, /// The output file path. output: PathBuf, + + #[arg(long)] + /// Output path for compilation statistics + stats: Option, } fn main() -> anyhow::Result<()> { @@ -27,15 +33,24 @@ fn main() -> anyhow::Result<()> { let (contract_class, sierra_program, sierra_version) = load_sierra_program_from_file(&args.path)?; + let mut stats_with_path = args.stats.map(|path| (Statistics::default(), path)); + let stats = stats_with_path.as_mut().map(|v| &mut v.0); + AotContractExecutor::new_into( &sierra_program, &contract_class.entry_points_by_type, sierra_version, args.output.clone(), args.opt_level.into(), + stats, ) .context("Error compiling Sierra program.")? .with_context(|| format!("Failed to take lock on path {}", args.output.display()))?; + + if let Some((stats, path)) = stats_with_path { + fs::write(path.with_extension("json"), serde_json::to_string(&stats)?)?; + } + Ok(()) } diff --git a/src/bin/utils/test.rs b/src/bin/utils/test.rs index 2b1b3cfc8..aee1bdb3f 100644 --- a/src/bin/utils/test.rs +++ b/src/bin/utils/test.rs @@ -138,7 +138,7 @@ pub fn run_tests( // Compile the sierra program into a MLIR module. let native_module = native_context - .compile(&sierra_program, false, Some(Default::default())) + .compile(&sierra_program, false, Some(Default::default()), None) .unwrap(); let native_executor: Box _> = match args.run_mode { diff --git a/src/cache/aot.rs b/src/cache/aot.rs index f939b131f..d14389583 100644 --- a/src/cache/aot.rs +++ b/src/cache/aot.rs @@ -47,10 +47,10 @@ where mut metadata, } = self .context - .compile(program, false, Some(Default::default()))?; + .compile(program, false, Some(Default::default()), None)?; // Compile module into an object. - let object_data = crate::ffi::module_to_object(&module, opt_level)?; + let object_data = crate::ffi::module_to_object(&module, opt_level, None)?; // Compile object into a shared library. let shared_library_path = tempfile::Builder::new() @@ -58,7 +58,7 @@ where .suffix(SHARED_LIBRARY_EXT) .tempfile()? .into_temp_path(); - crate::ffi::object_to_shared_lib(&object_data, &shared_library_path)?; + crate::ffi::object_to_shared_lib(&object_data, &shared_library_path, None)?; let shared_library = unsafe { Library::new(shared_library_path)? }; let executor = AotNativeExecutor::new( diff --git a/src/cache/jit.rs b/src/cache/jit.rs index 975f692f5..ae64ec338 100644 --- a/src/cache/jit.rs +++ b/src/cache/jit.rs @@ -48,7 +48,7 @@ where ) -> Result>> { let module = self .context - .compile(program, false, Some(Default::default()))?; + .compile(program, false, Some(Default::default()), None)?; let executor = JitNativeExecutor::from_native_module(module, opt_level)?; let executor = Arc::new(executor); diff --git a/src/compiler.rs b/src/compiler.rs index 13b7060d5..2e998d241 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -45,6 +45,7 @@ //! [BFS algorithm]: https://en.wikipedia.org/wiki/Breadth-first_search use crate::{ + clone_option_mut, debug::libfunc_to_name, error::{panic::ToNativeAssertError, Error}, libfuncs::{BranchArg, LibfuncBuilder, LibfuncHelper}, @@ -54,8 +55,9 @@ use crate::{ MetadataStorage, }, native_assert, native_panic, + statistics::Statistics, types::TypeBuilder, - utils::{generate_function_name, BlockExt}, + utils::{generate_function_name, walk_ir::walk_mlir_block, BlockExt}, }; use bumpalo::Bump; use cairo_lang_sierra::{ @@ -98,6 +100,7 @@ use mlir_sys::{ use std::{ cell::Cell, collections::{hash_map::Entry, BTreeMap, HashMap, HashSet}, + ffi::c_void, ops::Deref, }; @@ -121,6 +124,7 @@ type BlockStorage<'c, 'a> = /// /// Additionally, it needs a reference to the MLIR context, the output module and the metadata /// storage. The last one is passed externally so that stuff can be initialized if necessary. +#[allow(clippy::too_many_arguments)] pub fn compile( context: &Context, module: &Module, @@ -129,6 +133,7 @@ pub fn compile( metadata: &mut MetadataStorage, di_compile_unit_id: Attribute, ignore_debug_names: bool, + stats: Option<&mut Statistics>, ) -> Result<(), Error> { if let Ok(x) = std::env::var("NATIVE_DEBUG_DUMP") { if x == "1" || x == "true" { @@ -158,6 +163,7 @@ pub fn compile( di_compile_unit_id, sierra_stmt_start_offset, ignore_debug_names, + clone_option_mut!(stats), )?; } @@ -184,6 +190,7 @@ fn compile_func( di_compile_unit_id: Attribute, sierra_stmt_start_offset: usize, ignore_debug_names: bool, + stats: Option<&mut Statistics>, ) -> Result<(), Error> { let fn_location = Location::new( context, @@ -621,6 +628,23 @@ fn compile_func( &helper, metadata, )?; + + // When statistics are enabled, we iterate from the start + // to the end block of the compiled libfunc, and count all the operations. + if let Some(&mut ref mut stats) = stats { + unsafe extern "C" fn callback( + _: mlir_sys::MlirOperation, + data: *mut c_void, + ) -> mlir_sys::MlirWalkResult { + let data = data.cast::().as_mut().unwrap(); + *data += 1; + 0 + } + let data = walk_mlir_block(*block, *helper.last_block.get(), callback, 0); + let name = libfunc_to_name(libfunc).to_string(); + *stats.mlir_operations_by_libfunc.entry(name).or_insert(0) += data; + } + native_assert!( block.terminator().is_some(), "libfunc {} had no terminator", diff --git a/src/context.rs b/src/context.rs index 87d141c9f..de8802c43 100644 --- a/src/context.rs +++ b/src/context.rs @@ -1,10 +1,12 @@ use crate::{ + clone_option_mut, error::{panic::ToNativeAssertError, Error}, ffi::{get_data_layout_rep, get_target_triple}, metadata::{gas::GasMetadata, runtime_bindings::RuntimeBindingsMeta, MetadataStorage}, module::NativeModule, native_assert, - utils::run_pass_manager, + statistics::Statistics, + utils::{run_pass_manager, walk_ir::walk_mlir_operations}, }; use cairo_lang_sierra::{ extensions::core::{CoreLibfunc, CoreType}, @@ -31,8 +33,7 @@ use mlir_sys::{ mlirLLVMDIModuleAttrGet, MlirLLVMDIEmissionKind_MlirLLVMDIEmissionKindFull, MlirLLVMDINameTableKind_MlirLLVMDINameTableKindDefault, }; -use std::{sync::OnceLock, time::Instant}; -use tracing::trace; +use std::{ffi::c_void, sync::OnceLock, time::Instant}; /// Context of IRs, dialects and passes for Cairo programs compilation. #[derive(Debug, Eq, PartialEq)] @@ -69,10 +70,8 @@ impl NativeContext { program: &Program, ignore_debug_names: bool, gas_metadata_config: Option, + stats: Option<&mut Statistics>, ) -> Result { - trace!("starting sierra to mlir compilation"); - let pre_sierra_compilation_instant = Instant::now(); - static INITIALIZED: OnceLock<()> = OnceLock::new(); INITIALIZED.get_or_init(|| unsafe { LLVM_InitializeAllTargets(); @@ -167,6 +166,7 @@ impl NativeContext { // Create the Sierra program registry let registry = ProgramRegistry::::new(program)?; + let pre_sierra_to_mlir_instant = Instant::now(); crate::compile( &self.context, &module, @@ -175,13 +175,12 @@ impl NativeContext { &mut metadata, unsafe { Attribute::from_raw(di_unit_id) }, ignore_debug_names, + clone_option_mut!(stats), )?; - - let sierra_compilation_time = pre_sierra_compilation_instant.elapsed().as_millis(); - trace!( - time = sierra_compilation_time, - "sierra to mlir compilation finished" - ); + let sierra_to_mlir_time = pre_sierra_to_mlir_instant.elapsed().as_millis(); + if let Some(&mut ref mut stats) = stats { + stats.compilation_sierra_to_mlir_time_ms = Some(sierra_to_mlir_time); + } if let Ok(x) = std::env::var("NATIVE_DEBUG_DUMP") { if x == "1" || x == "true" { @@ -201,11 +200,25 @@ impl NativeContext { } } - trace!("starting mlir passes"); - let pre_passes_instant = Instant::now(); + if let Some(&mut ref mut stats) = stats { + unsafe extern "C" fn callback( + _: mlir_sys::MlirOperation, + data: *mut c_void, + ) -> mlir_sys::MlirWalkResult { + let data = data.cast::().as_mut().unwrap(); + *data += 1; + 0 + } + let data = walk_mlir_operations(module.as_operation(), callback, 0); + stats.mlir_operation_count = Some(data) + } + + let pre_mlir_passes_instant = Instant::now(); run_pass_manager(&self.context, &mut module)?; - let passes_time = pre_passes_instant.elapsed().as_millis(); - trace!(time = passes_time, "mlir passes finished"); + let mlir_passes_time = pre_mlir_passes_instant.elapsed().as_millis(); + if let Some(&mut ref mut stats) = stats { + stats.compilation_mlir_passes_time_ms = Some(mlir_passes_time); + } if let Ok(x) = std::env::var("NATIVE_DEBUG_DUMP") { if x == "1" || x == "true" { diff --git a/src/executor.rs b/src/executor.rs index 35f79fd96..7ac66cd6b 100644 --- a/src/executor.rs +++ b/src/executor.rs @@ -720,7 +720,7 @@ mod tests { fn test_invoke_dynamic_aot_native_executor(program: Program) { let native_context = NativeContext::new(); let module = native_context - .compile(&program, false, Some(Default::default())) + .compile(&program, false, Some(Default::default()), None) .expect("failed to compile context"); let executor = AotNativeExecutor::from_native_module(module, OptLevel::default()).unwrap(); @@ -738,7 +738,7 @@ mod tests { fn test_invoke_dynamic_jit_native_executor(program: Program) { let native_context = NativeContext::new(); let module = native_context - .compile(&program, false, None) + .compile(&program, false, None, None) .expect("failed to compile context"); let executor = JitNativeExecutor::from_native_module(module, OptLevel::default()).unwrap(); @@ -756,7 +756,7 @@ mod tests { fn test_invoke_contract_dynamic_aot(starknet_program: Program) { let native_context = NativeContext::new(); let module = native_context - .compile(&starknet_program, false, Some(Default::default())) + .compile(&starknet_program, false, Some(Default::default()), None) .expect("failed to compile context"); let executor = AotNativeExecutor::from_native_module(module, OptLevel::default()).unwrap(); @@ -788,7 +788,7 @@ mod tests { fn test_invoke_contract_dynamic_jit(starknet_program: Program) { let native_context = NativeContext::new(); let module = native_context - .compile(&starknet_program, false, Some(Default::default())) + .compile(&starknet_program, false, Some(Default::default()), None) .expect("failed to compile context"); let executor = JitNativeExecutor::from_native_module(module, OptLevel::default()).unwrap(); diff --git a/src/executor/aot.rs b/src/executor/aot.rs index 7de2cb4f8..a792b8c2b 100644 --- a/src/executor/aot.rs +++ b/src/executor/aot.rs @@ -76,8 +76,8 @@ impl AotNativeExecutor { .keep() .map_err(io::Error::from)?; - let object_data = crate::module_to_object(&module, opt_level)?; - crate::object_to_shared_lib(&object_data, &library_path)?; + let object_data = crate::module_to_object(&module, opt_level, None)?; + crate::object_to_shared_lib(&object_data, &library_path, None)?; Ok(Self::new( unsafe { Library::new(&library_path)? }, @@ -257,7 +257,7 @@ mod tests { fn test_invoke_dynamic(program: Program, #[case] optlevel: OptLevel) { let native_context = NativeContext::new(); let module = native_context - .compile(&program, false, Some(Default::default())) + .compile(&program, false, Some(Default::default()), None) .expect("failed to compile context"); let executor = AotNativeExecutor::from_native_module(module, optlevel).unwrap(); @@ -278,7 +278,7 @@ mod tests { fn test_invoke_dynamic_with_syscall_handler(program: Program, #[case] optlevel: OptLevel) { let native_context = NativeContext::new(); let module = native_context - .compile(&program, false, Some(Default::default())) + .compile(&program, false, Some(Default::default()), None) .expect("failed to compile context"); let executor = AotNativeExecutor::from_native_module(module, optlevel).unwrap(); @@ -317,7 +317,7 @@ mod tests { fn test_invoke_contract_dynamic(starknet_program: Program, #[case] optlevel: OptLevel) { let native_context = NativeContext::new(); let module = native_context - .compile(&starknet_program, false, Some(Default::default())) + .compile(&starknet_program, false, Some(Default::default()), None) .expect("failed to compile context"); let executor = AotNativeExecutor::from_native_module(module, optlevel).unwrap(); diff --git a/src/executor/contract.rs b/src/executor/contract.rs index 6dc417e7b..9e93bbe50 100644 --- a/src/executor/contract.rs +++ b/src/executor/contract.rs @@ -33,7 +33,9 @@ use crate::{ arch::AbiArgument, + clone_option_mut, context::NativeContext, + debug::libfunc_to_name, error::{panic::ToNativeAssertError, Error, Result}, execution_result::{BuiltinStats, ContractExecutionResult}, executor::{invoke_trampoline, BuiltinCostsGuard}, @@ -41,6 +43,7 @@ use crate::{ module::NativeModule, native_assert, native_panic, starknet::{handler::StarknetSyscallHandlerCallbacks, StarknetSyscallHandler}, + statistics::Statistics, types::TypeBuilder, utils::{ decode_error_message, generate_function_name, get_integer_layout, libc_free, libc_malloc, @@ -57,7 +60,7 @@ use cairo_lang_sierra::{ starknet::StarknetTypeConcrete, }, ids::FunctionId, - program::{GenFunction, Program, StatementIdx}, + program::{GenFunction, GenStatement, Program, StatementIdx}, program_registry::ProgramRegistry, }; use cairo_lang_sierra_to_casm::metadata::MetadataComputationConfig; @@ -80,6 +83,7 @@ use std::{ path::{Path, PathBuf}, ptr::{self, NonNull}, sync::Arc, + time::Instant, }; use tempfile::NamedTempFile; @@ -134,11 +138,15 @@ impl BuiltinType { impl AotContractExecutor { /// Compile and load a program using a temporary shared library. + /// + /// When enabled, compilation stats will be saved to the `stats`. The + /// initial statistics can be build using the default builder. pub fn new( program: &Program, entry_points: &ContractEntryPoints, sierra_version: VersionId, opt_level: OptLevel, + stats: Option<&mut Statistics>, ) -> Result { let output_path = NamedTempFile::new()? .into_temp_path() @@ -151,6 +159,7 @@ impl AotContractExecutor { sierra_version, output_path, opt_level, + stats, )? .to_native_assert_error("temporary contract path collision")?; @@ -165,12 +174,16 @@ impl AotContractExecutor { /// attempt to compile a program while the `output_path` is already locked will result in /// `Ok(None)` being returned. When this happens, the user should wait until the lock is /// released, at which point they can use `AotContractExecutor::from_path` to load it. + /// + /// When enabled, compilation stats will be saved to the `stats`. The + /// initial statistics can be build using the default builder. pub fn new_into( program: &Program, entry_points: &ContractEntryPoints, sierra_version: VersionId, output_path: impl Into, opt_level: OptLevel, + stats: Option<&mut Statistics>, ) -> Result> { let output_path = output_path.into(); let lock_file = match LockFile::new(&output_path)? { @@ -178,6 +191,8 @@ impl AotContractExecutor { None => return Ok(None), }; + let pre_compilation_instant = Instant::now(); + let context = NativeContext::new(); let no_eq_solver = match sierra_version.major.cmp(&1) { @@ -186,6 +201,13 @@ impl AotContractExecutor { Ordering::Greater => true, }; + if let Some(&mut ref mut stats) = stats { + stats.sierra_type_count = Some(program.type_declarations.len()); + stats.sierra_libfunc_count = Some(program.libfunc_declarations.len()); + stats.sierra_statement_count = Some(program.statements.len()); + stats.sierra_func_count = Some(program.funcs.len()); + } + // Compile the Sierra program. let NativeModule { module, registry, .. @@ -210,8 +232,19 @@ impl AotContractExecutor { skip_non_linear_solver_comparisons: false, compute_runtime_costs: false, }), + clone_option_mut!(stats), )?; + if let Some(&mut ref mut stats) = stats { + for statement in &program.statements { + if let GenStatement::Invocation(invocation) = statement { + let libfunc = registry.get_libfunc(&invocation.libfunc_id)?; + let name = libfunc_to_name(libfunc).to_string(); + *stats.sierra_libfunc_frequency.entry(name).or_insert(0) += 1; + } + } + } + // Generate mappings between the entry point's selectors and their function indexes. let entry_point_mappings = chain!( entry_points.constructor.iter(), @@ -236,10 +269,18 @@ impl AotContractExecutor { }) .collect::>>()?; - let object_data = crate::module_to_object(&module, opt_level)?; + let object_data = crate::module_to_object(&module, opt_level, clone_option_mut!(stats))?; + if let Some(&mut ref mut stats) = stats { + stats.object_size_bytes = Some(object_data.len()); + } // Build the shared library into the lockfile, to avoid using a tmp file. - crate::object_to_shared_lib(&object_data, &lock_file.0)?; + crate::object_to_shared_lib(&object_data, &lock_file.0, clone_option_mut!(stats))?; + + let compilation_time = pre_compilation_instant.elapsed().as_millis(); + if let Some(&mut ref mut stats) = stats { + stats.compilation_total_time_ms = Some(compilation_time); + } // Write the contract info. fs::write( @@ -250,6 +291,10 @@ impl AotContractExecutor { })?, )?; + if let Some(&mut ref mut stats) = stats { + native_assert!(stats.validate(), "some statistics are missing"); + } + // Atomically move the built shared library to the correct path. This will avoid data races // when loading contracts. lock_file.rename(&output_path)?; @@ -781,6 +826,7 @@ mod tests { &starknet_program.entry_points_by_type, sierra_version, optlevel, + None, ) .unwrap(), ); @@ -820,6 +866,7 @@ mod tests { &starknet_program.entry_points_by_type, sierra_version, optlevel, + None, ) .unwrap(); @@ -859,6 +906,7 @@ mod tests { &starknet_program_factorial.entry_points_by_type, sierra_version, optlevel, + None, ) .unwrap(); @@ -899,6 +947,7 @@ mod tests { &starknet_program_empty.entry_points_by_type, sierra_version, optlevel, + None, ) .unwrap(); diff --git a/src/ffi.rs b/src/ffi.rs index ead1d033b..040b602e6 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -3,11 +3,16 @@ //! This is a "hotfix" for missing Rust interfaces to the C/C++ libraries we use, namely LLVM/MLIR //! APIs that are missing from melior. -use crate::error::{panic::ToNativeAssertError, Error, Result}; +use crate::{ + error::{panic::ToNativeAssertError, Error, Result}, + statistics::Statistics, + utils::walk_ir::walk_llvm_instructions, +}; use llvm_sys::{ core::{ LLVMContextCreate, LLVMContextDispose, LLVMDisposeMemoryBuffer, LLVMDisposeMessage, - LLVMDisposeModule, LLVMGetBufferSize, LLVMGetBufferStart, + LLVMDisposeModule, LLVMGetBufferSize, LLVMGetBufferStart, LLVMGetFirstUse, + LLVMGetInstructionOpcode, }, error::LLVMGetErrorMessage, prelude::LLVMMemoryBufferRef, @@ -95,7 +100,11 @@ impl From for OptLevel { } /// Converts a MLIR module to a compile object, that can be linked with a linker. -pub fn module_to_object(module: &Module<'_>, opt_level: OptLevel) -> Result> { +pub fn module_to_object( + module: &Module<'_>, + opt_level: OptLevel, + stats: Option<&mut Statistics>, +) -> Result> { static INITIALIZED: OnceLock<()> = OnceLock::new(); INITIALIZED.get_or_init(|| unsafe { @@ -111,11 +120,42 @@ pub fn module_to_object(module: &Module<'_>, opt_level: OptLevel) -> Result, opt_level: OptLevel) -> Result")) .to_native_assert_error("only fails if the hardcoded string contains a null byte")?; - trace!("starting llvm passes"); - let pre_passes_instant = Instant::now(); + let pre_llvm_passes_instant = Instant::now(); let error = LLVMRunPasses(llvm_module, passes.as_ptr(), machine, opts); - let passes_time = pre_passes_instant.elapsed().as_millis(); - trace!(time = passes_time, "llvm passes finished"); + let llvm_passes_time = pre_llvm_passes_instant.elapsed().as_millis(); + if let Some(&mut ref mut stats) = stats { + stats.compilation_llvm_passes_time_ms = Some(llvm_passes_time); + } if !error.is_null() { let msg = LLVMGetErrorMessage(error); @@ -184,7 +225,7 @@ pub fn module_to_object(module: &Module<'_>, opt_level: OptLevel) -> Result = MaybeUninit::uninit(); trace!("starting llvm to object compilation"); - let pre_llvm_compilation_instant = Instant::now(); + let pre_llvm_to_object_instant = Instant::now(); let ok = LLVMTargetMachineEmitToMemoryBuffer( machine, llvm_module, @@ -192,11 +233,10 @@ pub fn module_to_object(module: &Module<'_>, opt_level: OptLevel) -> Result, opt_level: OptLevel) -> Result Result<()> { +pub fn object_to_shared_lib( + object: &[u8], + output_filename: &Path, + stats: Option<&mut Statistics>, +) -> Result<()> { // linker seems to need a file and doesn't accept stdin let mut file = NamedTempFile::new()?; file.write_all(object)?; @@ -289,11 +333,12 @@ pub fn object_to_shared_lib(object: &[u8], output_filename: &Path) -> Result<()> let mut linker = std::process::Command::new("ld"); - trace!("starting linking"); let pre_linking_instant = Instant::now(); let proc = linker.args(args.iter().map(|x| x.as_ref())).output()?; let linking_time = pre_linking_instant.elapsed().as_millis(); - trace!(time = linking_time, "linking finished"); + if let Some(&mut ref mut stats) = stats { + stats.compilation_linking_time_ms = Some(linking_time); + } if proc.status.success() { Ok(()) diff --git a/src/lib.rs b/src/lib.rs index a3f5c9898..b76c13932 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -31,6 +31,7 @@ pub mod module; mod runtime; pub mod starknet; pub mod starknet_stub; +pub mod statistics; mod types; pub mod utils; mod values; diff --git a/src/libfuncs/bounded_int.rs b/src/libfuncs/bounded_int.rs index b5d744dcb..3e048a31f 100644 --- a/src/libfuncs/bounded_int.rs +++ b/src/libfuncs/bounded_int.rs @@ -841,7 +841,7 @@ mod test { } ); let ctx = NativeContext::new(); - let module = ctx.compile(&program, false, None).unwrap(); + let module = ctx.compile(&program, false, None, None).unwrap(); let executor = JitNativeExecutor::from_native_module(module, OptLevel::Default).unwrap(); let ExecutionResult { remaining_gas: _, @@ -874,7 +874,7 @@ mod test { } ); let ctx = NativeContext::new(); - let module = ctx.compile(&program, false, None).unwrap(); + let module = ctx.compile(&program, false, None, None).unwrap(); let executor = JitNativeExecutor::from_native_module(module, OptLevel::Default).unwrap(); let ExecutionResult { remaining_gas: _, @@ -907,7 +907,7 @@ mod test { } ); let ctx = NativeContext::new(); - let module = ctx.compile(&program, false, None).unwrap(); + let module = ctx.compile(&program, false, None, None).unwrap(); let executor = JitNativeExecutor::from_native_module(module, OptLevel::Default).unwrap(); let ExecutionResult { remaining_gas: _, @@ -940,7 +940,7 @@ mod test { } ); let ctx = NativeContext::new(); - let module = ctx.compile(&program, false, None).unwrap(); + let module = ctx.compile(&program, false, None, None).unwrap(); let executor = JitNativeExecutor::from_native_module(module, OptLevel::Default).unwrap(); let ExecutionResult { remaining_gas: _, diff --git a/src/libfuncs/enum.rs b/src/libfuncs/enum.rs index 2a774b835..694f33054 100644 --- a/src/libfuncs/enum.rs +++ b/src/libfuncs/enum.rs @@ -647,7 +647,7 @@ mod test { let native_context = NativeContext::new(); native_context - .compile(&program, false, Some(Default::default())) + .compile(&program, false, Some(Default::default()), None) .unwrap(); } } diff --git a/src/libfuncs/int.rs b/src/libfuncs/int.rs index 0b800eccc..63b532cd4 100644 --- a/src/libfuncs/int.rs +++ b/src/libfuncs/int.rs @@ -947,7 +947,7 @@ mod test { .map_err(|e| e.to_string())?; let context = NativeContext::new(); - let module = context.compile(&program, false, None)?; + let module = context.compile(&program, false, None, None)?; let executor = JitNativeExecutor::from_native_module(module, OptLevel::default())?; let data = [T::min_value(), T::zero(), T::one(), T::max_value()]; @@ -993,7 +993,7 @@ mod test { .map_err(|e| e.to_string())?; let context = NativeContext::new(); - let module = context.compile(&program, false, None)?; + let module = context.compile(&program, false, None, None)?; let executor = JitNativeExecutor::from_native_module(module, OptLevel::default())?; let data = [0u128, 1u128, u128::MAX]; @@ -1073,7 +1073,7 @@ mod test { }; let context = NativeContext::new(); - let module = context.compile(&program, false, None)?; + let module = context.compile(&program, false, None, None)?; let executor = JitNativeExecutor::from_native_module(module, OptLevel::default())?; if min.is_zero() { @@ -1161,7 +1161,7 @@ mod test { .map_err(|e| e.to_string())?; let context = NativeContext::new(); - let module = context.compile(&program, false, None)?; + let module = context.compile(&program, false, None, None)?; let executor = JitNativeExecutor::from_native_module(module, OptLevel::default())?; let data = [T::min_value(), T::zero(), T::one(), T::max_value()]; @@ -1238,7 +1238,7 @@ mod test { .map_err(|e| e.to_string())?; let context = NativeContext::new(); - let module = context.compile(&program, false, None)?; + let module = context.compile(&program, false, None, None)?; let executor = JitNativeExecutor::from_native_module(module, OptLevel::default())?; let data = [T::min_value(), T::zero(), T::one(), T::max_value()]; @@ -1305,7 +1305,7 @@ mod test { .map_err(|e| e.to_string())?; let context = NativeContext::new(); - let module = context.compile(&program, false, None)?; + let module = context.compile(&program, false, None, None)?; let executor = JitNativeExecutor::from_native_module(module, OptLevel::default())?; let data = [T::min_value(), T::zero(), T::one(), T::max_value()]; @@ -1374,7 +1374,7 @@ mod test { .map_err(|e| e.to_string())?; let context = NativeContext::new(); - let module = context.compile(&program, false, None)?; + let module = context.compile(&program, false, None, None)?; let executor = JitNativeExecutor::from_native_module(module, OptLevel::default())?; let data = [ @@ -1447,7 +1447,7 @@ mod test { .map_err(|e| e.to_string())?; let context = NativeContext::new(); - let module = context.compile(&program, false, None)?; + let module = context.compile(&program, false, None, None)?; let executor = JitNativeExecutor::from_native_module(module, OptLevel::default())?; let data = [0u128, 1u128, u128::MAX]; @@ -1528,7 +1528,7 @@ mod test { .map_err(|e| e.to_string())?; let context = NativeContext::new(); - let module = context.compile(&program, false, None)?; + let module = context.compile(&program, false, None, None)?; let executor = JitNativeExecutor::from_native_module(module, OptLevel::default())?; let data = [T::min_value(), T::zero(), T::one(), T::max_value()]; @@ -1619,7 +1619,7 @@ mod test { .map_err(|e| e.to_string())?; let context = NativeContext::new(); - let module = context.compile(&program, false, None)?; + let module = context.compile(&program, false, None, None)?; let executor = JitNativeExecutor::from_native_module(module, OptLevel::default())?; let data = [T::min_value(), T::zero(), T::one(), T::max_value()]; @@ -1691,7 +1691,7 @@ mod test { .map_err(|e| e.to_string())?; let context = NativeContext::new(); - let module = context.compile(&program, false, None)?; + let module = context.compile(&program, false, None, None)?; let executor = JitNativeExecutor::from_native_module(module, OptLevel::default())?; let data = [T::min_value(), T::zero(), T::one(), T::max_value()]; @@ -1750,7 +1750,7 @@ mod test { .map_err(|e| e.to_string())?; let context = NativeContext::new(); - let module = context.compile(&program, false, None)?; + let module = context.compile(&program, false, None, None)?; let executor = JitNativeExecutor::from_native_module(module, OptLevel::default())?; let data = [T::min_value(), T::zero(), T::one(), T::max_value()]; @@ -1794,7 +1794,7 @@ mod test { .map_err(|e| e.to_string())?; let context = NativeContext::new(); - let module = context.compile(&program, false, None)?; + let module = context.compile(&program, false, None, None)?; let executor = JitNativeExecutor::from_native_module(module, OptLevel::default())?; let data = [ @@ -1864,7 +1864,7 @@ mod test { .map_err(|e| e.to_string())?; let context = NativeContext::new(); - let module = context.compile(&program, false, None)?; + let module = context.compile(&program, false, None, None)?; let executor = JitNativeExecutor::from_native_module(module, OptLevel::default())?; let data = [T::min_value(), T::zero(), T::one(), T::max_value()]; diff --git a/src/statistics.rs b/src/statistics.rs new file mode 100644 index 000000000..132ad85d8 --- /dev/null +++ b/src/statistics.rs @@ -0,0 +1,93 @@ +use std::collections::BTreeMap; + +use serde::Serialize; + +/// A set of compilation statistics gathered during the compilation. +/// It should be completely filled at the end of the compilation. +#[derive(Default, Serialize)] +pub struct Statistics { + /// Number of types defined in the Sierra code. + pub sierra_type_count: Option, + /// Number of libfuncs defined in the Sierra code. + pub sierra_libfunc_count: Option, + /// Number of statements contained in the Sierra code. + pub sierra_statement_count: Option, + /// Number of user functions defined in the Sierra code. + pub sierra_func_count: Option, + /// Number of statements for each distinct libfunc. + pub sierra_libfunc_frequency: BTreeMap, + /// Number of MLIR operations generated. + pub mlir_operation_count: Option, + /// Number of MLIR operations generated for each distinct libfunc. + pub mlir_operations_by_libfunc: BTreeMap, + /// Number of LLVMIR instructions generated. + pub llvmir_instruction_count: Option, + /// Number of LLVMIR virtual registers defined. + pub llvmir_virtual_register_count: Option, + /// Number of LLVMIR instructions for each distinct opcode. + pub llvmir_opcode_frequency: BTreeMap, + /// Total compilation time. + pub compilation_total_time_ms: Option, + /// Time spent at Sierra to MLIR. + pub compilation_sierra_to_mlir_time_ms: Option, + /// Time spent at MLIR passes. + pub compilation_mlir_passes_time_ms: Option, + /// Time spent at MLIR to LLVMIR translation. + pub compilation_mlir_to_llvm_time_ms: Option, + /// Time spent at LLVM passes. + pub compilation_llvm_passes_time_ms: Option, + /// Time spent at LLVM to object compilation. + pub compilation_llvm_to_object_time_ms: Option, + /// Time spent at linking the shared library. + pub compilation_linking_time_ms: Option, + /// Size of the compiled object. + pub object_size_bytes: Option, +} + +impl Statistics { + pub fn validate(&self) -> bool { + self.sierra_type_count.is_some() + && self.sierra_libfunc_count.is_some() + && self.sierra_statement_count.is_some() + && self.sierra_func_count.is_some() + && !self.sierra_libfunc_frequency.is_empty() + && self.mlir_operation_count.is_some() + && !self.mlir_operations_by_libfunc.is_empty() + && self.llvmir_instruction_count.is_some() + && self.llvmir_virtual_register_count.is_some() + && !self.llvmir_opcode_frequency.is_empty() + && self.compilation_total_time_ms.is_some() + && self.compilation_sierra_to_mlir_time_ms.is_some() + && self.compilation_mlir_passes_time_ms.is_some() + && self.compilation_mlir_to_llvm_time_ms.is_some() + && self.compilation_llvm_passes_time_ms.is_some() + && self.compilation_llvm_to_object_time_ms.is_some() + && self.compilation_linking_time_ms.is_some() + && self.object_size_bytes.is_some() + } +} + +/// Clones a variable of type `Option<&mut T>` without consuming self +/// +/// # Example +/// +/// The following example would fail to compile otherwise. +/// +/// ``` +/// # use cairo_native::clone_option_mut; +/// fn consume(v: Option<&mut Vec>) {} +/// +/// let mut vec = Vec::new(); +/// let option = Some(&mut vec); +/// consume(clone_option_mut!(option)); +/// consume(option); +/// ``` +#[macro_export] +macro_rules! clone_option_mut { + ( $var:ident ) => { + match $var { + None => None, + Some(&mut ref mut s) => Some(s), + } + }; +} diff --git a/src/utils.rs b/src/utils.rs index 1dda89185..78ae71733 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -38,6 +38,7 @@ mod range_ext; pub mod safe_runner; pub mod sierra_gen; pub mod trace_dump; +pub mod walk_ir; #[cfg(target_os = "macos")] pub const SHARED_LIBRARY_EXT: &str = "dylib"; @@ -608,7 +609,7 @@ pub mod test { let context = NativeContext::new(); let module = context - .compile(program, false, Some(Default::default())) + .compile(program, false, Some(Default::default()), None) .expect("Could not compile test program to MLIR."); let executor = JitNativeExecutor::from_native_module(module, OptLevel::Less).unwrap(); diff --git a/src/utils/trace_dump.rs b/src/utils/trace_dump.rs index afa489e03..3307d5b2b 100644 --- a/src/utils/trace_dump.rs +++ b/src/utils/trace_dump.rs @@ -184,7 +184,7 @@ mod tests { let native_context = NativeContext::new(); let module = native_context - .compile(&program, false, Some(Default::default())) + .compile(&program, false, Some(Default::default()), None) .expect("failed to compile context"); let executor = AotNativeExecutor::from_native_module(module, OptLevel::default()).unwrap(); diff --git a/src/utils/walk_ir.rs b/src/utils/walk_ir.rs new file mode 100644 index 000000000..b1563341c --- /dev/null +++ b/src/utils/walk_ir.rs @@ -0,0 +1,127 @@ +use std::ffi::c_void; + +use llvm_sys::{ + core::{ + LLVMGetFirstBasicBlock, LLVMGetFirstFunction, LLVMGetFirstInstruction, + LLVMGetNextBasicBlock, LLVMGetNextFunction, LLVMGetNextInstruction, + }, + prelude::{LLVMModuleRef, LLVMValueRef}, + LLVMBasicBlock, LLVMValue, +}; +use melior::ir::{BlockLike, BlockRef, OperationRef}; +use mlir_sys::{MlirOperation, MlirWalkResult}; + +type OperationWalkCallback = + unsafe extern "C" fn(MlirOperation, *mut ::std::os::raw::c_void) -> MlirWalkResult; + +/// Traverses the given operation tree in preorder. +/// +/// Calls `f` on each operation encountered. The second argument to `f` should +/// be interpreted as a pointer to a value of type `T`. +/// +/// TODO: Can we receive a closure instead? +/// We may need to save a pointer to the closure +/// inside of the callback data. +pub fn walk_mlir_operations( + top_op: OperationRef, + f: OperationWalkCallback, + initial: T, +) -> T { + let mut data = Box::new(initial); + unsafe { + mlir_sys::mlirOperationWalk( + top_op.to_raw(), + Some(f), + data.as_mut() as *mut _ as *mut c_void, + mlir_sys::MlirWalkOrder_MlirWalkPreOrder, + ); + }; + *data +} + +/// Traverses from start block to end block (including) in preorder. +/// +/// Calls `f` on each operation encountered. The second argument to `f` should +/// be interpreted as a pointer to a value of type `T`. +/// +/// TODO: Can we receive a closure instead? +/// We may need to save a pointer to the closure +/// inside of the callback data. +pub fn walk_mlir_block( + start_block: BlockRef, + end_block: BlockRef, + f: OperationWalkCallback, + initial: T, +) -> T { + let mut data = Box::new(initial); + + let mut current_block = start_block; + loop { + let mut next_operation = current_block.first_operation(); + + while let Some(operation) = next_operation { + unsafe { + mlir_sys::mlirOperationWalk( + operation.to_raw(), + Some(f), + data.as_mut() as *mut _ as *mut c_void, + mlir_sys::MlirWalkOrder_MlirWalkPreOrder, + ); + }; + + // we have to convert it to raw, and back to ref to bypass borrow checker. + next_operation = unsafe { + operation + .next_in_block() + .map(OperationRef::to_raw) + .map(|op| OperationRef::from_raw(op)) + } + } + + if current_block == end_block { + break; + } + + current_block = current_block + .next_in_region() + .expect("should always reach `end_block`"); + } + + *data +} + +/// Traverses the whole LLVM Module, calling `f` on each instruction. +/// +/// As this function receives a closure rather than a function, there is no need +/// to receive initial data, and can instead modify the captured environment. +pub unsafe fn walk_llvm_instructions(llvm_module: LLVMModuleRef, mut f: impl FnMut(LLVMValueRef)) { + let new_value = |function_ptr: *mut LLVMValue| { + if function_ptr.is_null() { + None + } else { + Some(function_ptr) + } + }; + let new_block = |function_ptr: *mut LLVMBasicBlock| { + if function_ptr.is_null() { + None + } else { + Some(function_ptr) + } + }; + + let mut current_function = new_value(LLVMGetFirstFunction(llvm_module)); + while let Some(function) = current_function { + let mut current_block = new_block(LLVMGetFirstBasicBlock(function)); + while let Some(block) = current_block { + let mut current_instruction = new_value(LLVMGetFirstInstruction(block)); + while let Some(instruction) = current_instruction { + f(instruction); + + current_instruction = new_value(LLVMGetNextInstruction(instruction)); + } + current_block = new_block(LLVMGetNextBasicBlock(block)); + } + current_function = new_value(LLVMGetNextFunction(function)); + } +} diff --git a/tests/common.rs b/tests/common.rs index 320f1ec21..d32063bf3 100644 --- a/tests/common.rs +++ b/tests/common.rs @@ -233,7 +233,7 @@ pub fn run_native_program( let context = NativeContext::new(); let module = context - .compile(program, false, Some(Default::default())) + .compile(program, false, Some(Default::default()), None) .expect("Could not compile test program to MLIR."); assert!( @@ -430,7 +430,7 @@ pub fn run_native_starknet_contract( let native_context = NativeContext::new(); let native_program = native_context - .compile(sierra_program, false, Some(Default::default())) + .compile(sierra_program, false, Some(Default::default()), None) .unwrap(); let entry_point_fn = find_entry_point_by_idx(sierra_program, entry_point_function_idx).unwrap(); @@ -456,6 +456,7 @@ pub fn run_native_starknet_aot_contract( &contract.entry_points_by_type, sierra_version, Default::default(), + None, ) .unwrap(); native_executor diff --git a/tests/tests/compile_library.rs b/tests/tests/compile_library.rs index cbd5a014c..eabd3a9c2 100644 --- a/tests/tests/compile_library.rs +++ b/tests/tests/compile_library.rs @@ -14,12 +14,12 @@ pub fn compile_library() -> Result<(), Box> { } }; - let module = context.compile(&program.1, false, Some(Default::default()))?; + let module = context.compile(&program.1, false, Some(Default::default()), None)?; - let object = cairo_native::module_to_object(module.module(), Default::default())?; + let object = cairo_native::module_to_object(module.module(), Default::default(), None)?; let file = NamedTempFile::new()?.into_temp_path(); - cairo_native::object_to_shared_lib(&object, &file)?; + cairo_native::object_to_shared_lib(&object, &file, None)?; Ok(()) } diff --git a/tests/tests/trampoline.rs b/tests/tests/trampoline.rs index 846172c08..d7f3a6b47 100644 --- a/tests/tests/trampoline.rs +++ b/tests/tests/trampoline.rs @@ -14,7 +14,7 @@ fn run_program(program: &Program, entry_point: &str, args: &[Value]) -> Executio let context = NativeContext::new(); let module = context - .compile(program, false, Some(Default::default())) + .compile(program, false, Some(Default::default()), None) .unwrap(); // FIXME: There are some bugs with non-zero LLVM optimization levels. let executor = JitNativeExecutor::from_native_module(module, OptLevel::None).unwrap();