Skip to content

Commit ed12bc2

Browse files
more work to fix the db shift bug
1 parent 9f6cbe3 commit ed12bc2

14 files changed

Lines changed: 220 additions & 209 deletions

File tree

examples/blocking.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,8 @@ fn main() {
2828
let (_, sketch_extracted_split_guide) =
2929
sketch::eclass_extract(&split_guide_sketch, AstSize, &runner_1.egraph, root_mm).unwrap();
3030

31-
println!("Guide Ground Truth");
31+
println!("\nGuide Ground Truth");
3232
split_guide.pp(false);
33-
3433
println!("\nSketch Extracted:");
3534
sketch_extracted_split_guide.pp(false);
3635

src/rewrite_system/rise/analysis.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use hashbrown::HashSet;
44
use crate::rewrite_system::rise::kind::Kindable;
55
use crate::rewrite_system::rise::nat::try_simplify;
66

7-
use super::{Index, Rise};
7+
use super::{DBIndex, Rise};
88

99
#[derive(Default, Debug)]
1010
pub struct RiseAnalysis {
@@ -38,7 +38,7 @@ impl RiseAnalysis {
3838

3939
#[derive(Default, Debug)]
4040
pub struct AnalysisData {
41-
pub free: HashSet<Index>,
41+
pub free: HashSet<DBIndex>,
4242
pub beta_extract: RecExpr<Rise>,
4343
// pub simple_nat: Option<RecExpr<Rise>>,
4444
}

src/rewrite_system/rise/func.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,25 @@
11
use egg::{Applier, EGraph, Id, Language, Pattern, PatternAst, RecExpr, Subst, Symbol, Var};
22
use hashbrown::HashSet;
33

4-
use super::{Index, Kind, Kindable, Rise, RiseAnalysis};
4+
use super::{DBIndex, Kind, Kindable, Rise, RiseAnalysis};
55

66
pub fn pat(pat: &str) -> impl Applier<Rise, RiseAnalysis> {
77
pat.parse::<Pattern<Rise>>().unwrap()
88
}
99

1010
pub struct NotFreeIn<A: Applier<Rise, RiseAnalysis>> {
1111
var: Var,
12-
index: Index,
12+
index: DBIndex,
1313
applier: A,
1414
}
1515

1616
impl<A: Applier<Rise, RiseAnalysis>> NotFreeIn<A> {
1717
pub fn new(var_str: &str, index: u32, applier: A) -> Self {
1818
let var: Var = var_str.parse().unwrap();
19-
let kind = var.kind().unwrap();
19+
let kind = var.kind();
2020
NotFreeIn {
2121
var,
22-
index: Index::new(index, kind),
22+
index: DBIndex::new(index, kind),
2323
applier,
2424
}
2525
}
@@ -108,7 +108,7 @@ fn extracted_int(expr: &RecExpr<Rise>) -> i32 {
108108
fn vec_expr(
109109
expr: &RecExpr<Rise>,
110110
n: i32,
111-
v_env: HashSet<Index>,
111+
v_env: HashSet<DBIndex>,
112112
type_of_id: Id,
113113
) -> Option<(RecExpr<Rise>, Id, Id)> {
114114
let Rise::TypeOf([expr_id, ty_id]) = &expr[type_of_id] else {
@@ -144,7 +144,7 @@ fn vec_expr(
144144
let v_env2 = v_env
145145
.into_iter()
146146
.map(|i| i.inc())
147-
.chain([Index::zero(Kind::Expr)])
147+
.chain([DBIndex::zero(Kind::Expr)])
148148
.collect::<HashSet<_>>();
149149

150150
// Vectorize e

src/rewrite_system/rise/indices.rs

Lines changed: 46 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,46 @@
11
use serde::{Deserialize, Serialize};
22
use thiserror::Error;
33

4-
use super::Kind;
5-
6-
// #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Hash, Copy, Serialize, Deserialize)]
7-
// pub struct Index(u32);
4+
use super::{Kind, Kindable};
85

96
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Hash, Copy, Serialize, Deserialize)]
10-
pub enum Index {
7+
pub enum DBIndex {
118
Expr(u32),
129
Nat(u32),
1310
Data(u32),
1411
Addr(u32),
1512
}
1613

17-
impl Index {
14+
impl DBIndex {
1815
pub fn inc(self) -> Self {
19-
self + Shift::up()
16+
self + DBShift::up()
2017
}
2118

2219
pub fn dec(self) -> Self {
23-
self + Shift::down()
20+
self + DBShift::down()
2421
}
2522

2623
pub fn new(value: u32, kind: Kind) -> Self {
2724
match kind {
28-
Kind::Expr => Index::Expr(value),
29-
Kind::Nat => Index::Nat(value),
30-
Kind::Data => Index::Data(value),
31-
Kind::Addr => Index::Addr(value),
25+
Kind::Expr => DBIndex::Expr(value),
26+
Kind::Nat => DBIndex::Nat(value),
27+
Kind::Data => DBIndex::Data(value),
28+
Kind::Addr => DBIndex::Addr(value),
3229
}
3330
}
3431

3532
pub fn zero(kind: Kind) -> Self {
3633
match kind {
37-
Kind::Expr => Index::Expr(0),
38-
Kind::Nat => Index::Nat(0),
39-
Kind::Data => Index::Data(0),
40-
Kind::Addr => Index::Addr(0),
41-
}
42-
}
43-
44-
pub fn zero_like(other: Self) -> Self {
45-
match other {
46-
Index::Expr(_) => Index::Expr(0),
47-
Index::Nat(_) => Index::Nat(0),
48-
Index::Data(_) => Index::Data(0),
49-
Index::Addr(_) => Index::Addr(0),
34+
Kind::Expr => DBIndex::Expr(0),
35+
Kind::Nat => DBIndex::Nat(0),
36+
Kind::Data => DBIndex::Data(0),
37+
Kind::Addr => DBIndex::Addr(0),
5038
}
5139
}
5240

5341
pub fn value(self) -> u32 {
5442
match self {
55-
Index::Expr(i) | Index::Nat(i) | Index::Data(i) | Index::Addr(i) => i,
43+
DBIndex::Expr(i) | DBIndex::Nat(i) | DBIndex::Data(i) | DBIndex::Addr(i) => i,
5644
}
5745
}
5846

@@ -61,30 +49,30 @@ impl Index {
6149
}
6250
}
6351

64-
impl std::ops::Add<Shift> for Index {
52+
impl std::ops::Add<DBShift> for DBIndex {
6553
type Output = Self;
6654

67-
fn add(self, rhs: Shift) -> Self::Output {
55+
fn add(self, rhs: DBShift) -> Self::Output {
6856
match self {
69-
Index::Expr(i) => Index::Expr(i.strict_add_signed(rhs.0)),
70-
Index::Nat(i) => Index::Nat(i.strict_add_signed(rhs.0)),
71-
Index::Data(i) => Index::Data(i.strict_add_signed(rhs.0)),
72-
Index::Addr(i) => Index::Addr(i.strict_add_signed(rhs.0)),
57+
DBIndex::Expr(i) => DBIndex::Expr(i.strict_add_signed(rhs.0)),
58+
DBIndex::Nat(i) => DBIndex::Nat(i.strict_add_signed(rhs.0)),
59+
DBIndex::Data(i) => DBIndex::Data(i.strict_add_signed(rhs.0)),
60+
DBIndex::Addr(i) => DBIndex::Addr(i.strict_add_signed(rhs.0)),
7361
}
7462
}
7563
}
7664

77-
impl std::str::FromStr for Index {
65+
impl std::str::FromStr for DBIndex {
7866
type Err = IndexError;
7967

8068
fn from_str(s: &str) -> Result<Self, Self::Err> {
8169
if let Some(stripped_s) = s.strip_prefix("%") {
8270
if let Some((tag, i)) = stripped_s.split_at_checked(1) {
8371
match tag {
84-
"e" => Ok(Index::Expr(i.parse()?)),
85-
"n" => Ok(Index::Nat(i.parse()?)),
86-
"d" => Ok(Index::Data(i.parse()?)),
87-
"a" => Ok(Index::Addr(i.parse()?)),
72+
"e" => Ok(DBIndex::Expr(i.parse()?)),
73+
"n" => Ok(DBIndex::Nat(i.parse()?)),
74+
"d" => Ok(DBIndex::Data(i.parse()?)),
75+
"a" => Ok(DBIndex::Addr(i.parse()?)),
8876
_ => Err(IndexError::ImproperTag(stripped_s.to_owned())),
8977
}
9078
} else {
@@ -96,21 +84,32 @@ impl std::str::FromStr for Index {
9684
}
9785
}
9886

99-
impl std::fmt::Display for Index {
87+
impl std::fmt::Display for DBIndex {
10088
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
10189
match self {
102-
Index::Expr(i) => write!(f, "%e{i}"),
103-
Index::Nat(i) => write!(f, "%n{i}"),
104-
Index::Data(i) => write!(f, "%d{i}"),
105-
Index::Addr(i) => write!(f, "%a{i}"),
90+
DBIndex::Expr(i) => write!(f, "%e{i}"),
91+
DBIndex::Nat(i) => write!(f, "%n{i}"),
92+
DBIndex::Data(i) => write!(f, "%d{i}"),
93+
DBIndex::Addr(i) => write!(f, "%a{i}"),
94+
}
95+
}
96+
}
97+
98+
impl Kindable for DBIndex {
99+
fn kind(&self) -> Kind {
100+
match self {
101+
DBIndex::Expr(_) => Kind::Expr,
102+
DBIndex::Nat(_) => Kind::Nat,
103+
DBIndex::Data(_) => Kind::Data,
104+
DBIndex::Addr(_) => Kind::Addr,
106105
}
107106
}
108107
}
109108

110109
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Hash, Copy, Serialize, Deserialize)]
111-
pub struct Shift(i32);
110+
pub struct DBShift(i32);
112111

113-
impl Shift {
112+
impl DBShift {
114113
pub fn up() -> Self {
115114
Self(1)
116115
}
@@ -120,18 +119,18 @@ impl Shift {
120119
}
121120
}
122121

123-
impl TryFrom<i32> for Shift {
122+
impl TryFrom<i32> for DBShift {
124123
type Error = IndexError;
125124

126125
fn try_from(value: i32) -> Result<Self, Self::Error> {
127126
if value == 0 {
128127
return Err(IndexError::ZeroShift);
129128
}
130-
Ok(Shift(value))
129+
Ok(DBShift(value))
131130
}
132131
}
133132

134-
impl std::fmt::Display for Shift {
133+
impl std::fmt::Display for DBShift {
135134
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136135
write!(f, "{}", self.0)
137136
}

src/rewrite_system/rise/kind.rs

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,12 @@ use std::fmt::Display;
33
use egg::Var;
44
use serde::{Deserialize, Serialize};
55

6-
use super::Index;
7-
86
pub trait Kindable {
9-
fn kind(&self) -> Option<Kind>;
7+
fn kind(&self) -> Kind;
108
}
119

1210
impl<T: Kindable> Kindable for &T {
13-
fn kind(&self) -> Option<Kind> {
11+
fn kind(&self) -> Kind {
1412
(*self).kind()
1513
}
1614
}
@@ -35,25 +33,18 @@ impl Display for Kind {
3533
}
3634

3735
impl Kindable for Var {
38-
fn kind(&self) -> Option<Kind> {
36+
fn kind(&self) -> Kind {
3937
let var_str = self.to_string();
40-
var_str.chars().nth(1).map(|c| match c {
41-
'd' | 't' => Kind::Data,
42-
'a' => Kind::Addr,
43-
'n' => Kind::Nat,
44-
x if x.is_numeric() => Kind::Expr,
45-
x => panic!("Wrong format {x}"),
46-
})
47-
}
48-
}
49-
50-
impl Kindable for Index {
51-
fn kind(&self) -> Option<Kind> {
52-
Some(match self {
53-
Index::Expr(_) => Kind::Expr,
54-
Index::Nat(_) => Kind::Nat,
55-
Index::Data(_) => Kind::Data,
56-
Index::Addr(_) => Kind::Addr,
57-
})
38+
var_str
39+
.chars()
40+
.nth(1)
41+
.map(|c| match c {
42+
'd' | 't' => Kind::Data,
43+
'a' => Kind::Addr,
44+
'n' => Kind::Nat,
45+
x if x.is_numeric() => Kind::Expr,
46+
x => panic!("Wrong format {x}"),
47+
})
48+
.expect("Wrong format {x}")
5849
}
5950
}

src/rewrite_system/rise/lang.rs

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@ use egg::Id;
22
use ordered_float::NotNan;
33
use serde::{Deserialize, Serialize};
44

5-
use super::{Index, Kind, Kindable};
5+
use super::{DBIndex, Kind, Kindable};
66

77
egg::define_language! {
88
#[derive(Serialize, Deserialize)]
99
pub enum Rise {
10-
Var(Index),
10+
Var(DBIndex),
1111
"app" = App([Id; 2]),
1212
"natApp" = NatApp([Id; 2]),
1313
"dataApp" = DataApp([Id; 2]),
@@ -83,11 +83,15 @@ egg::define_language! {
8383
}
8484
}
8585

86-
impl Rise {
87-
#[must_use]
88-
pub fn kind(&self) -> Option<Kind> {
89-
Some(match self {
90-
Rise::Var(index) => index.kind()?,
86+
impl Kindable for Rise {
87+
/// Returns the kind of this [`Rise`].
88+
///
89+
/// # Panics
90+
///
91+
/// Panics if called on an unkindable type
92+
fn kind(&self) -> Kind {
93+
match self {
94+
Rise::Var(index) => index.kind(),
9195
Rise::App(_)
9296
| Rise::Lambda(_)
9397
| Rise::Let
@@ -133,8 +137,8 @@ impl Rise {
133137
| Rise::F32 => Kind::Data,
134138
Rise::AddrApp(_) | Rise::AddrLambda(_) => Kind::Addr,
135139
Rise::TypeOf(_) | Rise::NatNatApp(_) | Rise::NatNatLambda(_) | Rise::Integer(_) => {
136-
return None;
140+
panic!("NOT KINDABLE");
137141
}
138-
})
142+
}
139143
}
140144
}

src/rewrite_system/rise/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ mod shifted;
1010

1111
use egg::{Id, Language, RecExpr, Rewrite};
1212

13-
use indices::{Index, Shift};
13+
use indices::{DBIndex, DBShift};
1414
use kind::{Kind, Kindable};
1515

1616
pub use analysis::RiseAnalysis;

0 commit comments

Comments
 (0)