@@ -102,11 +102,11 @@ impl TreeData {
102102 }
103103
104104 /// Gives a matrix that describes the relationship of an ancestor to a child as a distance between them
105- /// maximum distance (positive or negative) to be encoded.
106- /// If the distance is too large or no relationship exists, -1 is returned
105+ /// maximum distance (positive or negative) to be encoded mapped to the range 2 * max_rel_distance
106+ /// If the distance is too large or no relationship exists, 0 is returned
107107 #[ must_use]
108- #[ pyo3( signature = ( max_abs_distance , double_pad=true ) ) ]
109- pub fn anc_matrix ( & self , max_abs_distance : usize , double_pad : bool ) -> Vec < Vec < usize > > {
108+ #[ pyo3( signature = ( max_rel_distance , double_pad=true ) ) ]
109+ pub fn anc_matrix ( & self , max_rel_distance : usize , double_pad : bool ) -> Vec < Vec < usize > > {
110110 fn cmp_nodes (
111111 a : usize ,
112112 b : usize ,
@@ -134,16 +134,18 @@ impl TreeData {
134134 } ,
135135 ) ;
136136
137+ let center = max_rel_distance + 1 ;
138+
137139 let i = ( 0 ..self . adjacency . len ( ) ) . map ( |a_idx| {
138140 let inner_i = ( 0 ..self . adjacency . len ( ) ) . map ( |b_idx| {
139141 if a_idx == b_idx {
140- return max_abs_distance ; // Distance to self is always 0
142+ return center ; // Distance to self is always 0
141143 }
142144 if let Some ( d) = cmp_nodes ( a_idx, b_idx, & par_child, 1 ) {
143- ( d < max_abs_distance ) . then_some ( max_abs_distance + d)
145+ ( d < max_rel_distance ) . then_some ( center + d)
144146 // Positive since parent to child
145147 } else if let Some ( d) = cmp_nodes ( b_idx, a_idx, & par_child, 1 ) {
146- ( d < max_abs_distance ) . then_some ( max_abs_distance - d)
148+ ( d < max_rel_distance ) . then_some ( center - d)
147149 // Negative since child to parent
148150 } else {
149151 None
@@ -169,23 +171,25 @@ impl TreeData {
169171 }
170172
171173 /// Gives a matrix that describes the sibling relationship in nodes
172- /// max_abs_distance describes the maximum distance (positive or negative) to be encoded.
173- /// If the distance is too large or no relationship exists, -1 is returned
174+ /// max_relative_distance describes the maximum distance (positive or negative) to be encoded,
175+ /// mapped to the range 2 * max_relative_distance
176+ /// If the distance is too large or no relationship exists, 0 is returned
174177 #[ must_use]
175- #[ pyo3( signature = ( max_abs_distance , double_pad=true ) ) ]
176- pub fn sib_matrix ( & self , max_abs_distance : usize , double_pad : bool ) -> Vec < Vec < usize > > {
178+ #[ pyo3( signature = ( max_rel_distance , double_pad=true ) ) ]
179+ pub fn sib_matrix ( & self , max_rel_distance : usize , double_pad : bool ) -> Vec < Vec < usize > > {
177180 fn cmp_nodes (
178181 a : usize ,
179182 b : usize ,
180183 par_child : & HashMap < usize , Vec < usize > > ,
181184 child_par : & HashMap < usize , usize > ,
182- max_abs_distance : usize ,
185+ max_relative_distance : usize ,
186+ center : usize ,
183187 ) -> Option < usize > {
184188 // Distance to self is always 0 aka center
185189 // This catches the special case where root is compared to root
186190 // which would be problematic in the if let since root has no parents
187191 if a == b {
188- return Some ( max_abs_distance ) ;
192+ return Some ( center ) ;
189193 }
190194
191195 // Root case where a and b are both root and have no parents is caught by a==b
@@ -200,21 +204,23 @@ impl TreeData {
200204 let pos_b = sibilings. iter ( ) . position ( |x| x == & b) . unwrap ( ) ;
201205 let d = usize:: abs_diff ( pos_a, pos_b) ;
202206
203- if d >= max_abs_distance {
207+ if d >= max_relative_distance {
204208 return None ;
205209 }
206210 // == case caught earlier
207211
208212 if pos_a < pos_b {
209- Some ( max_abs_distance + d)
213+ Some ( center + d)
210214 } else {
211- Some ( max_abs_distance - d)
215+ Some ( center - d)
212216 }
213217 } else {
214218 None // Either not related or bigger distance than max so we return max
215219 }
216220 }
217221
222+ let center = max_rel_distance + 1 ;
223+
218224 let ( par_child, child_par) = self . adjacency . iter ( ) . fold (
219225 ( HashMap :: new ( ) , HashMap :: new ( ) ) ,
220226 |( mut par_child, mut child_par) , ( parent, child) | {
@@ -231,7 +237,15 @@ impl TreeData {
231237
232238 let i = ( 0 ..self . adjacency . len ( ) ) . map ( |a_idx| {
233239 let inner_i = ( 0 ..self . adjacency . len ( ) ) . map ( |b_idx| {
234- cmp_nodes ( a_idx, b_idx, & par_child, & child_par, max_abs_distance) . unwrap_or ( 0 )
240+ cmp_nodes (
241+ a_idx,
242+ b_idx,
243+ & par_child,
244+ & child_par,
245+ max_rel_distance,
246+ center,
247+ )
248+ . unwrap_or ( 0 )
235249 } ) ;
236250 if !double_pad {
237251 return inner_i. collect ( ) ;
@@ -446,7 +460,7 @@ mod tests {
446460 fn sib_matrix ( ) {
447461 let expr: RecExpr < HalideLang > = "( < ( * v0 35 ) ( * ( + v0 5 ) 17 ) )" . parse ( ) . unwrap ( ) ;
448462 let data: TreeData = ( & expr) . try_into ( ) . unwrap ( ) ;
449- let par_sib = data. sib_matrix ( 16 , false ) ;
463+ let par_sib = data. sib_matrix ( 15 , false ) ;
450464
451465 assert_eq ! (
452466 par_sib,
@@ -467,7 +481,7 @@ mod tests {
467481 fn anc_matrix ( ) {
468482 let expr: RecExpr < HalideLang > = "( < ( * v0 35 ) ( * ( + v0 5 ) 17 ) )" . parse ( ) . unwrap ( ) ;
469483 let data: TreeData = ( & expr) . try_into ( ) . unwrap ( ) ;
470- let par_sib = data. anc_matrix ( 16 , false ) ;
484+ let par_sib = data. anc_matrix ( 15 , false ) ;
471485
472486 assert_eq ! (
473487 par_sib,
@@ -488,7 +502,7 @@ mod tests {
488502 fn anc_matrix_padded ( ) {
489503 let expr: RecExpr < HalideLang > = "( < ( * v0 35 ) ( * ( + v0 5 ) 17 ) )" . parse ( ) . unwrap ( ) ;
490504 let data: TreeData = ( & expr) . try_into ( ) . unwrap ( ) ;
491- let par_sib = data. anc_matrix ( 16 , true ) ;
505+ let par_sib = data. anc_matrix ( 15 , true ) ;
492506
493507 assert_eq ! (
494508 par_sib,
@@ -511,7 +525,7 @@ mod tests {
511525 fn sib_matrix_padded ( ) {
512526 let expr: RecExpr < HalideLang > = "( < ( * v0 35 ) ( * ( + v0 5 ) 17 ) )" . parse ( ) . unwrap ( ) ;
513527 let data: TreeData = ( & expr) . try_into ( ) . unwrap ( ) ;
514- let par_sib = data. sib_matrix ( 16 , true ) ;
528+ let par_sib = data. sib_matrix ( 15 , true ) ;
515529
516530 assert_eq ! (
517531 par_sib,
0 commit comments