Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 92 additions & 30 deletions constraint-solver/src/grouped_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,12 @@ impl<F: FieldElement, T: RuntimeConstant<FieldType = F>, V> GroupedExpression<T,
}
}

impl<T: RuntimeConstant, V: Clone + Ord + Eq> Zero for GroupedExpression<T, V> {
impl<T: Zero, V: Clone + Ord + Eq> Zero for GroupedExpression<T, V>
where
// The bounds are so strict because Zero requires Add
T: Zero + PartialEq + Neg<Output = T> + AddAssign<T> + Clone,
V: Clone + Ord + Eq,
{
fn zero() -> Self {
Self {
quadratic: Default::default(),
Expand All @@ -67,11 +72,16 @@ impl<T: RuntimeConstant, V: Clone + Ord + Eq> Zero for GroupedExpression<T, V> {
}

fn is_zero(&self) -> bool {
self.try_to_known().is_some_and(|k| k.is_known_zero())
self.try_to_known().is_some_and(|k| k.is_zero())
}
}

impl<T: RuntimeConstant, V: Clone + Ord + Eq> One for GroupedExpression<T, V> {
impl<T, V> One for GroupedExpression<T, V>
where
// The bounds are so strict because One requise Mul
T: Zero + One + PartialEq + Neg<Output = T> + AddAssign<T> + MulAssign<T> + Clone,
V: Clone + Ord + Eq,
{
fn one() -> Self {
Self {
quadratic: Default::default(),
Expand All @@ -81,7 +91,7 @@ impl<T: RuntimeConstant, V: Clone + Ord + Eq> One for GroupedExpression<T, V> {
}

fn is_one(&self) -> bool {
self.try_to_known().is_some_and(|k| k.is_known_one())
self.try_to_known().is_some_and(|k| k.is_one())
}
}

Expand All @@ -91,23 +101,27 @@ impl<F: FieldElement, V: Ord + Clone + Eq> GroupedExpression<SymbolicExpression<
}
}

impl<T: RuntimeConstant, V: Ord + Clone + Eq> GroupedExpression<T, V> {
impl<T, V> GroupedExpression<T, V> {
pub fn from_runtime_constant(constant: T) -> Self {
Self {
quadratic: Default::default(),
linear: Default::default(),
constant,
}
}
}

impl<T: Zero + One, V: Ord> GroupedExpression<T, V> {
pub fn from_unknown_variable(var: V) -> Self {
Self {
quadratic: Default::default(),
linear: [(var.clone(), T::one())].into_iter().collect(),
linear: [(var, T::one())].into_iter().collect(),
constant: T::zero(),
}
}
}

impl<T, V> GroupedExpression<T, V> {
/// If this expression does not contain unknown variables, returns the symbolic expression.
pub fn try_to_known(&self) -> Option<&T> {
if self.quadratic.is_empty() && self.linear.is_empty() {
Expand All @@ -122,33 +136,37 @@ impl<T: RuntimeConstant, V: Ord + Clone + Eq> GroupedExpression<T, V> {
!self.is_quadratic()
}

/// Returns true if this expression contains at least one quadratic term.
pub fn is_quadratic(&self) -> bool {
!self.quadratic.is_empty()
}
}

impl<T: RuntimeConstant, V: Ord + Clone + Eq> GroupedExpression<T, V> {
/// If the expression is a known number, returns it.
pub fn try_to_number(&self) -> Option<T::FieldType> {
self.try_to_known()?.try_to_number()
}
}

impl<T: One + Zero + PartialEq, V: Ord + Clone + Eq> GroupedExpression<T, V> {
/// If the expression is equal to `GroupedExpression::from_unknown_variable(v)`, returns `v`.
pub fn try_to_simple_unknown(&self) -> Option<V> {
if self.is_quadratic() || !self.constant.is_known_zero() {
if self.is_quadratic() || !self.constant.is_zero() {
return None;
}
let Ok((var, coeff)) = self.linear.iter().exactly_one() else {
return None;
};
if !coeff.is_known_one() {
if !coeff.is_one() {
return None;
}
Some(var.clone())
}

/// Returns true if this expression contains at least one quadratic term.
pub fn is_quadratic(&self) -> bool {
!self.quadratic.is_empty()
}

/// Returns `(l, r)` if `self == l * r`.
pub fn try_as_single_product(&self) -> Option<(&Self, &Self)> {
if self.linear.is_empty() && self.constant.is_known_zero() {
if self.linear.is_empty() && self.constant.is_zero() {
match self.quadratic.as_slice() {
[(l, r)] => Some((l, r)),
_ => None,
Expand All @@ -157,7 +175,9 @@ impl<T: RuntimeConstant, V: Ord + Clone + Eq> GroupedExpression<T, V> {
None
}
}
}

impl<T: RuntimeConstant, V: Ord + Clone + Eq> GroupedExpression<T, V> {
/// Returns `vec![f1, f2, ..., fn]` such that `self` is equivalent to
/// `c * f1 * f2 * ... * fn` for some constant `c`.
/// Tries to find as many factors as possible and also tries to normalize
Expand Down Expand Up @@ -195,7 +215,9 @@ impl<T: RuntimeConstant, V: Ord + Clone + Eq> GroupedExpression<T, V> {
vec![self.clone() * T::one().field_div(&divide_by)]
}
}
}

impl<T: Clone, V: Ord + Clone + Eq> GroupedExpression<T, V> {
/// Returns the quadratic, linear and constant components of this expression.
pub fn components(
&self,
Expand Down Expand Up @@ -242,7 +264,9 @@ impl<T: RuntimeConstant, V: Ord + Clone + Eq> GroupedExpression<T, V> {
assert!(!self.is_quadratic());
self.linear.get(var)
}
}

impl<T: RuntimeConstant + Clone, V: Ord + Clone + Eq> GroupedExpression<T, V> {
/// Returns the range constraint of the full expression.
pub fn range_constraint(
&self,
Expand Down Expand Up @@ -440,7 +464,11 @@ impl<T: FieldElement, V> RangeConstraintProvider<T, V> for NoRangeConstraints {
}
}

impl<T: RuntimeConstant, V: Clone + Ord + Eq> Add for GroupedExpression<T, V> {
impl<T, V> Add for GroupedExpression<T, V>
where
T: Zero + PartialEq + Neg<Output = T> + AddAssign<T> + Clone,
V: Clone + Ord + Eq,
{
type Output = GroupedExpression<T, V>;

fn add(mut self, rhs: Self) -> Self {
Expand All @@ -449,16 +477,22 @@ impl<T: RuntimeConstant, V: Clone + Ord + Eq> Add for GroupedExpression<T, V> {
}
}

impl<T: RuntimeConstant, V: Clone + Ord + Eq> Add for &GroupedExpression<T, V> {
impl<T, V> Add for &GroupedExpression<T, V>
where
T: Zero + PartialEq + Neg<Output = T> + AddAssign<T> + Clone,
V: Clone + Ord + Eq,
{
type Output = GroupedExpression<T, V>;

fn add(self, rhs: Self) -> Self::Output {
self.clone() + rhs.clone()
}
}

impl<T: RuntimeConstant, V: Clone + Ord + Eq> AddAssign<GroupedExpression<T, V>>
for GroupedExpression<T, V>
impl<T, V> AddAssign<GroupedExpression<T, V>> for GroupedExpression<T, V>
where
T: Zero + PartialEq + Neg<Output = T> + AddAssign<T> + Clone,
V: Clone + Ord + Eq,
{
fn add_assign(&mut self, rhs: Self) {
self.quadratic = combine_removing_zeros(std::mem::take(&mut self.quadratic), rhs.quadratic);
Expand All @@ -469,7 +503,7 @@ impl<T: RuntimeConstant, V: Clone + Ord + Eq> AddAssign<GroupedExpression<T, V>>
.or_insert_with(|| coeff);
}
self.constant += rhs.constant.clone();
self.linear.retain(|_, f| !f.is_known_zero());
self.linear.retain(|_, f| !f.is_zero());
}
}

Expand Down Expand Up @@ -550,23 +584,31 @@ where
[n1, n2].contains(&(&first.0, &first.1)) || [n1, n2].contains(&(&first.1, &first.0))
}

impl<T: RuntimeConstant, V: Clone + Ord + Eq> Sub for &GroupedExpression<T, V> {
impl<T, V> Sub for &GroupedExpression<T, V>
where
T: Zero + PartialEq + Neg<Output = T> + AddAssign<T> + Clone,
V: Clone + Ord + Eq,
{
type Output = GroupedExpression<T, V>;

fn sub(self, rhs: Self) -> Self::Output {
self + &-rhs
}
}

impl<T: RuntimeConstant, V: Clone + Ord + Eq> Sub for GroupedExpression<T, V> {
impl<T, V> Sub for GroupedExpression<T, V>
where
T: Zero + PartialEq + Neg<Output = T> + AddAssign<T> + Clone,
V: Clone + Ord + Eq,
{
type Output = GroupedExpression<T, V>;

fn sub(self, rhs: Self) -> Self::Output {
&self - &rhs
}
}

impl<T: RuntimeConstant, V: Clone + Ord> GroupedExpression<T, V> {
impl<T: Neg<Output = T> + Clone, V: Clone + Ord> GroupedExpression<T, V> {
fn negate(&mut self) {
for (first, _) in &mut self.quadratic {
first.negate()
Expand All @@ -578,7 +620,7 @@ impl<T: RuntimeConstant, V: Clone + Ord> GroupedExpression<T, V> {
}
}

impl<T: RuntimeConstant, V: Clone + Ord> Neg for GroupedExpression<T, V> {
impl<T: Neg<Output = T> + Clone, V: Clone + Ord> Neg for GroupedExpression<T, V> {
type Output = GroupedExpression<T, V>;

fn neg(mut self) -> Self {
Expand All @@ -587,7 +629,7 @@ impl<T: RuntimeConstant, V: Clone + Ord> Neg for GroupedExpression<T, V> {
}
}

impl<T: RuntimeConstant, V: Clone + Ord> Neg for &GroupedExpression<T, V> {
impl<T: Neg<Output = T> + Clone, V: Clone + Ord> Neg for &GroupedExpression<T, V> {
type Output = GroupedExpression<T, V>;

fn neg(self) -> Self::Output {
Expand All @@ -596,7 +638,11 @@ impl<T: RuntimeConstant, V: Clone + Ord> Neg for &GroupedExpression<T, V> {
}

/// Multiply by known symbolic expression.
impl<T: RuntimeConstant, V: Clone + Ord + Eq> Mul<&T> for GroupedExpression<T, V> {
impl<T, V> Mul<&T> for GroupedExpression<T, V>
where
T: Zero + PartialEq + Neg<Output = T> + AddAssign<T> + MulAssign<T> + Clone,
V: Clone + Ord + Eq,
{
type Output = GroupedExpression<T, V>;

fn mul(mut self, rhs: &T) -> Self {
Expand All @@ -605,17 +651,25 @@ impl<T: RuntimeConstant, V: Clone + Ord + Eq> Mul<&T> for GroupedExpression<T, V
}
}

impl<T: RuntimeConstant, V: Clone + Ord + Eq> Mul<T> for GroupedExpression<T, V> {
impl<T, V> Mul<T> for GroupedExpression<T, V>
where
T: Zero + PartialEq + Neg<Output = T> + AddAssign<T> + MulAssign<T> + Clone,
V: Clone + Ord + Eq,
{
type Output = GroupedExpression<T, V>;

fn mul(self, rhs: T) -> Self {
self * &rhs
}
}

impl<T: RuntimeConstant, V: Clone + Ord + Eq> MulAssign<&T> for GroupedExpression<T, V> {
impl<T, V> MulAssign<&T> for GroupedExpression<T, V>
where
T: Zero + PartialEq + Neg<Output = T> + AddAssign<T> + MulAssign<T> + Clone,
V: Clone + Ord + Eq,
{
fn mul_assign(&mut self, rhs: &T) {
if rhs.is_known_zero() {
if rhs.is_zero() {
*self = Self::zero();
} else {
for (first, _) in &mut self.quadratic {
Expand All @@ -629,7 +683,11 @@ impl<T: RuntimeConstant, V: Clone + Ord + Eq> MulAssign<&T> for GroupedExpressio
}
}

impl<T: RuntimeConstant, V: Clone + Ord + Eq> Sum for GroupedExpression<T, V> {
impl<T, V> Sum for GroupedExpression<T, V>
where
T: Zero + PartialEq + Neg<Output = T> + AddAssign<T> + Clone,
V: Clone + Ord + Eq,
{
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.fold(Self::zero(), |mut acc, item| {
acc += item;
Expand All @@ -638,7 +696,11 @@ impl<T: RuntimeConstant, V: Clone + Ord + Eq> Sum for GroupedExpression<T, V> {
}
}

impl<T: RuntimeConstant, V: Clone + Ord + Eq> Mul for GroupedExpression<T, V> {
impl<T, V> Mul for GroupedExpression<T, V>
where
T: Zero + PartialEq + Neg<Output = T> + AddAssign<T> + MulAssign<T> + Clone,
V: Clone + Ord + Eq,
{
type Output = GroupedExpression<T, V>;

fn mul(self, rhs: GroupedExpression<T, V>) -> Self {
Expand Down
Loading