diff --git a/crates/prover/src/constraint_framework/component.rs b/crates/prover/src/constraint_framework/component.rs index c0d8319fa..7ba297378 100644 --- a/crates/prover/src/constraint_framework/component.rs +++ b/crates/prover/src/constraint_framework/component.rs @@ -3,6 +3,8 @@ use std::iter::zip; use std::ops::Deref; use itertools::Itertools; +#[cfg(feature = "parallel")] +use rayon::prelude::*; use tracing::{span, Level}; use super::{EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator}; @@ -10,6 +12,7 @@ use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluatio use crate::core::air::{Component, ComponentProver, Trace}; use crate::core::backend::simd::column::VeryPackedSecureColumnByCoords; use crate::core::backend::simd::m31::LOG_N_LANES; +use crate::core::backend::simd::utils::UnsafeShared; use crate::core::backend::simd::very_packed_m31::{VeryPackedBaseField, LOG_N_VERY_PACKED_ELEMS}; use crate::core::backend::simd::SimdBackend; use crate::core::circle::CirclePoint; @@ -58,7 +61,7 @@ impl TraceLocationAllocator { /// Implementing this trait introduces implementations for [`Component`] and [`ComponentProver`] for /// the SIMD backend. /// Note that the constraint framework only support components with columns of the same size. -pub trait FrameworkEval { +pub trait FrameworkEval: std::marker::Sync { fn log_size(&self) -> u32; fn max_constraint_log_degree_bound(&self) -> u32; @@ -176,7 +179,17 @@ impl ComponentProver for FrameworkComponent { let _span = span!(Level::INFO, "Constraint pointwise eval").entered(); let col = unsafe { VeryPackedSecureColumnByCoords::transform_under_mut(accum.col) }; - for vec_row in 0..(1 << (eval_domain.log_size() - LOG_N_LANES - LOG_N_VERY_PACKED_ELEMS)) { + let iter_range = 0..(1 << (eval_domain.log_size() - LOG_N_LANES - LOG_N_VERY_PACKED_ELEMS)); + + #[cfg(not(feature = "parallel"))] + let iter = iter_range; + + #[cfg(feature = "parallel")] + let iter = iter_range.into_par_iter(); + + let col = unsafe { UnsafeShared::new(col) }; + iter.for_each(|vec_row| { + let col = unsafe { col.get() }; let trace_cols = trace.as_cols_ref().map_cols(|c| c.as_ref()); // Evaluate constrains at row. @@ -197,7 +210,7 @@ impl ComponentProver for FrameworkComponent { ); col.set_packed(vec_row, col.packed_at(vec_row) + row_res * denom_inv) } - } + }); } } diff --git a/crates/prover/src/core/backend/simd/mod.rs b/crates/prover/src/core/backend/simd/mod.rs index 280278aa6..3525e9c03 100644 --- a/crates/prover/src/core/backend/simd/mod.rs +++ b/crates/prover/src/core/backend/simd/mod.rs @@ -22,7 +22,7 @@ pub mod poseidon252; pub mod prefix_sum; pub mod qm31; pub mod quotients; -mod utils; +pub mod utils; pub mod very_packed_m31; #[derive(Copy, Clone, Debug, Deserialize, Serialize)] diff --git a/crates/prover/src/core/backend/simd/utils.rs b/crates/prover/src/core/backend/simd/utils.rs index 87dfd2246..ad89fe9f9 100644 --- a/crates/prover/src/core/backend/simd/utils.rs +++ b/crates/prover/src/core/backend/simd/utils.rs @@ -24,6 +24,21 @@ const fn parity_interleave(odd: bool) -> [usize; N] { res } +#[derive(Clone, Copy)] +pub struct UnsafeShared(pub *mut T); +impl UnsafeShared { + pub unsafe fn new(t: &mut T) -> Self { + Self(t as *mut T) + } + #[allow(clippy::mut_from_ref)] + pub unsafe fn get(&self) -> &mut T { + &mut *self.0 + } +} + +unsafe impl Sync for UnsafeShared {} +unsafe impl Send for UnsafeShared {} + #[cfg(test)] mod tests { use std::simd::{u32x4, Swizzle};