Skip to content

Commit

Permalink
Generalize field in CirclePoint
Browse files Browse the repository at this point in the history
  • Loading branch information
atgrosso committed Oct 1, 2024
1 parent 9258b96 commit fe6f3ad
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 41 deletions.
118 changes: 84 additions & 34 deletions stwo_cairo_verifier/src/circle.cairo
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use stwo_cairo_verifier::fields::m31::{M31, m31};
use stwo_cairo_verifier::fields::m31::M31;
use stwo_cairo_verifier::fields::qm31::{QM31, QM31Trait};
use super::utils::pow;
use core::num::traits::zero::Zero;
use core::num::traits::one::One;

pub const M31_CIRCLE_GEN: CirclePointM31 =
CirclePointM31 { x: M31 { inner: 2 }, y: M31 { inner: 1268011823 }, };
pub const M31_CIRCLE_GEN: CirclePoint<M31> =
CirclePoint::<M31> { x: M31 { inner: 2 }, y: M31 { inner: 1268011823 }, };

pub const CIRCLE_LOG_ORDER: u32 = 31;

Expand All @@ -15,22 +18,27 @@ pub const CIRCLE_ORDER_BIT_MASK: u32 = 0x7fffffff;
// `U32_BIT_MASK` equals 2^32 - 1
pub const U32_BIT_MASK: u64 = 0xffffffff;

/// A point on the complex circle. Treated as an additive group.
#[derive(Drop, Copy, Debug, PartialEq, Eq)]
pub struct CirclePointM31 {
pub x: M31,
pub y: M31,
pub struct CirclePoint<F> {
pub x: F,
pub y: F
}

#[generate_trait]
pub impl CirclePointM31Impl of CirclePointM31Trait {
pub trait CirclePointTrait<F, +Add<F>, +Sub<F>, +Mul<F>, +Drop<F>, +Copy<F>, +Zero<F>, +One<F>> {
// Returns the neutral element of the circle.
fn zero() -> CirclePointM31 {
CirclePointM31 { x: m31(1), y: m31(0) }
fn zero() -> CirclePoint<F> {
CirclePoint::<F> { x: One::<F>::one(), y: Zero::<F>::zero() }
}

fn mul(self: @CirclePointM31, mut scalar: u32) -> CirclePointM31 {
let mut result = Self::zero();
let mut cur = *self;
fn mul(
self: @CirclePoint<F>, initial_scalar: u128
) -> CirclePoint<
F
> {
let mut scalar = initial_scalar;
let mut result: CirclePoint<F> = Self::zero();
let mut cur: CirclePoint<F> = *self;
while scalar > 0 {
if scalar & 1 == 1 {
result = result + cur;
Expand All @@ -42,13 +50,28 @@ pub impl CirclePointM31Impl of CirclePointM31Trait {
}
}

impl CirclePointM31Add of Add<CirclePointM31> {
impl CirclePointAdd<F, +Add<F>, +Sub<F>, +Mul<F>, +Drop<F>, +Copy<F>> of Add<CirclePoint<F>> {
// The operation of the circle as a group with additive notation.
fn add(lhs: CirclePointM31, rhs: CirclePointM31) -> CirclePointM31 {
CirclePointM31 { x: lhs.x * rhs.x - lhs.y * rhs.y, y: lhs.x * rhs.y + lhs.y * rhs.x }
fn add(lhs: CirclePoint<F>, rhs: CirclePoint<F>) -> CirclePoint<F> {
CirclePoint::<F> { x: lhs.x * rhs.x - lhs.y * rhs.y, y: lhs.x * rhs.y + lhs.y * rhs.x }
}
}

pub impl CirclePointM31Impl of CirclePointTrait<M31> {}

pub impl CirclePointQM31Impl of CirclePointTrait<QM31> {}

trait ComplexConjugate {
fn complex_conjugate(self: CirclePoint<QM31>) -> CirclePoint<QM31>;
}

pub impl ComplexConjugateImpl of ComplexConjugate {
fn complex_conjugate(self: CirclePoint<QM31>) -> CirclePoint<QM31> {
CirclePoint { x: self.x.complex_conjugate(), y: self.y.complex_conjugate() }
}
}

/// Represents the coset initial + <step>.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Drop)]
pub struct Coset {
// This is an index in the range [0, 2^31)
Expand Down Expand Up @@ -86,50 +109,77 @@ pub impl CosetImpl of CosetTrait {
}
}

fn at(self: @Coset, index: usize) -> CirclePointM31 {
M31_CIRCLE_GEN.mul(self.index_at(index))
fn at(self: @Coset, index: usize) -> CirclePoint::<M31> {
M31_CIRCLE_GEN.mul(self.index_at(index).into())
}

/// Returns the size of the coset.
fn size(self: @Coset) -> usize {
pow(2, *self.log_size)
}

/// Creates a coset of the form G_2n + \<G_n\>.
/// For example, for n=8, we get the point indices \[1,3,5,7,9,11,13,15\].
fn odds(log_size: u32) -> Coset {
//CIRCLE_LOG_ORDER
let subgroup_generator_index = Self::subgroup_generator_index(log_size);
Self::new(subgroup_generator_index, log_size)
}

/// Creates a coset of the form G_4n + <G_n>.
/// For example, for n=8, we get the point indices \[1,5,9,13,17,21,25,29\].
/// Its conjugate will be \[3,7,11,15,19,23,27,31\].
fn half_odds(log_size: u32) -> Coset {
Self::new(Self::subgroup_generator_index(log_size + 2), log_size)
}

fn subgroup_generator_index(log_size: u32) -> u32 {
assert!(log_size <= CIRCLE_LOG_ORDER);
pow(2, CIRCLE_LOG_ORDER - log_size)
}
}


#[cfg(test)]
mod tests {
use super::{M31_CIRCLE_GEN, CIRCLE_ORDER, CirclePointM31, CirclePointM31Impl, Coset, CosetImpl};
use stwo_cairo_verifier::fields::m31::m31;
use super::{M31_CIRCLE_GEN, CIRCLE_ORDER, CirclePoint, CirclePointM31Impl, Coset, CosetImpl};
use core::option::OptionTrait;
use core::array::ArrayTrait;
use core::traits::TryInto;
use super::CirclePointQM31Impl;
use stwo_cairo_verifier::fields::m31::{m31, M31};
use stwo_cairo_verifier::fields::qm31::{qm31, QM31, QM31One};
use stwo_cairo_verifier::utils::pow;

#[test]
fn test_add_1() {
let i = CirclePointM31 { x: m31(0), y: m31(1) };
let i = CirclePoint::<M31> { x: m31(0), y: m31(1) };
let result = i + i;
let expected_result = CirclePointM31 { x: -m31(1), y: m31(0) };
let expected_result = CirclePoint::<M31> { x: -m31(1), y: m31(0) };

assert_eq!(result, expected_result);
}

#[test]
fn test_add_2() {
let point_1 = CirclePointM31 { x: m31(750649172), y: m31(1991648574) };
let point_2 = CirclePointM31 { x: m31(1737427771), y: m31(309481134) };
let point_1 = CirclePoint::<M31> { x: m31(750649172), y: m31(1991648574) };
let point_2 = CirclePoint::<M31> { x: m31(1737427771), y: m31(309481134) };
let result = point_1 + point_2;
let expected_result = CirclePointM31 { x: m31(1476625263), y: m31(1040927458) };
let expected_result = CirclePoint::<M31> { x: m31(1476625263), y: m31(1040927458) };

assert_eq!(result, expected_result);
}

#[test]
fn test_zero_1() {
let result = CirclePointM31Impl::zero();
let expected_result = CirclePointM31 { x: m31(1), y: m31(0) };
let expected_result = CirclePoint::<M31> { x: m31(1), y: m31(0) };
assert_eq!(result, expected_result);
}

#[test]
fn test_zero_2() {
let point_1 = CirclePointM31 { x: m31(750649172), y: m31(1991648574) };
let point_1 = CirclePoint::<M31> { x: m31(750649172), y: m31(1991648574) };
let point_2 = CirclePointM31Impl::zero();
let expected_result = point_1.clone();
let result = point_1 + point_2;
Expand All @@ -139,7 +189,7 @@ mod tests {

#[test]
fn test_mul_1() {
let point_1 = CirclePointM31 { x: m31(750649172), y: m31(1991648574) };
let point_1 = CirclePoint::<M31> { x: m31(750649172), y: m31(1991648574) };
let result = point_1.mul(5);
let expected_result = point_1 + point_1 + point_1 + point_1 + point_1;

Expand All @@ -148,7 +198,7 @@ mod tests {

#[test]
fn test_mul_2() {
let point_1 = CirclePointM31 { x: m31(750649172), y: m31(1991648574) };
let point_1 = CirclePoint::<M31> { x: m31(750649172), y: m31(1991648574) };
let result = point_1.mul(8);
let mut expected_result = point_1 + point_1;
expected_result = expected_result + expected_result;
Expand All @@ -159,18 +209,18 @@ mod tests {

#[test]
fn test_mul_3() {
let point_1 = CirclePointM31 { x: m31(750649172), y: m31(1991648574) };
let point_1 = CirclePoint::<M31> { x: m31(750649172), y: m31(1991648574) };
let result = point_1.mul(418776494);
let expected_result = CirclePointM31 { x: m31(1987283985), y: m31(1500510905) };
let expected_result = CirclePoint::<M31> { x: m31(1987283985), y: m31(1500510905) };

assert_eq!(result, expected_result);
}

#[test]
fn test_generator_order() {
let half_order = CIRCLE_ORDER / 2;
let mut result = M31_CIRCLE_GEN.mul(half_order);
let expected_result = CirclePointM31 { x: -m31(1), y: m31(0) };
let mut result = M31_CIRCLE_GEN.mul(half_order.into());
let expected_result = CirclePoint::<M31> { x: -m31(1), y: m31(0) };

// Assert `M31_CIRCLE_GEN^{2^30}` equals `-1`.
assert_eq!(expected_result, result);
Expand Down Expand Up @@ -204,7 +254,7 @@ mod tests {
fn test_coset_at() {
let coset = Coset { initial_index: 16777216, step_size: 67108864, log_size: 5 };
let result = coset.at(17);
let expected_result = CirclePointM31 { x: m31(7144319), y: m31(1742797653) };
let expected_result = CirclePoint::<M31> { x: m31(7144319), y: m31(1742797653) };
assert_eq!(expected_result, result);
}

Expand Down
4 changes: 4 additions & 0 deletions stwo_cairo_verifier/src/fields/qm31.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ pub impl QM31Impl of QM31Trait {
let denom_inverse = denom.inverse();
QM31 { a: self.a * denom_inverse, b: -self.b * denom_inverse }
}

fn complex_conjugate(self: QM31) -> QM31 {
QM31 { a: self.a, b: -self.b }
}
}

pub impl QM31Add of core::traits::Add<QM31> {
Expand Down
13 changes: 6 additions & 7 deletions stwo_cairo_verifier/src/poly/circle.cairo
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use stwo_cairo_verifier::circle::CirclePointM31Trait;
use core::option::OptionTrait;
use core::clone::Clone;
use core::result::ResultTrait;
use stwo_cairo_verifier::fields::m31::{M31, m31};
use stwo_cairo_verifier::utils::pow;
use stwo_cairo_verifier::circle::{
Coset, CosetImpl, CirclePointM31, CirclePointM31Impl, M31_CIRCLE_GEN, CIRCLE_ORDER
Coset, CosetImpl, CirclePoint, CirclePointM31Impl, M31_CIRCLE_GEN, CIRCLE_ORDER
};

/// A valid domain for circle polynomial interpolation and evaluation.
Expand All @@ -31,8 +30,8 @@ pub impl CircleDomainImpl of CircleDomainTrait {
}
}

fn at(self: @CircleDomain, index: usize) -> CirclePointM31 {
M31_CIRCLE_GEN.mul(self.index_at(index))
fn at(self: @CircleDomain, index: usize) -> CirclePoint::<M31> {
M31_CIRCLE_GEN.mul(self.index_at(index).into())
}
}

Expand All @@ -41,7 +40,7 @@ pub impl CircleDomainImpl of CircleDomainTrait {
mod tests {
use super::{CircleDomain, CircleDomainTrait};
use stwo_cairo_verifier::circle::{
Coset, CosetImpl, CirclePointM31, CirclePointM31Impl, M31_CIRCLE_GEN, CIRCLE_ORDER
Coset, CosetImpl, CirclePoint, CirclePointM31Impl, M31_CIRCLE_GEN, CIRCLE_ORDER
};
use stwo_cairo_verifier::fields::m31::{M31, m31};

Expand All @@ -51,7 +50,7 @@ mod tests {
let domain = CircleDomain { half_coset };
let index = 17;
let result = domain.at(index);
let expected_result = CirclePointM31 { x: m31(7144319), y: m31(1742797653) };
let expected_result = CirclePoint::<M31> { x: m31(7144319), y: m31(1742797653) };
assert_eq!(expected_result, result);
}

Expand All @@ -61,7 +60,7 @@ mod tests {
let domain = CircleDomain { half_coset };
let index = 37;
let result = domain.at(index);
let expected_result = CirclePointM31 { x: m31(9803698), y: m31(2079025011) };
let expected_result = CirclePoint::<M31> { x: m31(9803698), y: m31(2079025011) };
assert_eq!(expected_result, result);
}
}

0 comments on commit fe6f3ad

Please sign in to comment.