Skip to content

Commit

Permalink
proof and lookup tests for ret opcode
Browse files Browse the repository at this point in the history
  • Loading branch information
ohad-starkware committed Jul 24, 2024
1 parent 858cc77 commit 7714e8a
Show file tree
Hide file tree
Showing 5 changed files with 234 additions and 20 deletions.
2 changes: 2 additions & 0 deletions rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[toolchain]
channel = "nightly-2024-01-04"
2 changes: 1 addition & 1 deletion stwo_cairo_prover/src/components/memory/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ pub mod component;
pub mod component_prover;

#[cfg(test)]
mod tests {
pub(crate) mod tests {
use std::collections::BTreeMap;

use component::{
Expand Down
55 changes: 48 additions & 7 deletions stwo_cairo_prover/src/components/ret_opcode/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ pub struct RetOpcode {

impl Component for RetOpcode {
fn n_constraints(&self) -> usize {
todo!()
3
}

fn max_constraint_log_degree_bound(&self) -> u32 {
Expand All @@ -59,12 +59,53 @@ impl Component for RetOpcode {

fn evaluate_constraint_quotients_at_point(
&self,
_point: CirclePoint<SecureField>,
_mask: &TreeVec<ColumnVec<Vec<SecureField>>>,
_evaluation_accumulator: &mut PointEvaluationAccumulator,
_interaction_elements: &InteractionElements,
_lookup_values: &LookupValues,
point: CirclePoint<SecureField>,
mask: &TreeVec<ColumnVec<Vec<SecureField>>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
interaction_elements: &InteractionElements,
lookup_values: &LookupValues,
) {
todo!()
// First lookup point boundary constraint.
let constraint_zero_domain = CanonicCoset::new(self.log_n_instances).coset;
let (alpha, z) = (
interaction_elements[MEMORY_ALPHA],
interaction_elements[MEMORY_Z],
);
let value =
SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][0]));
let address = mask[BASE_TRACE][0][0];
let value_at_pc = SecureField::one();
let mut address_and_value = [SecureField::zero(); N_M31_IN_FELT252 + 1];
address_and_value[0] = address;
address_and_value[1] = value_at_pc;

let numerator =
value * shifted_secure_combination(&address_and_value, alpha, z) - SecureField::one();
let denom = point_vanishing(constraint_zero_domain.at(0), point);
evaluation_accumulator.accumulate(numerator / denom);

// Last lookup point boundary constraint.
let lookup_value = SecureField::from_m31(
lookup_values[RET_LOOKUP_VALUE_0],
lookup_values[RET_LOOKUP_VALUE_1],
lookup_values[RET_LOOKUP_VALUE_2],
lookup_values[RET_LOOKUP_VALUE_3],
);
let numerator = value - lookup_value;
let denom = point_vanishing(
constraint_zero_domain.at(constraint_zero_domain.size() - 1),
point,
);
evaluation_accumulator.accumulate(numerator / denom);

// Lookup step constraint.
let prev_value =
SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][1]));
let numerator = (value - prev_value)
* shifted_secure_combination(&address_and_value, alpha, z)
- SecureField::one();
let denom = coset_vanishing(constraint_zero_domain, point)
/ point_excluder(constraint_zero_domain.at(0), point);
evaluation_accumulator.accumulate(numerator / denom);
}
}
144 changes: 134 additions & 10 deletions stwo_cairo_prover/src/components/ret_opcode/cpu_prover.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,154 @@
#![allow(unused_imports)]
use std::collections::BTreeMap;

use itertools::izip;
use num_traits::identities::Zero;
use num_traits::One;
use stwo_prover::core::air::accumulation::DomainEvaluationAccumulator;
use stwo_prover::core::air::{Component, ComponentProver, ComponentTrace};
use stwo_prover::core::backend::{Column, CpuBackend};
use stwo_prover::core::constraints::coset_vanishing;
use stwo_prover::core::constraints::{coset_vanishing, point_excluder};
use stwo_prover::core::fields::m31::BaseField;
use stwo_prover::core::fields::qm31::SecureField;
use stwo_prover::core::fields::FieldExpOps;
use stwo_prover::core::poly::circle::CanonicCoset;
use stwo_prover::core::utils::bit_reverse;
use stwo_prover::core::utils::{
bit_reverse, point_vanish_denominator_inverses, previous_bit_reversed_circle_domain_index,
shifted_secure_combination,
};
use stwo_prover::core::{InteractionElements, LookupValues};
use stwo_prover::trace_generation::{BASE_TRACE, INTERACTION_TRACE};

use super::component::RetOpcode;
use super::component::{
RetOpcode, RET_LOOKUP_VALUE_0, RET_LOOKUP_VALUE_1, RET_LOOKUP_VALUE_2, RET_LOOKUP_VALUE_3,
};
use crate::components::memory::component::{MEMORY_ALPHA, MEMORY_Z, N_M31_IN_FELT252};

impl ComponentProver<CpuBackend> for RetOpcode {
#[allow(unused_parens)]
fn evaluate_constraint_quotients_on_domain(
&self,
_trace: &ComponentTrace<'_, CpuBackend>,
_evaluation_accumulator: &mut DomainEvaluationAccumulator<CpuBackend>,
_interaction_elements: &InteractionElements,
_lookup_values: &LookupValues,
trace: &ComponentTrace<'_, CpuBackend>,
evaluation_accumulator: &mut DomainEvaluationAccumulator<CpuBackend>,
interaction_elements: &InteractionElements,
lookup_values: &LookupValues,
) {
todo!()
let max_constraint_degree = self.max_constraint_log_degree_bound();
let trace_eval_domain = CanonicCoset::new(max_constraint_degree).circle_domain();
let trace_evals = &trace.evals;
let zero_domain = CanonicCoset::new(self.log_n_instances).coset;
let [mut accum] =
evaluation_accumulator.columns([(max_constraint_degree, self.n_constraints())]);

// TODO(AlonH): Get all denominators in one loop and don't perform unnecessary inversions.
let first_point_denom_inverses =
point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(0));
let last_point_denom_inverses = point_vanish_denominator_inverses(
trace_eval_domain,
zero_domain.at(zero_domain.size() - 1),
);
let mut step_denoms = vec![];
for point in trace_eval_domain.iter() {
step_denoms.push(
coset_vanishing(zero_domain, point) / point_excluder(zero_domain.at(0), point),
);
}
bit_reverse(&mut step_denoms);
let mut step_denom_inverses = vec![BaseField::zero(); 1 << (max_constraint_degree)];
BaseField::batch_inverse(&step_denoms, &mut step_denom_inverses);

let (alpha, z) = (
interaction_elements[MEMORY_ALPHA],
interaction_elements[MEMORY_Z],
);

let lookup_value = SecureField::from_m31(
lookup_values[RET_LOOKUP_VALUE_0],
lookup_values[RET_LOOKUP_VALUE_1],
lookup_values[RET_LOOKUP_VALUE_2],
lookup_values[RET_LOOKUP_VALUE_3],
);

for (i, (first_point_denom_inverse, last_point_denom_inverse, step_denom_inverse)) in izip!(
first_point_denom_inverses,
last_point_denom_inverses,
step_denom_inverses,
)
.enumerate()
{
// Value = InteractionPoly(i);
let value = SecureField::from_m31_array(std::array::from_fn(|j| {
trace_evals[INTERACTION_TRACE][j][i]
}));

// PrevValue = InteractionPoly(i - g);
let prev_index = previous_bit_reversed_circle_domain_index(
i,
zero_domain.log_size,
trace_eval_domain.log_size(),
);
let prev_value = SecureField::from_m31_array(std::array::from_fn(|j| {
trace_evals[INTERACTION_TRACE][j][prev_index]
}));

// PC Column.
let address = trace_evals[BASE_TRACE][0][i];
let value_at_pc = BaseField::one();
let mut address_and_value = [BaseField::zero(); N_M31_IN_FELT252 + 1];
address_and_value[0] = address;
address_and_value[1] = value_at_pc;

// TODO(Ohad): add remaining lookup constraints.

let first_point_numerator = accum.random_coeff_powers[2]
* (value * shifted_secure_combination(&address_and_value, alpha, z)
- SecureField::one());
let last_point_numerator = accum.random_coeff_powers[1] * (value - lookup_value);
let step_numerator = accum.random_coeff_powers[0]
* ((value - prev_value) * shifted_secure_combination(&address_and_value, alpha, z)
- SecureField::from_u32_unchecked(1, 0, 0, 0));
accum.accumulate(
i,
first_point_numerator * first_point_denom_inverse
+ last_point_numerator * last_point_denom_inverse
+ step_numerator * step_denom_inverse,
);
}
}
fn lookup_values(&self, _trace: &ComponentTrace<'_, CpuBackend>) -> LookupValues {
todo!()

fn lookup_values(&self, trace: &ComponentTrace<'_, CpuBackend>) -> LookupValues {
let domain = CanonicCoset::new(self.log_n_instances);
let trace_poly = &trace.polys[INTERACTION_TRACE];
let values = BTreeMap::from_iter([
(
RET_LOOKUP_VALUE_0.to_string(),
trace_poly[0]
.eval_at_point(domain.at(domain.size() - 1).into_ef())
.try_into()
.unwrap(),
),
(
RET_LOOKUP_VALUE_1.to_string(),
trace_poly[1]
.eval_at_point(domain.at(domain.size() - 1).into_ef())
.try_into()
.unwrap(),
),
(
RET_LOOKUP_VALUE_2.to_string(),
trace_poly[2]
.eval_at_point(domain.at(domain.size() - 1).into_ef())
.try_into()
.unwrap(),
),
(
RET_LOOKUP_VALUE_3.to_string(),
trace_poly[3]
.eval_at_point(domain.at(domain.size() - 1).into_ef())
.try_into()
.unwrap(),
),
]);
LookupValues::new(values)
}
}
51 changes: 49 additions & 2 deletions stwo_cairo_prover/src/components/ret_opcode/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,24 @@ pub mod trace;
pub(crate) mod tests {
use itertools::Itertools;
use num_traits::{One, Zero};
use stwo_prover::core::backend::CpuBackend;
use stwo_prover::core::channel::{Blake2sChannel, Channel};
use stwo_prover::core::fields::m31::{BaseField, M31};
use stwo_prover::core::fields::qm31::SecureField;
use stwo_prover::core::fields::IntoSlice;
use stwo_prover::core::utils::shifted_secure_combination;
use stwo_prover::core::vcs::blake2_hash::Blake2sHasher;
use stwo_prover::core::vcs::hasher::Hasher;
use stwo_prover::trace_generation::{AirTraceGenerator, AirTraceVerifier};
use stwo_prover::trace_generation::{commit_and_prove, commit_and_verify, AirTraceGenerator, AirTraceVerifier, ComponentTraceGenerator};

use crate::components::memory::component::{MEMORY_ALPHA, MEMORY_Z, N_M31_IN_FELT252};
use crate::components::memory::component::{MemoryTraceGenerator, MEMORY_ALPHA, MEMORY_COMPONENT_ID, MEMORY_Z, N_M31_IN_FELT252};
use crate::components::ret_opcode::test_utils::TestRetAirGenerator;

use super::component::RET_COMPONENT_ID;
use super::test_utils::TestAir;
use super::trace::RetOpcodeCpuTraceGenerator;


#[test]
fn test_ret_interaction_trace() {
let mut air_generator = TestRetAirGenerator::new();
Expand Down Expand Up @@ -48,4 +54,45 @@ pub(crate) mod tests {

assert_eq!(logup_sum, expected_logup_sum);
}

#[test]
fn test_ret_lookup_values() {
let mut air_generator = TestRetAirGenerator::new();
let trace = air_generator.write_trace();
let prover_channel =
&mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[])));

let interaction_trace =
air_generator.interact(&trace, &air_generator.interaction_elements(prover_channel));
let (ret_interaction, memory_interaction) = interaction_trace.split_at(4);
let ret_lookup_value =
SecureField::from_m31_array(std::array::from_fn(|i| ret_interaction[i].values[1]));
let memory_lookup_value =
SecureField::from_m31_array(std::array::from_fn(|i| memory_interaction[i].values[1]));

assert_eq!(ret_lookup_value, memory_lookup_value);
}

#[test]
fn test_ret_proof() {
let mut air_generator = TestRetAirGenerator::new();
let trace = air_generator.write_trace();
let prover_channel =
&mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[])));
let proof = commit_and_prove::<CpuBackend>(&air_generator, prover_channel, trace).unwrap();

let verifier_channel =
&mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[])));
let air = TestAir {
memory_component: air_generator
.registry
.get_generator::<MemoryTraceGenerator>(MEMORY_COMPONENT_ID)
.component(),
ret_component: air_generator
.registry
.get_generator::<RetOpcodeCpuTraceGenerator>(RET_COMPONENT_ID)
.component(),
};
commit_and_verify(proof, &air, verifier_channel).unwrap();
}
}

0 comments on commit 7714e8a

Please sign in to comment.