Skip to content
19 changes: 18 additions & 1 deletion jolt-core/src/poly/commitment/hyperkzg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use super::{
};
use crate::field::JoltField;
use crate::poly::multilinear_polynomial::{MultilinearPolynomial, PolynomialEvaluation};
use crate::poly::rlc_polynomial::RLCPolynomial;
use crate::{
msm::VariableBaseMSM,
poly::{commitment::kzg::SRS, dense_mlpoly::DensePolynomial, unipoly::UniPoly},
Expand Down Expand Up @@ -158,7 +159,23 @@ where
let scalars = v.iter().flatten().collect::<Vec<&P::ScalarField>>();
transcript.append_scalars::<P::ScalarField>(&scalars);
let q_powers: Vec<P::ScalarField> = transcript.challenge_scalar_powers(f.len());
let B = MultilinearPolynomial::linear_combination(&f.iter().collect::<Vec<_>>(), &q_powers);
let f_arc: Vec<Arc<MultilinearPolynomial<P::ScalarField>>> =
f.iter().map(|poly| Arc::new(poly.clone())).collect();

// @TODO(markosg04) right now we don't use HyperKZG so we just handle both cases
let has_one_hot = f_arc
.iter()
.any(|poly| matches!(poly.as_ref(), MultilinearPolynomial::OneHot(_)));

let B = if has_one_hot {
let rlc_result = RLCPolynomial::linear_combination(f_arc, &q_powers);
MultilinearPolynomial::RLC(rlc_result)
} else {
let poly_refs: Vec<&MultilinearPolynomial<P::ScalarField>> =
f_arc.iter().map(|arc| arc.as_ref()).collect();
let dense_result = DensePolynomial::linear_combination(&poly_refs, &q_powers);
MultilinearPolynomial::from(dense_result.Z)
};

// Now open B at u0, ..., u_{t-1}
let w = kzg_batch_open_no_rem(&B, u, pk);
Expand Down
51 changes: 50 additions & 1 deletion jolt-core/src/poly/dense_mlpoly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@ use crate::utils::thread::unsafe_allocate_zero_vec;
use crate::utils::{compute_dotproduct, compute_dotproduct_low_optimized};

use crate::field::{JoltField, OptimizedMul};
use crate::poly::compact_polynomial::SmallScalar;
use crate::utils::math::Math;
use allocative::Allocative;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use core::ops::Index;
use rand_core::{CryptoRng, RngCore};
use rayon::prelude::*;

use super::multilinear_polynomial::BindingOrder;
use super::multilinear_polynomial::{BindingOrder, MultilinearPolynomial};

#[derive(Default, Debug, PartialEq, CanonicalSerialize, CanonicalDeserialize, Allocative)]
pub struct DensePolynomial<F: JoltField> {
Expand Down Expand Up @@ -408,6 +409,54 @@ impl<F: JoltField> DensePolynomial<F> {
.collect(),
)
}

#[tracing::instrument(skip_all)]
pub fn linear_combination(
polynomials: &[&MultilinearPolynomial<F>],
coefficients: &[F],
) -> Self {
debug_assert_eq!(polynomials.len(), coefficients.len());

let max_length = polynomials
.iter()
.map(|poly| poly.original_len())
.max()
.unwrap();

let result: Vec<F> = (0..max_length)
.into_par_iter()
.map(|i| {
let mut acc = F::zero();
for (coeff, poly) in coefficients.iter().zip(polynomials.iter()) {
if i < poly.original_len() {
match poly {
MultilinearPolynomial::LargeScalars(p) => {
acc += p.evals_ref()[i].mul_01_optimized(*coeff);
}
MultilinearPolynomial::U8Scalars(p) => {
acc += p.coeffs[i].field_mul(*coeff);
}
MultilinearPolynomial::U16Scalars(p) => {
acc += p.coeffs[i].field_mul(*coeff);
}
MultilinearPolynomial::U32Scalars(p) => {
acc += p.coeffs[i].field_mul(*coeff);
}
MultilinearPolynomial::U64Scalars(p) => {
acc += p.coeffs[i].field_mul(*coeff);
}
MultilinearPolynomial::I64Scalars(p) => {
acc += p.coeffs[i].field_mul(*coeff);
}
_ => unreachable!(),
}
}
}
acc
})
.collect();
DensePolynomial::new(result)
}
}

impl<F: JoltField> Clone for DensePolynomial<F> {
Expand Down
103 changes: 1 addition & 102 deletions jolt-core/src/poly/multilinear_polynomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use crate::{
};
use allocative::Allocative;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Valid};
use num_traits::MulAdd;
use rayon::prelude::*;
use strum_macros::EnumIter;

Expand All @@ -13,10 +12,7 @@ use super::{
dense_mlpoly::DensePolynomial,
eq_poly::EqPolynomial,
};
use crate::{
field::{JoltField, OptimizedMul},
utils::thread::unsafe_allocate_zero_vec,
};
use crate::field::{JoltField, OptimizedMul};

/// Wrapper enum for the various multilinear polynomial types used in Jolt
#[repr(u8)]
Expand Down Expand Up @@ -115,103 +111,6 @@ impl<F: JoltField> MultilinearPolynomial<F> {
}
}

#[tracing::instrument(skip_all)]
pub fn linear_combination(polynomials: &[&Self], coefficients: &[F]) -> Self {
debug_assert_eq!(polynomials.len(), coefficients.len());

// If there's at least one sparse polynomial in `polynomials`, the linear
// combination will be represented by an `RLCPolynomial`. Otherwise, it will
// be represented by a `DensePolynomial`.
if polynomials
.iter()
.any(|poly| matches!(poly, MultilinearPolynomial::OneHot(_)))
{
let mut result = RLCPolynomial::<F>::new();
for (coeff, polynomial) in coefficients.iter().zip(polynomials.iter()) {
result = match polynomial {
MultilinearPolynomial::LargeScalars(poly) => poly.mul_add(*coeff, result),
MultilinearPolynomial::U8Scalars(poly) => poly.mul_add(*coeff, result),
MultilinearPolynomial::U16Scalars(poly) => poly.mul_add(*coeff, result),
MultilinearPolynomial::U32Scalars(poly) => poly.mul_add(*coeff, result),
MultilinearPolynomial::U64Scalars(poly) => poly.mul_add(*coeff, result),
MultilinearPolynomial::I64Scalars(poly) => poly.mul_add(*coeff, result),
MultilinearPolynomial::OneHot(poly) => poly.mul_add(*coeff, result),
_ => unimplemented!("Unexpected polynomial type"),
};
}
return MultilinearPolynomial::RLC(result);
}

let max_length = polynomials
.iter()
.map(|poly| poly.original_len())
.max()
.unwrap();
let num_chunks = rayon::current_num_threads()
.next_power_of_two()
.min(max_length);
let chunk_size = (max_length / num_chunks).max(1);

let lc_coeffs: Vec<F> = (0..num_chunks)
.into_par_iter()
.flat_map_iter(|chunk_index| {
let index = chunk_index * chunk_size;
let mut chunk = unsafe_allocate_zero_vec::<F>(chunk_size);

for (coeff, poly) in coefficients.iter().zip(polynomials.iter()) {
let poly_len = poly.original_len();
if index >= poly_len {
continue;
}

match poly {
MultilinearPolynomial::LargeScalars(poly) => {
debug_assert!(!poly.is_bound());
let poly_evals = &poly.evals_ref()[index..];
for (rlc, poly_eval) in chunk.iter_mut().zip(poly_evals.iter()) {
*rlc += poly_eval.mul_01_optimized(*coeff);
}
}
MultilinearPolynomial::U8Scalars(poly) => {
let poly_evals = &poly.coeffs[index..];
for (rlc, poly_eval) in chunk.iter_mut().zip(poly_evals.iter()) {
*rlc += poly_eval.field_mul(*coeff);
}
}
MultilinearPolynomial::U16Scalars(poly) => {
let poly_evals = &poly.coeffs[index..];
for (rlc, poly_eval) in chunk.iter_mut().zip(poly_evals.iter()) {
*rlc += poly_eval.field_mul(*coeff);
}
}
MultilinearPolynomial::U32Scalars(poly) => {
let poly_evals = &poly.coeffs[index..];
for (rlc, poly_eval) in chunk.iter_mut().zip(poly_evals.iter()) {
*rlc += poly_eval.field_mul(*coeff);
}
}
MultilinearPolynomial::U64Scalars(poly) => {
let poly_evals = &poly.coeffs[index..];
for (rlc, poly_eval) in chunk.iter_mut().zip(poly_evals.iter()) {
*rlc += poly_eval.field_mul(*coeff);
}
}
MultilinearPolynomial::I64Scalars(poly) => {
let poly_evals = &poly.coeffs[index..];
for (rlc, poly_eval) in chunk.iter_mut().zip(poly_evals.iter()) {
*rlc += poly_eval.field_mul(*coeff);
}
}
_ => unimplemented!("Unexpected MultilinearPolynomial variant"),
}
}
chunk
})
.collect();

MultilinearPolynomial::from(lc_coeffs)
}

/// Gets the polynomial coefficient at the given `index`
pub fn get_coeff(&self, index: usize) -> F {
match self {
Expand Down
22 changes: 14 additions & 8 deletions jolt-core/src/poly/opening_proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};

use super::{
commitment::commitment_scheme::CommitmentScheme,
dense_mlpoly::DensePolynomial,
eq_poly::EqPolynomial,
multilinear_polynomial::{BindingOrder, MultilinearPolynomial, PolynomialBinding},
rlc_polynomial::RLCPolynomial,
split_eq_poly::GruenSplitEqPolynomial,
};
#[cfg(feature = "allocative")]
Expand Down Expand Up @@ -407,14 +409,18 @@ where

if let Some(prover_state) = self.prover_state.as_mut() {
let polynomials_map = polynomials_map.unwrap();
let polynomials: Vec<_> = self

let polynomials: Vec<&MultilinearPolynomial<F>> = self
.polynomials
.par_iter()
.map(|label| polynomials_map.get(label).unwrap())
.collect();

let rlc_poly =
MultilinearPolynomial::linear_combination(&polynomials, &self.rlc_coeffs);
let result =
DensePolynomial::linear_combination(polynomials.as_ref(), &self.rlc_coeffs);

let rlc_poly = MultilinearPolynomial::from(result.Z);

debug_assert_eq!(rlc_poly.evaluate(&self.opening_point), reduced_claim);
let num_vars = rlc_poly.get_num_vars();

Expand Down Expand Up @@ -450,8 +456,8 @@ where
match prover_state {
ProverOpening::Dense(opening) => opening.polynomial = Some(poly.clone()),
ProverOpening::OneHot(opening) => {
if let MultilinearPolynomial::OneHot(poly) = poly {
opening.initialize(poly.clone());
if let MultilinearPolynomial::OneHot(one_hot) = poly {
opening.initialize(one_hot.clone());
} else {
panic!("Unexpected non-one-hot polynomial")
}
Expand Down Expand Up @@ -846,10 +852,10 @@ where
.map(|(k, v)| (v, polynomials.remove(k).unwrap()))
.unzip();

MultilinearPolynomial::linear_combination(
&polynomials.iter().collect::<Vec<_>>(),
MultilinearPolynomial::RLC(RLCPolynomial::linear_combination(
polynomials.into_iter().map(Arc::new).collect(),
&coeffs,
)
))
};

#[cfg(test)]
Expand Down
Loading