
# Please refer to README.md in the same folder for more information.

+import logging
+from collections import defaultdict, deque
from dataclasses import dataclass
from functools import partial
from typing import Dict, List, Optional, Tuple
@@ -23,6 +25,8 @@

from torch import nn

+logger = logging.getLogger(__name__)
+

def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
@@ -507,6 +511,24 @@ def load_model(checkpoint_path, params_path, max_seq_length, use_cache_list):


class InputManager:
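+    # Bounded, de-duplicating FIFO cache of n-gram suffixes. lookahead_decode keeps one
+    # cache per leading token (via a defaultdict) and uses the stored suffixes to seed
+    # the verification branches.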
+    class NGramCache:
+        def __init__(self, max_size: int):
+            self.cache = deque()
+            self.max_size = max_size
+
+        def add(self, ngram: List[int]):
+            if ngram in self.cache:
+                return
+            if len(self.cache) == self.max_size:
+                self.cache.popleft()
+            self.cache.append(ngram)
+
+        def __iter__(self):
+            return iter(self.cache)
+
+        def __str__(self):
+            return str(self.cache)
+
    def __init__(
        self,
        n_layers: int,
@@ -519,6 +541,7 @@ def __init__(
        dtype=torch.float16,
        minus_infinity=-torch.inf,
        cache_size=None,
+        lookahead_enabled: bool = False,
    ):
        if cache_size is None:
            cache_size = max_seq_length - seq_length
@@ -532,6 +555,8 @@ def __init__(

        self.seq_length = seq_length
        self.use_cache_list = use_cache_list
+        self.lookahead_enabled = lookahead_enabled
+        self.minus_infinity = minus_infinity

        if self.use_cache_list:
            self.k_caches = [
@@ -609,10 +634,10 @@ def _update_cache(self, start, length, new_k_caches, new_v_caches):
        if self.cache_pos == self.cache_size:
            self.cache_pos = 0

-    def update(self, input_length, new_k_caches, new_v_caches):
+    def update(self, input_length, new_k_caches, new_v_caches, update_pos=0):
        # Copy as much new cache data into cache as possible without wrapping
        amount_to_copy = min(input_length, self.cache_size - self.cache_pos)
-        self._update_cache(0, amount_to_copy, new_k_caches, new_v_caches)
+        self._update_cache(update_pos, amount_to_copy, new_k_caches, new_v_caches)
        if self.input_pos <= self.cache_size:
            self.attn_mask[:, (self.input_pos) : (self.input_pos + amount_to_copy)] = (
                0.0
@@ -625,7 +650,10 @@ def update(self, input_length, new_k_caches, new_v_caches):
        )
        if remaining_to_copy > 0:
            self._update_cache(
-                amount_to_copy, remaining_to_copy, new_k_caches, new_v_caches
+                update_pos + amount_to_copy,
+                remaining_to_copy,
+                new_k_caches,
+                new_v_caches,
            )

        self.input_pos += input_length
@@ -661,3 +689,270 @@ def get_inputs_and_remaining_tokens(self, tokens: List[int]):
            self.get_inputs(tokens[0:processed_tokens]),
            tokens[processed_tokens:],
        )
+
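+    # Layout of the seq_length-long lookahead input:
+    #   [0, window_size * (ngram_size - 1)):
+    #       ngram_size - 1 lookahead branches of window_size tokens each
+    #   [verification_offset, verification_offset + n_verifications * (ngram_size - 1)):
+    #       n_verifications verification branches of ngram_size - 1 tokens each
+    # Position 0 holds the current token; the mask lets every verification branch attend to it.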
+    def _get_lookahead_decoding_mask(
+        self, ngram_size: int, window_size: int, n_verifications: int
+    ) -> torch.Tensor:
+        mask = torch.full((self.seq_length, self.seq_length), self.minus_infinity)
+        mask[0][0] = 0.0
+
+        lookahead_submask = torch.triu(
+            torch.full((window_size, window_size), self.minus_infinity),
+            diagonal=1,
+        )
+        for i in range(ngram_size - 1):
+            offset = window_size * i
+            mask[offset : offset + window_size, :window_size] = lookahead_submask
+            for j in range(1, i + 1):
+                mask[
+                    offset : offset + window_size,
+                    window_size * j : window_size * (j + 1),
+                ].fill_diagonal_(0.0)
+
+        verification_offset = max(window_size * (ngram_size - 1), 1)
+        verification_submask = torch.triu(
+            torch.full((ngram_size - 1, ngram_size - 1), self.minus_infinity),
+            diagonal=1,
+        )
+        for i in range(n_verifications):
+            mask[
+                verification_offset
+                + i * (ngram_size - 1) : verification_offset
+                + (i + 1) * (ngram_size - 1),
+                verification_offset
+                + i * (ngram_size - 1) : verification_offset
+                + (i + 1) * (ngram_size - 1),
+            ] = verification_submask
+        mask[verification_offset:, :1] = 0.0
+
+        return mask
+
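+    # RoPE position offsets relative to input_pos: token j of lookahead branch i sits at
+    # offset i + j; each verification branch uses offsets 1, ..., ngram_size - 1.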
+    def _get_lookahead_position_offsets(
+        self, ngram_size: int, window_size: int, n_verifications: int
+    ) -> torch.Tensor:
+        pos_offsets = torch.zeros(self.seq_length, dtype=torch.int32)
+        idx = 0
+        if window_size > 0:
+            for i in range(ngram_size - 1):
+                for j in range(window_size):
+                    pos_offsets[idx] = i + j
+                    idx += 1
+        else:
+            pos_offsets[0] = 0
+            idx += 1
+
+        # Verification branches: [1, 2, ..., ngram_size - 1].
+        for _ in range(n_verifications):
+            for j in range(1, ngram_size):
+                pos_offsets[idx] = j
+                idx += 1
+
+        return pos_offsets
+
+    def _validate_lookahead_config(
+        self, ngram_size: int, window_size: int, n_verifications: int
+    ) -> None:
+        """
+        Validate the lookahead decoding configuration.
+        """
+        if not self.lookahead_enabled:
+            raise RuntimeError("Lookahead decoding is not enabled")
+
+        if (ngram_size - 1) * (window_size + n_verifications) > self.seq_length:
+            raise RuntimeError(
+                f"Lookahead decoding configuration not compatible with seq_length {self.seq_length}. "
+                f"Required: {(ngram_size - 1) * (window_size + n_verifications)}"
+            )
+
+    def _setup_lookahead_mask(
+        self, ngram_size: int, window_size: int, n_verifications: int
+    ) -> None:
+        """
+        Set up the attention mask for lookahead decoding and log debug information.
+        """
+        self.attn_mask[:, self.cache_size :] = self._get_lookahead_decoding_mask(
+            ngram_size, window_size, n_verifications
+        )
+        logger.debug("Lookahead decoding mask: ")
+        for i in range(self.seq_length):
+            logger.debug(
+                " ".join(
+                    ("X" if x == 0.0 else " ")
+                    for x in self.attn_mask[i][self.cache_size :]
+                )
+            )
+
+    def _populate_verification_branches(
+        self, x: List[int], cache, verification_offset: int, ngram_size: int
+    ) -> None:
+        """
+        Populate verification branches with tokens from the n-gram cache.
+        """
+        for i, ngram in enumerate(cache):
+            for j, token in enumerate(ngram):
+                x[verification_offset + i * (ngram_size - 1) + j] = token
+
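+    # Each window position i contributes one candidate n-gram per step: key x[i] with
+    # suffix (x[i + window_size], ..., x[i + (ngram_size - 2) * window_size],
+    # y[i + (ngram_size - 2) * window_size]).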
+    def _collect_ngrams(
+        self,
+        x: List[int],
+        y: List[int],
+        ngram_caches: Dict[int, "InputManager.NGramCache"],
+        window_size: int,
+        ngram_size: int,
+    ) -> None:
+        """
+        Collect new n-grams from the current state and predictions.
+        """
+        for i in range(window_size):
+            key = x[i]
+            suffix = []
+            for j in range(1, ngram_size - 1):
+                suffix.append(x[i + j * window_size])
+            suffix.append(y[i + window_size * (ngram_size - 2)])
+            ngram_caches[key].add(suffix)
+
+    def _find_longest_match(
+        self,
+        x: List[int],
+        y: List[int],
+        verification_offset: int,
+        n_verifications: int,
+        ngram_size: int,
+    ) -> Tuple[List[int], Optional[int]]:
+        """
+        Find the longest matching sequence from verification branches.
+        Returns the matched tokens and the branch index.
+        """
+        longest_match = []
+        matched_branch = None
+
+        for i in range(n_verifications):
+            match = [y[0]]
+            j = 0
+            while (
+                j < ngram_size - 1
+                and x[verification_offset + (ngram_size - 1) * i + j] == match[-1]
+            ):
+                match.append(y[verification_offset + (ngram_size - 1) * i + j])
+                j += 1
+            if len(match) - 1 > len(longest_match):
+                longest_match = match[1:]
+                matched_branch = i
+
+        return longest_match, matched_branch
+
+    def _update_lookahead_branches(
+        self, x: List[int], y: List[int], ngram_size: int, window_size: int
+    ) -> None:
+        """
+        Update the lookahead branches with new predictions.
+        """
+        # Shift window contents up
+        for i in range(ngram_size - 2):
+            for j in range(window_size):
+                x[window_size * i + j] = x[window_size * (i + 1) + j]
+
+        # Fill the last window with new predictions
+        for j in range(window_size):
+            x[window_size * (ngram_size - 2) + j] = y[
+                window_size * (ngram_size - 2) + j
+            ]
+
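+    # Lookahead (Jacobi) decoding with n-gram verification: each forward pass decodes the
+    # next token, refines the lookahead branches, and accepts any verification branch whose
+    # cached n-gram matches the model's greedy predictions, so several tokens can be
+    # committed per inference.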
+    def lookahead_decode(
+        self,
+        model,
+        init_token: int,
+        n: int,
+        ngram_size: int,
+        window_size: int,
+        n_verifications: int,
+        stop_tokens: Optional[List[int]] = None,
+        ngram_caches: Optional[Dict[int, "InputManager.NGramCache"]] = None,
+    ) -> List[int]:
+        # Validate configuration
+        self._validate_lookahead_config(ngram_size, window_size, n_verifications)
+
+        # Set up attention mask and position offsets
+        self._setup_lookahead_mask(ngram_size, window_size, n_verifications)
+        offsets = self._get_lookahead_position_offsets(
+            ngram_size, window_size, n_verifications
+        )
+
+        # Initialize state
+        stop_tokens = stop_tokens or []
+        verification_offset = window_size * (ngram_size - 1)
+        if ngram_caches is None:
+            ngram_caches = defaultdict(lambda: InputManager.NGramCache(n_verifications))
+
+        new_tokens = [init_token]
+        x = [init_token] * self.seq_length
+        inference_count = 0
+
+        # Main decoding loop
+        while len(new_tokens) < n + 1:
+            # Populate verification branches
+            cache = ngram_caches[x[0]]
+            self._populate_verification_branches(
+                x, cache, verification_offset, ngram_size
+            )
+
+            # Run model inference
+            logits, new_k, new_v = model(
+                tokens=torch.tensor([x], dtype=torch.int64),
+                input_pos=torch.tensor([self.input_pos], dtype=torch.long),
+                k_caches=self.k_caches,
+                v_caches=self.v_caches,
+                attn_mask=self.attn_mask,
+                input_len=torch.tensor([len(x)], dtype=torch.long),
+                rope_indices=self.input_pos + offsets,
+            )
+            inference_count += 1
+
+            # Process model output (greedy selection)
+            y = logits[0].argmax(dim=-1).tolist()
+            new_tokens.append(y[0])
+            logger.debug(f"{self.input_pos}: x = {x[0]}, y = {y[0]}")
+            if new_tokens[-1] in stop_tokens:
+                break
+
+            # Collect new n-grams
+            self._collect_ngrams(x, y, ngram_caches, window_size, ngram_size)
+
+            # Find longest match from verification branches
+            longest_match, matched_branch = self._find_longest_match(
+                x, y, verification_offset, n_verifications, ngram_size
+            )
+
+            # Process match results
+            if matched_branch is not None:
+                logger.debug(
+                    f"Matched {len(longest_match)} additional tokens from n-grams: {longest_match}"
+                )
+                # Truncate at stop token if present
+                for stop in stop_tokens:
+                    if stop in longest_match:
+                        longest_match = longest_match[: longest_match.index(stop) + 1]
+
+                new_tokens.extend(longest_match)
+                branch_offset = verification_offset + (ngram_size - 1) * matched_branch
+                self.update(
+                    input_length=len(longest_match),
+                    new_k_caches=new_k,
+                    new_v_caches=new_v,
+                    update_pos=branch_offset,
+                )
+            else:
+                self.update(input_length=1, new_k_caches=new_k, new_v_caches=new_v)
+
+            # Update lookahead branches
+            self._update_lookahead_branches(x, y, ngram_size, window_size)
+
+            # Update first token and check for stop condition
+            x[0] = new_tokens[-1]
+            if new_tokens[-1] in stop_tokens:
+                break
+
+        logger.info(
+            f"Generated {len(new_tokens) - 1} tokens with {inference_count} inference(s)."
+        )
+        return new_tokens
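
A minimal caller sketch for the new lookahead path, assuming a model returned by load_model that accepts the keyword arguments used above, a tokenizer with eos_id and decode, and a prompt already prefilled through the InputManager; every name and value not shown in the diff is a placeholder:

# Hypothetical driver; constructor values and the model/tokenizer objects are assumptions.
manager = InputManager(
    n_layers=32,
    max_seq_length=1024,
    seq_length=64,
    use_cache_list=True,
    # ... remaining constructor arguments as for normal decoding ...
    lookahead_enabled=True,  # must be True, otherwise lookahead_decode raises RuntimeError
)
# Prefill the prompt first (get_inputs_and_remaining_tokens + update), then:
tokens = manager.lookahead_decode(
    model=model,                      # model from load_model(...), called with the kwargs above
    init_token=prompt_tokens[-1],     # last prompt token
    n=128,                            # generate up to 128 new tokens
    ngram_size=3,
    window_size=8,
    n_verifications=8,
    stop_tokens=[tokenizer.eos_id],   # assumed tokenizer attribute
)
print(tokenizer.decode(tokens))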