Skip to content

Commit c545e5f

Browse files
Minor cleanup in wording of docs and vars
1 parent e9f5522 commit c545e5f

2 files changed

Lines changed: 42 additions & 27 deletions

File tree

python/eggshell/__init__.pyi

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,20 @@ class TreeData:
2222
def transposed_adjacency(self) -> list[list[int]]:
2323
...
2424

25-
def anc_matrix(self, max_abs_distance,double_pad = ...) -> list[list[int]]:
25+
def anc_matrix(self, max_rel_distance,double_pad = ...) -> list[list[int]]:
2626
r"""
2727
Gives a matrix that describes the relationship of an ancestor to a child as a distance between them
28-
maximum distance (positive or negative) to be encoded.
29-
If the distance is too large or no relationship exists, -1 is returned
28+
maximum distance (positive or negative) to be encoded mapped to the range 2 * max_rel_distance
29+
If the distance is too large or no relationship exists, 0 is returned
3030
"""
3131
...
3232

33-
def sib_matrix(self, max_abs_distance,double_pad = ...) -> list[list[int]]:
33+
def sib_matrix(self, max_rel_distance,double_pad = ...) -> list[list[int]]:
3434
r"""
3535
Gives a matrix that describes the sibling relationship in nodes
36-
max_abs_distance describes the maximum distance (positive or negative) to be encoded.
37-
If the distance is too large or no relationship exists, -1 is returned
36+
max_relative_distance describes the maximum distance (positive or negative) to be encoded,
37+
mapped to the range 2 * max_relative_distance
38+
If the distance is too large or no relationship exists, 0 is returned
3839
"""
3940
...
4041

src/python/data.rs

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,11 @@ impl TreeData {
102102
}
103103

104104
/// Gives a matrix that describes the relationship of an ancestor to a child as a distance between them
105-
/// maximum distance (positive or negative) to be encoded.
106-
/// If the distance is too large or no relationship exists, -1 is returned
105+
/// maximum distance (positive or negative) to be encoded mapped to the range 2 * max_rel_distance
106+
/// If the distance is too large or no relationship exists, 0 is returned
107107
#[must_use]
108-
#[pyo3(signature = (max_abs_distance, double_pad=true))]
109-
pub fn anc_matrix(&self, max_abs_distance: usize, double_pad: bool) -> Vec<Vec<usize>> {
108+
#[pyo3(signature = (max_rel_distance, double_pad=true))]
109+
pub fn anc_matrix(&self, max_rel_distance: usize, double_pad: bool) -> Vec<Vec<usize>> {
110110
fn cmp_nodes(
111111
a: usize,
112112
b: usize,
@@ -134,16 +134,18 @@ impl TreeData {
134134
},
135135
);
136136

137+
let center = max_rel_distance + 1;
138+
137139
let i = (0..self.adjacency.len()).map(|a_idx| {
138140
let inner_i = (0..self.adjacency.len()).map(|b_idx| {
139141
if a_idx == b_idx {
140-
return max_abs_distance; // Distance to self is always 0
142+
return center; // Distance to self is always 0
141143
}
142144
if let Some(d) = cmp_nodes(a_idx, b_idx, &par_child, 1) {
143-
(d < max_abs_distance).then_some(max_abs_distance + d)
145+
(d < max_rel_distance).then_some(center + d)
144146
// Positive since parent to child
145147
} else if let Some(d) = cmp_nodes(b_idx, a_idx, &par_child, 1) {
146-
(d < max_abs_distance).then_some(max_abs_distance - d)
148+
(d < max_rel_distance).then_some(center - d)
147149
// Negative since child to parent
148150
} else {
149151
None
@@ -169,23 +171,25 @@ impl TreeData {
169171
}
170172

171173
/// Gives a matrix that describes the sibling relationship in nodes
172-
/// max_abs_distance describes the maximum distance (positive or negative) to be encoded.
173-
/// If the distance is too large or no relationship exists, -1 is returned
174+
/// max_relative_distance describes the maximum distance (positive or negative) to be encoded,
175+
/// mapped to the range 2 * max_relative_distance
176+
/// If the distance is too large or no relationship exists, 0 is returned
174177
#[must_use]
175-
#[pyo3(signature = (max_abs_distance, double_pad=true))]
176-
pub fn sib_matrix(&self, max_abs_distance: usize, double_pad: bool) -> Vec<Vec<usize>> {
178+
#[pyo3(signature = (max_rel_distance, double_pad=true))]
179+
pub fn sib_matrix(&self, max_rel_distance: usize, double_pad: bool) -> Vec<Vec<usize>> {
177180
fn cmp_nodes(
178181
a: usize,
179182
b: usize,
180183
par_child: &HashMap<usize, Vec<usize>>,
181184
child_par: &HashMap<usize, usize>,
182-
max_abs_distance: usize,
185+
max_relative_distance: usize,
186+
center: usize,
183187
) -> Option<usize> {
184188
// Distance to self is always 0 aka center
185189
// This catches the special case where root is compared to root
186190
// which would be problematic in the if let since root has no parents
187191
if a == b {
188-
return Some(max_abs_distance);
192+
return Some(center);
189193
}
190194

191195
// Root case where a and b are both root and have no parents is caught by a==b
@@ -200,21 +204,23 @@ impl TreeData {
200204
let pos_b = sibilings.iter().position(|x| x == &b).unwrap();
201205
let d = usize::abs_diff(pos_a, pos_b);
202206

203-
if d >= max_abs_distance {
207+
if d >= max_relative_distance {
204208
return None;
205209
}
206210
// == case caught earlier
207211

208212
if pos_a < pos_b {
209-
Some(max_abs_distance + d)
213+
Some(center + d)
210214
} else {
211-
Some(max_abs_distance - d)
215+
Some(center - d)
212216
}
213217
} else {
214218
None // Either not related or bigger distance than max so we return max
215219
}
216220
}
217221

222+
let center = max_rel_distance + 1;
223+
218224
let (par_child, child_par) = self.adjacency.iter().fold(
219225
(HashMap::new(), HashMap::new()),
220226
|(mut par_child, mut child_par), (parent, child)| {
@@ -231,7 +237,15 @@ impl TreeData {
231237

232238
let i = (0..self.adjacency.len()).map(|a_idx| {
233239
let inner_i = (0..self.adjacency.len()).map(|b_idx| {
234-
cmp_nodes(a_idx, b_idx, &par_child, &child_par, max_abs_distance).unwrap_or(0)
240+
cmp_nodes(
241+
a_idx,
242+
b_idx,
243+
&par_child,
244+
&child_par,
245+
max_rel_distance,
246+
center,
247+
)
248+
.unwrap_or(0)
235249
});
236250
if !double_pad {
237251
return inner_i.collect();
@@ -446,7 +460,7 @@ mod tests {
446460
fn sib_matrix() {
447461
let expr: RecExpr<HalideLang> = "( < ( * v0 35 ) ( * ( + v0 5 ) 17 ) )".parse().unwrap();
448462
let data: TreeData = (&expr).try_into().unwrap();
449-
let par_sib = data.sib_matrix(16, false);
463+
let par_sib = data.sib_matrix(15, false);
450464

451465
assert_eq!(
452466
par_sib,
@@ -467,7 +481,7 @@ mod tests {
467481
fn anc_matrix() {
468482
let expr: RecExpr<HalideLang> = "( < ( * v0 35 ) ( * ( + v0 5 ) 17 ) )".parse().unwrap();
469483
let data: TreeData = (&expr).try_into().unwrap();
470-
let par_sib = data.anc_matrix(16, false);
484+
let par_sib = data.anc_matrix(15, false);
471485

472486
assert_eq!(
473487
par_sib,
@@ -488,7 +502,7 @@ mod tests {
488502
fn anc_matrix_padded() {
489503
let expr: RecExpr<HalideLang> = "( < ( * v0 35 ) ( * ( + v0 5 ) 17 ) )".parse().unwrap();
490504
let data: TreeData = (&expr).try_into().unwrap();
491-
let par_sib = data.anc_matrix(16, true);
505+
let par_sib = data.anc_matrix(15, true);
492506

493507
assert_eq!(
494508
par_sib,
@@ -511,7 +525,7 @@ mod tests {
511525
fn sib_matrix_padded() {
512526
let expr: RecExpr<HalideLang> = "( < ( * v0 35 ) ( * ( + v0 5 ) 17 ) )".parse().unwrap();
513527
let data: TreeData = (&expr).try_into().unwrap();
514-
let par_sib = data.sib_matrix(16, true);
528+
let par_sib = data.sib_matrix(15, true);
515529

516530
assert_eq!(
517531
par_sib,

0 commit comments

Comments
 (0)