@@ -6,6 +6,8 @@

 # Please refer to README.md in the same folder for more information.

+import logging
+from collections import defaultdict, deque
 from dataclasses import dataclass
 from functools import partial
 from typing import Dict, List, Optional, Tuple
@@ -23,6 +25,8 @@

 from torch import nn

+logger = logging.getLogger(__name__)
+

 def find_multiple(n: int, k: int) -> int:
     if n % k == 0:
@@ -507,6 +511,24 @@ def load_model(checkpoint_path, params_path, max_seq_length, use_cache_list):


 class InputManager:
+    class NGramCache:
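+        # Fixed-size FIFO of n-grams observed for a single key token; duplicates are
+        # ignored and the oldest entry is evicted once max_size is reached.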
+        def __init__(self, max_size: int):
+            self.cache = deque()
+            self.max_size = max_size
+
+        def add(self, ngram: List[int]):
+            if ngram in self.cache:
+                return
+            if len(self.cache) == self.max_size:
+                self.cache.popleft()
+            self.cache.append(ngram)
+
+        def __iter__(self):
+            return iter(self.cache)
+
+        def __str__(self):
+            return str(self.cache)
+
     def __init__(
         self,
         n_layers: int,
@@ -519,6 +541,7 @@ def __init__(
         dtype=torch.float16,
         minus_infinity=-torch.inf,
         cache_size=None,
+        lookahead_enabled: bool = False,
     ):
         if cache_size is None:
             cache_size = max_seq_length - seq_length
@@ -532,6 +555,8 @@ def __init__(

         self.seq_length = seq_length
         self.use_cache_list = use_cache_list
+        self.lookahead_enabled = lookahead_enabled
+        self.minus_infinity = minus_infinity

         if self.use_cache_list:
             self.k_caches = [
@@ -609,10 +634,10 @@ def _update_cache(self, start, length, new_k_caches, new_v_caches):
         if self.cache_pos == self.cache_size:
             self.cache_pos = 0

-    def update(self, input_length, new_k_caches, new_v_caches):
+    def update(self, input_length, new_k_caches, new_v_caches, update_pos=0):
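+        # update_pos is the offset into new_k_caches/new_v_caches at which copying
+        # starts; lookahead decoding uses it to commit an accepted verification branch.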
         # Copy as much new cache data into cache as possible without wrapping
         amount_to_copy = min(input_length, self.cache_size - self.cache_pos)
-        self._update_cache(0, amount_to_copy, new_k_caches, new_v_caches)
+        self._update_cache(update_pos, amount_to_copy, new_k_caches, new_v_caches)
         if self.input_pos <= self.cache_size:
             self.attn_mask[:, (self.input_pos) : (self.input_pos + amount_to_copy)] = (
                 0.0
@@ -625,7 +650,7 @@ def update(self, input_length, new_k_caches, new_v_caches):
             )
         if remaining_to_copy > 0:
             self._update_cache(
-                amount_to_copy, remaining_to_copy, new_k_caches, new_v_caches
+                update_pos + amount_to_copy, remaining_to_copy, new_k_caches, new_v_caches
             )

         self.input_pos += input_length
@@ -661,3 +686,192 @@ def get_inputs_and_remaining_tokens(self, tokens: List[int]):
             self.get_inputs(tokens[0:processed_tokens]),
             tokens[processed_tokens:],
         )
+
+    def _get_lookahead_decoding_mask(
+        self, ngram_size: int, window_size: int, n_verifications: int
+    ) -> torch.Tensor:
+        mask = torch.full((self.seq_length, self.seq_length), self.minus_infinity)
+        mask[0][0] = 0.0
+
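+        # Lookahead branches: level i attends causally to the level-0 window and
+        # diagonally (same slot) to levels 1..i, including its own position.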
+        lookahead_submask = torch.triu(
+            torch.full((window_size, window_size), self.minus_infinity),
+            diagonal=1,
+        )
+        for i in range(ngram_size - 1):
+            offset = window_size * i
+            mask[offset : offset + window_size, :window_size] = lookahead_submask
+            for j in range(1, i + 1):
+                mask[
+                    offset : offset + window_size,
+                    window_size * j : window_size * (j + 1),
+                ].fill_diagonal_(0.0)
+
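+        # Verification branches: each candidate n-gram attends causally within its own
+        # branch, and every verification position can also attend to the current token.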
+        verification_offset = max(window_size * (ngram_size - 1), 1)
+        verification_submask = torch.triu(
+            torch.full((ngram_size - 1, ngram_size - 1), self.minus_infinity),
+            diagonal=1,
+        )
+        for i in range(n_verifications):
+            mask[
+                verification_offset + i * (ngram_size - 1) : verification_offset
+                + (i + 1) * (ngram_size - 1),
+                verification_offset + i * (ngram_size - 1) : verification_offset
+                + (i + 1) * (ngram_size - 1),
+            ] = verification_submask
+        mask[verification_offset:, :1] = 0.0
+
+        return mask
+
+    def _get_lookahead_position_offsets(
+        self, ngram_size: int, window_size: int, n_verifications: int
+    ) -> torch.Tensor:
+        pos_offsets = torch.zeros(self.seq_length, dtype=torch.int32)
+        idx = 0
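+        # Lookahead branches: slot j of level i sits at offset i + j from the current
+        # position (a window of size 0 degenerates to a single slot at offset 0).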
+        if window_size > 0:
+            for i in range(ngram_size - 1):
+                for j in range(window_size):
+                    pos_offsets[idx] = i + j
+                    idx += 1
+        else:
+            pos_offsets[0] = 0
+            idx += 1
+
+        # Verification branches: [1, 2, ..., ngram_size - 1].
+        for _ in range(n_verifications):
+            for j in range(1, ngram_size):
+                pos_offsets[idx] = j
+                idx += 1
+
+        return pos_offsets
+
+    def lookahead_decode(
+        self,
+        model,
+        init_token: int,
+        n: int,
+        ngram_size: int,
+        window_size: int,
+        n_verifications: int,
+        stop_tokens: Optional[List[int]] = None,
+        ngram_caches: Optional[Dict[int, "InputManager.NGramCache"]] = None,
+    ) -> List[int]:
+        if not self.lookahead_enabled:
+            raise RuntimeError("Lookahead decoding is not enabled")
+
+        if (ngram_size - 1) * (window_size + n_verifications) > self.seq_length:
+            raise RuntimeError(
+                f"Lookahead decoding configuration not compatible with seq_length {self.seq_length}. "
+                f"Required: {(ngram_size - 1) * (window_size + n_verifications)}"
+            )
+
+        self.attn_mask[:, self.cache_size :] = self._get_lookahead_decoding_mask(
+            ngram_size, window_size, n_verifications
+        )
+        logger.debug("Lookahead decoding mask: ")
+        for i in range(self.seq_length):
+            logger.debug(
+                " ".join(
+                    ("X" if x == 0.0 else " ")
+                    for x in self.attn_mask[i][self.cache_size :]
+                )
+            )
+
+        offsets = self._get_lookahead_position_offsets(
+            ngram_size, window_size, n_verifications
+        )
+
+        stop_tokens = stop_tokens or []
+        verification_offset = window_size * (ngram_size - 1)
+
+        if ngram_caches is None:
+            ngram_caches = defaultdict(lambda: InputManager.NGramCache(n_verifications))
+        new_tokens = [init_token]
+        x = [init_token] * self.seq_length
+        inference_count = 0
+
+        while len(new_tokens) < n + 1:
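+            # Fill the verification branches with the n-gram continuations cached for
+            # the current token x[0].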
+            cache = ngram_caches[x[0]]
+            for i, ngram in enumerate(cache):
+                for j, token in enumerate(ngram):
+                    x[verification_offset + i * (ngram_size - 1) + j] = token
+
+            logits, new_k, new_v = model(
+                tokens=torch.tensor([x], dtype=torch.int64),
+                input_pos=torch.tensor([self.input_pos], dtype=torch.long),
+                k_caches=self.k_caches,
+                v_caches=self.v_caches,
+                attn_mask=self.attn_mask,
+                input_len=torch.tensor([len(x)], dtype=torch.long),
+                rope_indices=self.input_pos + offsets,
+            )
+            inference_count += 1
+
+            # Greedy only
+            y = logits[0].argmax(dim=-1).tolist()
+            new_tokens.append(y[0])
+            logger.debug(f"{self.input_pos}: x = {x[0]}, y = {y[0]}")
+            if new_tokens[-1] in stop_tokens:
+                break
+
+            # Collect new n-grams.
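+            # For each window column, record the (ngram_size - 1)-token continuation
+            # keyed by that column's level-0 token.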
+            for i in range(window_size):
+                key = x[i]
+                suffix = []
+                for j in range(1, ngram_size - 1):
+                    suffix.append(x[i + j * window_size])
+                suffix.append(y[i + window_size * (ngram_size - 2)])
+                ngram_caches[key].add(suffix)
+
+            # Verification.
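+            # Accept the branch whose cached n-gram matches the longest run of this
+            # step's greedy predictions.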
+            longest_match = []
+            matched_branch = None
+            for i in range(n_verifications):
+                match = [y[0]]
+                j = 0
+                # for j in range(ngram_size - 1):
+                while (
+                    j < ngram_size - 1
+                    and x[verification_offset + (ngram_size - 1) * i + j] == match[-1]
+                ):
+                    match.append(y[verification_offset + (ngram_size - 1) * i + j])
+                    j += 1
+                if len(match) - 1 > len(longest_match):
+                    longest_match = match[1:]
+                    matched_branch = i
+
+            if matched_branch is not None:
+                logger.debug(
+                    f"Matched {len(longest_match)} additional tokens from n-grams: {longest_match}"
+                )
+                for stop in stop_tokens:
+                    if stop in longest_match:
+                        longest_match = longest_match[: longest_match.index(stop) + 1]
+
+                new_tokens.extend(longest_match)
+                branch_offset = verification_offset + (ngram_size - 1) * matched_branch
+                self.update(
+                    input_length=len(longest_match),
+                    new_k_caches=new_k,
+                    new_v_caches=new_v,
+                    update_pos=branch_offset,
+                )
+            else:
+                self.update(input_length=1, new_k_caches=new_k, new_v_caches=new_v)
+
+            # Update lookahead branch.
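+            # Slide the window forward: each level takes the next level's tokens, and
+            # the last level takes this step's predictions for those slots.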
+            for i in range(ngram_size - 2):
+                for j in range(window_size):
+                    x[window_size * i + j] = x[window_size * (i + 1) + j]
+            for j in range(window_size):
+                x[window_size * (ngram_size - 2) + j] = y[
+                    window_size * (ngram_size - 2) + j
+                ]
+
+            x[0] = new_tokens[-1]
+            if new_tokens[-1] in stop_tokens:
+                break
+
+        logger.info(
+            f"Generated {len(new_tokens) - 1} tokens with {inference_count} inference(s)."
+        )
+        return new_tokens
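
Note: a minimal driver sketch for the new lookahead_decode entry point. This is illustrative only; generate_with_lookahead, mgr, model, prompt_tokens, stop_tokens, and the parameter values are assumptions rather than part of this commit, and prompt prefill through the existing get_inputs/update path is elided.

from typing import List


def generate_with_lookahead(mgr, model, prompt_tokens: List[int], stop_tokens: List[int]) -> List[int]:
    # `mgr` is assumed to be an InputManager created with lookahead_enabled=True and
    # already prefilled with prompt_tokens via its existing get_inputs/update path.
    # seq_length must satisfy (ngram_size - 1) * (window_size + n_verifications)
    # <= seq_length; with the values below that is 2 * (8 + 8) = 32.
    return mgr.lookahead_decode(
        model=model,
        init_token=prompt_tokens[-1],  # decoding continues from the last prompt token
        n=128,                         # number of new tokens to generate
        ngram_size=3,
        window_size=8,
        n_verifications=8,
        stop_tokens=stop_tokens,
    )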