Skip to content

Commit

Permalink
Enabled bounded-int const folding. (#6419)
Browse files Browse the repository at this point in the history
  • Loading branch information
orizi authored Sep 26, 2024
1 parent f982d3c commit e4136a4
Show file tree
Hide file tree
Showing 4 changed files with 491 additions and 10 deletions.
5 changes: 2 additions & 3 deletions corelib/src/integer.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -1962,7 +1962,7 @@ impl I128PartialOrd of PartialOrd<i128> {

mod signed_div_rem {
use crate::internal::bounded_int::{
BoundedInt, ConstrainHelper, DivRemHelper, NegateHelper, constrain, div_rem
BoundedInt, ConstrainHelper, DivRemHelper, NegateHelper, constrain, div_rem, is_zero,
};
use super::{upcast, downcast};

Expand Down Expand Up @@ -2105,10 +2105,9 @@ mod signed_div_rem {
>;
pub impl I128DivRem = DivRemImpl<i128>;

extern fn bounded_int_is_zero<T>(value: T) -> super::IsZeroResult<T> implicits() nopanic;
pub impl TryIntoNonZero<T> of TryInto<T, NonZero<T>> {
fn try_into(self: T) -> Option<NonZero<T>> {
match bounded_int_is_zero(self) {
match is_zero(self) {
super::IsZeroResult::Zero => Option::None,
super::IsZeroResult::NonZero(x) => Option::Some(x),
}
Expand Down
5 changes: 4 additions & 1 deletion corelib/src/internal/bounded_int.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ extern fn bounded_int_constrain<T, const BOUNDARY: felt252, impl H: ConstrainHel
value: T
) -> Result<H::LowT, H::HighT> implicits(RangeCheck) nopanic;

extern fn bounded_int_is_zero<T>(value: T) -> crate::zeroable::IsZeroResult<T> implicits() nopanic;

/// Returns the negation of the given `felt252` value.
trait NegFelt252<const NUM: felt252> {
/// The negation of the given `felt252` value.
Expand Down Expand Up @@ -216,5 +218,6 @@ impl MulMinusOneNegateHelper<T, impl H: MulHelper<T, MinusOne>> of NegateHelper<

pub use {
bounded_int_add as add, bounded_int_sub as sub, bounded_int_mul as mul,
bounded_int_div_rem as div_rem, bounded_int_constrain as constrain
bounded_int_div_rem as div_rem, bounded_int_constrain as constrain,
bounded_int_is_zero as is_zero
};
56 changes: 50 additions & 6 deletions crates/cairo-lang-lowering/src/optimizations/const_folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use cairo_lang_semantic::{corelib, GenericArgumentId, TypeId};
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
use cairo_lang_utils::{try_extract_matches, Intern};
use cairo_lang_utils::{extract_matches, try_extract_matches, Intern, LookupIntern};
use id_arena::Arena;
use itertools::{chain, zip_eq};
use num_bigint::BigInt;
Expand Down Expand Up @@ -226,10 +226,19 @@ impl<'a> ConstFoldingContext<'a> {
self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
}
None
} else if self.wide_mul_fns.contains(&id) {
} else if self.wide_mul_fns.contains(&id)
|| id == self.bounded_int_add
|| id == self.bounded_int_sub
{
let lhs = self.as_int(stmt.inputs[0].var_id)?;
let rhs = self.as_int(stmt.inputs[1].var_id)?;
let value = lhs * rhs;
let value = if id == self.bounded_int_add {
lhs + rhs
} else if id == self.bounded_int_sub {
lhs - rhs
} else {
lhs * rhs
};
let output = stmt.outputs[0];
let ty = self.variables[output].ty;
let value = ConstValue::Int(value, ty);
Expand Down Expand Up @@ -344,6 +353,24 @@ impl<'a> ConstFoldingContext<'a> {
} else {
(None, FlatBlockEnd::Goto(info.arms[1].block_id, Default::default()))
})
} else if id == self.bounded_int_constrain {
let input_var = info.inputs[0].var_id;
let value = self.as_int(input_var)?;
let semantic_id =
extract_matches!(info.function.lookup_intern(self.db), FunctionLongId::Semantic);
let generic_arg = semantic_id.get_concrete(self.db.upcast()).generic_args[1];
let constrain_value = extract_matches!(generic_arg, GenericArgumentId::Constant)
.lookup_intern(self.db)
.into_int()
.unwrap();
let arm_idx = if value < &constrain_value { 0 } else { 1 };
let output = info.arms[arm_idx].var_ids[0];
let value = ConstValue::Int(value.clone(), self.variables[output].ty);
self.var_info.insert(output, VarInfo::Const(value.clone()));
Some((
Some(Statement::Const(StatementConst { value, output })),
FlatBlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()),
))
} else {
None
}
Expand Down Expand Up @@ -452,6 +479,12 @@ pub struct ConstFoldingLibfuncInfo {
isub_fns: OrderedHashSet<ExternFunctionId>,
/// The set of functions to multiply integers.
wide_mul_fns: OrderedHashSet<ExternFunctionId>,
/// The `bounded_int_add` libfunc.
bounded_int_add: ExternFunctionId,
/// The `bounded_int_sub` libfunc.
bounded_int_sub: ExternFunctionId,
/// The `bounded_int_constrain` libfunc.
bounded_int_constrain: ExternFunctionId,
/// The storage access module.
storage_access_module: ModuleId,
/// Type ranges.
Expand All @@ -464,14 +497,18 @@ impl ConstFoldingLibfuncInfo {
let box_module = core.submodule("box");
let into_box = box_module.extern_function_id("into_box");
let integer_module = core.submodule("integer");
let bounded_int_module = core.submodule("internal").submodule("bounded_int");
let upcast = integer_module.extern_function_id("upcast");
let downcast = integer_module.extern_function_id("downcast");
let starknet_module = core.submodule("starknet");
let storage_access_module = starknet_module.submodule("storage_access");
let storage_base_address_from_felt252 =
storage_access_module.extern_function_id("storage_base_address_from_felt252");
let nz_fns = OrderedHashSet::<_>::from_iter(chain!(
[core.extern_function_id("felt252_is_zero")],
[
core.extern_function_id("felt252_is_zero"),
bounded_int_module.extern_function_id("bounded_int_is_zero")
],
["u8", "u16", "u32", "u64", "u128", "u256", "i8", "i16", "i32", "i64", "i128"]
.map(|ty| integer_module.extern_function_id(format!("{ty}_is_zero")))
));
Expand All @@ -493,10 +530,14 @@ impl ConstFoldingLibfuncInfo {
itypes
.map(|ty| integer_module.extern_function_id(format!("{ty}_overflowing_sub_impl"))),
);
let wide_mul_fns = OrderedHashSet::<_>::from_iter(
let wide_mul_fns = OrderedHashSet::<_>::from_iter(chain!(
[bounded_int_module.extern_function_id("bounded_int_mul")],
["u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64"]
.map(|ty| integer_module.extern_function_id(format!("{ty}_wide_mul"))),
);
));
let bounded_int_add = bounded_int_module.extern_function_id("bounded_int_add");
let bounded_int_sub = bounded_int_module.extern_function_id("bounded_int_sub");
let bounded_int_constrain = bounded_int_module.extern_function_id("bounded_int_constrain");
let type_value_ranges = OrderedHashMap::from_iter(
[
("u8", TypeRange::closed(0, u8::MAX)),
Expand Down Expand Up @@ -527,6 +568,9 @@ impl ConstFoldingLibfuncInfo {
iadd_fns,
isub_fns,
wide_mul_fns,
bounded_int_add,
bounded_int_sub,
bounded_int_constrain,
storage_access_module: storage_access_module.id,
type_value_ranges,
}
Expand Down
Loading

0 comments on commit e4136a4

Please sign in to comment.