From 2f5a94c407dcecd54120ed363297390a00987fcf Mon Sep 17 00:00:00 2001 From: Alon Haramati Date: Thu, 18 Jul 2024 11:46:35 +0300 Subject: [PATCH] Create CairoAir. --- stwo_cairo_prover/Cargo.lock | 2 +- stwo_cairo_prover/Cargo.toml | 2 +- stwo_cairo_prover/src/air/cairo.rs | 211 ++++++++++++++++++ stwo_cairo_prover/src/air/mod.rs | 1 + .../src/components/memory/component.rs | 130 +++++++---- .../src/components/memory/component_prover.rs | 93 ++++---- .../src/components/memory/mod.rs | 32 +-- .../components/range_check_unit/component.rs | 13 +- .../range_check_unit/component_prover.rs | 14 +- .../src/components/ret_opcode/test_utils.rs | 17 +- stwo_cairo_prover/src/main.rs | 3 + stwo_cairo_prover/src/test_utils.rs | 19 ++ 12 files changed, 401 insertions(+), 136 deletions(-) create mode 100644 stwo_cairo_prover/src/air/cairo.rs create mode 100644 stwo_cairo_prover/src/air/mod.rs create mode 100644 stwo_cairo_prover/src/test_utils.rs diff --git a/stwo_cairo_prover/Cargo.lock b/stwo_cairo_prover/Cargo.lock index 2ac0ad07..3ef9ef16 100644 --- a/stwo_cairo_prover/Cargo.lock +++ b/stwo_cairo_prover/Cargo.lock @@ -553,7 +553,7 @@ dependencies = [ [[package]] name = "stwo-prover" version = "0.1.1" -source = "git+https://github.com/starkware-libs/stwo?rev=7a0bddee#7a0bddeec1a847654dbecff5df37bf5a5891f216" +source = "git+https://github.com/starkware-libs/stwo?rev=7614d354#7614d354a0083d294c647bbdad2a252ae7ff8cb1" dependencies = [ "blake2", "blake3", diff --git a/stwo_cairo_prover/Cargo.toml b/stwo_cairo_prover/Cargo.toml index 8ec2e947..4e6dda1e 100644 --- a/stwo_cairo_prover/Cargo.toml +++ b/stwo_cairo_prover/Cargo.toml @@ -7,4 +7,4 @@ edition = "2021" itertools = "0.12.0" num-traits = "0.2.17" # TODO(ShaharS): take stwo version from the source repository. -stwo-prover = { git = "https://github.com/starkware-libs/stwo", rev = "7a0bddee" } +stwo-prover = { git = "https://github.com/starkware-libs/stwo", rev = "7614d354" } diff --git a/stwo_cairo_prover/src/air/cairo.rs b/stwo_cairo_prover/src/air/cairo.rs new file mode 100644 index 00000000..07494bad --- /dev/null +++ b/stwo_cairo_prover/src/air/cairo.rs @@ -0,0 +1,211 @@ +use std::cmp::max; +use std::collections::BTreeMap; + +use stwo_prover::core::air::{Air, AirProver, Component, ComponentProver}; +use stwo_prover::core::backend::CpuBackend; +use stwo_prover::core::channel::{Blake2sChannel, Channel}; +use stwo_prover::core::fields::m31::BaseField; +use stwo_prover::core::fields::qm31::SecureField; +use stwo_prover::core::poly::circle::CircleEvaluation; +use stwo_prover::core::poly::BitReversedOrder; +use stwo_prover::core::prover::VerificationError; +use stwo_prover::core::{ColumnVec, InteractionElements, LookupValues}; +use stwo_prover::trace_generation::registry::ComponentGenerationRegistry; +use stwo_prover::trace_generation::{AirTraceGenerator, AirTraceVerifier, ComponentTraceGenerator}; + +use crate::components::memory::component::{ + MemoryComponent, MemoryTraceGenerator, MAX_MEMORY_CELL_VALUE, MEMORY_ALPHA, + MEMORY_COMPONENT_ID, MEMORY_LOOKUP_VALUE_0, MEMORY_LOOKUP_VALUE_1, MEMORY_LOOKUP_VALUE_2, + MEMORY_LOOKUP_VALUE_3, MEMORY_Z, N_MEMORY_COLUMNS, +}; +use crate::components::range_check_unit::component::{ + RangeCheckUnitComponent, RangeCheckUnitTraceGenerator, N_RC_COLUMNS, RC_COMPONENT_ID, + RC_LOOKUP_VALUE_0, RC_LOOKUP_VALUE_1, RC_LOOKUP_VALUE_2, RC_LOOKUP_VALUE_3, RC_Z, +}; + +struct CairoAirGenerator { + pub registry: ComponentGenerationRegistry, +} + +impl CairoAirGenerator { + #[allow(dead_code)] + pub fn new(path: String) -> Self { + let mut registry = ComponentGenerationRegistry::default(); + registry.register(MEMORY_COMPONENT_ID, MemoryTraceGenerator::new(path)); + registry.register( + RC_COMPONENT_ID, + RangeCheckUnitTraceGenerator::new(MAX_MEMORY_CELL_VALUE), + ); + Self { registry } + } +} + +impl AirTraceVerifier for CairoAirGenerator { + fn interaction_elements(&self, channel: &mut Blake2sChannel) -> InteractionElements { + let elements = channel.draw_felts(3); + InteractionElements::new(BTreeMap::from_iter(vec![ + (MEMORY_ALPHA.to_string(), elements[0]), + (MEMORY_Z.to_string(), elements[1]), + (RC_Z.to_string(), elements[2]), + ])) + } +} + +impl AirTraceGenerator for CairoAirGenerator { + fn write_trace(&mut self) -> Vec> { + let mut trace = Vec::with_capacity(N_MEMORY_COLUMNS + N_RC_COLUMNS); + trace.extend(MemoryTraceGenerator::write_trace( + MEMORY_COMPONENT_ID, + &mut self.registry, + )); + trace.extend(RangeCheckUnitTraceGenerator::write_trace( + RC_COMPONENT_ID, + &mut self.registry, + )); + trace + } + + fn interact( + &self, + trace: &ColumnVec>, + elements: &InteractionElements, + ) -> Vec> { + let mut interaction_trace = Vec::new(); + let trace_iter = &mut trace.iter(); + let memory_generator = self + .registry + .get_generator::(MEMORY_COMPONENT_ID); + interaction_trace.extend( + memory_generator + .write_interaction_trace(&trace_iter.take(N_MEMORY_COLUMNS).collect(), elements), + ); + let rc_generator = self + .registry + .get_generator::(RC_COMPONENT_ID); + interaction_trace.extend( + rc_generator + .write_interaction_trace(&trace_iter.take(N_RC_COLUMNS).collect(), elements), + ); + interaction_trace + } + + fn to_air_prover(&self) -> impl AirProver { + let memory = self + .registry + .get_generator::(MEMORY_COMPONENT_ID); + let range_check_unit = self + .registry + .get_generator::(RC_COMPONENT_ID); + CairoAir { + memory: memory.component(), + range_check_unit: range_check_unit.component(), + } + } + + fn composition_log_degree_bound(&self) -> u32 { + let memory = self + .registry + .get_generator::(MEMORY_COMPONENT_ID); + let range_check_unit = self + .registry + .get_generator::(RC_COMPONENT_ID); + max( + memory.component().max_constraint_log_degree_bound(), + range_check_unit + .component() + .max_constraint_log_degree_bound(), + ) + } +} + +#[derive(Clone)] +pub struct CairoAir { + pub memory: MemoryComponent, + pub range_check_unit: RangeCheckUnitComponent, +} + +impl Air for CairoAir { + fn components(&self) -> Vec<&dyn Component> { + vec![&self.memory, &self.range_check_unit] + } + + fn verify_lookups(&self, lookup_values: &LookupValues) -> Result<(), VerificationError> { + let memory_rc_lookup_value = SecureField::from_m31( + lookup_values[MEMORY_LOOKUP_VALUE_0], + lookup_values[MEMORY_LOOKUP_VALUE_1], + lookup_values[MEMORY_LOOKUP_VALUE_2], + lookup_values[MEMORY_LOOKUP_VALUE_3], + ); + let rc_lookup_value = SecureField::from_m31( + lookup_values[RC_LOOKUP_VALUE_0], + lookup_values[RC_LOOKUP_VALUE_1], + lookup_values[RC_LOOKUP_VALUE_2], + lookup_values[RC_LOOKUP_VALUE_3], + ); + if memory_rc_lookup_value != rc_lookup_value { + return Err(VerificationError::InvalidLookup( + "Memory and RC".to_string(), + )); + } + Ok(()) + } +} + +impl AirProver for CairoAir { + fn prover_components(&self) -> Vec<&dyn ComponentProver> { + vec![&self.memory, &self.range_check_unit] + } +} + +impl AirTraceVerifier for CairoAir { + fn interaction_elements(&self, channel: &mut Blake2sChannel) -> InteractionElements { + let elements = channel.draw_felts(3); + InteractionElements::new(BTreeMap::from_iter(vec![ + (MEMORY_ALPHA.to_string(), elements[0]), + (MEMORY_Z.to_string(), elements[1]), + (RC_Z.to_string(), elements[2]), + ])) + } +} + +#[cfg(test)] +mod tests { + use stwo_prover::core::backend::CpuBackend; + use stwo_prover::core::channel::{Blake2sChannel, Channel}; + use stwo_prover::core::fields::m31::BaseField; + use stwo_prover::core::fields::IntoSlice; + use stwo_prover::core::vcs::blake2_hash::Blake2sHasher; + use stwo_prover::core::vcs::hasher::Hasher; + use stwo_prover::trace_generation::registry::ComponentGenerationRegistry; + use stwo_prover::trace_generation::{ + commit_and_prove, commit_and_verify, AirTraceGenerator, ComponentTraceGenerator, + }; + + use super::*; + use crate::test_utils::register_test_memory; + + #[test] + fn test_air() { + let mut registry = ComponentGenerationRegistry::default(); + register_test_memory(&mut registry); + let mut air = CairoAirGenerator { registry }; + let trace = air.write_trace(); + let prover_channel = + &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); + let proof = commit_and_prove::(&air, prover_channel, trace).unwrap(); + + let verifier_channel = + &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); + let air = CairoAir { + memory: air + .registry + .get_generator::(MEMORY_COMPONENT_ID) + .component(), + range_check_unit: air + .registry + .get_generator::(RC_COMPONENT_ID) + .component(), + }; + commit_and_verify(proof, &air, verifier_channel).unwrap(); + } +} diff --git a/stwo_cairo_prover/src/air/mod.rs b/stwo_cairo_prover/src/air/mod.rs new file mode 100644 index 00000000..c6c62303 --- /dev/null +++ b/stwo_cairo_prover/src/air/mod.rs @@ -0,0 +1 @@ +pub mod cairo; diff --git a/stwo_cairo_prover/src/components/memory/component.rs b/stwo_cairo_prover/src/components/memory/component.rs index 06f40f5b..6e511f2f 100644 --- a/stwo_cairo_prover/src/components/memory/component.rs +++ b/stwo_cairo_prover/src/components/memory/component.rs @@ -1,11 +1,11 @@ use itertools::{zip_eq, Itertools}; -use num_traits::Zero; +use num_traits::{One, Zero}; use stwo_prover::core::air::accumulation::PointEvaluationAccumulator; use stwo_prover::core::air::mask::fixed_mask_points; use stwo_prover::core::air::Component; use stwo_prover::core::backend::CpuBackend; use stwo_prover::core::circle::CirclePoint; -use stwo_prover::core::constraints::{coset_vanishing, point_excluder, point_vanishing}; +use stwo_prover::core::constraints::coset_vanishing; use stwo_prover::core::fields::m31::BaseField; use stwo_prover::core::fields::qm31::SecureField; use stwo_prover::core::fields::secure_column::{SecureColumn, SECURE_EXTENSION_DEGREE}; @@ -22,6 +22,10 @@ use stwo_prover::trace_generation::{ ComponentGen, ComponentTraceGenerator, BASE_TRACE, INTERACTION_TRACE, }; +use crate::components::range_check_unit::component::{ + RangeCheckUnitTraceGenerator, RC_COMPONENT_ID, RC_Z, +}; + pub const MEMORY_ALPHA: &str = "MEMORY_ALPHA"; pub const MEMORY_Z: &str = "MEMORY_Z"; pub const MEMORY_COMPONENT_ID: &str = "MEMORY"; @@ -29,13 +33,19 @@ pub const MEMORY_LOOKUP_VALUE_0: &str = "MEMORY_LOOKUP_0"; pub const MEMORY_LOOKUP_VALUE_1: &str = "MEMORY_LOOKUP_1"; pub const MEMORY_LOOKUP_VALUE_2: &str = "MEMORY_LOOKUP_2"; pub const MEMORY_LOOKUP_VALUE_3: &str = "MEMORY_LOOKUP_3"; +pub const MEMORY_RC_LOOKUP_VALUE_0: &str = "MEMORY_RC_LOOKUP_0"; +pub const MEMORY_RC_LOOKUP_VALUE_1: &str = "MEMORY_RC_LOOKUP_1"; +pub const MEMORY_RC_LOOKUP_VALUE_2: &str = "MEMORY_RC_LOOKUP_2"; +pub const MEMORY_RC_LOOKUP_VALUE_3: &str = "MEMORY_RC_LOOKUP_3"; +pub const MAX_MEMORY_CELL_VALUE: usize = 1 << 9; pub const N_M31_IN_FELT252: usize = 28; pub const MULTIPLICITY_COLUMN_OFFSET: usize = N_M31_IN_FELT252 + 1; // TODO(AlonH): Make memory size configurable. pub const LOG_MEMORY_ADDRESS_BOUND: u32 = 3; pub const MEMORY_ADDRESS_BOUND: usize = 1 << LOG_MEMORY_ADDRESS_BOUND; -pub const N_MEMORY_COLUMNS: usize = N_M31_IN_FELT252 + 2; +// Addresses, M31 values, and multiplicities. +pub const N_MEMORY_COLUMNS: usize = 1 + N_M31_IN_FELT252 + 1; /// Addresses are continuous and start from 0. /// Values are Felt252 stored as `N_M31_IN_FELT252` M31 values (each value containing 9 bits). @@ -52,7 +62,7 @@ pub struct MemoryComponent { impl MemoryComponent { pub const fn n_columns(&self) -> usize { - N_M31_IN_FELT252 + 2 + N_MEMORY_COLUMNS } } @@ -107,11 +117,19 @@ impl ComponentTraceGenerator for MemoryTraceGenerator { // here or add constraints to the column here. trace[0][i] = BaseField::from_u32_unchecked(i as u32); for (j, value) in values.iter().enumerate() { - trace[j + 1][i] = BaseField::from_u32_unchecked(value.0); + trace[j + 1][i] = *value; } trace[MULTIPLICITY_COLUMN_OFFSET][i] = BaseField::from_u32_unchecked(*multiplicity); } + let rc_generator = + registry.get_generator_mut::(RC_COMPONENT_ID); + for column in trace[1..MULTIPLICITY_COLUMN_OFFSET].iter() { + column + .iter() + .for_each(|input| rc_generator.add_inputs(input)); + } + let domain = CanonicCoset::new(LOG_MEMORY_ADDRESS_BOUND).circle_domain(); trace .into_iter() @@ -126,7 +144,7 @@ impl ComponentTraceGenerator for MemoryTraceGenerator { ) -> ColumnVec> { let interaction_trace_domain = trace[0].domain; let domain_size = interaction_trace_domain.size(); - let (alpha, z) = (elements[MEMORY_ALPHA], elements[MEMORY_Z]); + let (alpha, z, rc_z) = (elements[MEMORY_ALPHA], elements[MEMORY_Z], elements[RC_Z]); let addresses_and_values: Vec<[BaseField; N_M31_IN_FELT252 + 1]> = (0 ..MEMORY_ADDRESS_BOUND) @@ -138,8 +156,8 @@ impl ComponentTraceGenerator for MemoryTraceGenerator { .collect_vec(); let mut denom_inverses = vec![SecureField::zero(); domain_size]; SecureField::batch_inverse(&denoms, &mut denom_inverses); - let mut logup_values = vec![SecureField::zero(); domain_size]; - let mut last = SecureField::zero(); + let mut logup_values = vec![vec![SecureField::zero(); domain_size]; 1 + N_M31_IN_FELT252]; + let mut column_last = SecureField::zero(); let log_size = interaction_trace_domain.log_size(); for i in 0..domain_size { let index = bit_reverse_index( @@ -147,13 +165,32 @@ impl ComponentTraceGenerator for MemoryTraceGenerator { log_size, ); let interaction_value = - last + (denom_inverses[index] * trace[MULTIPLICITY_COLUMN_OFFSET].values[index]); - logup_values[index] = interaction_value; - last = interaction_value; + denom_inverses[index] * trace[MULTIPLICITY_COLUMN_OFFSET].values[index]; + logup_values[0][index] = interaction_value; + let mut row_last = interaction_value; + + // TODO(AlonH): Batch inverse. + for j in 1..N_M31_IN_FELT252 { + let rc_interaction_value = row_last + (rc_z - trace[j].values[index]).inverse(); + logup_values[j][index] = rc_interaction_value; + row_last = rc_interaction_value; + } + + let final_interaction_value = + column_last + row_last + (rc_z - trace[N_M31_IN_FELT252].values[index]).inverse(); + logup_values[N_M31_IN_FELT252][index] = final_interaction_value; + column_last = final_interaction_value; } - let secure_column: SecureColumn = logup_values.into_iter().collect(); - secure_column - .columns + let interaction_columns: Vec> = logup_values + .into_iter() + .flat_map(|values| { + values + .into_iter() + .collect::>() + .columns + }) + .collect_vec(); + interaction_columns .into_iter() .map(|eval| CircleEvaluation::new(interaction_trace_domain, eval)) .collect_vec() @@ -168,7 +205,7 @@ impl ComponentTraceGenerator for MemoryTraceGenerator { impl Component for MemoryComponent { fn n_constraints(&self) -> usize { - 3 + N_M31_IN_FELT252 } fn max_constraint_log_degree_bound(&self) -> u32 { @@ -178,7 +215,7 @@ impl Component for MemoryComponent { fn trace_log_degree_bounds(&self) -> TreeVec> { TreeVec::new(vec![ vec![self.log_n_rows; self.n_columns()], - vec![self.log_n_rows; SECURE_EXTENSION_DEGREE], + vec![self.log_n_rows; SECURE_EXTENSION_DEGREE * (1 + N_M31_IN_FELT252)], ]) } @@ -189,7 +226,10 @@ impl Component for MemoryComponent { let domain = CanonicCoset::new(self.log_n_rows); TreeVec::new(vec![ fixed_mask_points(&vec![vec![0_usize]; self.n_columns()], point), - vec![vec![point, point - domain.step().into_ef()]; SECURE_EXTENSION_DEGREE], + vec![ + vec![point, point - domain.step().into_ef()]; + SECURE_EXTENSION_DEGREE * (1 + N_M31_IN_FELT252) + ], ]) } @@ -201,48 +241,50 @@ impl Component for MemoryComponent { interaction_elements: &InteractionElements, lookup_values: &LookupValues, ) { - // First lookup point boundary constraint. + // TODO(AlonH): Add constraints to the range check interaction columns. let constraint_zero_domain = CanonicCoset::new(self.log_n_rows).coset; - let (alpha, z) = ( + let (alpha, z, rc_z) = ( interaction_elements[MEMORY_ALPHA], interaction_elements[MEMORY_Z], + interaction_elements[RC_Z], ); + let value = SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][0])); let address_and_value: [SecureField; N_M31_IN_FELT252 + 1] = std::array::from_fn(|i| mask[BASE_TRACE][i][0]); - let numerator = value * shifted_secure_combination(&address_and_value, alpha, z) - - mask[BASE_TRACE][MULTIPLICITY_COLUMN_OFFSET][0]; - let denom = point_vanishing(constraint_zero_domain.at(0), point); - evaluation_accumulator.accumulate(numerator / denom); - - // Last lookup point boundary constraint. - let lookup_value = SecureField::from_m31( + let _lookup_value = SecureField::from_m31( lookup_values[MEMORY_LOOKUP_VALUE_0], lookup_values[MEMORY_LOOKUP_VALUE_1], lookup_values[MEMORY_LOOKUP_VALUE_2], lookup_values[MEMORY_LOOKUP_VALUE_3], ); - let numerator = value - lookup_value; - let denom = point_vanishing(constraint_zero_domain.at(1), point); - evaluation_accumulator.accumulate(numerator / denom); - // Lookup step constraint. - let prev_value = - SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][1])); - let numerator = (value - prev_value) - * shifted_secure_combination(&address_and_value, alpha, z) + // First interaction column constraint. + let numerator = value * shifted_secure_combination(&address_and_value, alpha, z) - mask[BASE_TRACE][MULTIPLICITY_COLUMN_OFFSET][0]; - let denom = coset_vanishing(constraint_zero_domain, point) - / point_excluder(constraint_zero_domain.at(0), point); + let denom = coset_vanishing(constraint_zero_domain, point); evaluation_accumulator.accumulate(numerator / denom); + + // Middle interaction columns constraints. + let mut prev_row_value = value; + #[allow(clippy::needless_range_loop)] + for i in 1..N_M31_IN_FELT252 { + let value = SecureField::from_partial_evals(std::array::from_fn(|j| { + mask[INTERACTION_TRACE][i * SECURE_EXTENSION_DEGREE + j][0] + })); + let numerator = + (value - prev_row_value) * (rc_z - address_and_value[i]) - BaseField::one(); + evaluation_accumulator.accumulate(numerator / denom); + prev_row_value = value; + } } } #[cfg(test)] mod tests { use super::*; - use crate::components::memory::tests::register_test_memory; + use crate::test_utils::register_test_memory; #[test] fn test_memory_trace() { @@ -251,8 +293,14 @@ mod tests { let trace = MemoryTraceGenerator::write_trace(MEMORY_COMPONENT_ID, &mut registry); let alpha = SecureField::from_u32_unchecked(1, 2, 3, 117); let z = SecureField::from_u32_unchecked(2, 3, 4, 118); + let rc_z = SecureField::from_u32_unchecked(3, 4, 5, 119); let interaction_elements = InteractionElements::new( - [(MEMORY_ALPHA.to_string(), alpha), (MEMORY_Z.to_string(), z)].into(), + [ + (MEMORY_ALPHA.to_string(), alpha), + (MEMORY_Z.to_string(), z), + (RC_Z.to_string(), rc_z), + ] + .into(), ); let interaction_trace = registry .get_generator::(MEMORY_COMPONENT_ID) @@ -267,9 +315,13 @@ mod tests { alpha, z, ); + #[allow(clippy::needless_range_loop)] + for j in 1..(N_M31_IN_FELT252 + 1) { + expected_logup_sum += (rc_z - trace[j].values[i]).inverse(); + } } let logup_sum = - SecureField::from_m31_array(std::array::from_fn(|j| interaction_trace[j][1])); + SecureField::from_m31_array(std::array::from_fn(|j| interaction_trace[112 + j][1])); assert_eq!(logup_sum, expected_logup_sum); } diff --git a/stwo_cairo_prover/src/components/memory/component_prover.rs b/stwo_cairo_prover/src/components/memory/component_prover.rs index 26e93b67..9433f73c 100644 --- a/stwo_cairo_prover/src/components/memory/component_prover.rs +++ b/stwo_cairo_prover/src/components/memory/component_prover.rs @@ -1,19 +1,16 @@ use std::collections::BTreeMap; -use itertools::izip; -use num_traits::Zero; +use num_traits::{One, Zero}; use stwo_prover::core::air::accumulation::DomainEvaluationAccumulator; use stwo_prover::core::air::{Component, ComponentProver, ComponentTrace}; use stwo_prover::core::backend::CpuBackend; -use stwo_prover::core::constraints::{coset_vanishing, point_excluder}; +use stwo_prover::core::constraints::coset_vanishing; use stwo_prover::core::fields::m31::BaseField; use stwo_prover::core::fields::qm31::SecureField; +use stwo_prover::core::fields::secure_column::SECURE_EXTENSION_DEGREE; use stwo_prover::core::fields::FieldExpOps; use stwo_prover::core::poly::circle::CanonicCoset; -use stwo_prover::core::utils::{ - bit_reverse, point_vanish_denominator_inverses, previous_bit_reversed_circle_domain_index, - shifted_secure_combination, -}; +use stwo_prover::core::utils::{bit_reverse, shifted_secure_combination}; use stwo_prover::core::{InteractionElements, LookupValues}; use stwo_prover::trace_generation::{BASE_TRACE, INTERACTION_TRACE}; @@ -22,6 +19,7 @@ use super::component::{ MEMORY_LOOKUP_VALUE_2, MEMORY_LOOKUP_VALUE_3, MEMORY_Z, MULTIPLICITY_COLUMN_OFFSET, N_M31_IN_FELT252, }; +use crate::components::range_check_unit::component::RC_Z; impl ComponentProver for MemoryComponent { fn evaluate_constraint_quotients_on_domain( @@ -38,65 +36,50 @@ impl ComponentProver for MemoryComponent { let [mut accum] = evaluation_accumulator.columns([(max_constraint_degree, self.n_constraints())]); - // TODO(AlonH): Get all denominators in one loop and don't perform unnecessary inversions. - let first_point_denom_inverses = - point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(0)); - let last_point_denom_inverses = - point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(1)); - let mut step_denoms = vec![]; + let mut denoms = vec![]; for point in trace_eval_domain.iter() { - step_denoms.push( - coset_vanishing(zero_domain, point) / point_excluder(zero_domain.at(0), point), - ); + denoms.push(coset_vanishing(zero_domain, point)); } - bit_reverse(&mut step_denoms); - let mut step_denom_inverses = vec![BaseField::zero(); 1 << (max_constraint_degree)]; - BaseField::batch_inverse(&step_denoms, &mut step_denom_inverses); - let (alpha, z) = ( + bit_reverse(&mut denoms); + let mut denom_inverses = vec![BaseField::zero(); 1 << (max_constraint_degree)]; + BaseField::batch_inverse(&denoms, &mut denom_inverses); + let (alpha, z, rc_z) = ( interaction_elements[MEMORY_ALPHA], interaction_elements[MEMORY_Z], + interaction_elements[RC_Z], ); - let lookup_value = SecureField::from_m31( + + let _lookup_value = SecureField::from_m31( lookup_values[MEMORY_LOOKUP_VALUE_0], lookup_values[MEMORY_LOOKUP_VALUE_1], lookup_values[MEMORY_LOOKUP_VALUE_2], lookup_values[MEMORY_LOOKUP_VALUE_3], ); - for (i, (first_point_denom_inverse, last_point_denom_inverse, step_denom_inverse)) in izip!( - first_point_denom_inverses, - last_point_denom_inverses, - step_denom_inverses, - ) - .enumerate() - { + for (i, denom_inverse) in denom_inverses.iter().enumerate() { let value = SecureField::from_m31_array(std::array::from_fn(|j| { trace_evals[INTERACTION_TRACE][j][i] })); - let prev_index = previous_bit_reversed_circle_domain_index( - i, - zero_domain.log_size, - trace_eval_domain.log_size(), - ); - let prev_value = SecureField::from_m31_array(std::array::from_fn(|j| { - trace_evals[INTERACTION_TRACE][j][prev_index] - })); let address_and_value: [BaseField; N_M31_IN_FELT252 + 1] = std::array::from_fn(|j| trace_evals[BASE_TRACE][j][i]); - let first_point_numerator = accum.random_coeff_powers[2] + // First interaction column constraint. + let mut numerator = accum.random_coeff_powers[N_M31_IN_FELT252 - 1] * (value * shifted_secure_combination(&address_and_value, alpha, z) - trace_evals[BASE_TRACE][MULTIPLICITY_COLUMN_OFFSET][i]); - let last_point_numerator = accum.random_coeff_powers[1] * (value - lookup_value); - let step_numerator = accum.random_coeff_powers[0] - * ((value - prev_value) * shifted_secure_combination(&address_and_value, alpha, z) - - trace_evals[BASE_TRACE][MULTIPLICITY_COLUMN_OFFSET][i]); - accum.accumulate( - i, - first_point_numerator * first_point_denom_inverse - + last_point_numerator * last_point_denom_inverse - + step_numerator * step_denom_inverse, - ); + // Middle interaction columns constraints. + let mut prev_row_value = value; + #[allow(clippy::needless_range_loop)] + for j in 1..N_M31_IN_FELT252 { + let value = SecureField::from_m31_array(std::array::from_fn(|k| { + trace_evals[INTERACTION_TRACE][j * SECURE_EXTENSION_DEGREE + k][i] + })); + numerator += accum.random_coeff_powers[N_M31_IN_FELT252 - j - 1] + * ((value - prev_row_value) * (rc_z - address_and_value[j]) - BaseField::one()); + prev_row_value = value; + } + + accum.accumulate(i, numerator * *denom_inverse); } } @@ -106,29 +89,29 @@ impl ComponentProver for MemoryComponent { let values = BTreeMap::from_iter([ ( MEMORY_LOOKUP_VALUE_0.to_string(), - trace_poly[0] - .eval_at_point(domain.at(1).into_ef()) + trace_poly[4 * N_M31_IN_FELT252] + .eval_at_point(domain.at(domain.size() - 1).into_ef()) .try_into() .unwrap(), ), ( MEMORY_LOOKUP_VALUE_1.to_string(), - trace_poly[1] - .eval_at_point(domain.at(1).into_ef()) + trace_poly[4 * N_M31_IN_FELT252 + 1] + .eval_at_point(domain.at(domain.size() - 1).into_ef()) .try_into() .unwrap(), ), ( MEMORY_LOOKUP_VALUE_2.to_string(), - trace_poly[2] - .eval_at_point(domain.at(1).into_ef()) + trace_poly[4 * N_M31_IN_FELT252 + 2] + .eval_at_point(domain.at(domain.size() - 1).into_ef()) .try_into() .unwrap(), ), ( MEMORY_LOOKUP_VALUE_3.to_string(), - trace_poly[3] - .eval_at_point(domain.at(1).into_ef()) + trace_poly[4 * N_M31_IN_FELT252 + 3] + .eval_at_point(domain.at(domain.size() - 1).into_ef()) .try_into() .unwrap(), ), diff --git a/stwo_cairo_prover/src/components/memory/mod.rs b/stwo_cairo_prover/src/components/memory/mod.rs index 7c60567a..84fab0e6 100644 --- a/stwo_cairo_prover/src/components/memory/mod.rs +++ b/stwo_cairo_prover/src/components/memory/mod.rs @@ -26,30 +26,8 @@ mod tests { }; use super::*; - - pub fn register_test_memory(registry: &mut ComponentGenerationRegistry) { - registry.register( - MEMORY_COMPONENT_ID, - MemoryTraceGenerator::new("".to_string()), - ); - vec![ - vec![BaseField::from_u32_unchecked(0); 3], - vec![BaseField::from_u32_unchecked(1); 1], - vec![BaseField::from_u32_unchecked(2); 2], - vec![BaseField::from_u32_unchecked(3); 5], - vec![BaseField::from_u32_unchecked(4); 10], - vec![BaseField::from_u32_unchecked(5); 1], - vec![BaseField::from_u32_unchecked(6); 0], - vec![BaseField::from_u32_unchecked(7); 1], - ] - .into_iter() - .flatten() - .for_each(|input| { - registry - .get_generator_mut::(MEMORY_COMPONENT_ID) - .add_inputs(&input); - }); - } + use crate::components::range_check_unit::component::RC_Z; + use crate::test_utils::register_test_memory; struct TestAirGenerator { pub registry: ComponentGenerationRegistry, @@ -65,10 +43,11 @@ mod tests { impl AirTraceVerifier for TestAirGenerator { fn interaction_elements(&self, channel: &mut Blake2sChannel) -> InteractionElements { - let elements = channel.draw_felts(2); + let elements = channel.draw_felts(3); InteractionElements::new(BTreeMap::from_iter(vec![ (MEMORY_ALPHA.to_string(), elements[0]), (MEMORY_Z.to_string(), elements[1]), + (RC_Z.to_string(), elements[2]), ])) } } @@ -133,10 +112,11 @@ mod tests { impl AirTraceVerifier for TestAir { fn interaction_elements(&self, channel: &mut Blake2sChannel) -> InteractionElements { - let elements = channel.draw_felts(2); + let elements = channel.draw_felts(3); InteractionElements::new(BTreeMap::from_iter(vec![ (MEMORY_ALPHA.to_string(), elements[0]), (MEMORY_Z.to_string(), elements[1]), + (RC_Z.to_string(), elements[2]), ])) } } diff --git a/stwo_cairo_prover/src/components/range_check_unit/component.rs b/stwo_cairo_prover/src/components/range_check_unit/component.rs index b98b7a12..a6180aa5 100644 --- a/stwo_cairo_prover/src/components/range_check_unit/component.rs +++ b/stwo_cairo_prover/src/components/range_check_unit/component.rs @@ -27,6 +27,8 @@ pub const RC_LOOKUP_VALUE_1: &str = "RC_UNIT_LOOKUP_1"; pub const RC_LOOKUP_VALUE_2: &str = "RC_UNIT_LOOKUP_2"; pub const RC_LOOKUP_VALUE_3: &str = "RC_UNIT_LOOKUP_3"; +pub const N_RC_COLUMNS: usize = 2; + #[derive(Clone)] pub struct RangeCheckUnitComponent { pub log_n_rows: u32, @@ -48,7 +50,7 @@ impl Component for RangeCheckUnitComponent { fn trace_log_degree_bounds(&self) -> TreeVec> { TreeVec::new(vec![ - vec![self.log_n_rows; 2], + vec![self.log_n_rows; N_RC_COLUMNS], vec![self.log_n_rows; SECURE_EXTENSION_DEGREE], ]) } @@ -59,7 +61,7 @@ impl Component for RangeCheckUnitComponent { ) -> TreeVec>>> { let domain = CanonicCoset::new(self.log_n_rows); TreeVec::new(vec![ - fixed_mask_points(&vec![vec![0_usize]; 2], point), + fixed_mask_points(&vec![vec![0_usize]; N_RC_COLUMNS], point), vec![vec![point, point - domain.step().into_ef()]; SECURE_EXTENSION_DEGREE], ]) } @@ -89,7 +91,10 @@ impl Component for RangeCheckUnitComponent { lookup_values[RC_LOOKUP_VALUE_3], ); let numerator = value - lookup_value; - let denom = point_vanishing(constraint_zero_domain.at(1), point); + let denom = point_vanishing( + constraint_zero_domain.at(constraint_zero_domain.size() - 1), + point, + ); evaluation_accumulator.accumulate(numerator / denom); // Lookup step constraint. @@ -134,7 +139,7 @@ impl ComponentTraceGenerator for RangeCheckUnitTraceGenerator { registry.get_generator::(component_id); let rc_max_value = rc_unit_trace_generator.max_value; - let mut trace = vec![vec![BaseField::zero(); rc_max_value]; 2]; + let mut trace = vec![vec![BaseField::zero(); rc_max_value]; N_RC_COLUMNS]; for (i, multiplicity) in rc_unit_trace_generator.multiplicities.iter().enumerate() { // TODO(AlonH): Either create a constant column for the addresses and remove it from // here or add constraints to the column here. diff --git a/stwo_cairo_prover/src/components/range_check_unit/component_prover.rs b/stwo_cairo_prover/src/components/range_check_unit/component_prover.rs index 072169ff..dc0321a6 100644 --- a/stwo_cairo_prover/src/components/range_check_unit/component_prover.rs +++ b/stwo_cairo_prover/src/components/range_check_unit/component_prover.rs @@ -39,8 +39,10 @@ impl ComponentProver for RangeCheckUnitComponent { // TODO(AlonH): Get all denominators in one loop and don't perform unnecessary inversions. let first_point_denom_inverses = point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(0)); - let last_point_denom_inverses = - point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(1)); + let last_point_denom_inverses = point_vanish_denominator_inverses( + trace_eval_domain, + zero_domain.at(zero_domain.size() - 1), + ); let mut step_denoms = vec![]; for point in trace_eval_domain.iter() { step_denoms.push( @@ -99,28 +101,28 @@ impl ComponentProver for RangeCheckUnitComponent { ( RC_LOOKUP_VALUE_0.to_string(), trace_poly[0] - .eval_at_point(domain.at(1).into_ef()) + .eval_at_point(domain.at(domain.size() - 1).into_ef()) .try_into() .unwrap(), ), ( RC_LOOKUP_VALUE_1.to_string(), trace_poly[1] - .eval_at_point(domain.at(1).into_ef()) + .eval_at_point(domain.at(domain.size() - 1).into_ef()) .try_into() .unwrap(), ), ( RC_LOOKUP_VALUE_2.to_string(), trace_poly[2] - .eval_at_point(domain.at(1).into_ef()) + .eval_at_point(domain.at(domain.size() - 1).into_ef()) .try_into() .unwrap(), ), ( RC_LOOKUP_VALUE_3.to_string(), trace_poly[3] - .eval_at_point(domain.at(1).into_ef()) + .eval_at_point(domain.at(domain.size() - 1).into_ef()) .try_into() .unwrap(), ), diff --git a/stwo_cairo_prover/src/components/ret_opcode/test_utils.rs b/stwo_cairo_prover/src/components/ret_opcode/test_utils.rs index 72518d38..29468f01 100644 --- a/stwo_cairo_prover/src/components/ret_opcode/test_utils.rs +++ b/stwo_cairo_prover/src/components/ret_opcode/test_utils.rs @@ -13,8 +13,11 @@ use stwo_prover::trace_generation::registry::ComponentGenerationRegistry; use stwo_prover::trace_generation::{AirTraceGenerator, AirTraceVerifier, ComponentTraceGenerator}; use crate::components::memory::component::{ - MemoryComponent, MemoryTraceGenerator, MEMORY_ALPHA, MEMORY_COMPONENT_ID, MEMORY_Z, - N_M31_IN_FELT252, N_MEMORY_COLUMNS, + MemoryComponent, MemoryTraceGenerator, MAX_MEMORY_CELL_VALUE, MEMORY_ALPHA, + MEMORY_COMPONENT_ID, MEMORY_Z, N_M31_IN_FELT252, N_MEMORY_COLUMNS, +}; +use crate::components::range_check_unit::component::{ + RangeCheckUnitTraceGenerator, RC_COMPONENT_ID, RC_Z, }; use crate::components::ret_opcode::component::{RetOpcode, RET_COMPONENT_ID, RET_N_TRACE_CELLS}; use crate::components::ret_opcode::trace::RetOpcodeCpuTraceGenerator; @@ -24,6 +27,10 @@ pub fn register_test_ret_memory(registry: &mut ComponentGenerationRegistry) { MEMORY_COMPONENT_ID, MemoryTraceGenerator::new("".to_string()), ); + registry.register( + RC_COMPONENT_ID, + RangeCheckUnitTraceGenerator::new(MAX_MEMORY_CELL_VALUE), + ); let mut value = [M31::from_u32_unchecked(0); N_M31_IN_FELT252]; value[0] = M31::from_u32_unchecked(1); @@ -67,10 +74,11 @@ impl TestRetAirGenerator { impl AirTraceVerifier for TestRetAirGenerator { fn interaction_elements(&self, channel: &mut Blake2sChannel) -> InteractionElements { - let elements = channel.draw_felts(2); + let elements = channel.draw_felts(3); InteractionElements::new(BTreeMap::from_iter(vec![ (MEMORY_ALPHA.to_string(), elements[0]), (MEMORY_Z.to_string(), elements[1]), + (RC_Z.to_string(), elements[2]), ])) } } @@ -154,10 +162,11 @@ impl AirProver for TestAir { impl AirTraceVerifier for TestAir { fn interaction_elements(&self, channel: &mut Blake2sChannel) -> InteractionElements { - let elements = channel.draw_felts(2); + let elements = channel.draw_felts(3); InteractionElements::new(BTreeMap::from_iter(vec![ (MEMORY_ALPHA.to_string(), elements[0]), (MEMORY_Z.to_string(), elements[1]), + (RC_Z.to_string(), elements[2]), ])) } } diff --git a/stwo_cairo_prover/src/main.rs b/stwo_cairo_prover/src/main.rs index d4f2eaff..b47cd5b6 100644 --- a/stwo_cairo_prover/src/main.rs +++ b/stwo_cairo_prover/src/main.rs @@ -1,4 +1,7 @@ +pub mod air; pub mod components; +#[cfg(test)] +pub mod test_utils; fn main() { println!("Hello, world!"); diff --git a/stwo_cairo_prover/src/test_utils.rs b/stwo_cairo_prover/src/test_utils.rs new file mode 100644 index 00000000..abd4cdb3 --- /dev/null +++ b/stwo_cairo_prover/src/test_utils.rs @@ -0,0 +1,19 @@ +use stwo_prover::trace_generation::registry::ComponentGenerationRegistry; + +use crate::components::memory::component::{ + MemoryTraceGenerator, MAX_MEMORY_CELL_VALUE, MEMORY_COMPONENT_ID, +}; +use crate::components::range_check_unit::component::{ + RangeCheckUnitTraceGenerator, RC_COMPONENT_ID, +}; + +pub fn register_test_memory(registry: &mut ComponentGenerationRegistry) { + registry.register( + MEMORY_COMPONENT_ID, + MemoryTraceGenerator::new("".to_string()), + ); + registry.register( + RC_COMPONENT_ID, + RangeCheckUnitTraceGenerator::new(MAX_MEMORY_CELL_VALUE), + ); +}