Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallelize domain eval #821

Open
wants to merge 3 commits into
base: graphite-base/821
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions crates/prover/src/constraint_framework/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@ use std::iter::zip;
use std::ops::Deref;

use itertools::Itertools;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
use tracing::{span, Level};

use super::{EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator};
use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use crate::core::air::{Component, ComponentProver, Trace};
use crate::core::backend::simd::column::VeryPackedSecureColumnByCoords;
use crate::core::backend::simd::m31::LOG_N_LANES;
use crate::core::backend::simd::utils::UnsafeShared;
use crate::core::backend::simd::very_packed_m31::{VeryPackedBaseField, LOG_N_VERY_PACKED_ELEMS};
use crate::core::backend::simd::SimdBackend;
use crate::core::circle::CirclePoint;
Expand Down Expand Up @@ -58,7 +61,7 @@ impl TraceLocationAllocator {
/// Implementing this trait introduces implementations for [`Component`] and [`ComponentProver`] for
/// the SIMD backend.
/// Note that the constraint framework only support components with columns of the same size.
pub trait FrameworkEval {
pub trait FrameworkEval: std::marker::Sync {
fn log_size(&self) -> u32;

fn max_constraint_log_degree_bound(&self) -> u32;
Expand Down Expand Up @@ -176,7 +179,17 @@ impl<E: FrameworkEval> ComponentProver<SimdBackend> for FrameworkComponent<E> {
let _span = span!(Level::INFO, "Constraint pointwise eval").entered();
let col = unsafe { VeryPackedSecureColumnByCoords::transform_under_mut(accum.col) };

for vec_row in 0..(1 << (eval_domain.log_size() - LOG_N_LANES - LOG_N_VERY_PACKED_ELEMS)) {
let iter_range = 0..(1 << (eval_domain.log_size() - LOG_N_LANES - LOG_N_VERY_PACKED_ELEMS));

#[cfg(not(feature = "parallel"))]
let iter = iter_range;

#[cfg(feature = "parallel")]
let iter = iter_range.into_par_iter();

let col = unsafe { UnsafeShared::new(col) };
iter.for_each(|vec_row| {
let col = unsafe { col.get() };
let trace_cols = trace.as_cols_ref().map_cols(|c| c.as_ref());

// Evaluate constrains at row.
Expand All @@ -197,7 +210,7 @@ impl<E: FrameworkEval> ComponentProver<SimdBackend> for FrameworkComponent<E> {
);
col.set_packed(vec_row, col.packed_at(vec_row) + row_res * denom_inv)
}
}
});
}
}

Expand Down
102 changes: 84 additions & 18 deletions crates/prover/src/core/backend/simd/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@ use std::iter::zip;
use std::mem::transmute;

use bytemuck::{cast_slice, Zeroable};
use itertools::Itertools;
use num_traits::One;

use super::fft::{ifft, rfft, CACHED_FFT_LOG_SIZE};
use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES};
use super::qm31::PackedSecureField;
use super::SimdBackend;
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::{Col, CpuBackend};
use crate::core::circle::{CirclePoint, Coset};
use crate::core::fields::m31::BaseField;
use crate::core::backend::simd::m31::PackedM31;
use crate::core::backend::{Col, Column, CpuBackend};
use crate::core::circle::{CirclePoint, Coset, M31_CIRCLE_LOG_ORDER};
use crate::core::fields::m31::{BaseField, M31};
use crate::core::fields::qm31::SecureField;
use crate::core::fields::{Field, FieldExpOps};
use crate::core::poly::circle::{
Expand All @@ -20,6 +22,7 @@ use crate::core::poly::circle::{
use crate::core::poly::twiddles::TwiddleTree;
use crate::core::poly::utils::{domain_line_twiddles_from_tree, fold};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::{bit_reverse, bit_reverse_index};

impl SimdBackend {
// TODO(Ohad): optimize.
Expand Down Expand Up @@ -275,32 +278,95 @@ impl PolyOps for SimdBackend {
)
}

fn precompute_twiddles(coset: Coset) -> TwiddleTree<Self> {
let mut twiddles = Vec::with_capacity(coset.size());
let mut itwiddles = Vec::with_capacity(coset.size());
#[allow(clippy::int_plus_one)]
fn precompute_twiddles(mut coset: Coset) -> TwiddleTree<Self> {
let root_coset = coset;

// TODO(spapini): Optimize.
for layer in &rfft::get_twiddle_dbls(coset) {
twiddles.extend(layer);
// Generate xs for descending cosets, each bit reversed.
let mut xs = Vec::with_capacity(coset.size() / N_LANES);
while coset.log_size() - 1 >= LOG_N_LANES {
gen_coset_xs(coset, &mut xs);
coset = coset.double();
}

let mut extra = Vec::with_capacity(N_LANES);
while coset.log_size() > 0 {
let start = extra.len();
extra.extend(
coset
.iter()
.take(coset.size() / 2)
.map(|p| p.x)
.collect_vec(),
);
bit_reverse(&mut extra[start..]);
coset = coset.double();
}
// Pad by any value, to make the size a power of 2.
twiddles.push(1);
assert_eq!(twiddles.len(), coset.size());
for layer in &ifft::get_itwiddle_dbls(coset) {
itwiddles.extend(layer);
extra.push(M31::one());

if extra.len() < N_LANES {
let twiddles = extra.iter().map(|x| x.0 * 2).collect();
let itwiddles = extra.iter().map(|x| x.inverse().0 * 2).collect();
return TwiddleTree {
root_coset,
twiddles,
itwiddles,
};
}
// Pad by any value, to make the size a power of 2.
itwiddles.push(1);
assert_eq!(itwiddles.len(), coset.size());

xs.push(PackedM31::from_array(extra.try_into().unwrap()));

let mut ixs = unsafe { BaseColumn::uninitialized(root_coset.size()) }.data;
PackedBaseField::batch_inverse(&xs, &mut ixs);

let twiddles = xs
.into_iter()
.flat_map(|x| x.to_array().map(|x| x.0 * 2))
.collect();
let itwiddles = ixs
.into_iter()
.flat_map(|x| x.to_array().map(|x| x.0 * 2))
.collect();

TwiddleTree {
root_coset: coset,
root_coset,
twiddles,
itwiddles,
}
}
}

#[allow(clippy::int_plus_one)]
fn gen_coset_xs(coset: Coset, res: &mut Vec<PackedM31>) {
let log_size = coset.log_size() - 1;
assert!(log_size >= LOG_N_LANES);

let initial_points = std::array::from_fn(|i| coset.at(bit_reverse_index(i, log_size)));
let mut current = CirclePoint {
x: PackedM31::from_array(initial_points.each_ref().map(|p| p.x)),
y: PackedM31::from_array(initial_points.each_ref().map(|p| p.y)),
};

let mut flips = [CirclePoint::zero(); (M31_CIRCLE_LOG_ORDER - LOG_N_LANES) as usize];
for i in 0..(log_size - LOG_N_LANES) {
let prev_mul = bit_reverse_index((1 << i) - 1, log_size - LOG_N_LANES);
let new_mul = bit_reverse_index(1 << i, log_size - LOG_N_LANES);
let flip = coset.step.mul(new_mul as u128) - coset.step.mul(prev_mul as u128);
flips[i as usize] = flip;
}

for i in 0u32..1 << (log_size - LOG_N_LANES) {
let x = current.x;
let flip_index = i.trailing_ones() as usize;
let flip = CirclePoint {
x: PackedM31::broadcast(flips[flip_index].x),
y: PackedM31::broadcast(flips[flip_index].y),
};
current = current + flip;
res.push(x);
}
}

fn slow_eval_at_point(
poly: &CirclePoly<SimdBackend>,
point: CirclePoint<SecureField>,
Expand Down
32 changes: 28 additions & 4 deletions crates/prover/src/core/backend/simd/fft/ifft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
use std::simd::{simd_swizzle, u32x16, u32x2, u32x4};

use itertools::Itertools;
#[cfg(feature = "parallel")]
use rayon::prelude::*;

use super::{
compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE,
};
use crate::core::backend::simd::fft::UnsafeMutI32;
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
use crate::core::circle::Coset;
use crate::core::fields::FieldExpOps;
Expand All @@ -29,6 +32,7 @@ use crate::core::utils::bit_reverse;
/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`].
pub unsafe fn ifft(values: *mut u32, twiddle_dbl: &[&[u32]], log_n_elements: usize) {
assert!(log_n_elements >= MIN_FFT_LOG_SIZE as usize);

let log_n_vecs = log_n_elements - LOG_N_LANES as usize;
if log_n_elements <= CACHED_FFT_LOG_SIZE as usize {
ifft_lower_with_vecwise(values, twiddle_dbl, log_n_elements, log_n_elements);
Expand Down Expand Up @@ -81,7 +85,17 @@ pub unsafe fn ifft_lower_with_vecwise(

assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2));

for index_h in 0..1 << (log_size - fft_layers) {
let iter_range = 0..1 << (log_size - fft_layers);

#[cfg(not(feature = "parallel"))]
let iter = iter_range;

#[cfg(feature = "parallel")]
let iter = iter_range.into_par_iter();

let values = UnsafeMutI32(values);
iter.for_each(|index_h| {
let values = values.get();
ifft_vecwise_loop(values, twiddle_dbl, fft_layers - VECWISE_FFT_BITS, index_h);
for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3) {
match fft_layers - layer {
Expand All @@ -102,7 +116,7 @@ pub unsafe fn ifft_lower_with_vecwise(
}
}
}
}
});
}

/// Computes partial ifft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits of
Expand Down Expand Up @@ -131,7 +145,17 @@ pub unsafe fn ifft_lower_without_vecwise(
) {
assert!(log_size >= LOG_N_LANES as usize);

for index_h in 0..1 << (log_size - fft_layers - LOG_N_LANES as usize) {
let iter_range = 0..1 << (log_size - fft_layers - LOG_N_LANES as usize);

#[cfg(not(feature = "parallel"))]
let iter = iter_range;

#[cfg(feature = "parallel")]
let iter = iter_range.into_par_iter();

let values = UnsafeMutI32(values);
iter.for_each(|index_h| {
let values = values.get();
for layer in (0..fft_layers).step_by(3) {
let fixed_layer = layer + LOG_N_LANES as usize;
match fft_layers - layer {
Expand All @@ -152,7 +176,7 @@ pub unsafe fn ifft_lower_without_vecwise(
}
}
}
}
});
}

/// Runs the first 5 ifft layers across the entire array.
Expand Down
38 changes: 35 additions & 3 deletions crates/prover/src/core/backend/simd/fft/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use std::simd::{simd_swizzle, u32x16, u32x8};

#[cfg(feature = "parallel")]
use rayon::prelude::*;

use super::m31::PackedBaseField;
use crate::core::fields::m31::P;

Expand All @@ -10,6 +13,26 @@ pub const CACHED_FFT_LOG_SIZE: u32 = 16;

pub const MIN_FFT_LOG_SIZE: u32 = 5;

pub struct UnsafeMutI32(pub *mut u32);
impl UnsafeMutI32 {
pub fn get(&self) -> *mut u32 {
self.0
}
}

unsafe impl Send for UnsafeMutI32 {}
unsafe impl Sync for UnsafeMutI32 {}

pub struct UnsafeConstI32(pub *const u32);
impl UnsafeConstI32 {
pub fn get(&self) -> *const u32 {
self.0
}
}

unsafe impl Send for UnsafeConstI32 {}
unsafe impl Sync for UnsafeConstI32 {}

// TODO(spapini): FFTs return a redundant representation, that can get the value P. need to reduce
// it somewhere.

Expand All @@ -29,8 +52,17 @@ pub const MIN_FFT_LOG_SIZE: u32 = 5;
/// Behavior is undefined if `values` does not have the same alignment as [`u32x16`].
pub unsafe fn transpose_vecs(values: *mut u32, log_n_vecs: usize) {
let half = log_n_vecs / 2;
for b in 0..1 << (log_n_vecs & 1) {
for a in 0..1 << half {

#[cfg(not(feature = "parallel"))]
let iter = 0..1 << half;

#[cfg(feature = "parallel")]
let iter = (0..1 << half).into_par_iter();

let values = UnsafeMutI32(values);
iter.for_each(|a| {
let values = values.get();
for b in 0..1 << (log_n_vecs & 1) {
for c in 0..1 << half {
let i = (a << (log_n_vecs - half)) | (b << half) | c;
let j = (c << (log_n_vecs - half)) | (b << half) | a;
Expand All @@ -43,7 +75,7 @@ pub unsafe fn transpose_vecs(values: *mut u32, log_n_vecs: usize) {
store(values.add(j << 4), val0);
}
}
}
});
}

/// Computes the twiddles for the first fft layer from the second, and loads both to SIMD registers.
Expand Down
Loading
Loading