From a1b9395360a30bda729a51a507085672778b33d7 Mon Sep 17 00:00:00 2001 From: Esteve Soler Arderiu Date: Tue, 14 Jan 2025 17:23:02 +0100 Subject: [PATCH 1/8] Add basic Sierra generator. --- src/utils.rs | 1 + src/utils/sierra_gen.rs | 236 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 237 insertions(+) create mode 100644 src/utils/sierra_gen.rs diff --git a/src/utils.rs b/src/utils.rs index 6ac8733c2..e6f4c31f7 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -34,6 +34,7 @@ mod block_ext; pub mod mem_tracing; mod program_registry_ext; mod range_ext; +pub mod sierra_gen; #[cfg(target_os = "macos")] pub const SHARED_LIBRARY_EXT: &str = "dylib"; diff --git a/src/utils/sierra_gen.rs b/src/utils/sierra_gen.rs new file mode 100644 index 000000000..b4869fe7d --- /dev/null +++ b/src/utils/sierra_gen.rs @@ -0,0 +1,236 @@ +#![cfg(test)] + +use cairo_lang_sierra::{ + extensions::{ + lib_func::{SierraApChange, SignatureSpecializationContext}, + type_specialization_context::TypeSpecializationContext, + GenericLibfunc, + }, + ids::{ConcreteLibfuncId, ConcreteTypeId, FunctionId, GenericTypeId, VarId}, + program::{ + BranchInfo, ConcreteLibfuncLongId, ConcreteTypeLongId, Function, FunctionSignature, + GenBranchTarget, GenericArg, Invocation, LibfuncDeclaration, Param, Program, Statement, + StatementIdx, TypeDeclaration, + }, +}; +use std::cell::RefCell; + +pub fn generate_program(args: &[GenericArg]) -> Program +where + T: GenericLibfunc, +{ + // Initialize the Sierra generation context (which contains an empty program). + let context = Context(RefCell::new(Program { + type_declarations: Vec::new(), + libfunc_declarations: Vec::new(), + statements: Vec::new(), + funcs: Vec::new(), + })); + + // Extract the libfunc id. + let libfunc_ids = T::supported_ids(); + let libfunc = T::by_id(&libfunc_ids[0]).unwrap(); + assert_eq!(libfunc_ids.len(), 1); + + // Specialize the target libfunc signature. This will generate the required types within the + // program. + let libfunc_signature = libfunc.specialize_signature(&context, args).unwrap(); + + // Generate the target libfunc declaration. + let mut program = context.0.into_inner(); + let libfunc_id = ConcreteLibfuncId::new(program.libfunc_declarations.len() as u64); + program.libfunc_declarations.push(LibfuncDeclaration { + id: libfunc_id.clone(), + long_id: ConcreteLibfuncLongId { + generic_id: libfunc_ids[0].clone(), + generic_args: args.to_vec(), + }, + }); + + // Generate the test's entry point. + let num_builtins; + let ret_types = { + // Add all builtins. + let mut ret_types: Vec = libfunc_signature + .param_signatures + .iter() + .take_while(|param_signature| { + let ty = program + .type_declarations + .iter() + .find(|ty| ty.id == param_signature.ty) + .unwrap(); + matches!( + ty.long_id.generic_id.0.as_str(), + "Bitwise" + | "EcOp" + | "GasBuiltin" + | "BuiltinCosts" + | "RangeCheck" + | "RangeCheck96" + | "Pedersen" + | "Poseidon" + | "Coupon" + | "System" + | "SegmentArena" + | "AddMod" + | "MulMod" + ) + }) + .map(|param_signature| param_signature.ty.clone()) + .collect(); + num_builtins = ret_types.len(); + + // Push the return value. + ret_types.push({ + let num_branches = libfunc_signature.branch_signatures.len(); + let mut iter = libfunc_signature + .branch_signatures + .iter() + .map(|branch_signature| match branch_signature.vars.len() { + 1 => branch_signature.vars[0].ty.clone(), + _ => todo!(), + }); + + match num_branches { + 0 => todo!(), + 1 => iter.next().unwrap(), + _ => todo!(), + } + }); + + ret_types + }; + + program.funcs.push(Function { + id: FunctionId::new(0), + signature: FunctionSignature { + param_types: libfunc_signature + .param_signatures + .iter() + .map(|param_signature| param_signature.ty.clone()) + .collect(), + ret_types, + }, + params: libfunc_signature + .param_signatures + .iter() + .enumerate() + .map(|(id, param_signature)| Param { + id: VarId::new(id as u64), + ty: param_signature.ty.clone(), + }) + .collect(), + entry_point: StatementIdx(0), + }); + + // Generate the statements. + let mut libfunc_invocation = Invocation { + libfunc_id, + args: libfunc_signature + .param_signatures + .iter() + .enumerate() + .map(|(idx, _)| VarId::new(idx as u64)) + .collect(), + branches: Vec::new(), + }; + + for branch_signature in &libfunc_signature.branch_signatures { + libfunc_invocation.branches.push(BranchInfo { + target: GenBranchTarget::Statement(StatementIdx(program.statements.len() + 1)), + results: branch_signature + .vars + .iter() + .enumerate() + .map(|(idx, _)| VarId::new(idx as u64)) + .collect(), + }); + + // TODO: Handle multiple return values (struct_construct). + // TODO: Handle multiple branches (enum_init). + + program.statements.push(Statement::Return( + (0..=num_builtins) + .map(|idx| VarId::new(idx as u64)) + .collect(), + )); + } + + program + .statements + .insert(0, Statement::Invocation(libfunc_invocation)); + + program +} + +struct Context(RefCell); + +impl TypeSpecializationContext for Context { + fn try_get_type_info( + &self, + _id: ConcreteTypeId, + ) -> Option { + todo!() + } +} + +impl SignatureSpecializationContext for Context { + fn try_get_concrete_type( + &self, + id: GenericTypeId, + generic_args: &[GenericArg], + ) -> Option { + let mut program = self.0.borrow_mut(); + + let long_id = ConcreteTypeLongId { + generic_id: id, + generic_args: generic_args.to_vec(), + }; + match program + .type_declarations + .iter() + .find_map(|ty| (ty.long_id == long_id).then_some(ty.id.clone())) + { + Some(x) => Some(x), + None => { + let type_id = ConcreteTypeId { + id: program.type_declarations.len() as u64, + debug_name: None, + }; + program.type_declarations.push(TypeDeclaration { + id: type_id.clone(), + long_id, + declared_type_info: None, + }); + + Some(type_id) + } + } + } + + fn try_get_function_signature(&self, _function_id: &FunctionId) -> Option { + todo!() + } + + fn try_get_function_ap_change(&self, _function_id: &FunctionId) -> Option { + todo!() + } + + fn as_type_specialization_context(&self) -> &dyn TypeSpecializationContext { + todo!() + } +} + +#[cfg(test)] +mod test { + use super::*; + use cairo_lang_sierra::extensions::int::{unsigned::Uint64Traits, IntConstLibfunc}; + + #[test] + fn sierra_generator() { + let program = + generate_program::>(&[GenericArg::Value(0.into())]); + println!("{program}"); + } +} From 280691a498a898583e44f9b1d1b14043450c9d6a Mon Sep 17 00:00:00 2001 From: Esteve Soler Arderiu Date: Wed, 15 Jan 2025 13:28:56 +0100 Subject: [PATCH 2/8] Add multi-return support. --- src/utils/sierra_gen.rs | 112 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 103 insertions(+), 9 deletions(-) diff --git a/src/utils/sierra_gen.rs b/src/utils/sierra_gen.rs index b4869fe7d..04ba2a59d 100644 --- a/src/utils/sierra_gen.rs +++ b/src/utils/sierra_gen.rs @@ -3,17 +3,21 @@ use cairo_lang_sierra::{ extensions::{ lib_func::{SierraApChange, SignatureSpecializationContext}, + structure::{StructConstructLibfunc, StructType}, type_specialization_context::TypeSpecializationContext, - GenericLibfunc, + GenericLibfunc, NamedLibfunc, NamedType, + }, + ids::{ + ConcreteLibfuncId, ConcreteTypeId, FunctionId, GenericLibfuncId, GenericTypeId, UserTypeId, + VarId, }, - ids::{ConcreteLibfuncId, ConcreteTypeId, FunctionId, GenericTypeId, VarId}, program::{ - BranchInfo, ConcreteLibfuncLongId, ConcreteTypeLongId, Function, FunctionSignature, - GenBranchTarget, GenericArg, Invocation, LibfuncDeclaration, Param, Program, Statement, + BranchInfo, BranchTarget, ConcreteLibfuncLongId, ConcreteTypeLongId, Function, + FunctionSignature, GenericArg, Invocation, LibfuncDeclaration, Param, Program, Statement, StatementIdx, TypeDeclaration, }, }; -use std::cell::RefCell; +use std::{cell::RefCell, iter::once}; pub fn generate_program(args: &[GenericArg]) -> Program where @@ -89,7 +93,41 @@ where .iter() .map(|branch_signature| match branch_signature.vars.len() { 1 => branch_signature.vars[0].ty.clone(), - _ => todo!(), + _ => { + // Generate struct type. + let return_type = + ConcreteTypeId::new(program.type_declarations.len() as u64); + program.type_declarations.push(TypeDeclaration { + id: return_type.clone(), + long_id: ConcreteTypeLongId { + generic_id: StructType::ID, + generic_args: once(GenericArg::UserType(UserTypeId::from_string( + "Tuple", + ))) + .chain( + branch_signature + .vars + .iter() + .map(|var_info| GenericArg::Type(var_info.ty.clone())), + ) + .collect(), + }, + declared_type_info: None, + }); + + // Add the struct_construct libfunc declaration. + program.libfunc_declarations.push(LibfuncDeclaration { + id: ConcreteLibfuncId::new(program.libfunc_declarations.len() as u64), + long_id: ConcreteLibfuncLongId { + generic_id: GenericLibfuncId::from_string( + StructConstructLibfunc::STR_ID, + ), + generic_args: vec![GenericArg::Type(return_type.clone())], + }, + }); + + return_type + } }); match num_branches { @@ -136,9 +174,17 @@ where branches: Vec::new(), }; - for branch_signature in &libfunc_signature.branch_signatures { + let mut libfunc_idx = match libfunc_signature.branch_signatures.len() { + 0 => todo!(), + 1 => 1, + _ => 2, + }; + for (branch_idx, branch_signature) in libfunc_signature.branch_signatures.iter().enumerate() { libfunc_invocation.branches.push(BranchInfo { - target: GenBranchTarget::Statement(StatementIdx(program.statements.len() + 1)), + target: match branch_idx { + 0 => BranchTarget::Fallthrough, + _ => BranchTarget::Statement(StatementIdx(program.statements.len() + 1)), + }, results: branch_signature .vars .iter() @@ -147,7 +193,38 @@ where .collect(), }); + if branch_idx != 0 { + program.statements.push(Statement::Invocation(Invocation { + libfunc_id: program.libfunc_declarations[1].id.clone(), + args: Vec::new(), + branches: vec![BranchInfo { + target: BranchTarget::Fallthrough, + results: Vec::new(), + }], + })); + } + // TODO: Handle multiple return values (struct_construct). + if branch_signature.vars.len() != 1 { + let packer_libfunc = &program.libfunc_declarations[libfunc_idx].id; + libfunc_idx += 1; + + program.statements.push(Statement::Invocation(Invocation { + libfunc_id: packer_libfunc.clone(), + args: branch_signature + .vars + .iter() + .enumerate() + .skip(num_builtins) + .map(|(idx, _)| VarId::new(idx as u64)) + .collect(), + branches: vec![BranchInfo { + target: BranchTarget::Fallthrough, + results: vec![VarId::new(num_builtins as u64)], + }], + })); + } + // TODO: Handle multiple branches (enum_init). program.statements.push(Statement::Return( @@ -225,7 +302,12 @@ impl SignatureSpecializationContext for Context { #[cfg(test)] mod test { use super::*; - use cairo_lang_sierra::extensions::int::{unsigned::Uint64Traits, IntConstLibfunc}; + use cairo_lang_sierra::extensions::int::{ + signed::{Sint8Traits, SintDiffLibfunc}, + unsigned::Uint64Traits, + unsigned128::U128GuaranteeMulLibfunc, + IntConstLibfunc, + }; #[test] fn sierra_generator() { @@ -233,4 +315,16 @@ mod test { generate_program::>(&[GenericArg::Value(0.into())]); println!("{program}"); } + + #[test] + fn sierra_generator_multiret() { + let program = generate_program::(&[]); + println!("{program}"); + } + + #[test] + fn sierra_generator_multibranch() { + let program = generate_program::>(&[]); + println!("{program}"); + } } From 5595836d58d3644e6bc7af65de9c1d0bd19bb2a2 Mon Sep 17 00:00:00 2001 From: Esteve Soler Arderiu Date: Wed, 15 Jan 2025 19:36:57 +0100 Subject: [PATCH 3/8] Refactor and complete the Sierra generator. --- src/utils/sierra_gen.rs | 527 +++++++++++++++++++++++++--------------- 1 file changed, 331 insertions(+), 196 deletions(-) diff --git a/src/utils/sierra_gen.rs b/src/utils/sierra_gen.rs index 04ba2a59d..d9829a0fc 100644 --- a/src/utils/sierra_gen.rs +++ b/src/utils/sierra_gen.rs @@ -2,9 +2,12 @@ use cairo_lang_sierra::{ extensions::{ + branch_align::BranchAlignLibfunc, + enm::{EnumInitLibfunc, EnumType}, lib_func::{SierraApChange, SignatureSpecializationContext}, structure::{StructConstructLibfunc, StructType}, type_specialization_context::TypeSpecializationContext, + types::TypeInfo, GenericLibfunc, NamedLibfunc, NamedType, }, ids::{ @@ -17,130 +20,140 @@ use cairo_lang_sierra::{ StatementIdx, TypeDeclaration, }, }; -use std::{cell::RefCell, iter::once}; +use std::{ + cell::{OnceCell, RefCell}, + iter::once, +}; -pub fn generate_program(args: &[GenericArg]) -> Program +pub fn generate_program_with_libfunc_name( + generic_id: GenericLibfuncId, + generic_args: impl Into>, +) -> Program where T: GenericLibfunc, { - // Initialize the Sierra generation context (which contains an empty program). - let context = Context(RefCell::new(Program { - type_declarations: Vec::new(), - libfunc_declarations: Vec::new(), - statements: Vec::new(), - funcs: Vec::new(), - })); - - // Extract the libfunc id. - let libfunc_ids = T::supported_ids(); - let libfunc = T::by_id(&libfunc_ids[0]).unwrap(); - assert_eq!(libfunc_ids.len(), 1); - - // Specialize the target libfunc signature. This will generate the required types within the - // program. - let libfunc_signature = libfunc.specialize_signature(&context, args).unwrap(); - - // Generate the target libfunc declaration. - let mut program = context.0.into_inner(); - let libfunc_id = ConcreteLibfuncId::new(program.libfunc_declarations.len() as u64); - program.libfunc_declarations.push(LibfuncDeclaration { - id: libfunc_id.clone(), - long_id: ConcreteLibfuncLongId { - generic_id: libfunc_ids[0].clone(), - generic_args: args.to_vec(), - }, - }); + let context = ContextWrapper(RefCell::new(Context::default())); + let generic_args = generic_args.into(); - // Generate the test's entry point. - let num_builtins; - let ret_types = { - // Add all builtins. - let mut ret_types: Vec = libfunc_signature - .param_signatures - .iter() - .take_while(|param_signature| { - let ty = program - .type_declarations - .iter() - .find(|ty| ty.id == param_signature.ty) - .unwrap(); - matches!( - ty.long_id.generic_id.0.as_str(), - "Bitwise" - | "EcOp" - | "GasBuiltin" - | "BuiltinCosts" - | "RangeCheck" - | "RangeCheck96" - | "Pedersen" - | "Poseidon" - | "Coupon" - | "System" - | "SegmentArena" - | "AddMod" - | "MulMod" - ) - }) - .map(|param_signature| param_signature.ty.clone()) - .collect(); - num_builtins = ret_types.len(); - - // Push the return value. - ret_types.push({ - let num_branches = libfunc_signature.branch_signatures.len(); - let mut iter = libfunc_signature - .branch_signatures + let libfunc = T::by_id(&generic_id).unwrap(); + let libfunc_signature = libfunc + .specialize_signature(&context, &generic_args) + .unwrap(); + + let mut context = RefCell::into_inner(context.0); + + // Push the libfunc declaration. + let libfunc_id = context + .push_libfunc_declaration(ConcreteLibfuncLongId { + generic_id, + generic_args: generic_args.to_vec(), + }) + .clone(); + + // Generate packed types. + let num_builtins = libfunc_signature + .param_signatures + .iter() + .take_while(|param_signature| { + let long_id = &context + .program + .type_declarations .iter() - .map(|branch_signature| match branch_signature.vars.len() { - 1 => branch_signature.vars[0].ty.clone(), - _ => { - // Generate struct type. - let return_type = - ConcreteTypeId::new(program.type_declarations.len() as u64); - program.type_declarations.push(TypeDeclaration { - id: return_type.clone(), - long_id: ConcreteTypeLongId { - generic_id: StructType::ID, - generic_args: once(GenericArg::UserType(UserTypeId::from_string( - "Tuple", - ))) - .chain( - branch_signature - .vars - .iter() - .map(|var_info| GenericArg::Type(var_info.ty.clone())), - ) - .collect(), - }, - declared_type_info: None, - }); - - // Add the struct_construct libfunc declaration. - program.libfunc_declarations.push(LibfuncDeclaration { - id: ConcreteLibfuncId::new(program.libfunc_declarations.len() as u64), - long_id: ConcreteLibfuncLongId { - generic_id: GenericLibfuncId::from_string( - StructConstructLibfunc::STR_ID, - ), - generic_args: vec![GenericArg::Type(return_type.clone())], - }, - }); - - return_type - } - }); - - match num_branches { - 0 => todo!(), - 1 => iter.next().unwrap(), - _ => todo!(), - } + .find(|type_declaration| type_declaration.id == param_signature.ty) + .unwrap() + .long_id; + + matches!( + long_id.generic_id.0.as_str(), + "Bitwise" + | "EcOp" + | "GasBuiltin" + | "BuiltinCosts" + | "RangeCheck" + | "RangeCheck96" + | "Pedersen" + | "Poseidon" + | "Coupon" + | "System" + | "SegmentArena" + | "AddMod" + | "MulMod" + ) + }) + .count(); + + let mut return_types = Vec::with_capacity(libfunc_signature.branch_signatures.len()); + let mut packed_unit_type_id = None; + for branch_signature in &libfunc_signature.branch_signatures { + assert!(branch_signature + .vars + .iter() + .zip(libfunc_signature.param_signatures.iter().take(num_builtins)) + .all(|(lhs, rhs)| lhs.ty == rhs.ty)); + + return_types.push(match branch_signature.vars.len() - num_builtins { + 0 => match libfunc_signature.branch_signatures.len() { + 1 => ResultVarType::Empty(None), + _ => ResultVarType::Empty(Some( + packed_unit_type_id + .get_or_insert_with(|| { + context + .push_type_declaration(ConcreteTypeLongId { + generic_id: StructType::ID, + generic_args: vec![GenericArg::UserType( + UserTypeId::from_string("Tuple"), + )], + }) + .clone() + }) + .clone(), + )), + }, + 1 => ResultVarType::Single(branch_signature.vars[num_builtins].ty.clone()), + _ => ResultVarType::Multi( + context + .push_type_declaration(ConcreteTypeLongId { + generic_id: StructType::ID, + generic_args: once(GenericArg::UserType(UserTypeId::from_string("Tuple"))) + .chain( + branch_signature + .vars + .iter() + .skip(num_builtins) + .map(|var_info| GenericArg::Type(var_info.ty.clone())), + ) + .collect(), + }) + .clone(), + ), }); + } - ret_types + // Generate switch type. + let return_type = match return_types.len() { + 1 => match return_types[0].clone() { + ResultVarType::Empty(ty) => ty.unwrap().clone(), + ResultVarType::Single(ty) => ty.clone(), + ResultVarType::Multi(ty) => ty.clone(), + }, + _ => context + .push_type_declaration(ConcreteTypeLongId { + generic_id: EnumType::ID, + generic_args: once(GenericArg::UserType(UserTypeId::from_string("Tuple"))) + .chain(return_types.iter().map(|ty| { + GenericArg::Type(match ty { + ResultVarType::Empty(ty) => ty.clone().unwrap(), + ResultVarType::Single(ty) => ty.clone(), + ResultVarType::Multi(ty) => ty.clone(), + }) + })) + .collect(), + }) + .clone(), }; - program.funcs.push(Function { + // Generate function declaration. + context.program.funcs.push(Function { id: FunctionId::new(0), signature: FunctionSignature { param_types: libfunc_signature @@ -148,21 +161,25 @@ where .iter() .map(|param_signature| param_signature.ty.clone()) .collect(), - ret_types, + ret_types: libfunc_signature.param_signatures[..num_builtins] + .iter() + .map(|param_signature| param_signature.ty.clone()) + .chain(once(return_type.clone())) + .collect(), }, params: libfunc_signature .param_signatures .iter() .enumerate() - .map(|(id, param_signature)| Param { - id: VarId::new(id as u64), + .map(|(idx, param_signature)| Param { + id: VarId::new(idx as u64), ty: param_signature.ty.clone(), }) .collect(), entry_point: StatementIdx(0), }); - // Generate the statements. + // Generate statements. let mut libfunc_invocation = Invocation { libfunc_id, args: libfunc_signature @@ -174,17 +191,123 @@ where branches: Vec::new(), }; - let mut libfunc_idx = match libfunc_signature.branch_signatures.len() { - 0 => todo!(), - 1 => 1, - _ => 2, - }; - for (branch_idx, branch_signature) in libfunc_signature.branch_signatures.iter().enumerate() { + let branch_align_libfunc = OnceCell::new(); + let construct_unit_libfunc = packed_unit_type_id.map(|ty| { + context + .push_libfunc_declaration(ConcreteLibfuncLongId { + generic_id: GenericLibfuncId::from_string(StructConstructLibfunc::STR_ID), + generic_args: vec![GenericArg::Type(ty)], + }) + .clone() + }); + + for (branch_index, branch_signature) in libfunc_signature.branch_signatures.iter().enumerate() { + let branch_target = match branch_index { + 0 => BranchTarget::Fallthrough, + _ => { + let statement_idx = StatementIdx(context.program.statements.len() + 1); + let branch_align_libfunc_id = branch_align_libfunc + .get_or_init(|| { + context + .push_libfunc_declaration(ConcreteLibfuncLongId { + generic_id: GenericLibfuncId::from_string( + BranchAlignLibfunc::STR_ID, + ), + generic_args: Vec::new(), + }) + .clone() + }) + .clone(); + + context + .program + .statements + .push(Statement::Invocation(Invocation { + libfunc_id: branch_align_libfunc_id, + args: Vec::new(), + branches: vec![BranchInfo { + target: BranchTarget::Fallthrough, + results: Vec::new(), + }], + })); + + BranchTarget::Statement(statement_idx) + } + }; + + // Maybe pack values. + match &return_types[branch_index] { + ResultVarType::Empty(Some(_)) => { + context + .program + .statements + .push(Statement::Invocation(Invocation { + libfunc_id: construct_unit_libfunc.clone().unwrap(), + args: Vec::new(), + branches: vec![BranchInfo { + target: BranchTarget::Fallthrough, + results: vec![VarId::new(num_builtins as u64)], + }], + })); + } + ResultVarType::Multi(type_id) => { + let construct_libfunc_id = context + .push_libfunc_declaration(ConcreteLibfuncLongId { + generic_id: GenericLibfuncId::from_string(StructConstructLibfunc::STR_ID), + generic_args: vec![GenericArg::Type(type_id.clone())], + }) + .clone(); + + context + .program + .statements + .push(Statement::Invocation(Invocation { + libfunc_id: construct_libfunc_id, + args: (num_builtins..branch_signature.vars.len()) + .map(|x| VarId::new(x as u64)) + .collect(), + branches: vec![BranchInfo { + target: BranchTarget::Fallthrough, + results: vec![VarId::new(num_builtins as u64)], + }], + })); + } + _ => {} + } + + // Maybe enum values. + if libfunc_signature.branch_signatures.len() > 1 { + let enum_libfunc_id = context + .push_libfunc_declaration(ConcreteLibfuncLongId { + generic_id: GenericLibfuncId::from_string(EnumInitLibfunc::STR_ID), + generic_args: vec![ + GenericArg::Type(return_type.clone()), + GenericArg::Value(branch_index.into()), + ], + }) + .clone(); + + context + .program + .statements + .push(Statement::Invocation(Invocation { + libfunc_id: enum_libfunc_id, + args: vec![VarId::new(num_builtins as u64)], + branches: vec![BranchInfo { + target: BranchTarget::Fallthrough, + results: vec![VarId::new(num_builtins as u64)], + }], + })); + } + + // Return. + context.program.statements.push(Statement::Return( + (0..=num_builtins).map(|x| VarId::new(x as u64)).collect(), + )); + + // Push the branch target. libfunc_invocation.branches.push(BranchInfo { - target: match branch_idx { - 0 => BranchTarget::Fallthrough, - _ => BranchTarget::Statement(StatementIdx(program.statements.len() + 1)), - }, + target: branch_target, results: branch_signature .vars .iter() @@ -192,98 +315,97 @@ where .map(|(idx, _)| VarId::new(idx as u64)) .collect(), }); + } - if branch_idx != 0 { - program.statements.push(Statement::Invocation(Invocation { - libfunc_id: program.libfunc_declarations[1].id.clone(), - args: Vec::new(), - branches: vec![BranchInfo { - target: BranchTarget::Fallthrough, - results: Vec::new(), - }], - })); - } - - // TODO: Handle multiple return values (struct_construct). - if branch_signature.vars.len() != 1 { - let packer_libfunc = &program.libfunc_declarations[libfunc_idx].id; - libfunc_idx += 1; - - program.statements.push(Statement::Invocation(Invocation { - libfunc_id: packer_libfunc.clone(), - args: branch_signature - .vars - .iter() - .enumerate() - .skip(num_builtins) - .map(|(idx, _)| VarId::new(idx as u64)) - .collect(), - branches: vec![BranchInfo { - target: BranchTarget::Fallthrough, - results: vec![VarId::new(num_builtins as u64)], - }], - })); - } + context + .program + .statements + .insert(0, Statement::Invocation(libfunc_invocation)); - // TODO: Handle multiple branches (enum_init). + context.program +} - program.statements.push(Statement::Return( - (0..=num_builtins) - .map(|idx| VarId::new(idx as u64)) - .collect(), - )); +pub fn generate_program(args: &[GenericArg]) -> Program +where + T: GenericLibfunc, +{ + match T::supported_ids().as_slice() { + [generic_id] => generate_program_with_libfunc_name::(generic_id.clone(), args), + _ => panic!(), } +} - program - .statements - .insert(0, Statement::Invocation(libfunc_invocation)); +#[derive(Debug)] +struct Context { + program: Program, +} - program +impl Default for Context { + fn default() -> Self { + Self { + program: Program { + type_declarations: Vec::new(), + libfunc_declarations: Vec::new(), + statements: Vec::new(), + funcs: Vec::new(), + }, + } + } } -struct Context(RefCell); +impl Context { + pub fn push_type_declaration(&mut self, long_id: ConcreteTypeLongId) -> &ConcreteTypeId { + let id = ConcreteTypeId::new(self.program.type_declarations.len() as u64); + self.program.type_declarations.push(TypeDeclaration { + id, + long_id, + declared_type_info: None, + }); -impl TypeSpecializationContext for Context { - fn try_get_type_info( - &self, - _id: ConcreteTypeId, - ) -> Option { - todo!() + &self.program.type_declarations.last().unwrap().id + } + + pub fn push_libfunc_declaration( + &mut self, + long_id: ConcreteLibfuncLongId, + ) -> &ConcreteLibfuncId { + let id = ConcreteLibfuncId::new(self.program.libfunc_declarations.len() as u64); + self.program + .libfunc_declarations + .push(LibfuncDeclaration { id, long_id }); + + &self.program.libfunc_declarations.last().unwrap().id } } -impl SignatureSpecializationContext for Context { +struct ContextWrapper(RefCell); + +impl SignatureSpecializationContext for ContextWrapper { fn try_get_concrete_type( &self, id: GenericTypeId, generic_args: &[GenericArg], ) -> Option { - let mut program = self.0.borrow_mut(); + let mut context = self.0.borrow_mut(); let long_id = ConcreteTypeLongId { generic_id: id, generic_args: generic_args.to_vec(), }; - match program + assert!(!context + .program .type_declarations .iter() - .find_map(|ty| (ty.long_id == long_id).then_some(ty.id.clone())) - { - Some(x) => Some(x), - None => { - let type_id = ConcreteTypeId { - id: program.type_declarations.len() as u64, - debug_name: None, - }; - program.type_declarations.push(TypeDeclaration { - id: type_id.clone(), - long_id, - declared_type_info: None, - }); - - Some(type_id) - } - } + .any(|type_declaration| type_declaration.long_id == long_id)); + + let id = ConcreteTypeId::new(context.program.type_declarations.len() as u64); + context.program.type_declarations.push(TypeDeclaration { + id: id.clone(), + long_id, + declared_type_info: None, + }); + + Some(id) } fn try_get_function_signature(&self, _function_id: &FunctionId) -> Option { @@ -295,10 +417,23 @@ impl SignatureSpecializationContext for Context { } fn as_type_specialization_context(&self) -> &dyn TypeSpecializationContext { + self + } +} + +impl TypeSpecializationContext for ContextWrapper { + fn try_get_type_info(&self, _id: ConcreteTypeId) -> Option { todo!() } } +#[derive(Clone)] +enum ResultVarType { + Empty(Option), + Single(ConcreteTypeId), + Multi(ConcreteTypeId), +} + #[cfg(test)] mod test { use super::*; From 029e67ea8079d05ba84a26ac090346d3552e654e Mon Sep 17 00:00:00 2001 From: Esteve Soler Arderiu Date: Wed, 15 Jan 2025 23:10:44 +0100 Subject: [PATCH 4/8] Fix API for template types. --- src/utils/sierra_gen.rs | 636 +++++++++++++++++++++------------------- 1 file changed, 337 insertions(+), 299 deletions(-) diff --git a/src/utils/sierra_gen.rs b/src/utils/sierra_gen.rs index d9829a0fc..45370f27e 100644 --- a/src/utils/sierra_gen.rs +++ b/src/utils/sierra_gen.rs @@ -23,98 +23,136 @@ use cairo_lang_sierra::{ use std::{ cell::{OnceCell, RefCell}, iter::once, + marker::PhantomData, }; -pub fn generate_program_with_libfunc_name( - generic_id: GenericLibfuncId, - generic_args: impl Into>, -) -> Program +#[derive(Debug)] +pub struct SierraGenerator where T: GenericLibfunc, { - let context = ContextWrapper(RefCell::new(Context::default())); - let generic_args = generic_args.into(); + program: Program, + phantom: PhantomData, +} - let libfunc = T::by_id(&generic_id).unwrap(); - let libfunc_signature = libfunc - .specialize_signature(&context, &generic_args) - .unwrap(); +impl Default for SierraGenerator +where + T: GenericLibfunc, +{ + fn default() -> Self { + Self { + program: Program { + type_declarations: Vec::new(), + libfunc_declarations: Vec::new(), + statements: Vec::new(), + funcs: Vec::new(), + }, + phantom: PhantomData, + } + } +} - let mut context = RefCell::into_inner(context.0); +impl SierraGenerator +where + T: GenericLibfunc, +{ + pub fn build(self, generic_args: impl Into>) -> Program { + match T::supported_ids().as_slice() { + [generic_id] => self.build_with_generic_id(generic_id.clone(), generic_args.into()), + _ => panic!("multiple generic ids detected, please use build_with_generic_id directly"), + } + } - // Push the libfunc declaration. - let libfunc_id = context - .push_libfunc_declaration(ConcreteLibfuncLongId { - generic_id, - generic_args: generic_args.to_vec(), - }) - .clone(); - - // Generate packed types. - let num_builtins = libfunc_signature - .param_signatures - .iter() - .take_while(|param_signature| { - let long_id = &context - .program - .type_declarations - .iter() - .find(|type_declaration| type_declaration.id == param_signature.ty) - .unwrap() - .long_id; - - matches!( - long_id.generic_id.0.as_str(), - "Bitwise" - | "EcOp" - | "GasBuiltin" - | "BuiltinCosts" - | "RangeCheck" - | "RangeCheck96" - | "Pedersen" - | "Poseidon" - | "Coupon" - | "System" - | "SegmentArena" - | "AddMod" - | "MulMod" - ) - }) - .count(); - - let mut return_types = Vec::with_capacity(libfunc_signature.branch_signatures.len()); - let mut packed_unit_type_id = None; - for branch_signature in &libfunc_signature.branch_signatures { - assert!(branch_signature - .vars + pub fn build_with_generic_id( + self, + generic_id: GenericLibfuncId, + generic_args: impl Into>, + ) -> Program { + let context = SierraGeneratorWrapper(RefCell::new(self)); + let generic_args = generic_args.into(); + + let libfunc = T::by_id(&generic_id).unwrap(); + let libfunc_signature = libfunc + .specialize_signature(&context, &generic_args) + .unwrap(); + + let mut context = RefCell::into_inner(context.0); + + // Push the libfunc declaration. + let libfunc_id = context + .push_libfunc_declaration(ConcreteLibfuncLongId { + generic_id, + generic_args: generic_args.to_vec(), + }) + .clone(); + + // Generate packed types. + let num_builtins = libfunc_signature + .param_signatures .iter() - .zip(libfunc_signature.param_signatures.iter().take(num_builtins)) - .all(|(lhs, rhs)| lhs.ty == rhs.ty)); - - return_types.push(match branch_signature.vars.len() - num_builtins { - 0 => match libfunc_signature.branch_signatures.len() { - 1 => ResultVarType::Empty(None), - _ => ResultVarType::Empty(Some( - packed_unit_type_id - .get_or_insert_with(|| { - context - .push_type_declaration(ConcreteTypeLongId { - generic_id: StructType::ID, - generic_args: vec![GenericArg::UserType( - UserTypeId::from_string("Tuple"), - )], - }) - .clone() - }) - .clone(), - )), - }, - 1 => ResultVarType::Single(branch_signature.vars[num_builtins].ty.clone()), - _ => ResultVarType::Multi( - context - .push_type_declaration(ConcreteTypeLongId { - generic_id: StructType::ID, - generic_args: once(GenericArg::UserType(UserTypeId::from_string("Tuple"))) + .take_while(|param_signature| { + let long_id = &context + .program + .type_declarations + .iter() + .find(|type_declaration| type_declaration.id == param_signature.ty) + .unwrap() + .long_id; + + matches!( + long_id.generic_id.0.as_str(), + "Bitwise" + | "EcOp" + | "GasBuiltin" + | "BuiltinCosts" + | "RangeCheck" + | "RangeCheck96" + | "Pedersen" + | "Poseidon" + | "Coupon" + | "System" + | "SegmentArena" + | "AddMod" + | "MulMod" + ) + }) + .count(); + + let mut return_types = Vec::with_capacity(libfunc_signature.branch_signatures.len()); + let mut packed_unit_type_id = None; + for branch_signature in &libfunc_signature.branch_signatures { + assert!(branch_signature + .vars + .iter() + .zip(libfunc_signature.param_signatures.iter().take(num_builtins)) + .all(|(lhs, rhs)| lhs.ty == rhs.ty)); + + return_types.push(match branch_signature.vars.len() - num_builtins { + 0 => match libfunc_signature.branch_signatures.len() { + 1 => ResultVarType::Empty(None), + _ => ResultVarType::Empty(Some( + packed_unit_type_id + .get_or_insert_with(|| { + context + .push_type_declaration(ConcreteTypeLongId { + generic_id: StructType::ID, + generic_args: vec![GenericArg::UserType( + UserTypeId::from_string("Tuple"), + )], + }) + .clone() + }) + .clone(), + )), + }, + 1 => ResultVarType::Single(branch_signature.vars[num_builtins].ty.clone()), + _ => ResultVarType::Multi( + context + .push_type_declaration(ConcreteTypeLongId { + generic_id: StructType::ID, + generic_args: once(GenericArg::UserType(UserTypeId::from_string( + "Tuple", + ))) .chain( branch_signature .vars @@ -123,138 +161,171 @@ where .map(|var_info| GenericArg::Type(var_info.ty.clone())), ) .collect(), - }) - .clone(), - ), - }); - } - - // Generate switch type. - let return_type = match return_types.len() { - 1 => match return_types[0].clone() { - ResultVarType::Empty(ty) => ty.unwrap().clone(), - ResultVarType::Single(ty) => ty.clone(), - ResultVarType::Multi(ty) => ty.clone(), - }, - _ => context - .push_type_declaration(ConcreteTypeLongId { - generic_id: EnumType::ID, - generic_args: once(GenericArg::UserType(UserTypeId::from_string("Tuple"))) - .chain(return_types.iter().map(|ty| { - GenericArg::Type(match ty { - ResultVarType::Empty(ty) => ty.clone().unwrap(), - ResultVarType::Single(ty) => ty.clone(), - ResultVarType::Multi(ty) => ty.clone(), }) - })) - .collect(), - }) - .clone(), - }; + .clone(), + ), + }); + } - // Generate function declaration. - context.program.funcs.push(Function { - id: FunctionId::new(0), - signature: FunctionSignature { - param_types: libfunc_signature + // Generate switch type. + let return_type = match return_types.len() { + 1 => match return_types[0].clone() { + ResultVarType::Empty(ty) => ty.unwrap().clone(), + ResultVarType::Single(ty) => ty.clone(), + ResultVarType::Multi(ty) => ty.clone(), + }, + _ => context + .push_type_declaration(ConcreteTypeLongId { + generic_id: EnumType::ID, + generic_args: once(GenericArg::UserType(UserTypeId::from_string("Tuple"))) + .chain(return_types.iter().map(|ty| { + GenericArg::Type(match ty { + ResultVarType::Empty(ty) => ty.clone().unwrap(), + ResultVarType::Single(ty) => ty.clone(), + ResultVarType::Multi(ty) => ty.clone(), + }) + })) + .collect(), + }) + .clone(), + }; + + // Generate function declaration. + context.program.funcs.push(Function { + id: FunctionId::new(0), + signature: FunctionSignature { + param_types: libfunc_signature + .param_signatures + .iter() + .map(|param_signature| param_signature.ty.clone()) + .collect(), + ret_types: libfunc_signature.param_signatures[..num_builtins] + .iter() + .map(|param_signature| param_signature.ty.clone()) + .chain(once(return_type.clone())) + .collect(), + }, + params: libfunc_signature .param_signatures .iter() - .map(|param_signature| param_signature.ty.clone()) + .enumerate() + .map(|(idx, param_signature)| Param { + id: VarId::new(idx as u64), + ty: param_signature.ty.clone(), + }) .collect(), - ret_types: libfunc_signature.param_signatures[..num_builtins] + entry_point: StatementIdx(0), + }); + + // Generate statements. + let mut libfunc_invocation = Invocation { + libfunc_id, + args: libfunc_signature + .param_signatures .iter() - .map(|param_signature| param_signature.ty.clone()) - .chain(once(return_type.clone())) + .enumerate() + .map(|(idx, _)| VarId::new(idx as u64)) .collect(), - }, - params: libfunc_signature - .param_signatures - .iter() - .enumerate() - .map(|(idx, param_signature)| Param { - id: VarId::new(idx as u64), - ty: param_signature.ty.clone(), - }) - .collect(), - entry_point: StatementIdx(0), - }); - - // Generate statements. - let mut libfunc_invocation = Invocation { - libfunc_id, - args: libfunc_signature - .param_signatures - .iter() - .enumerate() - .map(|(idx, _)| VarId::new(idx as u64)) - .collect(), - branches: Vec::new(), - }; - - let branch_align_libfunc = OnceCell::new(); - let construct_unit_libfunc = packed_unit_type_id.map(|ty| { - context - .push_libfunc_declaration(ConcreteLibfuncLongId { - generic_id: GenericLibfuncId::from_string(StructConstructLibfunc::STR_ID), - generic_args: vec![GenericArg::Type(ty)], - }) - .clone() - }); - - for (branch_index, branch_signature) in libfunc_signature.branch_signatures.iter().enumerate() { - let branch_target = match branch_index { - 0 => BranchTarget::Fallthrough, - _ => { - let statement_idx = StatementIdx(context.program.statements.len() + 1); - let branch_align_libfunc_id = branch_align_libfunc - .get_or_init(|| { - context - .push_libfunc_declaration(ConcreteLibfuncLongId { - generic_id: GenericLibfuncId::from_string( - BranchAlignLibfunc::STR_ID, - ), - generic_args: Vec::new(), - }) - .clone() - }) - .clone(); + branches: Vec::new(), + }; - context - .program - .statements - .push(Statement::Invocation(Invocation { - libfunc_id: branch_align_libfunc_id, - args: Vec::new(), - branches: vec![BranchInfo { - target: BranchTarget::Fallthrough, - results: Vec::new(), - }], - })); + let branch_align_libfunc = OnceCell::new(); + let construct_unit_libfunc = packed_unit_type_id.map(|ty| { + context + .push_libfunc_declaration(ConcreteLibfuncLongId { + generic_id: GenericLibfuncId::from_string(StructConstructLibfunc::STR_ID), + generic_args: vec![GenericArg::Type(ty)], + }) + .clone() + }); - BranchTarget::Statement(statement_idx) + for (branch_index, branch_signature) in + libfunc_signature.branch_signatures.iter().enumerate() + { + let branch_target = match branch_index { + 0 => BranchTarget::Fallthrough, + _ => { + let statement_idx = StatementIdx(context.program.statements.len() + 1); + let branch_align_libfunc_id = branch_align_libfunc + .get_or_init(|| { + context + .push_libfunc_declaration(ConcreteLibfuncLongId { + generic_id: GenericLibfuncId::from_string( + BranchAlignLibfunc::STR_ID, + ), + generic_args: Vec::new(), + }) + .clone() + }) + .clone(); + + context + .program + .statements + .push(Statement::Invocation(Invocation { + libfunc_id: branch_align_libfunc_id, + args: Vec::new(), + branches: vec![BranchInfo { + target: BranchTarget::Fallthrough, + results: Vec::new(), + }], + })); + + BranchTarget::Statement(statement_idx) + } + }; + + // Maybe pack values. + match &return_types[branch_index] { + ResultVarType::Empty(Some(_)) => { + context + .program + .statements + .push(Statement::Invocation(Invocation { + libfunc_id: construct_unit_libfunc.clone().unwrap(), + args: Vec::new(), + branches: vec![BranchInfo { + target: BranchTarget::Fallthrough, + results: vec![VarId::new(num_builtins as u64)], + }], + })); + } + ResultVarType::Multi(type_id) => { + let construct_libfunc_id = context + .push_libfunc_declaration(ConcreteLibfuncLongId { + generic_id: GenericLibfuncId::from_string( + StructConstructLibfunc::STR_ID, + ), + generic_args: vec![GenericArg::Type(type_id.clone())], + }) + .clone(); + + context + .program + .statements + .push(Statement::Invocation(Invocation { + libfunc_id: construct_libfunc_id, + args: (num_builtins..branch_signature.vars.len()) + .map(|x| VarId::new(x as u64)) + .collect(), + branches: vec![BranchInfo { + target: BranchTarget::Fallthrough, + results: vec![VarId::new(num_builtins as u64)], + }], + })); + } + _ => {} } - }; - // Maybe pack values. - match &return_types[branch_index] { - ResultVarType::Empty(Some(_)) => { - context - .program - .statements - .push(Statement::Invocation(Invocation { - libfunc_id: construct_unit_libfunc.clone().unwrap(), - args: Vec::new(), - branches: vec![BranchInfo { - target: BranchTarget::Fallthrough, - results: vec![VarId::new(num_builtins as u64)], - }], - })); - } - ResultVarType::Multi(type_id) => { - let construct_libfunc_id = context + // Maybe enum values. + if libfunc_signature.branch_signatures.len() > 1 { + let enum_libfunc_id = context .push_libfunc_declaration(ConcreteLibfuncLongId { - generic_id: GenericLibfuncId::from_string(StructConstructLibfunc::STR_ID), - generic_args: vec![GenericArg::Type(type_id.clone())], + generic_id: GenericLibfuncId::from_string(EnumInitLibfunc::STR_ID), + generic_args: vec![ + GenericArg::Type(return_type.clone()), + GenericArg::Value(branch_index.into()), + ], }) .clone(); @@ -262,98 +333,40 @@ where .program .statements .push(Statement::Invocation(Invocation { - libfunc_id: construct_libfunc_id, - args: (num_builtins..branch_signature.vars.len()) - .map(|x| VarId::new(x as u64)) - .collect(), + libfunc_id: enum_libfunc_id, + args: vec![VarId::new(num_builtins as u64)], branches: vec![BranchInfo { target: BranchTarget::Fallthrough, results: vec![VarId::new(num_builtins as u64)], }], })); } - _ => {} - } - // Maybe enum values. - if libfunc_signature.branch_signatures.len() > 1 { - let enum_libfunc_id = context - .push_libfunc_declaration(ConcreteLibfuncLongId { - generic_id: GenericLibfuncId::from_string(EnumInitLibfunc::STR_ID), - generic_args: vec![ - GenericArg::Type(return_type.clone()), - GenericArg::Value(branch_index.into()), - ], - }) - .clone(); - - context - .program - .statements - .push(Statement::Invocation(Invocation { - libfunc_id: enum_libfunc_id, - args: vec![VarId::new(num_builtins as u64)], - branches: vec![BranchInfo { - target: BranchTarget::Fallthrough, - results: vec![VarId::new(num_builtins as u64)], - }], - })); + // Return. + context.program.statements.push(Statement::Return( + (0..=num_builtins).map(|x| VarId::new(x as u64)).collect(), + )); + + // Push the branch target. + libfunc_invocation.branches.push(BranchInfo { + target: branch_target, + results: branch_signature + .vars + .iter() + .enumerate() + .map(|(idx, _)| VarId::new(idx as u64)) + .collect(), + }); } - // Return. - context.program.statements.push(Statement::Return( - (0..=num_builtins).map(|x| VarId::new(x as u64)).collect(), - )); - - // Push the branch target. - libfunc_invocation.branches.push(BranchInfo { - target: branch_target, - results: branch_signature - .vars - .iter() - .enumerate() - .map(|(idx, _)| VarId::new(idx as u64)) - .collect(), - }); - } - - context - .program - .statements - .insert(0, Statement::Invocation(libfunc_invocation)); - - context.program -} - -pub fn generate_program(args: &[GenericArg]) -> Program -where - T: GenericLibfunc, -{ - match T::supported_ids().as_slice() { - [generic_id] => generate_program_with_libfunc_name::(generic_id.clone(), args), - _ => panic!(), - } -} - -#[derive(Debug)] -struct Context { - program: Program, -} + context + .program + .statements + .insert(0, Statement::Invocation(libfunc_invocation)); -impl Default for Context { - fn default() -> Self { - Self { - program: Program { - type_declarations: Vec::new(), - libfunc_declarations: Vec::new(), - statements: Vec::new(), - funcs: Vec::new(), - }, - } + context.program } -} -impl Context { pub fn push_type_declaration(&mut self, long_id: ConcreteTypeLongId) -> &ConcreteTypeId { let id = ConcreteTypeId::new(self.program.type_declarations.len() as u64); self.program.type_declarations.push(TypeDeclaration { @@ -365,10 +378,7 @@ impl Context { &self.program.type_declarations.last().unwrap().id } - pub fn push_libfunc_declaration( - &mut self, - long_id: ConcreteLibfuncLongId, - ) -> &ConcreteLibfuncId { + fn push_libfunc_declaration(&mut self, long_id: ConcreteLibfuncLongId) -> &ConcreteLibfuncId { let id = ConcreteLibfuncId::new(self.program.libfunc_declarations.len() as u64); self.program .libfunc_declarations @@ -378,9 +388,14 @@ impl Context { } } -struct ContextWrapper(RefCell); +struct SierraGeneratorWrapper(RefCell>) +where + T: GenericLibfunc; -impl SignatureSpecializationContext for ContextWrapper { +impl SignatureSpecializationContext for SierraGeneratorWrapper +where + T: GenericLibfunc, +{ fn try_get_concrete_type( &self, id: GenericTypeId, @@ -421,7 +436,10 @@ impl SignatureSpecializationContext for ContextWrapper { } } -impl TypeSpecializationContext for ContextWrapper { +impl TypeSpecializationContext for SierraGeneratorWrapper +where + T: GenericLibfunc, +{ fn try_get_type_info(&self, _id: ConcreteTypeId) -> Option { todo!() } @@ -437,29 +455,49 @@ enum ResultVarType { #[cfg(test)] mod test { use super::*; - use cairo_lang_sierra::extensions::int::{ - signed::{Sint8Traits, SintDiffLibfunc}, - unsigned::Uint64Traits, - unsigned128::U128GuaranteeMulLibfunc, - IntConstLibfunc, + use cairo_lang_sierra::extensions::{ + array::ArrayNewLibfunc, + int::{ + signed::{Sint8Traits, SintDiffLibfunc}, + unsigned::{Uint64Traits, Uint8Type}, + unsigned128::U128GuaranteeMulLibfunc, + IntConstLibfunc, + }, }; #[test] fn sierra_generator() { - let program = - generate_program::>(&[GenericArg::Value(0.into())]); + let program = SierraGenerator::>::default() + .build(&[GenericArg::Value(0.into())]); println!("{program}"); } #[test] fn sierra_generator_multiret() { - let program = generate_program::(&[]); + let program = SierraGenerator::::default().build(&[]); println!("{program}"); } #[test] fn sierra_generator_multibranch() { - let program = generate_program::>(&[]); + let program = SierraGenerator::>::default().build(&[]); + println!("{program}"); + } + + #[test] + fn sierra_generator_template() { + let program = { + let mut generator = SierraGenerator::::default(); + + let u8_type = generator + .push_type_declaration(ConcreteTypeLongId { + generic_id: Uint8Type::ID, + generic_args: Vec::new(), + }) + .clone(); + + generator.build(&[GenericArg::Type(u8_type)]) + }; println!("{program}"); } } From 7550f630446b32d509855527c9055985506c947f Mon Sep 17 00:00:00 2001 From: Esteve Soler Arderiu Date: Thu, 16 Jan 2025 17:21:46 +0100 Subject: [PATCH 5/8] Fix type declarations. --- src/utils/sierra_gen.rs | 231 +++++++++++++++++++++++----------------- 1 file changed, 132 insertions(+), 99 deletions(-) diff --git a/src/utils/sierra_gen.rs b/src/utils/sierra_gen.rs index 45370f27e..ee03b9021 100644 --- a/src/utils/sierra_gen.rs +++ b/src/utils/sierra_gen.rs @@ -3,19 +3,20 @@ use cairo_lang_sierra::{ extensions::{ branch_align::BranchAlignLibfunc, + core::CoreType, enm::{EnumInitLibfunc, EnumType}, lib_func::{SierraApChange, SignatureSpecializationContext}, structure::{StructConstructLibfunc, StructType}, type_specialization_context::TypeSpecializationContext, types::TypeInfo, - GenericLibfunc, NamedLibfunc, NamedType, + ConcreteType, GenericLibfunc, GenericType, NamedLibfunc, NamedType, }, ids::{ ConcreteLibfuncId, ConcreteTypeId, FunctionId, GenericLibfuncId, GenericTypeId, UserTypeId, VarId, }, program::{ - BranchInfo, BranchTarget, ConcreteLibfuncLongId, ConcreteTypeLongId, Function, + BranchInfo, BranchTarget, ConcreteLibfuncLongId, DeclaredTypeInfo, Function, FunctionSignature, GenericArg, Invocation, LibfuncDeclaration, Param, Program, Statement, StatementIdx, TypeDeclaration, }, @@ -64,22 +65,22 @@ where } pub fn build_with_generic_id( - self, + mut self, generic_id: GenericLibfuncId, generic_args: impl Into>, ) -> Program { - let context = SierraGeneratorWrapper(RefCell::new(self)); let generic_args = generic_args.into(); let libfunc = T::by_id(&generic_id).unwrap(); let libfunc_signature = libfunc - .specialize_signature(&context, &generic_args) + .specialize_signature( + &SierraGeneratorWrapper(RefCell::new(&mut self)), + &generic_args, + ) .unwrap(); - let mut context = RefCell::into_inner(context.0); - // Push the libfunc declaration. - let libfunc_id = context + let libfunc_id = self .push_libfunc_declaration(ConcreteLibfuncLongId { generic_id, generic_args: generic_args.to_vec(), @@ -91,7 +92,7 @@ where .param_signatures .iter() .take_while(|param_signature| { - let long_id = &context + let long_id = &self .program .type_declarations .iter() @@ -133,26 +134,18 @@ where _ => ResultVarType::Empty(Some( packed_unit_type_id .get_or_insert_with(|| { - context - .push_type_declaration(ConcreteTypeLongId { - generic_id: StructType::ID, - generic_args: vec![GenericArg::UserType( - UserTypeId::from_string("Tuple"), - )], - }) - .clone() + self.push_type_declaration::(&[GenericArg::UserType( + UserTypeId::from_string("Tuple"), + )]) + .clone() }) .clone(), )), }, 1 => ResultVarType::Single(branch_signature.vars[num_builtins].ty.clone()), _ => ResultVarType::Multi( - context - .push_type_declaration(ConcreteTypeLongId { - generic_id: StructType::ID, - generic_args: once(GenericArg::UserType(UserTypeId::from_string( - "Tuple", - ))) + self.push_type_declaration::( + once(GenericArg::UserType(UserTypeId::from_string("Tuple"))) .chain( branch_signature .vars @@ -160,9 +153,9 @@ where .skip(num_builtins) .map(|var_info| GenericArg::Type(var_info.ty.clone())), ) - .collect(), - }) - .clone(), + .collect::>(), + ) + .clone(), ), }); } @@ -174,10 +167,9 @@ where ResultVarType::Single(ty) => ty.clone(), ResultVarType::Multi(ty) => ty.clone(), }, - _ => context - .push_type_declaration(ConcreteTypeLongId { - generic_id: EnumType::ID, - generic_args: once(GenericArg::UserType(UserTypeId::from_string("Tuple"))) + _ => self + .push_type_declaration::( + once(GenericArg::UserType(UserTypeId::from_string("Tuple"))) .chain(return_types.iter().map(|ty| { GenericArg::Type(match ty { ResultVarType::Empty(ty) => ty.clone().unwrap(), @@ -185,13 +177,13 @@ where ResultVarType::Multi(ty) => ty.clone(), }) })) - .collect(), - }) + .collect::>(), + ) .clone(), }; // Generate function declaration. - context.program.funcs.push(Function { + self.program.funcs.push(Function { id: FunctionId::new(0), signature: FunctionSignature { param_types: libfunc_signature @@ -231,12 +223,11 @@ where let branch_align_libfunc = OnceCell::new(); let construct_unit_libfunc = packed_unit_type_id.map(|ty| { - context - .push_libfunc_declaration(ConcreteLibfuncLongId { - generic_id: GenericLibfuncId::from_string(StructConstructLibfunc::STR_ID), - generic_args: vec![GenericArg::Type(ty)], - }) - .clone() + self.push_libfunc_declaration(ConcreteLibfuncLongId { + generic_id: GenericLibfuncId::from_string(StructConstructLibfunc::STR_ID), + generic_args: vec![GenericArg::Type(ty)], + }) + .clone() }); for (branch_index, branch_signature) in @@ -245,22 +236,20 @@ where let branch_target = match branch_index { 0 => BranchTarget::Fallthrough, _ => { - let statement_idx = StatementIdx(context.program.statements.len() + 1); + let statement_idx = StatementIdx(self.program.statements.len() + 1); let branch_align_libfunc_id = branch_align_libfunc .get_or_init(|| { - context - .push_libfunc_declaration(ConcreteLibfuncLongId { - generic_id: GenericLibfuncId::from_string( - BranchAlignLibfunc::STR_ID, - ), - generic_args: Vec::new(), - }) - .clone() + self.push_libfunc_declaration(ConcreteLibfuncLongId { + generic_id: GenericLibfuncId::from_string( + BranchAlignLibfunc::STR_ID, + ), + generic_args: Vec::new(), + }) + .clone() }) .clone(); - context - .program + self.program .statements .push(Statement::Invocation(Invocation { libfunc_id: branch_align_libfunc_id, @@ -278,8 +267,7 @@ where // Maybe pack values. match &return_types[branch_index] { ResultVarType::Empty(Some(_)) => { - context - .program + self.program .statements .push(Statement::Invocation(Invocation { libfunc_id: construct_unit_libfunc.clone().unwrap(), @@ -291,7 +279,7 @@ where })); } ResultVarType::Multi(type_id) => { - let construct_libfunc_id = context + let construct_libfunc_id = self .push_libfunc_declaration(ConcreteLibfuncLongId { generic_id: GenericLibfuncId::from_string( StructConstructLibfunc::STR_ID, @@ -300,8 +288,7 @@ where }) .clone(); - context - .program + self.program .statements .push(Statement::Invocation(Invocation { libfunc_id: construct_libfunc_id, @@ -319,7 +306,7 @@ where // Maybe enum values. if libfunc_signature.branch_signatures.len() > 1 { - let enum_libfunc_id = context + let enum_libfunc_id = self .push_libfunc_declaration(ConcreteLibfuncLongId { generic_id: GenericLibfuncId::from_string(EnumInitLibfunc::STR_ID), generic_args: vec![ @@ -329,8 +316,7 @@ where }) .clone(); - context - .program + self.program .statements .push(Statement::Invocation(Invocation { libfunc_id: enum_libfunc_id, @@ -343,7 +329,7 @@ where } // Return. - context.program.statements.push(Statement::Return( + self.program.statements.push(Statement::Return( (0..=num_builtins).map(|x| VarId::new(x as u64)).collect(), )); @@ -359,21 +345,51 @@ where }); } - context - .program + self.program .statements .insert(0, Statement::Invocation(libfunc_invocation)); - context.program + self.program + } + + pub fn push_type_declaration( + &mut self, + generic_args: impl Into>, + ) -> &ConcreteTypeId + where + U: NamedType, + { + self.push_type_declaration_with_generic_id::(U::ID, generic_args) } - pub fn push_type_declaration(&mut self, long_id: ConcreteTypeLongId) -> &ConcreteTypeId { + pub fn push_type_declaration_with_generic_id( + &mut self, + generic_id: GenericTypeId, + generic_args: impl Into>, + ) -> &ConcreteTypeId + where + U: GenericType, + { + let generic_args = generic_args.into(); + + let type_info = U::by_id(&generic_id) + .unwrap() + .specialize(&SierraGeneratorWrapper(RefCell::new(self)), &generic_args) + .unwrap() + .info() + .clone(); + let id = ConcreteTypeId::new(self.program.type_declarations.len() as u64); - self.program.type_declarations.push(TypeDeclaration { + self.program.type_declarations.push(dbg!(TypeDeclaration { id, - long_id, - declared_type_info: None, - }); + long_id: type_info.long_id, + declared_type_info: Some(DeclaredTypeInfo { + storable: type_info.storable, + droppable: type_info.droppable, + duplicatable: type_info.duplicatable, + zero_sized: type_info.zero_sized, + }), + })); &self.program.type_declarations.last().unwrap().id } @@ -388,11 +404,11 @@ where } } -struct SierraGeneratorWrapper(RefCell>) +struct SierraGeneratorWrapper<'a, T>(RefCell<&'a mut SierraGenerator>) where T: GenericLibfunc; -impl SignatureSpecializationContext for SierraGeneratorWrapper +impl SignatureSpecializationContext for SierraGeneratorWrapper<'_, T> where T: GenericLibfunc, { @@ -401,26 +417,12 @@ where id: GenericTypeId, generic_args: &[GenericArg], ) -> Option { - let mut context = self.0.borrow_mut(); - - let long_id = ConcreteTypeLongId { - generic_id: id, - generic_args: generic_args.to_vec(), - }; - assert!(!context - .program - .type_declarations - .iter() - .any(|type_declaration| type_declaration.long_id == long_id)); - - let id = ConcreteTypeId::new(context.program.type_declarations.len() as u64); - context.program.type_declarations.push(TypeDeclaration { - id: id.clone(), - long_id, - declared_type_info: None, - }); - - Some(id) + Some( + self.0 + .borrow_mut() + .push_type_declaration_with_generic_id::(id, generic_args) + .clone(), + ) } fn try_get_function_signature(&self, _function_id: &FunctionId) -> Option { @@ -436,12 +438,32 @@ where } } -impl TypeSpecializationContext for SierraGeneratorWrapper +impl TypeSpecializationContext for SierraGeneratorWrapper<'_, T> where T: GenericLibfunc, { - fn try_get_type_info(&self, _id: ConcreteTypeId) -> Option { - todo!() + fn try_get_type_info(&self, id: ConcreteTypeId) -> Option { + self.0 + .borrow() + .program + .type_declarations + .iter() + .find_map(|type_declaration| { + (type_declaration.id == id).then(|| { + dbg!( + &type_declaration.long_id, + type_declaration.declared_type_info.as_ref() + ); + let declared_type_info = type_declaration.declared_type_info.as_ref().unwrap(); + TypeInfo { + long_id: type_declaration.long_id.clone(), + storable: declared_type_info.storable, + droppable: declared_type_info.droppable, + duplicatable: declared_type_info.duplicatable, + zero_sized: declared_type_info.zero_sized, + } + }) + }) } } @@ -457,9 +479,10 @@ mod test { use super::*; use cairo_lang_sierra::extensions::{ array::ArrayNewLibfunc, + bounded_int::BoundedIntTrimLibfunc, int::{ signed::{Sint8Traits, SintDiffLibfunc}, - unsigned::{Uint64Traits, Uint8Type}, + unsigned::{Uint32Type, Uint64Traits, Uint8Type}, unsigned128::U128GuaranteeMulLibfunc, IntConstLibfunc, }, @@ -489,15 +512,25 @@ mod test { let program = { let mut generator = SierraGenerator::::default(); - let u8_type = generator - .push_type_declaration(ConcreteTypeLongId { - generic_id: Uint8Type::ID, - generic_args: Vec::new(), - }) - .clone(); + let u8_type = generator.push_type_declaration::(&[]).clone(); generator.build(&[GenericArg::Type(u8_type)]) }; println!("{program}"); } + + #[test] + fn sierra_generator_type_info() { + let program = { + let mut generator = SierraGenerator::::default(); + + let u32_type = generator.push_type_declaration::(&[]).clone(); + + generator.build(&[ + GenericArg::Type(u32_type), + GenericArg::Value(u32::MAX.into()), + ]) + }; + println!("{program}"); + } } From 4f8c8ff2a71b15e19d5b945244ecbe83cd04e42f Mon Sep 17 00:00:00 2001 From: Esteve Soler Arderiu Date: Fri, 17 Jan 2025 14:22:50 +0100 Subject: [PATCH 6/8] Fix branch align generation. --- src/utils/sierra_gen.rs | 62 ++++++++++++++++++++++------------------- 1 file changed, 33 insertions(+), 29 deletions(-) diff --git a/src/utils/sierra_gen.rs b/src/utils/sierra_gen.rs index ee03b9021..a3a8a609e 100644 --- a/src/utils/sierra_gen.rs +++ b/src/utils/sierra_gen.rs @@ -233,33 +233,32 @@ where for (branch_index, branch_signature) in libfunc_signature.branch_signatures.iter().enumerate() { + if libfunc_signature.branch_signatures.len() > 1 { + let branch_align_libfunc_id = branch_align_libfunc + .get_or_init(|| { + self.push_libfunc_declaration(ConcreteLibfuncLongId { + generic_id: GenericLibfuncId::from_string(BranchAlignLibfunc::STR_ID), + generic_args: Vec::new(), + }) + .clone() + }) + .clone(); + self.program + .statements + .push(Statement::Invocation(Invocation { + libfunc_id: branch_align_libfunc_id, + args: Vec::new(), + branches: vec![BranchInfo { + target: BranchTarget::Fallthrough, + results: Vec::new(), + }], + })); + } + let branch_target = match branch_index { 0 => BranchTarget::Fallthrough, _ => { let statement_idx = StatementIdx(self.program.statements.len() + 1); - let branch_align_libfunc_id = branch_align_libfunc - .get_or_init(|| { - self.push_libfunc_declaration(ConcreteLibfuncLongId { - generic_id: GenericLibfuncId::from_string( - BranchAlignLibfunc::STR_ID, - ), - generic_args: Vec::new(), - }) - .clone() - }) - .clone(); - - self.program - .statements - .push(Statement::Invocation(Invocation { - libfunc_id: branch_align_libfunc_id, - args: Vec::new(), - branches: vec![BranchInfo { - target: BranchTarget::Fallthrough, - results: Vec::new(), - }], - })); - BranchTarget::Statement(statement_idx) } }; @@ -380,7 +379,7 @@ where .clone(); let id = ConcreteTypeId::new(self.program.type_declarations.len() as u64); - self.program.type_declarations.push(dbg!(TypeDeclaration { + self.program.type_declarations.push(TypeDeclaration { id, long_id: type_info.long_id, declared_type_info: Some(DeclaredTypeInfo { @@ -389,7 +388,7 @@ where duplicatable: type_info.duplicatable, zero_sized: type_info.zero_sized, }), - })); + }); &self.program.type_declarations.last().unwrap().id } @@ -450,10 +449,6 @@ where .iter() .find_map(|type_declaration| { (type_declaration.id == id).then(|| { - dbg!( - &type_declaration.long_id, - type_declaration.declared_type_info.as_ref() - ); let declared_type_info = type_declaration.declared_type_info.as_ref().unwrap(); TypeInfo { long_id: type_declaration.long_id.clone(), @@ -480,12 +475,14 @@ mod test { use cairo_lang_sierra::extensions::{ array::ArrayNewLibfunc, bounded_int::BoundedIntTrimLibfunc, + bytes31::Bytes31FromFelt252Trait, int::{ signed::{Sint8Traits, SintDiffLibfunc}, unsigned::{Uint32Type, Uint64Traits, Uint8Type}, unsigned128::U128GuaranteeMulLibfunc, IntConstLibfunc, }, + try_from_felt252::TryFromFelt252Libfunc, }; #[test] @@ -533,4 +530,11 @@ mod test { }; println!("{program}"); } + + #[test] + fn sierra_generator_branch_align() { + let program = + SierraGenerator::>::default().build(&[]); + println!("{program}"); + } } From ff88db982d2884e927c213228283f0b9a8a2a787 Mon Sep 17 00:00:00 2001 From: Esteve Soler Arderiu Date: Fri, 24 Jan 2025 16:55:31 +0100 Subject: [PATCH 7/8] Fix bug. --- src/utils/sierra_gen.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils/sierra_gen.rs b/src/utils/sierra_gen.rs index a3a8a609e..2b9bb54ed 100644 --- a/src/utils/sierra_gen.rs +++ b/src/utils/sierra_gen.rs @@ -258,7 +258,7 @@ where let branch_target = match branch_index { 0 => BranchTarget::Fallthrough, _ => { - let statement_idx = StatementIdx(self.program.statements.len() + 1); + let statement_idx = StatementIdx(self.program.statements.len()); BranchTarget::Statement(statement_idx) } }; From 7b21d211e47cf65c84df15eec662cface5b116b5 Mon Sep 17 00:00:00 2001 From: Esteve Soler Arderiu Date: Tue, 4 Feb 2025 20:51:45 +0100 Subject: [PATCH 8/8] Fix duplicated type generation. --- src/utils/sierra_gen.rs | 49 ++++++++++++++++++++++++++++++----------- 1 file changed, 36 insertions(+), 13 deletions(-) diff --git a/src/utils/sierra_gen.rs b/src/utils/sierra_gen.rs index 2b9bb54ed..f224130e9 100644 --- a/src/utils/sierra_gen.rs +++ b/src/utils/sierra_gen.rs @@ -377,20 +377,28 @@ where .unwrap() .info() .clone(); - - let id = ConcreteTypeId::new(self.program.type_declarations.len() as u64); - self.program.type_declarations.push(TypeDeclaration { - id, - long_id: type_info.long_id, - declared_type_info: Some(DeclaredTypeInfo { - storable: type_info.storable, - droppable: type_info.droppable, - duplicatable: type_info.duplicatable, - zero_sized: type_info.zero_sized, - }), + let current_index = self + .program + .type_declarations + .iter() + .enumerate() + .find_map(|(idx, type_decl)| (type_decl.long_id == type_info.long_id).then_some(idx)); + + let current_index = current_index.unwrap_or_else(|| { + let idx = self.program.type_declarations.len(); + self.program.type_declarations.push(TypeDeclaration { + id: ConcreteTypeId::new(idx as u64), + long_id: type_info.long_id, + declared_type_info: Some(DeclaredTypeInfo { + storable: type_info.storable, + droppable: type_info.droppable, + duplicatable: type_info.duplicatable, + zero_sized: type_info.zero_sized, + }), + }); + idx }); - - &self.program.type_declarations.last().unwrap().id + &self.program.type_declarations[current_index].id } fn push_libfunc_declaration(&mut self, long_id: ConcreteLibfuncLongId) -> &ConcreteLibfuncId { @@ -537,4 +545,19 @@ mod test { SierraGenerator::>::default().build(&[]); println!("{program}"); } + + #[test] + fn sierra_generator_type_generation() { + let mut generator = + SierraGenerator::::default(); + let u32_ty = generator.push_type_declaration::(&[]).clone(); + let array_ty = generator + .push_type_declaration::(&[ + GenericArg::Type(u32_ty), + ]) + .clone(); + + let program = generator.build(&[GenericArg::Type(array_ty)]); + println!("{program}"); + } }