diff --git a/stwo_cairo_prover/src/components/memory/component.rs b/stwo_cairo_prover/src/components/memory/component.rs index a1dd83ab..9924031d 100644 --- a/stwo_cairo_prover/src/components/memory/component.rs +++ b/stwo_cairo_prover/src/components/memory/component.rs @@ -70,6 +70,10 @@ impl MemoryTraceGenerator { multiplicities, } } + + pub fn deduce_output(&self, input: BaseField) -> [BaseField; N_M31_IN_FELT252] { + self.values[input.0 as usize] + } } impl ComponentGen for MemoryTraceGenerator {} diff --git a/stwo_cairo_prover/src/components/mod.rs b/stwo_cairo_prover/src/components/mod.rs index 9d62543b..6bc603ce 100644 --- a/stwo_cairo_prover/src/components/mod.rs +++ b/stwo_cairo_prover/src/components/mod.rs @@ -1,2 +1,3 @@ pub mod memory; pub mod range_check_unit; +pub mod ret_opcode; diff --git a/stwo_cairo_prover/src/components/ret_opcode/component.rs b/stwo_cairo_prover/src/components/ret_opcode/component.rs new file mode 100644 index 00000000..e7274d0f --- /dev/null +++ b/stwo_cairo_prover/src/components/ret_opcode/component.rs @@ -0,0 +1,70 @@ +#![allow(unused_imports)] +use num_traits::{One, Zero}; +use stwo_prover::core::air::accumulation::PointEvaluationAccumulator; +use stwo_prover::core::air::mask::fixed_mask_points; +use stwo_prover::core::air::Component; +use stwo_prover::core::circle::CirclePoint; +use stwo_prover::core::constraints::{coset_vanishing, point_excluder, point_vanishing}; +use stwo_prover::core::fields::m31::M31; +use stwo_prover::core::fields::qm31::SecureField; +use stwo_prover::core::fields::secure_column::SECURE_EXTENSION_DEGREE; +use stwo_prover::core::fields::FieldExpOps; +use stwo_prover::core::pcs::TreeVec; +use stwo_prover::core::poly::circle::CanonicCoset; +use stwo_prover::core::utils::shifted_secure_combination; +use stwo_prover::core::{ColumnVec, InteractionElements, LookupValues}; +use stwo_prover::trace_generation::{BASE_TRACE, INTERACTION_TRACE}; + +use crate::components::memory::component::{MEMORY_ALPHA, MEMORY_Z, N_M31_IN_FELT252}; + +pub const RET_COMPONENT_ID: &str = "RET"; +pub const RET_LOOKUP_VALUE_0: &str = "RET_LOOKUP_0"; +pub const RET_LOOKUP_VALUE_1: &str = "RET_LOOKUP_1"; +pub const RET_LOOKUP_VALUE_2: &str = "RET_LOOKUP_2"; +pub const RET_LOOKUP_VALUE_3: &str = "RET_LOOKUP_3"; + +#[allow(non_camel_case_types)] +#[derive(Clone)] + +pub struct RetOpcode { + pub log_n_instances: u32, +} + +impl Component for RetOpcode { + fn n_constraints(&self) -> usize { + todo!() + } + + fn max_constraint_log_degree_bound(&self) -> u32 { + self.log_n_instances + 1 + } + + fn trace_log_degree_bounds(&self) -> TreeVec> { + TreeVec(vec![ + vec![self.log_n_instances; 7], + vec![self.log_n_instances; 4], + ]) + } + + fn mask_points( + &self, + point: CirclePoint, + ) -> TreeVec>>> { + let domain = CanonicCoset::new(self.log_n_instances); + TreeVec(vec![ + fixed_mask_points(&vec![vec![0_usize]; 7], point), + vec![vec![point, point - domain.step().into_ef()]; SECURE_EXTENSION_DEGREE], + ]) + } + + fn evaluate_constraint_quotients_at_point( + &self, + _point: CirclePoint, + _mask: &TreeVec>>, + _evaluation_accumulator: &mut PointEvaluationAccumulator, + _interaction_elements: &InteractionElements, + _lookup_values: &LookupValues, + ) { + todo!() + } +} diff --git a/stwo_cairo_prover/src/components/ret_opcode/cpu_prover.rs b/stwo_cairo_prover/src/components/ret_opcode/cpu_prover.rs new file mode 100644 index 00000000..79fa96c5 --- /dev/null +++ b/stwo_cairo_prover/src/components/ret_opcode/cpu_prover.rs @@ -0,0 +1,30 @@ +#![allow(unused_imports)] +use num_traits::identities::Zero; +use stwo_prover::core::air::accumulation::DomainEvaluationAccumulator; +use stwo_prover::core::air::{Component, ComponentProver, ComponentTrace}; +use stwo_prover::core::backend::{Column, CpuBackend}; +use stwo_prover::core::constraints::coset_vanishing; +use stwo_prover::core::fields::m31::BaseField; +use stwo_prover::core::fields::qm31::SecureField; +use stwo_prover::core::fields::FieldExpOps; +use stwo_prover::core::poly::circle::CanonicCoset; +use stwo_prover::core::utils::bit_reverse; +use stwo_prover::core::{InteractionElements, LookupValues}; + +use super::component::RetOpcode; + +impl ComponentProver for RetOpcode { + #[allow(unused_parens)] + fn evaluate_constraint_quotients_on_domain( + &self, + _trace: &ComponentTrace<'_, CpuBackend>, + _evaluation_accumulator: &mut DomainEvaluationAccumulator, + _interaction_elements: &InteractionElements, + _lookup_values: &LookupValues, + ) { + todo!() + } + fn lookup_values(&self, _trace: &ComponentTrace<'_, CpuBackend>) -> LookupValues { + todo!() + } +} diff --git a/stwo_cairo_prover/src/components/ret_opcode/mod.rs b/stwo_cairo_prover/src/components/ret_opcode/mod.rs new file mode 100644 index 00000000..83a49781 --- /dev/null +++ b/stwo_cairo_prover/src/components/ret_opcode/mod.rs @@ -0,0 +1,51 @@ +pub mod component; +pub mod cpu_prover; +pub mod test_utils; +pub mod trace; + +#[cfg(test)] +pub(crate) mod tests { + use itertools::Itertools; + use num_traits::{One, Zero}; + use stwo_prover::core::channel::{Blake2sChannel, Channel}; + use stwo_prover::core::fields::m31::{BaseField, M31}; + use stwo_prover::core::fields::qm31::SecureField; + use stwo_prover::core::fields::IntoSlice; + use stwo_prover::core::utils::shifted_secure_combination; + use stwo_prover::core::vcs::blake2_hash::Blake2sHasher; + use stwo_prover::core::vcs::hasher::Hasher; + use stwo_prover::trace_generation::{AirTraceGenerator, AirTraceVerifier}; + + use crate::components::memory::component::{MEMORY_ALPHA, MEMORY_Z, N_M31_IN_FELT252}; + use crate::components::ret_opcode::test_utils::TestRetAirGenerator; + + #[test] + fn test_ret_interaction_trace() { + let mut air_generator = TestRetAirGenerator::new(); + let trace = air_generator.write_trace(); + let prover_channel = + &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); + let interaction_elements = air_generator.interaction_elements(prover_channel); + let alpha = interaction_elements[MEMORY_ALPHA]; + let z = interaction_elements[MEMORY_Z]; + let mut expected_logup_sum = SecureField::zero(); + for i in 0..8 { + assert_eq!(trace[0].values[i], M31::from_u32_unchecked(i as u32)); + let mut address_and_value = [M31::zero(); N_M31_IN_FELT252 + 1]; + address_and_value[0] = M31::from_u32_unchecked(i as u32); + address_and_value[1] = M31::one(); + expected_logup_sum += + M31::one() / shifted_secure_combination(&address_and_value, alpha, z); + } + + let interaction_trace = air_generator + .interact(&trace, &interaction_elements) + .into_iter() + .take(4) + .collect_vec(); + let logup_sum = + SecureField::from_m31_array(std::array::from_fn(|j| interaction_trace[j][1])); + + assert_eq!(logup_sum, expected_logup_sum); + } +} diff --git a/stwo_cairo_prover/src/components/ret_opcode/test_utils.rs b/stwo_cairo_prover/src/components/ret_opcode/test_utils.rs new file mode 100644 index 00000000..474defd0 --- /dev/null +++ b/stwo_cairo_prover/src/components/ret_opcode/test_utils.rs @@ -0,0 +1,164 @@ +#![cfg(test)] +use std::collections::BTreeMap; + +use component::{RetOpcode, RET_COMPONENT_ID}; +use itertools::Itertools; +use stwo_prover::core::air::{Air, AirProver, Component, ComponentProver}; +use stwo_prover::core::backend::CpuBackend; +use stwo_prover::core::channel::{Blake2sChannel, Channel}; +use stwo_prover::core::fields::m31::{BaseField, M31}; +use stwo_prover::core::poly::circle::CircleEvaluation; +use stwo_prover::core::poly::BitReversedOrder; +use stwo_prover::core::prover::VerificationError; +use stwo_prover::core::{ColumnVec, InteractionElements, LookupValues}; +use stwo_prover::trace_generation::registry::ComponentGenerationRegistry; +use stwo_prover::trace_generation::{AirTraceGenerator, AirTraceVerifier, ComponentTraceGenerator}; +use trace::RetOpcodeCpuTraceGenerator; + +use super::*; +use crate::components::memory::component::{ + MemoryComponent, MemoryTraceGenerator, MEMORY_ALPHA, MEMORY_COMPONENT_ID, MEMORY_Z, + N_M31_IN_FELT252, +}; + +pub fn register_test_ret_memory(registry: &mut ComponentGenerationRegistry) { + registry.register( + MEMORY_COMPONENT_ID, + MemoryTraceGenerator::new("".to_string()), + ); + let mut value = [M31::from_u32_unchecked(0); N_M31_IN_FELT252]; + value[0] = M31::from_u32_unchecked(1); + + registry + .get_generator_mut::(MEMORY_COMPONENT_ID) + .values = vec![value; 8]; +} + +pub fn register_test_ret(registry: &mut ComponentGenerationRegistry) { + registry.register( + RET_COMPONENT_ID, + RetOpcodeCpuTraceGenerator { inputs: vec![] }, + ); + let inputs = (0..8) + .map(|i| { + [ + M31::from_u32_unchecked(i), + M31::from_u32_unchecked(2), + M31::from_u32_unchecked(2), + ] + }) + .collect_vec(); + registry + .get_generator_mut::(RET_COMPONENT_ID) + .add_inputs(&inputs); +} + +pub(crate) struct TestRetAirGenerator { + pub registry: ComponentGenerationRegistry, +} + +impl TestRetAirGenerator { + pub fn new() -> Self { + let mut registry = ComponentGenerationRegistry::default(); + register_test_ret_memory(&mut registry); + register_test_ret(&mut registry); + Self { registry } + } +} + +impl AirTraceVerifier for TestRetAirGenerator { + fn interaction_elements(&self, channel: &mut Blake2sChannel) -> InteractionElements { + let elements = channel.draw_felts(2); + InteractionElements::new(BTreeMap::from_iter(vec![ + (MEMORY_ALPHA.to_string(), elements[0]), + (MEMORY_Z.to_string(), elements[1]), + ])) + } +} + +impl AirTraceGenerator for TestRetAirGenerator { + fn write_trace(&mut self) -> Vec> { + // TODO(Ohad): add memory trace. + let ret_trace = + RetOpcodeCpuTraceGenerator::write_trace(RET_COMPONENT_ID, &mut self.registry); + let memory_trace = + MemoryTraceGenerator::write_trace(MEMORY_COMPONENT_ID, &mut self.registry); + ret_trace.into_iter().chain(memory_trace).collect() + } + + fn interact( + &self, + trace: &ColumnVec>, + elements: &InteractionElements, + ) -> Vec> { + let ret_trace = trace.iter().take(7).collect_vec(); + let memory_trace = trace.iter().skip(7).collect_vec(); + let ret_intraction_trace = self + .registry + .get_generator::(RET_COMPONENT_ID) + .write_interaction_trace(&ret_trace, elements); + let memory_interaction_trace = self + .registry + .get_generator::(MEMORY_COMPONENT_ID) + .write_interaction_trace(&memory_trace, elements); + + ret_intraction_trace + .into_iter() + .chain(memory_interaction_trace) + .collect() + } + + fn to_air_prover(&self) -> impl AirProver { + let ret_component_generator = self + .registry + .get_generator::(RET_COMPONENT_ID); + let memory_component_generator = self + .registry + .get_generator::(MEMORY_COMPONENT_ID); + TestAir { + ret_component: ret_component_generator.component(), + memory_component: memory_component_generator.component(), + } + } + + fn composition_log_degree_bound(&self) -> u32 { + let component_generator = self + .registry + .get_generator::(RET_COMPONENT_ID); + component_generator + .component() + .max_constraint_log_degree_bound() + } +} + +#[derive(Clone)] +pub struct TestAir { + pub ret_component: RetOpcode, + pub memory_component: MemoryComponent, +} + +impl Air for TestAir { + fn components(&self) -> Vec<&dyn Component> { + vec![&self.ret_component, &self.memory_component] + } + + fn verify_lookups(&self, _lookup_values: &LookupValues) -> Result<(), VerificationError> { + Ok(()) + } +} + +impl AirProver for TestAir { + fn prover_components(&self) -> Vec<&dyn ComponentProver> { + vec![&self.ret_component, &self.memory_component] + } +} + +impl AirTraceVerifier for TestAir { + fn interaction_elements(&self, channel: &mut Blake2sChannel) -> InteractionElements { + let elements = channel.draw_felts(2); + InteractionElements::new(BTreeMap::from_iter(vec![ + (MEMORY_ALPHA.to_string(), elements[0]), + (MEMORY_Z.to_string(), elements[1]), + ])) + } +} diff --git a/stwo_cairo_prover/src/components/ret_opcode/trace.rs b/stwo_cairo_prover/src/components/ret_opcode/trace.rs new file mode 100644 index 00000000..e466b0a2 --- /dev/null +++ b/stwo_cairo_prover/src/components/ret_opcode/trace.rs @@ -0,0 +1,198 @@ +#![allow(unused_imports)] +use itertools::Itertools; +use num_traits::{One, Zero}; +use stwo_prover::core::air::Component; +use stwo_prover::core::backend::cpu::CpuCircleEvaluation; +use stwo_prover::core::backend::CpuBackend; +use stwo_prover::core::fields::m31::M31; +use stwo_prover::core::fields::qm31::SecureField; +use stwo_prover::core::fields::secure_column::SecureColumn; +use stwo_prover::core::fields::FieldExpOps; +use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation}; +use stwo_prover::core::poly::BitReversedOrder; +use stwo_prover::core::utils::{ + bit_reverse_index, coset_order_to_circle_domain_order_index, shifted_secure_combination, +}; +use stwo_prover::core::{ColumnVec, InteractionElements}; +use stwo_prover::trace_generation::registry::ComponentGenerationRegistry; +use stwo_prover::trace_generation::{ComponentGen, ComponentTraceGenerator}; + +use super::component::RetOpcode; +use crate::components::memory::component::{ + MemoryComponent, MemoryTraceGenerator, MEMORY_ADDRESS_BOUND, MEMORY_ALPHA, MEMORY_COMPONENT_ID, + MEMORY_Z, N_M31_IN_FELT252, +}; + +#[allow(non_camel_case_types)] +pub struct RetOpcodeCpuTraceGenerator { + pub inputs: Vec<[M31; 3]>, +} +impl ComponentGen for RetOpcodeCpuTraceGenerator {} + +impl ComponentTraceGenerator for RetOpcodeCpuTraceGenerator { + type Component = RetOpcode; + type Inputs = Vec<[M31; 3]>; + + fn write_trace( + component_id: &str, + registry: &mut ComponentGenerationRegistry, + ) -> Vec> { + let generator = registry.get_generator::(component_id); + let memory_trace_generator = + registry.get_generator::(MEMORY_COMPONENT_ID); + let (trace, sub_component_inputs) = write_trace_cpu( + &generator.component(), + &generator.inputs, + memory_trace_generator, + ); + let trace_generator = + registry.get_generator_mut::(MEMORY_COMPONENT_ID); + sub_component_inputs.memory_inputs.iter().for_each(|input| { + trace_generator.add_inputs(input); + }); + trace + } + + fn add_inputs(&mut self, inputs: &Self::Inputs) { + self.inputs.extend(inputs); + } + + fn component(&self) -> RetOpcode { + RetOpcode { + log_n_instances: self + .inputs + .len() + .checked_ilog2() + .expect("Input not a power of 2!"), + } + } + + fn write_interaction_trace( + &self, + trace: &ColumnVec<&CircleEvaluation>, + elements: &InteractionElements, + ) -> ColumnVec> { + let interaction_trace_domain = trace[0].domain; + let domain_size = interaction_trace_domain.size(); + let log_domain_size = interaction_trace_domain.log_size(); + let (memory_alpha, memory_z) = (elements[MEMORY_ALPHA], elements[MEMORY_Z]); + + // PC Column. + let pc_column = &trace[0].values; + let denoms = pc_column + .iter() + .copied() + .map(|pc| { + let mut address_and_value = [M31::zero(); N_M31_IN_FELT252 + 1]; + address_and_value[0] = pc; + address_and_value[1] = M31::one(); + shifted_secure_combination(&address_and_value, memory_alpha, memory_z) + }) + .collect_vec(); + let mut denom_inverses = vec![SecureField::zero(); domain_size]; + SecureField::batch_inverse(&denoms, &mut denom_inverses); + + let mut logup_values = vec![SecureField::zero(); domain_size]; + let mut prefix_sum = SecureField::zero(); + for i in 0..domain_size { + let index = bit_reverse_index( + coset_order_to_circle_domain_order_index(i, log_domain_size), + log_domain_size, + ); + prefix_sum += denom_inverses[index]; + logup_values[index] = prefix_sum; + } + + let secure_column: SecureColumn = logup_values.into_iter().collect(); + secure_column + .columns + .into_iter() + .map(|eval| CircleEvaluation::new(interaction_trace_domain, eval)) + .collect_vec() + } +} + +#[allow(non_snake_case)] +pub struct ReturnedInputs { + pub memory_inputs: Vec, +} + +impl ReturnedInputs { + #[allow(unused_variables)] + fn with_capacity(capacity: usize) -> Self { + Self { + memory_inputs: Vec::with_capacity(capacity * 3), + } + } +} + +#[allow(clippy::ptr_arg)] +#[allow(clippy::type_complexity)] +#[allow(clippy::let_unit_value)] +pub fn write_trace_cpu( + component: &RetOpcode, + secrets: &Vec<[M31; 3]>, + memory_trace_generator: &MemoryTraceGenerator, +) -> ( + Vec>, + ReturnedInputs, +) { + let n_trace_columns = component.trace_log_degree_bounds()[0].len(); + let mut trace_values = vec![vec![M31::zero(); secrets.len()]; n_trace_columns]; + let mut sub_components_inputs = ReturnedInputs::with_capacity(secrets.len()); + secrets.iter().enumerate().for_each(|(i, secret)| { + write_trace_row( + &mut trace_values, + *secret, + i, + &mut sub_components_inputs, + memory_trace_generator, + ); + }); + + let trace = trace_values + .into_iter() + .map(|eval| { + let domain = + CanonicCoset::new(eval.len().checked_ilog2().expect("Input not a power of 2!")) + .circle_domain(); + CpuCircleEvaluation::::new(domain, eval) + }) + .collect_vec(); + + (trace, sub_components_inputs) +} + +#[allow(non_snake_case)] +#[allow(clippy::useless_conversion)] +#[allow(unused_variables)] +fn write_trace_row( + dst: &mut [Vec], + RetOpcode_input: [M31; 3], + row_index: usize, + returned_inputs: &mut ReturnedInputs, + memory_trace_generator: &MemoryTraceGenerator, +) { + let deduction_tmp_0 = [RetOpcode_input[0], RetOpcode_input[1], RetOpcode_input[2]]; + let col0 = deduction_tmp_0[0]; + dst[0][row_index] = col0.into(); + let col1 = deduction_tmp_0[1]; + dst[1][row_index] = col1.into(); + let col2 = deduction_tmp_0[2]; + dst[2][row_index] = col2.into(); + returned_inputs.memory_inputs.push(col0); + let deduction_tmp_2 = memory_trace_generator.deduce_output(col0); + // TODO(Ohad): implement and uncomment. + // returned_inputs.memory_inputs.push((col2) - (M31::from(1))); + let deduction_tmp_4 = memory_trace_generator.deduce_output((col2) - (M31::from(1))); + let col3 = deduction_tmp_4[0]; + dst[3][row_index] = col3.into(); + let col4 = deduction_tmp_4[1]; + dst[4][row_index] = col4.into(); + // returned_inputs.memory_inputs.push((col2) - (M31::from(2))); + let deduction_tmp_5 = memory_trace_generator.deduce_output((col2) - (M31::from(2))); + let col5 = deduction_tmp_5[0]; + dst[5][row_index] = col5.into(); + let col6 = deduction_tmp_5[1]; + dst[6][row_index] = col6.into(); +}