@@ -77,6 +77,8 @@ const PARALLEL_GOAL: &str = "(typeOf (natLam (typeOf (natLam (typeOf (natLam (ty
7777#[ cfg( test) ]
7878mod test {
7979
80+ use std:: time:: Duration ;
81+
8082 use egg:: { AstSize , RecExpr , Runner , SimpleScheduler } ;
8183
8284 use crate :: sketch:: eclass_extract;
@@ -90,7 +92,7 @@ mod test {
9092 }
9193
9294 #[ test]
93- fn mm_full_test ( ) {
95+ fn mm_terms_parse_test ( ) {
9496 // START TERM
9597 let mm: RecExpr < Rise > = MM . parse ( ) . unwrap ( ) ;
9698 println ! ( "mm nodes: {}" , mm. len( ) ) ;
@@ -126,39 +128,72 @@ mod test {
126128 pub fn baseline_goal ( ) {
127129 let mm: RecExpr < Rise > = MM . parse ( ) . unwrap ( ) ;
128130 let baseline_goal: RecExpr < Rise > = BASELINE_GOAL . parse ( ) . unwrap ( ) ;
129- let bg_for_closure = baseline_goal. clone ( ) ;
130131
131132 let runner = Runner :: default ( )
132133 . with_expr ( & mm)
133134 . with_iter_limit ( 3 )
134135 . with_scheduler ( SimpleScheduler )
135- . with_hook ( move |r| {
136- // printer_hook(r);
137- if r. egraph . lookup_expr ( & bg_for_closure) . is_some ( ) {
138- return Err ( "FOUND IT" . to_owned ( ) ) ;
139- }
140- Ok ( ( ) )
141- } )
136+ . with_hook ( find_closure ( baseline_goal. clone ( ) ) )
142137 . run ( & rules ( RiseRuleset :: MM ) ) ;
143138
144139 let root_mm = runner. egraph . find ( runner. roots [ 0 ] ) ;
145140 assert_eq ! ( root_mm, runner. egraph. lookup_expr( & mm) . unwrap( ) ) ;
146141
147142 let baseline_sketch = sketchify ( BASELINE_GOAL ) ;
148- let ( _, sketch_extracted_baseline) =
149- eclass_extract ( & baseline_sketch, AstSize , & runner. egraph , root_mm) . unwrap ( ) ;
150-
151- let diff = find_diff (
152- & sketch_extracted_baseline,
153- sketch_extracted_baseline. root ( ) ,
154- & baseline_goal,
155- baseline_goal. root ( ) ,
156- ) ;
157143
158- assert_eq ! ( diff, None ) ;
159144 assert_eq ! ( root_mm, runner. egraph. lookup_expr( & baseline_goal) . unwrap( ) ) ;
160145 }
161146
147+ #[ test]
148+ pub fn blocking_goal ( ) {
149+ let mm: RecExpr < Rise > = MM . parse ( ) . unwrap ( ) ;
150+ let split_guide: RecExpr < Rise > = SPLIT_GUIDE . parse ( ) . unwrap ( ) ;
151+
152+ let runner_1 = Runner :: default ( )
153+ . with_expr ( & mm)
154+ . with_iter_limit ( 5 )
155+ . with_time_limit ( Duration :: from_secs ( 30 ) )
156+ . with_node_limit ( 1_000_000 )
157+ . with_scheduler ( SimpleScheduler )
158+ . with_hook ( find_closure ( split_guide. clone ( ) ) )
159+ . run ( & rules ( RiseRuleset :: MM ) ) ;
160+
161+ println ! ( "{}" , runner_1. report( ) ) ;
162+
163+ let root_mm = runner_1. egraph . find ( runner_1. roots [ 0 ] ) ;
164+ let baseline_sketch = sketchify ( SPLIT_GUIDE ) ;
165+ let ( _, sketch_extracted_split_guide) =
166+ eclass_extract ( & baseline_sketch, AstSize , & runner_1. egraph , root_mm) . unwrap ( ) ;
167+
168+ assert_eq ! ( None , find_diff( & sketch_extracted_split_guide, & split_guide) ) ;
169+ assert_eq ! ( root_mm, runner_1. egraph. lookup_expr( & split_guide) . unwrap( ) ) ;
170+
171+ let blocking_goal: RecExpr < Rise > = BLOCKING_GOAL . parse ( ) . unwrap ( ) ;
172+ let runner_2 = Runner :: default ( )
173+ . with_expr ( & split_guide)
174+ . with_iter_limit ( 6 )
175+ . with_scheduler ( SimpleScheduler )
176+ . with_hook ( find_closure ( blocking_goal. clone ( ) ) )
177+ . run ( & rules ( RiseRuleset :: MM ) ) ;
178+
179+ let root_guide = runner_2. egraph . find ( runner_2. roots [ 0 ] ) ;
180+ assert_eq ! (
181+ root_guide,
182+ runner_2. egraph. lookup_expr( & blocking_goal) . unwrap( )
183+ ) ;
184+ }
185+
186+ fn find_closure (
187+ bg_for_closure : RecExpr < Rise > ,
188+ ) -> impl FnMut ( & mut Runner < Rise , RiseAnalysis > ) -> Result < ( ) , String > {
189+ move |r : & mut Runner < Rise , RiseAnalysis > | {
190+ if r. egraph . lookup_expr ( & bg_for_closure) . is_some ( ) {
191+ return Err ( "FOUND IT" . to_owned ( ) ) ;
192+ }
193+ Ok ( ( ) )
194+ }
195+ }
196+
162197 fn printer_hook ( r : & mut Runner < Rise , RiseAnalysis > ) {
163198 println ! ( "ITERATION {}:\n " , r. iterations. len( ) ) ;
164199 println ! ( "Nodes: {}\n " , r. egraph. nodes( ) . len( ) ) ;
@@ -188,25 +223,28 @@ mod test {
188223 . unwrap ( )
189224 }
190225
191- fn find_diff (
192- lhs : & RecExpr < Rise > ,
193- lhs_id : Id ,
194- rhs : & RecExpr < Rise > ,
195- rhs_id : Id ,
196- ) -> Option < ( Rise , Rise ) > {
197- if let Rise :: Var ( index) = lhs[ lhs_id]
198- && let Rise :: Var ( index2) = rhs[ rhs_id]
199- && index != index2
200- {
201- Some ( ( lhs[ lhs_id] . clone ( ) , rhs[ rhs_id] . clone ( ) ) )
202- } else if lhs[ lhs_id] . matches ( & rhs[ rhs_id] ) {
203- lhs[ lhs_id]
204- . children ( )
205- . iter ( )
206- . zip ( rhs[ rhs_id] . children ( ) )
207- . find_map ( |( lcid, rcid) | find_diff ( lhs, * lcid, rhs, * rcid) )
208- } else {
209- Some ( ( lhs[ lhs_id] . clone ( ) , rhs[ rhs_id] . clone ( ) ) )
226+ fn find_diff ( lhs : & RecExpr < Rise > , rhs : & RecExpr < Rise > ) -> Option < ( Rise , Rise ) > {
227+ fn rec (
228+ lhs : & RecExpr < Rise > ,
229+ lhs_id : Id ,
230+ rhs : & RecExpr < Rise > ,
231+ rhs_id : Id ,
232+ ) -> Option < ( Rise , Rise ) > {
233+ if let Rise :: Var ( index) = lhs[ lhs_id]
234+ && let Rise :: Var ( index2) = rhs[ rhs_id]
235+ && index != index2
236+ {
237+ Some ( ( lhs[ lhs_id] . clone ( ) , rhs[ rhs_id] . clone ( ) ) )
238+ } else if lhs[ lhs_id] . matches ( & rhs[ rhs_id] ) {
239+ lhs[ lhs_id]
240+ . children ( )
241+ . iter ( )
242+ . zip ( rhs[ rhs_id] . children ( ) )
243+ . find_map ( |( lcid, rcid) | rec ( lhs, * lcid, rhs, * rcid) )
244+ } else {
245+ Some ( ( lhs[ lhs_id] . clone ( ) , rhs[ rhs_id] . clone ( ) ) )
246+ }
210247 }
248+ rec ( lhs, lhs. root ( ) , rhs, rhs. root ( ) )
211249 }
212250}
0 commit comments