Skip to content

Commit

Permalink
Parallelize domain eval
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Sep 5, 2024
1 parent 755db7c commit 48eb60e
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 4 deletions.
19 changes: 16 additions & 3 deletions crates/prover/src/constraint_framework/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@ use std::iter::zip;
use std::ops::Deref;

use itertools::Itertools;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
use tracing::{span, Level};

use super::{EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator};
use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use crate::core::air::{Component, ComponentProver, Trace};
use crate::core::backend::simd::column::VeryPackedSecureColumnByCoords;
use crate::core::backend::simd::m31::LOG_N_LANES;
use crate::core::backend::simd::utils::UnsafeShared;
use crate::core::backend::simd::very_packed_m31::{VeryPackedBaseField, LOG_N_VERY_PACKED_ELEMS};
use crate::core::backend::simd::SimdBackend;
use crate::core::circle::CirclePoint;
Expand Down Expand Up @@ -58,7 +61,7 @@ impl TraceLocationAllocator {
/// Implementing this trait introduces implementations for [`Component`] and [`ComponentProver`] for
/// the SIMD backend.
/// Note that the constraint framework only support components with columns of the same size.
pub trait FrameworkEval {
pub trait FrameworkEval: std::marker::Sync {
fn log_size(&self) -> u32;

fn max_constraint_log_degree_bound(&self) -> u32;
Expand Down Expand Up @@ -176,7 +179,17 @@ impl<E: FrameworkEval> ComponentProver<SimdBackend> for FrameworkComponent<E> {
let _span = span!(Level::INFO, "Constraint pointwise eval").entered();
let col = unsafe { VeryPackedSecureColumnByCoords::transform_under_mut(accum.col) };

for vec_row in 0..(1 << (eval_domain.log_size() - LOG_N_LANES - LOG_N_VERY_PACKED_ELEMS)) {
let iter_range = 0..(1 << (eval_domain.log_size() - LOG_N_LANES - LOG_N_VERY_PACKED_ELEMS));

#[cfg(not(feature = "parallel"))]
let iter = iter_range;

#[cfg(feature = "parallel")]
let iter = iter_range.into_par_iter();

let col = unsafe { UnsafeShared::new(col) };
iter.for_each(|vec_row| {
let col = unsafe { col.get() };
let trace_cols = trace.as_cols_ref().map_cols(|c| c.as_ref());

// Evaluate constrains at row.
Expand All @@ -197,7 +210,7 @@ impl<E: FrameworkEval> ComponentProver<SimdBackend> for FrameworkComponent<E> {
);
col.set_packed(vec_row, col.packed_at(vec_row) + row_res * denom_inv)
}
}
});
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/core/backend/simd/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub mod poseidon252;
pub mod prefix_sum;
pub mod qm31;
pub mod quotients;
mod utils;
pub mod utils;
pub mod very_packed_m31;

#[derive(Copy, Clone, Debug, Deserialize, Serialize)]
Expand Down
15 changes: 15 additions & 0 deletions crates/prover/src/core/backend/simd/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,21 @@ const fn parity_interleave<const N: usize>(odd: bool) -> [usize; N] {
res
}

#[derive(Clone, Copy)]
pub struct UnsafeShared<T>(pub *mut T);
impl<T> UnsafeShared<T> {
pub unsafe fn new(t: &mut T) -> Self {
Self(t as *mut T)
}
#[allow(clippy::mut_from_ref)]
pub unsafe fn get(&self) -> &mut T {
&mut *self.0
}
}

unsafe impl<T> Sync for UnsafeShared<T> {}
unsafe impl<T> Send for UnsafeShared<T> {}

#[cfg(test)]
mod tests {
use std::simd::{u32x4, Swizzle};
Expand Down

0 comments on commit 48eb60e

Please sign in to comment.