@@ -327,6 +327,35 @@ where
327327 TrieHard :: U256 ( trie) => TrieIter :: U256 ( trie. prefix_search ( prefix) ) ,
328328 }
329329 }
330+
331+ /// Find the closest ancestor to the given key, where an ancestor is defined as the longest
332+ /// string present in the trie that appears as a prefix of the given key.
333+ ///
334+ /// ```
335+ /// # use trie_hard::TrieHard;
336+ /// let trie = ["dad", "ant", "and", "dot", "do"]
337+ /// .into_iter()
338+ /// .collect::<TrieHard<'_, _>>();
339+ ///
340+ /// assert_eq!(
341+ /// trie.ancestor("dada").map(|(_, v)| v),
342+ /// Some("dad")
343+ /// );
344+ /// assert_eq!(
345+ /// trie.ancestor("an").map(|(_, v)| v),
346+ /// None
347+ /// );
348+ /// ```
349+ pub fn ancestor < K : AsRef < [ u8 ] > > ( & self , key : K ) -> Option < ( & [ u8 ] , T ) > {
350+ match self {
351+ TrieHard :: U8 ( trie) => trie. ancestor ( key) ,
352+ TrieHard :: U16 ( trie) => trie. ancestor ( key) ,
353+ TrieHard :: U32 ( trie) => trie. ancestor ( key) ,
354+ TrieHard :: U64 ( trie) => trie. ancestor ( key) ,
355+ TrieHard :: U128 ( trie) => trie. ancestor ( key) ,
356+ TrieHard :: U256 ( trie) => trie. ancestor ( key) ,
357+ }
358+ }
330359}
331360
332361/// Structure used for iterative over the contents of trie
@@ -591,6 +620,71 @@ macro_rules! trie_impls {
591620
592621 TrieIterSized :: new( self , node_index)
593622 }
623+
624+ /// Find the closest ancestor to the given key, where an ancestor is defined as the
625+ /// longest string present in the trie that appears as a prefix of the given key.
626+ ///
627+ /// ```
628+ /// # use trie_hard::TrieHard;
629+ /// let trie = ["dad", "ant", "and", "dot", "do"]
630+ /// .into_iter()
631+ /// .collect::<TrieHard<'_, _>>();
632+ ///
633+ /// let TrieHard::U8(sized_trie) = trie else {
634+ /// unreachable!()
635+ /// };
636+ ///
637+ /// assert_eq!(
638+ /// sized_trie.ancestor("dada").map(|(_, v)| v),
639+ /// Some("dad")
640+ /// );
641+ /// assert_eq!(
642+ /// sized_trie.ancestor("an").map(|(_, v)| v),
643+ /// None
644+ /// );
645+ /// ```
646+ pub fn ancestor<K : AsRef <[ u8 ] >>(
647+ & self ,
648+ key: K ,
649+ ) -> Option <( & [ u8 ] , T ) > {
650+ self . ancestor_recurse( 0 , key. as_ref( ) , self . nodes. get( 0 ) ?)
651+ }
652+
653+ fn ancestor_recurse(
654+ & self ,
655+ i: usize ,
656+ key: & [ u8 ] ,
657+ state: & TrieState <' a, T , $int_type>,
658+ ) -> Option <( & [ u8 ] , T ) > {
659+ match state {
660+ TrieState :: Leaf ( k, value) => {
661+ (
662+ k. len( ) <= key. len( )
663+ && k[ i..] == key[ i..k. len( ) ]
664+ ) . then_some( ( k, * value) )
665+ }
666+ TrieState :: Search ( search) => {
667+ let c = key. get( i) ?;
668+ let next_state_index = search. evaluate( * c, self ) ?;
669+ self . ancestor_recurse( i + 1 , key, & self . nodes[ next_state_index] )
670+ }
671+ TrieState :: SearchOrLeaf ( k, value, search) => {
672+ // lambda to enable using `?` operator
673+ let search = || {
674+ let c = key. get( i) ?;
675+ let next_state_index = search. evaluate( * c, self ) ?;
676+ self . ancestor_recurse( i + 1 , key, & self . nodes[ next_state_index] )
677+ } ;
678+
679+ search( ) . or_else( || {
680+ (
681+ k. len( ) <= key. len( )
682+ && k[ i..] == key[ i..k. len( ) ]
683+ ) . then_some( ( k, * value) )
684+ } )
685+ }
686+ }
687+ }
594688 }
595689
596690 impl <' a, T > TrieHardSized <' a, T , $int_type> where T : ' a + Copy {
@@ -932,4 +1026,30 @@ mod tests {
9321026 . collect :: < Vec < _ > > ( ) ;
9331027 assert_eq ! ( emitted, output) ;
9341028 }
1029+
1030+ #[ rstest]
1031+ #[ case( & [ ] , "" , None ) ]
1032+ #[ case( & [ "" ] , "" , Some ( "" ) ) ]
1033+ #[ case( & [ "aaa" , "a" , "" ] , "" , Some ( "" ) ) ]
1034+ #[ case( & [ "aaa" , "a" , "" ] , "a" , Some ( "a" ) ) ]
1035+ #[ case( & [ "aaa" , "a" , "" ] , "aa" , Some ( "a" ) ) ]
1036+ #[ case( & [ "aaa" , "a" , "" ] , "aab" , Some ( "a" ) ) ]
1037+ #[ case( & [ "aaa" , "a" , "" ] , "aaa" , Some ( "aaa" ) ) ]
1038+ #[ case( & [ "aaa" , "a" , "" ] , "b" , Some ( "" ) ) ]
1039+ #[ case( & [ "dad" , "ant" , "and" , "dot" , "do" ] , "d" , None ) ]
1040+ #[ case( & [ "dad" , "ant" , "and" , "dot" , "do" ] , "dad" , Some ( "dad" ) ) ]
1041+ #[ case( & [ "dad" , "ant" , "and" , "dot" , "do" ] , "dada" , Some ( "dad" ) ) ]
1042+ #[ case( & [ "dad" , "ant" , "and" , "dot" , "do" ] , "do" , Some ( "do" ) ) ]
1043+ #[ case( & [ "dad" , "ant" , "and" , "dot" , "do" ] , "dot" , Some ( "dot" ) ) ]
1044+ #[ case( & [ "dad" , "ant" , "and" , "dot" , "do" ] , "dob" , Some ( "do" ) ) ]
1045+ #[ case( & [ "dad" , "ant" , "and" , "dot" , "do" ] , "doto" , Some ( "dot" ) ) ]
1046+ fn test_ancestor (
1047+ #[ case] input : & [ & str ] ,
1048+ #[ case] key : & str ,
1049+ #[ case] output : Option < & str > ,
1050+ ) {
1051+ let trie = input. iter ( ) . copied ( ) . collect :: < TrieHard < ' _ , _ > > ( ) ;
1052+ let emitted = trie. ancestor ( key) . map ( |( _, v) | v) ;
1053+ assert_eq ! ( emitted, output) ;
1054+ }
9351055}
0 commit comments