Skip to content

Commit e71a7f1

Browse files
minor cleanup
1 parent 3d30cd4 commit e71a7f1

6 files changed

Lines changed: 55 additions & 63 deletions

File tree

src/rewrite_system/rise/nat/applier.rs

Lines changed: 2 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
use egg::{
2-
Applier, EGraph, ENodeOrVar, Id, Language, Pattern, PatternAst, RecExpr, Searcher, Subst,
3-
Symbol, Var,
4-
};
1+
use egg::{Applier, EGraph, Id, Pattern, PatternAst, Searcher, Subst, Symbol, Var};
52

63
use super::{Rise, RiseAnalysis};
74

@@ -73,7 +70,7 @@ impl<A: Applier<Rise, RiseAnalysis>> Applier<Rise, RiseAnalysis> for ComputeNatC
7370
rule_name: Symbol,
7471
) -> Vec<Id> {
7572
let expected = &egraph[subst[self.var]].data.beta_extract.clone();
76-
let extracted = &extract_small(egraph, &self.nat_pattern, subst);
73+
let extracted = &super::extract_small(egraph, &self.nat_pattern, subst);
7774
let a = &mut egraph.analysis;
7875
if super::check_equivalence(a, expected, extracted) {
7976
self.applier
@@ -83,39 +80,3 @@ impl<A: Applier<Rise, RiseAnalysis>> Applier<Rise, RiseAnalysis> for ComputeNatC
8380
}
8481
}
8582
}
86-
87-
// Quick check for trivial cases:
88-
// fn quick_check(lhs: &RecExpr<Math>, lhs_id: Id, rhs: &RecExpr<Math>, rhs_id: Id) -> bool {
89-
// lhs[lhs_id].matches(&rhs[rhs_id])
90-
// && lhs[lhs_id]
91-
// .children()
92-
// .iter()
93-
// .zip(rhs[rhs_id].children())
94-
// .all(|(lcid, rcid)| quick_check(lhs, *lcid, rhs, *rcid))
95-
// }
96-
97-
// if quick_check(expected, expected.root(), extracted, extracted.root()) {
98-
// return true;
99-
// }
100-
101-
fn extract_small(
102-
egraph: &EGraph<Rise, RiseAnalysis>,
103-
pattern: &Pattern<Rise>,
104-
subst: &Subst,
105-
) -> RecExpr<Rise> {
106-
fn rec(
107-
ast: &PatternAst<Rise>,
108-
id: Id,
109-
subst: &Subst,
110-
egraph: &EGraph<Rise, RiseAnalysis>,
111-
) -> RecExpr<Rise> {
112-
match &ast[id] {
113-
ENodeOrVar::Var(w) => egraph[subst[*w]].data.beta_extract.clone(),
114-
ENodeOrVar::ENode(e) => {
115-
let new_e = e.clone();
116-
new_e.join_recexprs(|i| rec(ast, i, subst, egraph))
117-
}
118-
}
119-
}
120-
rec(&pattern.ast, pattern.ast.root(), subst, egraph)
121-
}

src/rewrite_system/rise/nat/mod.rs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ mod rational;
55

66
use std::num::TryFromIntError;
77

8-
use egg::RecExpr;
8+
use egg::{EGraph, ENodeOrVar, Id, Language, Pattern, PatternAst, RecExpr, Subst};
99
use thiserror::Error;
1010

1111
use super::{Rise, RiseAnalysis};
@@ -46,6 +46,28 @@ fn check_equivalence<'a, 'b: 'a>(
4646
false
4747
}
4848

49+
fn extract_small(
50+
egraph: &EGraph<Rise, RiseAnalysis>,
51+
pattern: &Pattern<Rise>,
52+
subst: &Subst,
53+
) -> RecExpr<Rise> {
54+
fn rec(
55+
ast: &PatternAst<Rise>,
56+
id: Id,
57+
subst: &Subst,
58+
egraph: &EGraph<Rise, RiseAnalysis>,
59+
) -> RecExpr<Rise> {
60+
match &ast[id] {
61+
ENodeOrVar::Var(w) => egraph[subst[*w]].data.beta_extract.clone(),
62+
ENodeOrVar::ENode(e) => {
63+
let new_e = e.clone();
64+
new_e.join_recexprs(|i| rec(ast, i, subst, egraph))
65+
}
66+
}
67+
}
68+
rec(&pattern.ast, pattern.ast.root(), subst, egraph)
69+
}
70+
4971
// ============================================================================
5072
// Error Types
5173
// ============================================================================

src/rewrite_system/rise/nat/polynomial/from.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,6 @@ use crate::rewrite_system::rise::Index;
99
// RecExpr Conversions
1010
// ============================================================================
1111

12-
// -------------------------------------
13-
// From Polynomial to RecExpr<Rise>
14-
// -------------------------------------
15-
1612
impl From<&Polynomial> for RecExpr<Rise> {
1713
fn from(p: &Polynomial) -> Self {
1814
let mut expr = RecExpr::default();
@@ -69,19 +65,20 @@ impl From<Polynomial> for RecExpr<Rise> {
6965
// From Simple Types
7066
// ============================================================================
7167

72-
/// Create a polynomial from an integer constant
68+
/// Create a `Polynomial` from an integer constant
7369
impl From<i32> for Polynomial {
7470
fn from(n: i32) -> Self {
7571
Self::new().add_term(n.into(), Monomial::new())
7672
}
7773
}
78-
/// Create a polynomial from an integer constant
74+
/// Create a `Polynomial` from an integer constant
7975
impl From<Ratio<i32>> for Polynomial {
8076
fn from(r: Ratio<i32>) -> Self {
8177
Self::new().add_term(r, Monomial::new())
8278
}
8379
}
8480

81+
/// Create a `Polynomial` from with a single variable
8582
impl From<Index> for Polynomial {
8683
fn from(index: Index) -> Self {
8784
Self::new().add_term(Ratio::one(), Monomial::new().with_var(index, 1))

src/rewrite_system/rise/nat/polynomial/mod.rs

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,12 @@ impl Polynomial {
4040
}
4141

4242
/// Create a polynomial from a constant integer
43-
pub fn constant(n: i32) -> Self {
43+
pub fn from_i32(n: i32) -> Self {
4444
n.into()
4545
}
4646

4747
/// Create a polynomial from a constant ratio
48-
pub fn constant_ratio(r: Ratio<i32>) -> Self {
48+
pub fn from_ratio(r: Ratio<i32>) -> Self {
4949
r.into()
5050
}
5151

@@ -228,9 +228,11 @@ impl Polynomial {
228228
if b.is_zero() {
229229
return Ok(a);
230230
}
231-
if a.is_constant() && b.is_constant() {
231+
if let Some(a_const) = a.as_constant()
232+
&& let Some(b_const) = b.as_constant()
233+
{
232234
// GCD of constants
233-
let gcd = gcd_ratio(a.as_constant().unwrap(), b.as_constant().unwrap());
235+
let gcd = gcd_ratio(a_const, b_const);
234236
return Ok((gcd).into());
235237
}
236238

@@ -243,7 +245,7 @@ impl Polynomial {
243245
a.as_constant().unwrap_or(Ratio::one()),
244246
b.as_constant().unwrap_or(Ratio::one()),
245247
);
246-
return Ok(Polynomial::constant_ratio(gcd));
248+
return Ok((gcd).into());
247249
}
248250

249251
// Factor out content (GCD of coefficients) first
@@ -265,7 +267,7 @@ impl Polynomial {
265267
if content_gcd.is_one() {
266268
Ok(primitive_gcd)
267269
} else {
268-
Ok(primitive_gcd * Polynomial::constant_ratio(content_gcd))
270+
Ok(primitive_gcd * Polynomial::from_ratio(content_gcd))
269271
}
270272
}
271273

@@ -874,7 +876,7 @@ mod tests {
874876
}
875877

876878
fn constant(n: i32) -> Polynomial {
877-
Polynomial::constant(n)
879+
Polynomial::from_i32(n)
878880
}
879881

880882
#[test]

src/rewrite_system/rise/nat/rational/from.rs

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use egg::{Id, Language, RecExpr};
2+
use num::rational::Ratio;
23

34
use super::{NatSolverError, Polynomial, RationalFunction, Rise};
45
use crate::rewrite_system::rise::Index;
@@ -39,6 +40,7 @@ impl TryFrom<RationalFunction> for Polynomial {
3940
// RationalFunction -> RecExpr<Rise>
4041
// -------------------------------------
4142

43+
// TODO: i think this has a bug but not sure
4244
impl From<&RationalFunction> for RecExpr<Rise> {
4345
fn from(rf: &RationalFunction) -> Self {
4446
let mut expr = RecExpr::default();
@@ -161,15 +163,23 @@ impl TryFrom<&RecExpr<Rise>> for RationalFunction {
161163
// From Simple Types
162164
// ============================================================================
163165

164-
impl From<Index> for RationalFunction {
165-
fn from(index: Index) -> Self {
166-
(Polynomial::var(index)).into()
166+
/// Create a `RationalFunction` from an integer constant
167+
impl From<i32> for RationalFunction {
168+
fn from(n: i32) -> RationalFunction {
169+
Polynomial::from_i32(n).into()
167170
}
168171
}
169172

170-
impl From<i32> for RationalFunction {
171-
fn from(n: i32) -> RationalFunction {
172-
let p: Polynomial = n.into();
173-
p.into()
173+
/// Create a `RationalFunction` from an integer constant
174+
impl From<Ratio<i32>> for RationalFunction {
175+
fn from(r: Ratio<i32>) -> Self {
176+
Polynomial::from_ratio(r).into()
177+
}
178+
}
179+
180+
/// Create a `RationalFunction` from with a single variable
181+
impl From<Index> for RationalFunction {
182+
fn from(index: Index) -> Self {
183+
(Polynomial::var(index)).into()
174184
}
175185
}

src/rewrite_system/rise/nat/rational/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -570,7 +570,7 @@ mod tests {
570570
// 1/(-x) should normalize to (-1)/x or -x^(-1)
571571
let rf = RationalFunction::new(
572572
Polynomial::one(),
573-
Polynomial::constant(-1) * Polynomial::var(idx(1)),
573+
Polynomial::from_i32(-1) * Polynomial::var(idx(1)),
574574
)
575575
.unwrap();
576576

0 commit comments

Comments
 (0)