diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 00000000..a0f1a930 --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,2 @@ +[toolchain] +channel = "nightly-2024-01-04" diff --git a/stwo_cairo_prover/src/components/memory/mod.rs b/stwo_cairo_prover/src/components/memory/mod.rs index 7c60567a..2dad05d8 100644 --- a/stwo_cairo_prover/src/components/memory/mod.rs +++ b/stwo_cairo_prover/src/components/memory/mod.rs @@ -2,7 +2,7 @@ pub mod component; pub mod component_prover; #[cfg(test)] -mod tests { +pub(crate) mod tests { use std::collections::BTreeMap; use component::{ diff --git a/stwo_cairo_prover/src/components/ret_opcode/component.rs b/stwo_cairo_prover/src/components/ret_opcode/component.rs index e7274d0f..35ba8bef 100644 --- a/stwo_cairo_prover/src/components/ret_opcode/component.rs +++ b/stwo_cairo_prover/src/components/ret_opcode/component.rs @@ -32,7 +32,7 @@ pub struct RetOpcode { impl Component for RetOpcode { fn n_constraints(&self) -> usize { - todo!() + 3 } fn max_constraint_log_degree_bound(&self) -> u32 { @@ -59,12 +59,53 @@ impl Component for RetOpcode { fn evaluate_constraint_quotients_at_point( &self, - _point: CirclePoint, - _mask: &TreeVec>>, - _evaluation_accumulator: &mut PointEvaluationAccumulator, - _interaction_elements: &InteractionElements, - _lookup_values: &LookupValues, + point: CirclePoint, + mask: &TreeVec>>, + evaluation_accumulator: &mut PointEvaluationAccumulator, + interaction_elements: &InteractionElements, + lookup_values: &LookupValues, ) { - todo!() + // First lookup point boundary constraint. + let constraint_zero_domain = CanonicCoset::new(self.log_n_instances).coset; + let (alpha, z) = ( + interaction_elements[MEMORY_ALPHA], + interaction_elements[MEMORY_Z], + ); + let value = + SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][0])); + let address = mask[BASE_TRACE][0][0]; + let value_at_pc = SecureField::one(); + let mut address_and_value = [SecureField::zero(); N_M31_IN_FELT252 + 1]; + address_and_value[0] = address; + address_and_value[1] = value_at_pc; + + let numerator = + value * shifted_secure_combination(&address_and_value, alpha, z) - SecureField::one(); + let denom = point_vanishing(constraint_zero_domain.at(0), point); + evaluation_accumulator.accumulate(numerator / denom); + + // Last lookup point boundary constraint. + let lookup_value = SecureField::from_m31( + lookup_values[RET_LOOKUP_VALUE_0], + lookup_values[RET_LOOKUP_VALUE_1], + lookup_values[RET_LOOKUP_VALUE_2], + lookup_values[RET_LOOKUP_VALUE_3], + ); + let numerator = value - lookup_value; + let denom = point_vanishing( + constraint_zero_domain.at(constraint_zero_domain.size() - 1), + point, + ); + evaluation_accumulator.accumulate(numerator / denom); + + // Lookup step constraint. + let prev_value = + SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][1])); + let numerator = (value - prev_value) + * shifted_secure_combination(&address_and_value, alpha, z) + - SecureField::one(); + let denom = coset_vanishing(constraint_zero_domain, point) + / point_excluder(constraint_zero_domain.at(0), point); + evaluation_accumulator.accumulate(numerator / denom); } } diff --git a/stwo_cairo_prover/src/components/ret_opcode/cpu_prover.rs b/stwo_cairo_prover/src/components/ret_opcode/cpu_prover.rs index 79fa96c5..a5b9cc2d 100644 --- a/stwo_cairo_prover/src/components/ret_opcode/cpu_prover.rs +++ b/stwo_cairo_prover/src/components/ret_opcode/cpu_prover.rs @@ -1,30 +1,154 @@ #![allow(unused_imports)] +use std::collections::BTreeMap; + +use itertools::izip; use num_traits::identities::Zero; +use num_traits::One; use stwo_prover::core::air::accumulation::DomainEvaluationAccumulator; use stwo_prover::core::air::{Component, ComponentProver, ComponentTrace}; use stwo_prover::core::backend::{Column, CpuBackend}; -use stwo_prover::core::constraints::coset_vanishing; +use stwo_prover::core::constraints::{coset_vanishing, point_excluder}; use stwo_prover::core::fields::m31::BaseField; use stwo_prover::core::fields::qm31::SecureField; use stwo_prover::core::fields::FieldExpOps; use stwo_prover::core::poly::circle::CanonicCoset; -use stwo_prover::core::utils::bit_reverse; +use stwo_prover::core::utils::{ + bit_reverse, point_vanish_denominator_inverses, previous_bit_reversed_circle_domain_index, + shifted_secure_combination, +}; use stwo_prover::core::{InteractionElements, LookupValues}; +use stwo_prover::trace_generation::{BASE_TRACE, INTERACTION_TRACE}; -use super::component::RetOpcode; +use super::component::{ + RetOpcode, RET_LOOKUP_VALUE_0, RET_LOOKUP_VALUE_1, RET_LOOKUP_VALUE_2, RET_LOOKUP_VALUE_3, +}; +use crate::components::memory::component::{MEMORY_ALPHA, MEMORY_Z, N_M31_IN_FELT252}; impl ComponentProver for RetOpcode { #[allow(unused_parens)] fn evaluate_constraint_quotients_on_domain( &self, - _trace: &ComponentTrace<'_, CpuBackend>, - _evaluation_accumulator: &mut DomainEvaluationAccumulator, - _interaction_elements: &InteractionElements, - _lookup_values: &LookupValues, + trace: &ComponentTrace<'_, CpuBackend>, + evaluation_accumulator: &mut DomainEvaluationAccumulator, + interaction_elements: &InteractionElements, + lookup_values: &LookupValues, ) { - todo!() + let max_constraint_degree = self.max_constraint_log_degree_bound(); + let trace_eval_domain = CanonicCoset::new(max_constraint_degree).circle_domain(); + let trace_evals = &trace.evals; + let zero_domain = CanonicCoset::new(self.log_n_instances).coset; + let [mut accum] = + evaluation_accumulator.columns([(max_constraint_degree, self.n_constraints())]); + + // TODO(AlonH): Get all denominators in one loop and don't perform unnecessary inversions. + let first_point_denom_inverses = + point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(0)); + let last_point_denom_inverses = point_vanish_denominator_inverses( + trace_eval_domain, + zero_domain.at(zero_domain.size() - 1), + ); + let mut step_denoms = vec![]; + for point in trace_eval_domain.iter() { + step_denoms.push( + coset_vanishing(zero_domain, point) / point_excluder(zero_domain.at(0), point), + ); + } + bit_reverse(&mut step_denoms); + let mut step_denom_inverses = vec![BaseField::zero(); 1 << (max_constraint_degree)]; + BaseField::batch_inverse(&step_denoms, &mut step_denom_inverses); + + let (alpha, z) = ( + interaction_elements[MEMORY_ALPHA], + interaction_elements[MEMORY_Z], + ); + + let lookup_value = SecureField::from_m31( + lookup_values[RET_LOOKUP_VALUE_0], + lookup_values[RET_LOOKUP_VALUE_1], + lookup_values[RET_LOOKUP_VALUE_2], + lookup_values[RET_LOOKUP_VALUE_3], + ); + + for (i, (first_point_denom_inverse, last_point_denom_inverse, step_denom_inverse)) in izip!( + first_point_denom_inverses, + last_point_denom_inverses, + step_denom_inverses, + ) + .enumerate() + { + // Value = InteractionPoly(i); + let value = SecureField::from_m31_array(std::array::from_fn(|j| { + trace_evals[INTERACTION_TRACE][j][i] + })); + + // PrevValue = InteractionPoly(i - g); + let prev_index = previous_bit_reversed_circle_domain_index( + i, + zero_domain.log_size, + trace_eval_domain.log_size(), + ); + let prev_value = SecureField::from_m31_array(std::array::from_fn(|j| { + trace_evals[INTERACTION_TRACE][j][prev_index] + })); + + // PC Column. + let address = trace_evals[BASE_TRACE][0][i]; + let value_at_pc = BaseField::one(); + let mut address_and_value = [BaseField::zero(); N_M31_IN_FELT252 + 1]; + address_and_value[0] = address; + address_and_value[1] = value_at_pc; + + // TODO(Ohad): add remaining lookup constraints. + + let first_point_numerator = accum.random_coeff_powers[2] + * (value * shifted_secure_combination(&address_and_value, alpha, z) + - SecureField::one()); + let last_point_numerator = accum.random_coeff_powers[1] * (value - lookup_value); + let step_numerator = accum.random_coeff_powers[0] + * ((value - prev_value) * shifted_secure_combination(&address_and_value, alpha, z) + - SecureField::from_u32_unchecked(1, 0, 0, 0)); + accum.accumulate( + i, + first_point_numerator * first_point_denom_inverse + + last_point_numerator * last_point_denom_inverse + + step_numerator * step_denom_inverse, + ); + } } - fn lookup_values(&self, _trace: &ComponentTrace<'_, CpuBackend>) -> LookupValues { - todo!() + + fn lookup_values(&self, trace: &ComponentTrace<'_, CpuBackend>) -> LookupValues { + let domain = CanonicCoset::new(self.log_n_instances); + let trace_poly = &trace.polys[INTERACTION_TRACE]; + let values = BTreeMap::from_iter([ + ( + RET_LOOKUP_VALUE_0.to_string(), + trace_poly[0] + .eval_at_point(domain.at(domain.size() - 1).into_ef()) + .try_into() + .unwrap(), + ), + ( + RET_LOOKUP_VALUE_1.to_string(), + trace_poly[1] + .eval_at_point(domain.at(domain.size() - 1).into_ef()) + .try_into() + .unwrap(), + ), + ( + RET_LOOKUP_VALUE_2.to_string(), + trace_poly[2] + .eval_at_point(domain.at(domain.size() - 1).into_ef()) + .try_into() + .unwrap(), + ), + ( + RET_LOOKUP_VALUE_3.to_string(), + trace_poly[3] + .eval_at_point(domain.at(domain.size() - 1).into_ef()) + .try_into() + .unwrap(), + ), + ]); + LookupValues::new(values) } } diff --git a/stwo_cairo_prover/src/components/ret_opcode/mod.rs b/stwo_cairo_prover/src/components/ret_opcode/mod.rs index 83a49781..68a96e76 100644 --- a/stwo_cairo_prover/src/components/ret_opcode/mod.rs +++ b/stwo_cairo_prover/src/components/ret_opcode/mod.rs @@ -7,6 +7,7 @@ pub mod trace; pub(crate) mod tests { use itertools::Itertools; use num_traits::{One, Zero}; + use stwo_prover::core::backend::CpuBackend; use stwo_prover::core::channel::{Blake2sChannel, Channel}; use stwo_prover::core::fields::m31::{BaseField, M31}; use stwo_prover::core::fields::qm31::SecureField; @@ -14,11 +15,16 @@ pub(crate) mod tests { use stwo_prover::core::utils::shifted_secure_combination; use stwo_prover::core::vcs::blake2_hash::Blake2sHasher; use stwo_prover::core::vcs::hasher::Hasher; - use stwo_prover::trace_generation::{AirTraceGenerator, AirTraceVerifier}; + use stwo_prover::trace_generation::{commit_and_prove, commit_and_verify, AirTraceGenerator, AirTraceVerifier, ComponentTraceGenerator}; - use crate::components::memory::component::{MEMORY_ALPHA, MEMORY_Z, N_M31_IN_FELT252}; + use crate::components::memory::component::{MemoryTraceGenerator, MEMORY_ALPHA, MEMORY_COMPONENT_ID, MEMORY_Z, N_M31_IN_FELT252}; use crate::components::ret_opcode::test_utils::TestRetAirGenerator; + use super::component::RET_COMPONENT_ID; + use super::test_utils::TestAir; + use super::trace::RetOpcodeCpuTraceGenerator; + + #[test] fn test_ret_interaction_trace() { let mut air_generator = TestRetAirGenerator::new(); @@ -48,4 +54,45 @@ pub(crate) mod tests { assert_eq!(logup_sum, expected_logup_sum); } + + #[test] + fn test_ret_lookup_values() { + let mut air_generator = TestRetAirGenerator::new(); + let trace = air_generator.write_trace(); + let prover_channel = + &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); + + let interaction_trace = + air_generator.interact(&trace, &air_generator.interaction_elements(prover_channel)); + let (ret_interaction, memory_interaction) = interaction_trace.split_at(4); + let ret_lookup_value = + SecureField::from_m31_array(std::array::from_fn(|i| ret_interaction[i].values[1])); + let memory_lookup_value = + SecureField::from_m31_array(std::array::from_fn(|i| memory_interaction[i].values[1])); + + assert_eq!(ret_lookup_value, memory_lookup_value); + } + + #[test] + fn test_ret_proof() { + let mut air_generator = TestRetAirGenerator::new(); + let trace = air_generator.write_trace(); + let prover_channel = + &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); + let proof = commit_and_prove::(&air_generator, prover_channel, trace).unwrap(); + + let verifier_channel = + &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); + let air = TestAir { + memory_component: air_generator + .registry + .get_generator::(MEMORY_COMPONENT_ID) + .component(), + ret_component: air_generator + .registry + .get_generator::(RET_COMPONENT_ID) + .component(), + }; + commit_and_verify(proof, &air, verifier_channel).unwrap(); + } }