Skip to content

Commit

Permalink
Add memory component. (#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
alonh5 authored Jul 17, 2024
2 parents 952318b + 09c5c7d commit 0b749df
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 9 deletions.
146 changes: 146 additions & 0 deletions stwo_cairo_prover/src/components/memory/component.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
use itertools::{zip_eq, Itertools};
use num_traits::Zero;
use stwo_prover::core::air::accumulation::PointEvaluationAccumulator;
use stwo_prover::core::air::mask::fixed_mask_points;
use stwo_prover::core::air::Component;
use stwo_prover::core::backend::CpuBackend;
use stwo_prover::core::circle::CirclePoint;
use stwo_prover::core::fields::m31::{BaseField, M31};
use stwo_prover::core::fields::qm31::SecureField;
use stwo_prover::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use stwo_prover::core::pcs::TreeVec;
use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation};
use stwo_prover::core::poly::BitReversedOrder;
use stwo_prover::core::{ColumnVec, InteractionElements, LookupValues};
use stwo_prover::trace_generation::registry::ComponentGenerationRegistry;
use stwo_prover::trace_generation::{ComponentGen, ComponentTraceGenerator};

const N_M31_IN_FELT252: usize = 21;
const LOG_MEMORY_ADDRESS_BOUND: u32 = 20;
const MEMORY_ADDRESS_BOUND: usize = 1 << LOG_MEMORY_ADDRESS_BOUND;

/// Addresses are continuous and start from 0.
/// Values are Felt252 stored as `N_M31_IN_FELT252` M31 values (each value contain 12 bits).
pub struct MemoryTraceGenerator {
// TODO(AlonH): Consider to change values to be Felt252.
pub values: Vec<[M31; N_M31_IN_FELT252]>,
pub multiplicities: Vec<u32>,
}

pub struct MemoryComponent {
pub log_n_rows: u32,
}

impl MemoryComponent {
pub const fn n_columns(&self) -> usize {
N_M31_IN_FELT252 + 2
}
}

impl MemoryTraceGenerator {
pub fn new(_path: String) -> Self {
// TODO(AlonH): change to read from file.
let values = vec![[M31::zero(); N_M31_IN_FELT252]; MEMORY_ADDRESS_BOUND];
let multiplicities = vec![0; MEMORY_ADDRESS_BOUND];
Self {
values,
multiplicities,
}
}
}

impl ComponentGen for MemoryTraceGenerator {}

impl ComponentTraceGenerator<CpuBackend> for MemoryTraceGenerator {
type Component = MemoryComponent;
type Inputs = M31;

fn add_inputs(&mut self, inputs: &Self::Inputs) {
let input = inputs.0 as usize;
// TODO: replace the debug_assert! with an error return.
debug_assert!(input < MEMORY_ADDRESS_BOUND, "Input out of range");
self.multiplicities[input] += 1;
}

fn write_trace(
component_id: &str,
registry: &mut ComponentGenerationRegistry,
) -> ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
let memory_trace_generator = registry.get_generator::<MemoryTraceGenerator>(component_id);

let mut trace = vec![vec![BaseField::zero(); MEMORY_ADDRESS_BOUND]; N_M31_IN_FELT252 + 2];
for (i, (values, multiplicity)) in zip_eq(
&memory_trace_generator.values,
&memory_trace_generator.multiplicities,
)
.enumerate()
{
// TODO(AlonH): Either create a constant column for the addresses and remove it from
// here or add constraints to the column here.
trace[0][i] = BaseField::from_u32_unchecked(i as u32);
for (j, value) in values.iter().enumerate() {
trace[j + 1][i] = BaseField::from_u32_unchecked(value.0);
}
trace[22][i] = BaseField::from_u32_unchecked(*multiplicity);
}

let domain = CanonicCoset::new(LOG_MEMORY_ADDRESS_BOUND).circle_domain();
trace
.into_iter()
.map(|eval| CircleEvaluation::<CpuBackend, _, BitReversedOrder>::new(domain, eval))
.collect_vec()
}

fn write_interaction_trace(
&self,
_trace: &ColumnVec<&CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>,
_elements: &InteractionElements,
) -> ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
todo!()
}

fn component(&self) -> Self::Component {
MemoryComponent {
log_n_rows: LOG_MEMORY_ADDRESS_BOUND,
}
}
}

impl Component for MemoryComponent {
fn n_constraints(&self) -> usize {
3
}

fn max_constraint_log_degree_bound(&self) -> u32 {
LOG_MEMORY_ADDRESS_BOUND + 1
}

fn trace_log_degree_bounds(&self) -> TreeVec<ColumnVec<u32>> {
TreeVec::new(vec![
vec![self.log_n_rows; self.n_columns()],
vec![self.log_n_rows; SECURE_EXTENSION_DEGREE],
])
}

fn mask_points(
&self,
point: CirclePoint<SecureField>,
) -> TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>> {
let domain = CanonicCoset::new(self.log_n_rows);
TreeVec::new(vec![
fixed_mask_points(&vec![vec![0_usize]; self.n_columns()], point),
vec![vec![point, point - domain.step().into_ef()]; SECURE_EXTENSION_DEGREE],
])
}

fn evaluate_constraint_quotients_at_point(
&self,
_point: CirclePoint<SecureField>,
_mask: &TreeVec<Vec<Vec<SecureField>>>,
_evaluation_accumulator: &mut PointEvaluationAccumulator,
_interaction_elements: &InteractionElements,
_lookup_values: &LookupValues,
) {
todo!()
}
}
1 change: 1 addition & 0 deletions stwo_cairo_prover/src/components/memory/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod component;
1 change: 1 addition & 0 deletions stwo_cairo_prover/src/components/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub mod memory;
pub mod range_check_unit;
16 changes: 9 additions & 7 deletions stwo_cairo_prover/src/components/range_check_unit/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub const RC_LOOKUP_VALUE_3: &str = "RC_UNIT_LOOKUP_3";

#[derive(Clone)]
pub struct RangeCheckUnitComponent {
pub log_n_instances: u32,
pub log_n_rows: u32,
}

pub struct RangeCheckUnitTraceGenerator {
Expand All @@ -43,21 +43,21 @@ impl Component for RangeCheckUnitComponent {
}

fn max_constraint_log_degree_bound(&self) -> u32 {
self.log_n_instances + 1
self.log_n_rows + 1
}

fn trace_log_degree_bounds(&self) -> TreeVec<ColumnVec<u32>> {
TreeVec::new(vec![
vec![self.log_n_instances; 2],
vec![self.log_n_instances; SECURE_EXTENSION_DEGREE],
vec![self.log_n_rows; 2],
vec![self.log_n_rows; SECURE_EXTENSION_DEGREE],
])
}

fn mask_points(
&self,
point: CirclePoint<SecureField>,
) -> TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>> {
let domain = CanonicCoset::new(self.log_n_instances);
let domain = CanonicCoset::new(self.log_n_rows);
TreeVec::new(vec![
fixed_mask_points(&vec![vec![0_usize]; 2], point),
vec![vec![point, point - domain.step().into_ef()]; SECURE_EXTENSION_DEGREE],
Expand All @@ -73,7 +73,7 @@ impl Component for RangeCheckUnitComponent {
lookup_values: &LookupValues,
) {
// First lookup point boundary constraint.
let constraint_zero_domain = CanonicCoset::new(self.log_n_instances).coset;
let constraint_zero_domain = CanonicCoset::new(self.log_n_rows).coset;
let z = interaction_elements[RC_Z];
let value =
SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][0]));
Expand Down Expand Up @@ -138,6 +138,8 @@ impl ComponentTraceGenerator<CpuBackend> for RangeCheckUnitTraceGenerator {

let mut trace = vec![vec![BaseField::zero(); rc_max_value]; 2];
for (i, multiplicity) in rc_unit_trace_generator.multiplicities.iter().enumerate() {
// TODO(AlonH): Either create a constant column for the addresses and remove it from
// here or add constraints to the column here.
trace[0][i] = BaseField::from_u32_unchecked(i as u32);
trace[1][i] = BaseField::from_u32_unchecked(*multiplicity);
}
Expand Down Expand Up @@ -180,7 +182,7 @@ impl ComponentTraceGenerator<CpuBackend> for RangeCheckUnitTraceGenerator {

fn component(&self) -> RangeCheckUnitComponent {
RangeCheckUnitComponent {
log_n_instances: self.max_value.checked_ilog2().unwrap(),
log_n_rows: self.max_value.checked_ilog2().unwrap(),
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ impl ComponentProver<CpuBackend> for RangeCheckUnitComponent {
let max_constraint_degree = self.max_constraint_log_degree_bound();
let trace_eval_domain = CanonicCoset::new(max_constraint_degree).circle_domain();
let trace_evals = &trace.evals;
let zero_domain = CanonicCoset::new(self.log_n_instances).coset;
let zero_domain = CanonicCoset::new(self.log_n_rows).coset;
let [mut accum] =
evaluation_accumulator.columns([(max_constraint_degree, self.n_constraints())]);

Expand Down Expand Up @@ -93,7 +93,7 @@ impl ComponentProver<CpuBackend> for RangeCheckUnitComponent {
}

fn lookup_values(&self, trace: &ComponentTrace<'_, CpuBackend>) -> LookupValues {
let domain = CanonicCoset::new(self.log_n_instances);
let domain = CanonicCoset::new(self.log_n_rows);
let trace_poly = &trace.polys[INTERACTION_TRACE];
let values = BTreeMap::from_iter([
(
Expand Down

0 comments on commit 0b749df

Please sign in to comment.