diff --git a/Cargo.lock b/Cargo.lock index acc32481df0..274f5a68853 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -707,6 +707,7 @@ dependencies = [ "itertools 0.12.1", "log", "num-bigint", + "num-integer", "num-traits 0.2.19", "pretty_assertions", "rust-analyzer-salsa", diff --git a/crates/cairo-lang-lowering/Cargo.toml b/crates/cairo-lang-lowering/Cargo.toml index 63238d2bd28..822cdc31e24 100644 --- a/crates/cairo-lang-lowering/Cargo.toml +++ b/crates/cairo-lang-lowering/Cargo.toml @@ -20,6 +20,7 @@ id-arena.workspace = true itertools = { workspace = true, default-features = true } log.workspace = true num-bigint = { workspace = true, default-features = true } +num-integer = { workspace = true, default-features = true } num-traits = { workspace = true, default-features = true } salsa.workspace = true smol_str.workspace = true diff --git a/crates/cairo-lang-lowering/src/optimizations/const_folding.rs b/crates/cairo-lang-lowering/src/optimizations/const_folding.rs index a38c3338087..6869c4a7c84 100644 --- a/crates/cairo-lang-lowering/src/optimizations/const_folding.rs +++ b/crates/cairo-lang-lowering/src/optimizations/const_folding.rs @@ -14,6 +14,7 @@ use cairo_lang_utils::{extract_matches, try_extract_matches, Intern, LookupInter use id_arena::Arena; use itertools::{chain, zip_eq}; use num_bigint::BigInt; +use num_integer::Integer; use num_traits::Zero; use smol_str::SmolStr; @@ -64,6 +65,7 @@ pub fn const_folding(db: &dyn LoweringGroup, lowered: &mut FlatLowered) { visited[block_id.0] = true; let block = &mut lowered.blocks[block_id]; + let mut additional_consts = vec![]; for stmt in block.statements.iter_mut() { ctx.maybe_replace_inputs(stmt.inputs_mut()); match stmt { @@ -92,8 +94,9 @@ pub fn const_folding(db: &dyn LoweringGroup, lowered: &mut FlatLowered) { } } Statement::Call(call_stmt) => { - if let Some(updated_stmt) = ctx.handle_statement_call(call_stmt) { - *stmt = updated_stmt; + if let Some((updated_stmt, additional)) = ctx.handle_statement_call(call_stmt) { + *stmt = Statement::Const(updated_stmt); + additional_consts.extend(additional); } } Statement::StructConstruct(StatementStructConstruct { inputs, output }) => { @@ -164,6 +167,7 @@ pub fn const_folding(db: &dyn LoweringGroup, lowered: &mut FlatLowered) { } } } + block.statements.splice(0..0, additional_consts.into_iter().map(Statement::Const)); match &mut block.end { FlatBlockEnd::Goto(block_id, remappings) => { @@ -188,7 +192,7 @@ pub fn const_folding(db: &dyn LoweringGroup, lowered: &mut FlatLowered) { MatchInfo::Extern(info) => { if let Some((extra_stmt, updated_end)) = ctx.handle_extern_block_end(info) { if let Some(stmt) = extra_stmt { - block.statements.push(stmt); + block.statements.push(Statement::Const(stmt)); } block.end = updated_end; } @@ -216,8 +220,13 @@ struct ConstFoldingContext<'a> { impl<'a> ConstFoldingContext<'a> { /// Handles a statement call. /// Returns None if no additional changes are required. - /// If changes are required, returns an updated statement. - fn handle_statement_call(&mut self, stmt: &mut StatementCall) -> Option { + /// If changes are required, returns an updated const-statement (to override the current + /// statement), and a possible additional const-statement, if multiple statements are required + /// for replacing the existing statement. + fn handle_statement_call( + &mut self, + stmt: &mut StatementCall, + ) -> Option<(StatementConst, Option)> { let id = stmt.function.get_extern(self.db)?; if id == self.felt_sub { // (a - 0) can be replaced by a. @@ -245,7 +254,21 @@ impl<'a> ConstFoldingContext<'a> { value = ConstValue::NonZero(Box::new(value)); } self.var_info.insert(output, VarInfo::Const(value.clone())); - Some(Statement::Const(StatementConst { value, output })) + Some((StatementConst { value, output }, None)) + } else if self.div_rem_fns.contains(&id) { + let lhs = self.as_int(stmt.inputs[0].var_id)?; + let rhs = self.as_int(stmt.inputs[1].var_id)?; + let (q, r) = lhs.div_rem(rhs); + let q_output = stmt.outputs[0]; + let q_value = ConstValue::Int(q, self.variables[q_output].ty); + self.var_info.insert(q_output, VarInfo::Const(q_value.clone())); + let r_output = stmt.outputs[1]; + let r_value = ConstValue::Int(r, self.variables[r_output].ty); + self.var_info.insert(r_output, VarInfo::Const(r_value.clone())); + Some(( + StatementConst { value: q_value, output: q_output }, + Some(StatementConst { value: r_value, output: r_output }), + )) } else if id == self.storage_base_address_from_felt252 { let input_var = stmt.inputs[0].var_id; if let Some(ConstValue::Int(val, ty)) = self.as_const(input_var) { @@ -268,12 +291,13 @@ impl<'a> ConstFoldingContext<'a> { let value = ConstValue::Boxed(const_value.clone().into()); // Not inserting the value into the `var_info` map because the // resulting box isn't an actual const at the Sierra level. - Some(Statement::Const(StatementConst { value, output: stmt.outputs[0] })) + Some((StatementConst { value, output: stmt.outputs[0] }, None)) } else if id == self.upcast { let int_value = self.as_int(stmt.inputs[0].var_id)?; - let value = ConstValue::Int(int_value.clone(), self.variables[stmt.outputs[0]].ty); - self.var_info.insert(stmt.outputs[0], VarInfo::Const(value.clone())); - Some(Statement::Const(StatementConst { value, output: stmt.outputs[0] })) + let output = stmt.outputs[0]; + let value = ConstValue::Int(int_value.clone(), self.variables[output].ty); + self.var_info.insert(output, VarInfo::Const(value.clone())); + Some((StatementConst { value, output }, None)) } else { None } @@ -281,12 +305,12 @@ impl<'a> ConstFoldingContext<'a> { /// Handles the end of an extern block. /// Returns None if no additional changes are required. - /// If changes are required, returns a possible additional statement to the block, as well as an - /// updated block end. + /// If changes are required, returns a possible additional const-statement to the block, as well + /// as an updated block end. fn handle_extern_block_end( &mut self, info: &mut MatchExternInfo, - ) -> Option<(Option, FlatBlockEnd)> { + ) -> Option<(Option, FlatBlockEnd)> { let id = info.function.get_extern(self.db)?; if self.nz_fns.contains(&id) { let val = self.as_const(info.inputs[0].var_id)?; @@ -305,7 +329,7 @@ impl<'a> ConstFoldingContext<'a> { let nz_val = ConstValue::NonZero(Box::new(val.clone())); self.var_info.insert(nz_var, VarInfo::Const(nz_val.clone())); ( - Some(Statement::Const(StatementConst { value: nz_val, output: nz_var })), + Some(StatementConst { value: nz_val, output: nz_var }), FlatBlockEnd::Goto(arm.block_id, Default::default()), ) }) @@ -336,7 +360,7 @@ impl<'a> ConstFoldingContext<'a> { let value = ConstValue::Int(value, ty); self.var_info.insert(actual_output, VarInfo::Const(value.clone())); Some(( - Some(Statement::Const(StatementConst { value, output: actual_output })), + Some(StatementConst { value, output: actual_output }), FlatBlockEnd::Goto(arm.block_id, Default::default()), )) } else if id == self.downcast { @@ -349,7 +373,7 @@ impl<'a> ConstFoldingContext<'a> { let value = ConstValue::Int(value, ty); self.var_info.insert(success_output, VarInfo::Const(value.clone())); ( - Some(Statement::Const(StatementConst { value, output: success_output })), + Some(StatementConst { value, output: success_output }), FlatBlockEnd::Goto(info.arms[0].block_id, Default::default()), ) } else { @@ -373,7 +397,7 @@ impl<'a> ConstFoldingContext<'a> { } self.var_info.insert(output, VarInfo::Const(value.clone())); Some(( - Some(Statement::Const(StatementConst { value, output })), + Some(StatementConst { value, output }), FlatBlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()), )) } else { @@ -500,6 +524,8 @@ pub struct ConstFoldingLibfuncInfo { isub_fns: OrderedHashSet, /// The set of functions to multiply integers. wide_mul_fns: OrderedHashSet, + /// The set of functions to divide and get the remainder of integers. + div_rem_fns: OrderedHashSet, /// The `bounded_int_add` libfunc. bounded_int_add: ExternFunctionId, /// The `bounded_int_sub` libfunc. @@ -556,6 +582,10 @@ impl ConstFoldingLibfuncInfo { ["u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64"] .map(|ty| integer_module.extern_function_id(format!("{ty}_wide_mul"))), )); + let div_rem_fns = OrderedHashSet::<_>::from_iter(chain!( + [bounded_int_module.extern_function_id("bounded_int_div_rem")], + utypes.map(|ty| integer_module.extern_function_id(format!("{ty}_safe_divmod"))), + )); let bounded_int_add = bounded_int_module.extern_function_id("bounded_int_add"); let bounded_int_sub = bounded_int_module.extern_function_id("bounded_int_sub"); let bounded_int_constrain = bounded_int_module.extern_function_id("bounded_int_constrain"); @@ -589,6 +619,7 @@ impl ConstFoldingLibfuncInfo { iadd_fns, isub_fns, wide_mul_fns, + div_rem_fns, bounded_int_add, bounded_int_sub, bounded_int_constrain, diff --git a/crates/cairo-lang-lowering/src/optimizations/test_data/const_folding b/crates/cairo-lang-lowering/src/optimizations/test_data/const_folding index a84b5456bfc..526b8a05d4b 100644 --- a/crates/cairo-lang-lowering/src/optimizations/test_data/const_folding +++ b/crates/cairo-lang-lowering/src/optimizations/test_data/const_folding @@ -173,8 +173,8 @@ End: test_match_optimizer //! > function -fn foo(x: u8) -> u8 { - x / 4 +fn foo() -> u8 { + 8 / 4 } //! > function_name @@ -185,9 +185,10 @@ foo //! > semantic_diagnostics //! > before -Parameters: v0: core::integer::u8 +Parameters: blk0 (root): Statements: + (v0: core::integer::u8) <- 8 (v1: core::integer::u8) <- 4 End: Match(match core::integer::u8_is_zero(v1) { @@ -235,9 +236,10 @@ End: Return(v19) //! > after -Parameters: v0: core::integer::u8 +Parameters: blk0 (root): Statements: + (v0: core::integer::u8) <- 8 (v1: core::integer::u8) <- 4 (v2: core::zeroable::NonZero::) <- NonZero(4) End: @@ -256,7 +258,8 @@ End: blk2: Statements: - (v10: core::integer::u8, v11: core::integer::u8) <- core::integer::u8_safe_divmod(v0, v2) + (v11: core::integer::u8) <- 0 + (v10: core::integer::u8) <- 2 (v12: (core::integer::u8,)) <- struct_construct(v10) (v13: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Ok(v12) End: @@ -3203,8 +3206,8 @@ End: test_match_optimizer //! > function -fn foo(x: i8) -> i8 { - x / 4 +fn foo() -> i8 { + 8 / 4 } //! > function_name @@ -3215,9 +3218,10 @@ foo //! > semantic_diagnostics //! > before -Parameters: v0: core::integer::i8 +Parameters: blk0 (root): Statements: + (v0: core::integer::i8) <- 8 (v1: core::integer::i8) <- 4 End: Match(match core::internal::bounded_int::bounded_int_is_zero::(v1) { @@ -3281,9 +3285,10 @@ End: Return(v24) //! > after -Parameters: v0: core::integer::i8 +Parameters: blk0 (root): Statements: + (v0: core::integer::i8) <- 8 (v1: core::integer::i8) <- 4 (v2: core::zeroable::NonZero::) <- NonZero(4) End: