Skip to content

Commit 9e1c609

Browse files
minor cleanup
1 parent 8ca1366 commit 9e1c609

8 files changed

Lines changed: 88 additions & 47 deletions

File tree

src/rewrite_system/rise/analysis.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ impl Analysis<Rise> for RiseAnalysis {
5454
.free
5555
.iter()
5656
.filter(|idx| !idx.is_zero() && idx.kind() == enode.kind())
57-
.map(|idx| idx.downshifted())
57+
.map(|idx| idx.dec())
5858
.collect(),
5959
_ => enode
6060
.children()
@@ -63,7 +63,7 @@ impl Analysis<Rise> for RiseAnalysis {
6363
.copied()
6464
.collect(),
6565
};
66-
let empty = enode.any(|id| egraph[id].data.beta_extract.as_ref().is_empty());
66+
let empty = enode.any(|id| egraph[id].data.beta_extract.is_empty());
6767
let beta_extract = if empty {
6868
RecExpr::default()
6969
} else {

src/rewrite_system/rise/func.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ fn vec_expr(
145145
Rise::Lambda(e) => {
146146
let v_env2 = v_env
147147
.into_iter()
148-
.map(|i| i.upshifted())
148+
.map(|i| i.inc())
149149
.chain([Index::zero(Kind::Expr)])
150150
.collect::<HashSet<_>>();
151151

src/rewrite_system/rise/indices.rs

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@ pub enum Index {
1515
}
1616

1717
impl Index {
18-
pub fn upshifted(self) -> Self {
18+
pub fn inc(self) -> Self {
1919
self + Shift::up()
2020
}
2121

22-
pub fn downshifted(self) -> Self {
22+
pub fn dec(self) -> Self {
2323
self + Shift::down()
2424
}
2525

@@ -50,7 +50,7 @@ impl Index {
5050
}
5151
}
5252

53-
fn value(self) -> u32 {
53+
pub fn value(self) -> u32 {
5454
match self {
5555
Index::Expr(i) | Index::Nat(i) | Index::Data(i) | Index::Addr(i) => i,
5656
}
@@ -61,18 +61,11 @@ impl Index {
6161
}
6262
}
6363

64-
impl PartialEq<u32> for &Index {
65-
fn eq(&self, other: &u32) -> bool {
66-
self.value().eq(other)
67-
}
68-
}
69-
7064
impl std::ops::Add<Shift> for Index {
7165
type Output = Self;
7266

7367
fn add(self, rhs: Shift) -> Self::Output {
74-
let v = |i: u32| i.checked_add_signed(rhs.0).unwrap();
75-
68+
let v = |i: u32| i.strict_add_signed(rhs.0);
7669
match self {
7770
Index::Expr(i) => Index::Expr(v(i)),
7871
Index::Nat(i) => Index::Nat(v(i)),

src/rewrite_system/rise/kind.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ impl Kindable for Var {
4141
'd' | 't' => Kind::Data,
4242
'a' => Kind::Addr,
4343
'n' => Kind::Nat,
44-
_ => Kind::Expr,
44+
x if x.is_numeric() => Kind::Expr,
45+
x => panic!("Wrong format {x}"),
4546
})
4647
}
4748
}

src/rewrite_system/rise/lang.rs

Lines changed: 66 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -191,35 +191,79 @@ impl PPNode {
191191
.map(|c_id| Self::new(expr, *c_id, skip_wrapper))
192192
.collect(),
193193
expr: expr[id].to_string().into(),
194-
ty: format!("NO TYPE INFO AVAILABLE {{{}}}", kind_string(&expr[id])).red(),
194+
ty: ("NO TYPE INFO AVAILABLE").red(),
195195
};
196196
};
197197
let node = &expr[*expr_id];
198-
match color_expr(node, skip_wrapper) {
199-
Ok(expr_string) => Self {
200-
children: expr[*expr_id]
201-
.children()
202-
.iter()
203-
.map(|c_id| Self::new(expr, *c_id, skip_wrapper))
204-
.collect(),
205-
expr: expr_string,
206-
ty: format!(
207-
"{} {{{}}}",
208-
pp_ty(expr, *ty_id, false),
209-
kind_string(&expr[*expr_id])
210-
)
211-
.into(),
212-
},
213-
Err(c_id) => Self::new(expr, c_id, skip_wrapper),
198+
let colored_string = match node {
199+
Rise::Var(index) => index.to_string().magenta(),
200+
Rise::App(_) | Rise::Lambda(_) => node.to_string().red(),
201+
Rise::NatApp(_) | Rise::DataApp(_) | Rise::AddrApp(_) | Rise::NatNatApp(_) => {
202+
node.to_string().cyan()
203+
}
204+
Rise::NatLambda(c_id)
205+
| Rise::DataLambda(c_id)
206+
| Rise::AddrLambda(c_id)
207+
| Rise::NatNatLambda(c_id) => {
208+
if skip_wrapper {
209+
return Self::new(expr, *c_id, skip_wrapper);
210+
}
211+
node.to_string().cyan()
212+
}
213+
Rise::FunType(_)
214+
| Rise::NatFun(_)
215+
| Rise::DataFun(_)
216+
| Rise::AddrFun(_)
217+
| Rise::NatNatFun(_)
218+
| Rise::TypeOf(_)
219+
| Rise::ArrType(_)
220+
| Rise::VecType(_)
221+
| Rise::PairType(_)
222+
| Rise::IndexType(_)
223+
| Rise::NatType
224+
| Rise::F32 => panic!("Should not see types here: {node}"),
225+
Rise::NatAdd(_)
226+
| Rise::NatSub(_)
227+
| Rise::NatMul(_)
228+
| Rise::NatDiv(_)
229+
| Rise::NatPow(_) => {
230+
panic!("NatExpr should only appear in types: {node}")
231+
} // node.to_string().white()
232+
Rise::Let
233+
| Rise::AsVector
234+
| Rise::AsScalar
235+
| Rise::VectorFromScalar
236+
| Rise::Snd
237+
| Rise::Fst
238+
| Rise::Add
239+
| Rise::Mul
240+
| Rise::ToMem
241+
| Rise::Split
242+
| Rise::Join
243+
| Rise::Generate
244+
| Rise::Transpose
245+
| Rise::Zip
246+
| Rise::Unzip
247+
| Rise::Map
248+
| Rise::MapPar
249+
| Rise::Reduce
250+
| Rise::ReduceSeq
251+
| Rise::ReduceSeqUnroll
252+
| Rise::Float(_) => node.to_string().yellow(),
253+
Rise::Integer(i) => format!("int{i}").purple(),
254+
};
255+
Self {
256+
children: node
257+
.children()
258+
.iter()
259+
.map(|c_id| Self::new(expr, *c_id, skip_wrapper))
260+
.collect(),
261+
expr: colored_string,
262+
ty: pp_ty(expr, *ty_id, false),
214263
}
215264
}
216265
}
217266

218-
fn kind_string(node: &Rise) -> String {
219-
node.kind()
220-
.map_or_else(|| String::from("UNKINDABLE"), |k| k.to_string())
221-
}
222-
223267
fn color_expr(node: &Rise, skip_wrapper: bool) -> Result<ColoredString, Id> {
224268
let colored_string = match node {
225269
Rise::Var(index) => index.to_string().magenta(),

src/rewrite_system/rise/mod.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ mod test {
7979

8080
use egg::{AstSize, RecExpr, Runner, SimpleScheduler};
8181

82-
use crate::{rewrite_system::rise::lang::PrettyPrint, sketch::eclass_extract};
82+
use crate::sketch::eclass_extract;
8383

8484
use super::*;
8585

@@ -161,9 +161,12 @@ mod test {
161161
&baseline_goal,
162162
baseline_goal.root(),
163163
);
164-
println!("mm: {mm}");
165-
println!("baseline_goal: {baseline_goal}");
166-
println!("sketch_baseline_extr: {sketch_extracted_baseline}");
164+
println!("mm:");
165+
mm.pp(false);
166+
println!("baseline_goal:");
167+
baseline_goal.pp(false);
168+
println!("sketch_baseline_extr:");
169+
sketch_extracted_baseline.pp(false);
167170
assert_eq!(diff, None);
168171
// assert_eq!(root_mm, r.egraph.lookup_expr(&baseline_goal).unwrap());
169172
}

src/rewrite_system/rise/rules.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,31 +131,31 @@ fn replace(expr: &RecExpr<Rise>, index: Index, subs: &mut RecExpr<Rise>) -> RecE
131131
}
132132
Rise::Lambda(e) => {
133133
shift_mut(subs, Shift::up(), Index::zero_like(index));
134-
let e2 = rec(result, expr, *e, index.upshifted(), subs);
134+
let e2 = rec(result, expr, *e, index.inc(), subs);
135135
shift_mut(subs, Shift::down(), Index::zero_like(index));
136136
result.add(Rise::Lambda(e2))
137137
}
138138
Rise::NatLambda(e) => {
139139
shift_mut(subs, Shift::up(), Index::zero_like(index));
140-
let e2 = rec(result, expr, *e, index.upshifted(), subs);
140+
let e2 = rec(result, expr, *e, index.inc(), subs);
141141
shift_mut(subs, Shift::down(), Index::zero_like(index));
142142
result.add(Rise::NatLambda(e2))
143143
}
144144
Rise::DataLambda(e) => {
145145
shift_mut(subs, Shift::up(), Index::zero_like(index));
146-
let e2 = rec(result, expr, *e, index.upshifted(), subs);
146+
let e2 = rec(result, expr, *e, index.inc(), subs);
147147
shift_mut(subs, Shift::down(), Index::zero_like(index));
148148
result.add(Rise::DataLambda(e2))
149149
}
150150
Rise::AddrLambda(e) => {
151151
shift_mut(subs, Shift::up(), Index::zero_like(index));
152-
let e2 = rec(result, expr, *e, index.upshifted(), subs);
152+
let e2 = rec(result, expr, *e, index.inc(), subs);
153153
shift_mut(subs, Shift::down(), Index::zero_like(index));
154154
result.add(Rise::AddrLambda(e2))
155155
}
156156
Rise::NatNatLambda(e) => {
157157
shift_mut(subs, Shift::up(), Index::zero_like(index));
158-
let e2 = rec(result, expr, *e, index.upshifted(), subs);
158+
let e2 = rec(result, expr, *e, index.inc(), subs);
159159
shift_mut(subs, Shift::down(), Index::zero_like(index));
160160
result.add(Rise::NatNatLambda(e2))
161161
}

src/rewrite_system/rise/shifted.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ pub fn shift_mut(expr: &mut RecExpr<Rise>, shift: Shift, cutoff: Index) {
125125
| Rise::AddrLambda(e)
126126
| Rise::NatNatLambda(e) => {
127127
if expr[ei].kind() == cutoff.kind() {
128-
rec(expr, e, shift, cutoff.upshifted());
128+
rec(expr, e, shift, cutoff.inc());
129129
} else {
130130
rec(expr, e, shift, cutoff);
131131
}

0 commit comments

Comments
 (0)