Skip to content

Commit a11fa89

Browse files
Feat/delayed reduction (#970)
* starting to expose API for unreduced ops + standalone reduction * more changes * fmt + clippy * added api after arkworks update * simplify new field API, start changes in spartan * debug, implement delayed reduction for Spartan * changed init_Q to use delayed reduction * implement TODO with `init_Q_dual` * optimized dense + onehot + prefix-suffix evals * fmt + clippy * generic enough? probably not... Signed-off-by: Andrew Tretyakov <42178850+0xAndoroid@users.noreply.github.com> * remove unsafe allocate methods Signed-off-by: Andrew Tretyakov <42178850+0xAndoroid@users.noreply.github.com> * replace default with zero() Signed-off-by: Andrew Tretyakov <42178850+0xAndoroid@users.noreply.github.com> * replace L with N Signed-off-by: Andrew Tretyakov <42178850+0xAndoroid@users.noreply.github.com> * update cargo.lock Signed-off-by: Andrew Tretyakov <42178850+0xAndoroid@users.noreply.github.com> * add comments about init_Q Signed-off-by: Andrew Tretyakov <42178850+0xAndoroid@users.noreply.github.com> * nit Signed-off-by: Andrew Tretyakov <42178850+0xAndoroid@users.noreply.github.com> --------- Signed-off-by: Andrew Tretyakov <42178850+0xAndoroid@users.noreply.github.com> Co-authored-by: Quang Dao <qvd@andrew.cmu.edu>
1 parent ee1fd08 commit a11fa89

File tree

13 files changed

+1002
-413
lines changed

13 files changed

+1002
-413
lines changed

Cargo.lock

Lines changed: 78 additions & 93 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

jolt-core/src/field/ark.rs

Lines changed: 53 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
use ark_ff::{prelude::*, AdditiveGroup, BigInt, PrimeField, UniformRand};
1+
use ark_ff::{prelude::*, BigInt, PrimeField, UniformRand};
22
use rayon::prelude::*;
33

44
use crate::utils::thread::unsafe_allocate_zero_vec;
55

6-
use super::{FieldOps, JoltField};
6+
use super::{FieldOps, FmaddTrunc, JoltField, MulU64WithCarry};
77

88
impl FieldOps for ark_bn254::Fr {}
99
impl FieldOps<&ark_bn254::Fr, ark_bn254::Fr> for &ark_bn254::Fr {}
@@ -27,6 +27,7 @@ impl JoltField for ark_bn254::Fr {
2727
use ark_ff::MontConfig;
2828
std::mem::transmute(<ark_bn254::FrConfig as MontConfig<4>>::R2)
2929
};
30+
type Unreduced<const N: usize> = BigInt<N>;
3031
type SmallValueLookupTables = [Vec<Self>; 2];
3132

3233
fn random<R: rand_core::RngCore>(rng: &mut R) -> Self {
@@ -89,6 +90,7 @@ impl JoltField for ark_bn254::Fr {
8990
}
9091
}
9192

93+
#[inline]
9294
fn from_i64(val: i64) -> Self {
9395
if val.is_negative() {
9496
let val = val.unsigned_abs();
@@ -111,6 +113,7 @@ impl JoltField for ark_bn254::Fr {
111113
}
112114
}
113115

116+
#[inline]
114117
fn from_i128(val: i128) -> Self {
115118
if val.is_negative() {
116119
let val = val.unsigned_abs();
@@ -139,6 +142,7 @@ impl JoltField for ark_bn254::Fr {
139142
}
140143
}
141144

145+
#[inline]
142146
fn from_u128(val: u128) -> Self {
143147
if val <= u16::MAX as u128 {
144148
<Self as JoltField>::from_u16(val as u16)
@@ -152,6 +156,7 @@ impl JoltField for ark_bn254::Fr {
152156
}
153157
}
154158

159+
#[inline]
155160
fn to_u64(&self) -> Option<u64> {
156161
let bigint = <Self as ark_ff::PrimeField>::into_bigint(*self);
157162
let limbs: &[u64] = bigint.as_ref();
@@ -164,35 +169,33 @@ impl JoltField for ark_bn254::Fr {
164169
}
165170
}
166171

172+
#[inline]
167173
fn square(&self) -> Self {
168174
<Self as ark_ff::Field>::square(self)
169175
}
170176

177+
#[inline]
171178
fn inverse(&self) -> Option<Self> {
172179
<Self as ark_ff::Field>::inverse(self)
173180
}
174181

182+
#[inline]
175183
fn from_bytes(bytes: &[u8]) -> Self {
176184
ark_bn254::Fr::from_le_bytes_mod_order(bytes)
177185
}
178186

187+
#[inline]
179188
fn num_bits(&self) -> u32 {
180189
<Self as ark_ff::PrimeField>::into_bigint(*self).num_bits()
181190
}
182191

183192
#[inline(always)]
184-
fn as_bigint_ref(&self) -> &ark_ff::BigInt<4> {
193+
fn as_unreduced_ref(&self) -> &Self::Unreduced<4> {
185194
// arkworks field elements are just wrappers around BigInt, so we can get a direct reference
186195
&self.0
187196
}
188197

189-
#[inline(always)]
190-
fn from_montgomery_reduce_2n(unreduced: ark_ff::BigInt<8>) -> Self {
191-
// Use arkworks Montgomery backend to efficiently reduce 8-limb to 4-limb
192-
ark_bn254::Fr::montgomery_reduce_2n(unreduced)
193-
}
194-
195-
#[inline(always)]
198+
#[inline]
196199
fn mul_u64(&self, n: u64) -> Self {
197200
if n == 0 || self.is_zero() {
198201
Self::zero()
@@ -213,7 +216,7 @@ impl JoltField for ark_bn254::Fr {
213216
ark_ff::Fp::mul_u128::<5, 6>(*self, n)
214217
}
215218

216-
#[inline(always)]
219+
#[inline]
217220
fn mul_i128(&self, n: i128) -> Self {
218221
if n == 0 || self.is_zero() {
219222
Self::zero()
@@ -225,77 +228,50 @@ impl JoltField for ark_bn254::Fr {
225228
}
226229

227230
#[inline]
228-
fn linear_combination_u64(pairs: &[(Self, u64)], add_terms: &[Self]) -> Self {
229-
let mut tmp = ark_ff::BigInt::<4>::mul_u64_w_carry::<5>(&pairs[0].0 .0, pairs[0].1);
230-
for (a, b) in &pairs[1..] {
231-
let carry = tmp.add_with_carry(&ark_ff::BigInt::<4>::mul_u64_w_carry::<5>(&a.0, *b));
232-
debug_assert!(!carry, "spurious carry in linear_combination_u64");
233-
}
231+
fn mul_unreduced<const L: usize>(self, other: Self) -> BigInt<L> {
232+
self.0.mul_trunc::<4, L>(&other.0)
233+
}
234234

235-
// Add the additional terms that don't need multiplication
236-
let mut result = ark_ff::Fp::from_unchecked_nplus1(tmp);
237-
for term in add_terms {
238-
result += *term;
239-
}
240-
result
235+
#[inline]
236+
fn mul_u64_unreduced(self, other: u64) -> BigInt<5> {
237+
self.0.mul_trunc::<1, 5>(&BigInt::new([other]))
241238
}
242239

243240
#[inline]
244-
fn linear_combination_i64(
245-
pos: &[(Self, u64)],
246-
neg: &[(Self, u64)],
247-
pos_add: &[Self],
248-
neg_add: &[Self],
249-
) -> Self {
250-
// unreduced linear combination of positive and negative terms
251-
let mut pos_lc = if !pos.is_empty() {
252-
let mut tmp = ark_ff::BigInt::<4>::mul_u64_w_carry::<5>(&pos[0].0 .0, pos[0].1);
253-
for (a, b) in &pos[1..] {
254-
let carry =
255-
tmp.add_with_carry(&ark_ff::BigInt::<4>::mul_u64_w_carry::<5>(&a.0, *b));
256-
debug_assert!(!carry, "spurious carry in linear_combination_i64");
257-
}
258-
tmp
259-
} else {
260-
ark_ff::BigInt::<5>::zero()
261-
};
262-
263-
let mut neg_lc = if !neg.is_empty() {
264-
let mut tmp = ark_ff::BigInt::<4>::mul_u64_w_carry::<5>(&neg[0].0 .0, neg[0].1);
265-
for (a, b) in &neg[1..] {
266-
let carry =
267-
tmp.add_with_carry(&ark_ff::BigInt::<4>::mul_u64_w_carry::<5>(&a.0, *b));
268-
debug_assert!(!carry, "spurious carry in linear_combination_i64");
269-
}
270-
tmp
271-
} else {
272-
ark_ff::BigInt::<5>::zero()
273-
};
274-
275-
// Compute the difference of the linear combinations
276-
let diff = match pos_lc.cmp(&neg_lc) {
277-
std::cmp::Ordering::Greater => {
278-
let borrow = pos_lc.sub_with_borrow(&neg_lc);
279-
debug_assert!(!borrow, "spurious borrow in linear_combination_i64");
280-
ark_ff::Fp::from_unchecked_nplus1(pos_lc)
281-
}
282-
std::cmp::Ordering::Less => {
283-
let borrow = neg_lc.sub_with_borrow(&pos_lc);
284-
debug_assert!(!borrow, "spurious borrow in linear_combination_i64");
285-
*ark_ff::Fp::from_unchecked_nplus1(neg_lc).neg_in_place()
286-
}
287-
std::cmp::Ordering::Equal => ark_ff::Fp::zero(),
288-
};
241+
fn mul_u128_unreduced(self, other: u128) -> BigInt<6> {
242+
self.0
243+
.mul_trunc::<2, 6>(&BigInt::new([other as u64, (other >> 64) as u64]))
244+
}
289245

290-
// Add the positive and negative add terms
291-
let mut result = diff;
292-
for term in pos_add {
293-
result += *term;
294-
}
295-
for term in neg_add {
296-
result -= *term;
297-
}
298-
result
246+
#[inline]
247+
fn from_montgomery_reduce<const L: usize>(unreduced: BigInt<L>) -> Self {
248+
ark_bn254::Fr::from_montgomery_reduce::<L, 5>(unreduced)
249+
}
250+
251+
#[inline]
252+
fn from_barrett_reduce<const L: usize>(unreduced: BigInt<L>) -> Self {
253+
ark_bn254::Fr::from_barrett_reduce::<L, 5>(unreduced)
254+
}
255+
}
256+
257+
impl<const N: usize> FmaddTrunc for BigInt<N> {
258+
type Other<const M: usize> = BigInt<M>;
259+
type Acc<const P: usize> = BigInt<P>;
260+
261+
fn fmadd_trunc<const M: usize, const P: usize>(
262+
&self,
263+
other: &Self::Other<M>,
264+
acc: &mut Self::Acc<P>,
265+
) {
266+
self.fmadd_trunc(other, acc)
267+
}
268+
}
269+
270+
impl<const N: usize> MulU64WithCarry for BigInt<N> {
271+
type Output<const NPLUS1: usize> = BigInt<NPLUS1>;
272+
273+
fn mul_u64_w_carry<const NPLUS1: usize>(&self, other: u64) -> Self::Output<NPLUS1> {
274+
<BigInt<N> as BigInteger>::mul_u64_w_carry(self, other)
299275
}
300276
}
301277

0 commit comments

Comments
 (0)