
Commit 28cb813

Lookahead decoding eager implementation (#12491)
Summary: Implement reference lookahead decoding for the CoreML implementation.

Reviewed By: sxu, billmguo
Differential Revision: D78323399
1 parent 80da097 commit 28cb813

File tree

4 files changed: +596 -3 lines


examples/apple/coreml/llama/llama_transformer.py

Lines changed: 298 additions & 3 deletions
@@ -6,6 +6,8 @@
 
 # Please refer to README.md in the same folder for more information.
 
+import logging
+from collections import defaultdict, deque
 from dataclasses import dataclass
 from functools import partial
 from typing import Dict, List, Optional, Tuple
@@ -23,6 +25,8 @@
 
 from torch import nn
 
+logger = logging.getLogger(__name__)
+
 
 def find_multiple(n: int, k: int) -> int:
     if n % k == 0:
@@ -507,6 +511,24 @@ def load_model(checkpoint_path, params_path, max_seq_length, use_cache_list):
 
 
 class InputManager:
+    class NGramCache:
+        def __init__(self, max_size: int):
+            self.cache = deque()
+            self.max_size = max_size
+
+        def add(self, ngram: List[int]):
+            if ngram in self.cache:
+                return
+            if len(self.cache) == self.max_size:
+                self.cache.popleft()
+            self.cache.append(ngram)
+
+        def __iter__(self):
+            return iter(self.cache)
+
+        def __str__(self):
+            return str(self.cache)
+
     def __init__(
         self,
         n_layers: int,
@@ -519,6 +541,7 @@ def __init__(
         dtype=torch.float16,
         minus_infinity=-torch.inf,
         cache_size=None,
+        lookahead_enabled: bool = False,
     ):
         if cache_size is None:
             cache_size = max_seq_length - seq_length
@@ -532,6 +555,8 @@
 
         self.seq_length = seq_length
         self.use_cache_list = use_cache_list
+        self.lookahead_enabled = lookahead_enabled
+        self.minus_infinity = minus_infinity
 
         if self.use_cache_list:
             self.k_caches = [
@@ -609,10 +634,10 @@ def _update_cache(self, start, length, new_k_caches, new_v_caches):
         if self.cache_pos == self.cache_size:
             self.cache_pos = 0
 
-    def update(self, input_length, new_k_caches, new_v_caches):
+    def update(self, input_length, new_k_caches, new_v_caches, update_pos=0):
         # Copy as much new cache data into cache as possible without wrapping
         amount_to_copy = min(input_length, self.cache_size - self.cache_pos)
-        self._update_cache(0, amount_to_copy, new_k_caches, new_v_caches)
+        self._update_cache(update_pos, amount_to_copy, new_k_caches, new_v_caches)
         if self.input_pos <= self.cache_size:
             self.attn_mask[:, (self.input_pos) : (self.input_pos + amount_to_copy)] = (
                 0.0
@@ -625,7 +650,10 @@ def update(self, input_length, new_k_caches, new_v_caches):
         )
         if remaining_to_copy > 0:
             self._update_cache(
-                amount_to_copy, remaining_to_copy, new_k_caches, new_v_caches
+                update_pos + amount_to_copy,
+                remaining_to_copy,
+                new_k_caches,
+                new_v_caches,
             )
 
         self.input_pos += input_length
@@ -661,3 +689,270 @@ def get_inputs_and_remaining_tokens(self, tokens: List[int]):
             self.get_inputs(tokens[0:processed_tokens]),
             tokens[processed_tokens:],
         )
+
+    def _get_lookahead_decoding_mask(
+        self, ngram_size: int, window_size: int, n_verifications: int
+    ) -> torch.Tensor:
+        mask = torch.full((self.seq_length, self.seq_length), self.minus_infinity)
+        mask[0][0] = 0.0
+
+        lookahead_submask = torch.triu(
+            torch.full((window_size, window_size), self.minus_infinity),
+            diagonal=1,
+        )
+        for i in range(ngram_size - 1):
+            offset = window_size * i
+            mask[offset : offset + window_size, :window_size] = lookahead_submask
+            for j in range(1, i + 1):
+                mask[
+                    offset : offset + window_size,
+                    window_size * j : window_size * (j + 1),
+                ].fill_diagonal_(0.0)
+
+        verification_offset = max(window_size * (ngram_size - 1), 1)
+        verification_submask = torch.triu(
+            torch.full((ngram_size - 1, ngram_size - 1), self.minus_infinity),
+            diagonal=1,
+        )
+        for i in range(n_verifications):
+            mask[
+                verification_offset
+                + i * (ngram_size - 1) : verification_offset
+                + (i + 1) * (ngram_size - 1),
+                verification_offset
+                + i * (ngram_size - 1) : verification_offset
+                + (i + 1) * (ngram_size - 1),
+            ] = verification_submask
+        mask[verification_offset:, :1] = 0.0
+
+        return mask
+
+    def _get_lookahead_position_offsets(
+        self, ngram_size: int, window_size: int, n_verifications: int
+    ) -> torch.Tensor:
+        pos_offsets = torch.zeros(self.seq_length, dtype=torch.int32)
+        idx = 0
+        if window_size > 0:
+            for i in range(ngram_size - 1):
+                for j in range(window_size):
+                    pos_offsets[idx] = i + j
+                    idx += 1
+        else:
+            pos_offsets[0] = 0
+            idx += 1
+
+        # Verification branches: [1, 2, ..., ngram_size - 1].
+        for _ in range(n_verifications):
+            for j in range(1, ngram_size):
+                pos_offsets[idx] = j
+                idx += 1
+
+        return pos_offsets
+
+    def _validate_lookahead_config(
+        self, ngram_size: int, window_size: int, n_verifications: int
+    ) -> None:
+        """
+        Validate the lookahead decoding configuration.
+        """
+        if not self.lookahead_enabled:
+            raise RuntimeError("Lookahead decoding is not enabled")
+
+        if (ngram_size - 1) * (window_size + n_verifications) > self.seq_length:
+            raise RuntimeError(
+                f"Lookahead decoding configuration not compatible with seq_length {self.seq_length}. "
+                f"Required: {(ngram_size - 1) * (window_size + n_verifications)}"
+            )
+
+    def _setup_lookahead_mask(
+        self, ngram_size: int, window_size: int, n_verifications: int
+    ) -> None:
+        """
+        Set up the attention mask for lookahead decoding and log debug information.
+        """
+        self.attn_mask[:, self.cache_size :] = self._get_lookahead_decoding_mask(
+            ngram_size, window_size, n_verifications
+        )
+        logger.debug("Lookahead decoding mask: ")
+        for i in range(self.seq_length):
+            logger.debug(
+                " ".join(
+                    ("X" if x == 0.0 else " ")
+                    for x in self.attn_mask[i][self.cache_size :]
+                )
+            )
+
+    def _populate_verification_branches(
+        self, x: List[int], cache, verification_offset: int, ngram_size: int
+    ) -> None:
+        """
+        Populate verification branches with tokens from the n-gram cache.
+        """
+        for i, ngram in enumerate(cache):
+            for j, token in enumerate(ngram):
+                x[verification_offset + i * (ngram_size - 1) + j] = token
+
+    def _collect_ngrams(
+        self,
+        x: List[int],
+        y: List[int],
+        ngram_caches: Dict[int, "InputManager.NGramCache"],
+        window_size: int,
+        ngram_size: int,
+    ) -> None:
+        """
+        Collect new n-grams from the current state and predictions.
+        """
+        for i in range(window_size):
+            key = x[i]
+            suffix = []
+            for j in range(1, ngram_size - 1):
+                suffix.append(x[i + j * window_size])
+            suffix.append(y[i + window_size * (ngram_size - 2)])
+            ngram_caches[key].add(suffix)
+
+    def _find_longest_match(
+        self,
+        x: List[int],
+        y: List[int],
+        verification_offset: int,
+        n_verifications: int,
+        ngram_size: int,
+    ) -> Tuple[List[int], Optional[int]]:
+        """
+        Find the longest matching sequence from verification branches.
+        Returns the matched tokens and the branch index.
+        """
+        longest_match = []
+        matched_branch = None
+
+        for i in range(n_verifications):
+            match = [y[0]]
+            j = 0
+            while (
+                j < ngram_size - 1
+                and x[verification_offset + (ngram_size - 1) * i + j] == match[-1]
+            ):
+                match.append(y[verification_offset + (ngram_size - 1) * i + j])
+                j += 1
+            if len(match) - 1 > len(longest_match):
+                longest_match = match[1:]
+                matched_branch = i
+
+        return longest_match, matched_branch
+
+    def _update_lookahead_branches(
+        self, x: List[int], y: List[int], ngram_size: int, window_size: int
+    ) -> None:
+        """
+        Update the lookahead branches with new predictions.
+        """
+        # Shift window contents up
+        for i in range(ngram_size - 2):
+            for j in range(window_size):
+                x[window_size * i + j] = x[window_size * (i + 1) + j]
+
+        # Fill the last window with new predictions
+        for j in range(window_size):
+            x[window_size * (ngram_size - 2) + j] = y[
+                window_size * (ngram_size - 2) + j
+            ]
+
+    def lookahead_decode(
+        self,
+        model,
+        init_token: int,
+        n: int,
+        ngram_size: int,
+        window_size: int,
+        n_verifications: int,
+        stop_tokens: Optional[List[int]] = None,
+        ngram_caches: Optional[Dict[int, "InputManager.NGramCache"]] = None,
+    ) -> List[int]:
+        # Validate configuration
+        self._validate_lookahead_config(ngram_size, window_size, n_verifications)
+
+        # Setup attention mask and position offsets
+        self._setup_lookahead_mask(ngram_size, window_size, n_verifications)
+        offsets = self._get_lookahead_position_offsets(
+            ngram_size, window_size, n_verifications
+        )
+
+        # Initialize state
+        stop_tokens = stop_tokens or []
+        verification_offset = window_size * (ngram_size - 1)
+        if ngram_caches is None:
+            ngram_caches = defaultdict(lambda: InputManager.NGramCache(n_verifications))
+
+        new_tokens = [init_token]
+        x = [init_token] * self.seq_length
+        inference_count = 0
+
+        # Main decoding loop
+        while len(new_tokens) < n + 1:
+            # Populate verification branches
+            cache = ngram_caches[x[0]]
+            self._populate_verification_branches(
+                x, cache, verification_offset, ngram_size
+            )
+
+            # Run model inference
+            logits, new_k, new_v = model(
+                tokens=torch.tensor([x], dtype=torch.int64),
+                input_pos=torch.tensor([self.input_pos], dtype=torch.long),
+                k_caches=self.k_caches,
+                v_caches=self.v_caches,
+                attn_mask=self.attn_mask,
+                input_len=torch.tensor([len(x)], dtype=torch.long),
+                rope_indices=self.input_pos + offsets,
+            )
+            inference_count += 1
+
+            # Process model output (greedy selection)
+            y = logits[0].argmax(dim=-1).tolist()
+            new_tokens.append(y[0])
+            logger.debug(f"{self.input_pos}: x = {x[0]}, y = {y[0]}")
+            if new_tokens[-1] in stop_tokens:
+                break
+
+            # Collect new n-grams
+            self._collect_ngrams(x, y, ngram_caches, window_size, ngram_size)
+
+            # Find longest match from verification branches
+            longest_match, matched_branch = self._find_longest_match(
+                x, y, verification_offset, n_verifications, ngram_size
+            )
+
+            # Process match results
+            if matched_branch is not None:
+                logger.debug(
+                    f"Matched {len(longest_match)} additional tokens from n-grams: {longest_match}"
+                )
+                # Truncate at stop token if present
+                for stop in stop_tokens:
+                    if stop in longest_match:
+                        longest_match = longest_match[: longest_match.index(stop) + 1]
+
+                new_tokens.extend(longest_match)
+                branch_offset = verification_offset + (ngram_size - 1) * matched_branch
+                self.update(
+                    input_length=len(longest_match),
+                    new_k_caches=new_k,
+                    new_v_caches=new_v,
+                    update_pos=branch_offset,
+                )
+            else:
+                self.update(input_length=1, new_k_caches=new_k, new_v_caches=new_v)
+
+            # Update lookahead branches
+            self._update_lookahead_branches(x, y, ngram_size, window_size)
+
+            # Update first token and check for stop condition
+            x[0] = new_tokens[-1]
+            if new_tokens[-1] in stop_tokens:
+                break
+
+        logger.info(
+            f"Generated {len(new_tokens) - 1} tokens with {inference_count} inference(s)."
+        )
+        return new_tokens
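
Usage note (not part of the commit): the sketch below shows one way the new lookahead_decode entry point might be driven end to end. The import path, the InputManager constructor argument names other than those visible in this diff, the model attribute names, the checkpoint paths, and the token ids are assumptions for illustration; the ngram_size/window_size/n_verifications values are chosen so that (ngram_size - 1) * (window_size + n_verifications) fits within seq_length, as _validate_lookahead_config requires.

# Illustrative sketch only; names marked "assumed" are not confirmed by this diff.
import torch

from executorch.examples.apple.coreml.llama.llama_transformer import (  # assumed import path
    InputManager,
    load_model,
)

model = load_model(
    checkpoint_path="/path/to/consolidated.00.pth",  # hypothetical checkpoint
    params_path="/path/to/params.json",  # hypothetical params file
    max_seq_length=1024,
    use_cache_list=True,
)

manager = InputManager(
    n_layers=model.params.n_layers,  # assumed attribute layout on the loaded model
    max_batch_size=1,  # assumed argument name
    n_kv_heads=model.params.n_kv_heads,  # assumed
    max_seq_length=1024,
    head_dim=model.params.head_dim,  # assumed
    use_cache_list=True,
    seq_length=64,
    dtype=torch.float16,
    minus_infinity=-30000.0,
    lookahead_enabled=True,  # new flag added by this commit
)

# (ngram_size - 1) * (window_size + n_verifications) = 2 * (8 + 8) = 32 <= seq_length = 64
generated = manager.lookahead_decode(
    model=model,
    init_token=1,  # e.g. the last prompt token id (illustrative)
    n=128,  # number of new tokens to generate
    ngram_size=3,
    window_size=8,
    n_verifications=8,
    stop_tokens=[2],  # e.g. an EOS token id (illustrative)
)
print(generated)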
