Skip to content

Commit 25d40c6

Browse files
First take at partial term reconstruction
1 parent c545e5f commit 25d40c6

10 files changed

Lines changed: 403 additions & 210 deletions

File tree

Lines changed: 251 additions & 210 deletions
Large diffs are not rendered by default.

src/python/data/nodes.rs

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
use pyo3::prelude::*;
2+
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
3+
4+
use serde::Serialize;
5+
6+
use crate::trs::SymbolInfo;
7+
8+
#[derive(Debug, PartialEq, Clone, Serialize)]
9+
pub enum NodeOrPlaceHolder {
10+
Node(Node),
11+
Placeholder {
12+
depth: usize,
13+
dfs_order: usize,
14+
nth_child: usize,
15+
},
16+
}
17+
18+
impl NodeOrPlaceHolder {
19+
#[must_use]
20+
pub fn id(&self) -> Option<usize> {
21+
match self {
22+
NodeOrPlaceHolder::Node(node) => Some(node.id()),
23+
NodeOrPlaceHolder::Placeholder { .. } => None,
24+
}
25+
}
26+
27+
#[must_use]
28+
pub fn value(&self) -> Option<String> {
29+
match self {
30+
NodeOrPlaceHolder::Node(node) => node.symbol_info.value(),
31+
NodeOrPlaceHolder::Placeholder { .. } => None,
32+
}
33+
}
34+
35+
#[must_use]
36+
pub fn name(&self) -> String {
37+
if let NodeOrPlaceHolder::Node(n) = self {
38+
match n.symbol_info.symbol_type() {
39+
crate::trs::SymbolType::Constant(_) => "[constant]".to_owned(),
40+
crate::trs::SymbolType::Variable(_) => "[variable]".to_owned(),
41+
crate::trs::SymbolType::MetaSymbol | crate::trs::SymbolType::Operator => {
42+
n.raw_name.clone()
43+
}
44+
}
45+
} else {
46+
"[PLACEHOLDER]".to_owned()
47+
}
48+
}
49+
50+
#[must_use]
51+
pub fn depth(&self) -> usize {
52+
match self {
53+
NodeOrPlaceHolder::Node(node) => node.depth,
54+
NodeOrPlaceHolder::Placeholder { depth, .. } => *depth,
55+
}
56+
}
57+
}
58+
59+
#[gen_stub_pyclass]
60+
#[pyclass(frozen, module = "eggshell")]
61+
#[derive(Debug, PartialEq, Clone, Serialize)]
62+
pub struct Node {
63+
#[pyo3(get)]
64+
raw_name: String,
65+
#[pyo3(get)]
66+
arity: usize,
67+
#[pyo3(get)]
68+
nth_child: usize,
69+
#[pyo3(get)]
70+
dfs_order: usize,
71+
#[pyo3(get)]
72+
depth: usize,
73+
symbol_info: SymbolInfo,
74+
}
75+
76+
impl Node {
77+
#[must_use]
78+
pub fn new(
79+
raw_name: String,
80+
arity: usize,
81+
nth_child: usize,
82+
dfs_order: usize,
83+
depth: usize,
84+
symbol_info: SymbolInfo,
85+
) -> Self {
86+
Self {
87+
raw_name,
88+
arity,
89+
nth_child,
90+
dfs_order,
91+
depth,
92+
symbol_info,
93+
}
94+
}
95+
96+
#[must_use]
97+
pub fn symbol_info(&self) -> &SymbolInfo {
98+
&self.symbol_info
99+
}
100+
}
101+
102+
#[gen_stub_pymethods]
103+
#[pymethods]
104+
impl Node {
105+
#[must_use]
106+
#[getter]
107+
pub fn id(&self) -> usize {
108+
self.symbol_info.id()
109+
}
110+
111+
#[must_use]
112+
#[getter]
113+
pub fn value(&self) -> Option<String> {
114+
self.symbol_info.value()
115+
}
116+
117+
#[must_use]
118+
#[getter]
119+
pub fn name(&self) -> String {
120+
match self.symbol_info.symbol_type() {
121+
crate::trs::SymbolType::Constant(_) => "[constant]".to_owned(),
122+
crate::trs::SymbolType::Variable(_) => "[variable]".to_owned(),
123+
crate::trs::SymbolType::MetaSymbol | crate::trs::SymbolType::Operator => {
124+
self.raw_name.clone()
125+
}
126+
}
127+
}
128+
}

src/python/monomorphize.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ macro_rules! monomorphize {
99

1010
use $crate::eqsat::conf::EqsatConf;
1111
use $crate::eqsat::{Eqsat, StartMaterial};
12+
use $crate::python::data;
1213
use $crate::python::data::TreeData;
1314
use $crate::python::err::EggshellError;
1415
use $crate::trs::{MetaInfo, TermRewriteSystem};
@@ -87,6 +88,16 @@ macro_rules! monomorphize {
8788
}
8889
}
8990

91+
#[gen_stub_pyfunction(module = $module_name)]
92+
#[pyfunction]
93+
#[expect(clippy::missing_errors_doc)]
94+
pub fn partial_parse(token_list: Vec<String>) -> PyResult<TreeData> {
95+
let r = data::partial_parse::<L>(token_list)
96+
.and_then(|partial: Vec<Option<L>>| TreeData::try_from(partial))
97+
.map_err(|e| EggshellError::<L>::from(e))?;
98+
Ok(r)
99+
}
100+
90101
#[gen_stub_pyfunction(module = $module_name)]
91102
#[pyfunction]
92103
#[must_use]

src/sketch/full_sketch.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ impl<L: Language + MetaInfo> MetaInfo for SketchLang<L> {
116116
}
117117

118118
const NUM_SYMBOLS: usize = L::NUM_SYMBOLS + Self::COUNT;
119+
120+
const MAX_ARITY: usize = { if L::MAX_ARITY > 2 { L::MAX_ARITY } else { 2 } };
119121
}
120122

121123
impl<L: Language + Display> Display for SketchLang<L> {

src/sketch/partial_sketch.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ impl<L: Language + MetaInfo> MetaInfo for PartialSketchLang<L> {
104104
}
105105

106106
const NUM_SYMBOLS: usize = SketchLang::<L>::NUM_SYMBOLS + Self::COUNT;
107+
108+
const MAX_ARITY: usize = SketchLang::<L>::MAX_ARITY;
107109
}
108110

109111
impl<L: Language + Display> Display for PartialSketchLang<L> {

src/trs/arithmetic/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ impl MetaInfo for Math {
5858
"d", "i", "+", "-", "*", "/", "pow", "ln", "sqrt", "sin", "cos",
5959
]
6060
}
61+
62+
const MAX_ARITY: usize = 2;
6163
}
6264

6365
// impl Typeable for Math {

src/trs/halide/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ impl MetaInfo for HalideLang {
6565
"&&",
6666
]
6767
}
68+
69+
const MAX_ARITY: usize = 2;
6870
}
6971

7072
// impl Typeable for HalideLang {

src/trs/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ pub trait MetaInfo: Display + Language + EnumCount {
9090
fn operators() -> Vec<&'static str>;
9191

9292
const NUM_SYMBOLS: usize = Self::COUNT;
93+
const MAX_ARITY: usize;
9394
}
9495

9596
#[derive(Debug, Error)]

src/trs/rise/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ impl MetaInfo for RiseLang {
9191
"snd",
9292
]
9393
}
94+
95+
const MAX_ARITY: usize = 3;
9496
}
9597

9698
#[derive(Default, Debug, Clone, Copy, Serialize)]

src/trs/simple/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ impl MetaInfo for SimpleLang {
3636
fn operators() -> Vec<&'static str> {
3737
vec!["+", "*"]
3838
}
39+
40+
const MAX_ARITY: usize = 2;
3941
}
4042

4143
// impl Typeable for SimpleLang {

0 commit comments

Comments
 (0)