Skip to content

Commit 6dc214e

Browse files
Lookahead decoding eager implementation (#12491)
Summary: Implement reference lookahead decoding for CoreML implementation. Reviewed By: sxu Differential Revision: D78323399
1 parent e1db341 commit 6dc214e

File tree

4 files changed

+515
-3
lines changed

4 files changed

+515
-3
lines changed

examples/apple/coreml/llama/llama_transformer.py

Lines changed: 217 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
# Please refer to README.md in the same folder for more information.
88

9+
import logging
10+
from collections import defaultdict, deque
911
from dataclasses import dataclass
1012
from functools import partial
1113
from typing import Dict, List, Optional, Tuple
@@ -23,6 +25,8 @@
2325

2426
from torch import nn
2527

28+
logger = logging.getLogger(__name__)
29+
2630

2731
def find_multiple(n: int, k: int) -> int:
2832
if n % k == 0:
@@ -507,6 +511,24 @@ def load_model(checkpoint_path, params_path, max_seq_length, use_cache_list):
507511

508512

509513
class InputManager:
514+
class NGramCache:
515+
def __init__(self, max_size: int):
516+
self.cache = deque()
517+
self.max_size = max_size
518+
519+
def add(self, ngram: List[int]):
520+
if ngram in self.cache:
521+
return
522+
if len(self.cache) == self.max_size:
523+
self.cache.popleft()
524+
self.cache.append(ngram)
525+
526+
def __iter__(self):
527+
return iter(self.cache)
528+
529+
def __str__(self):
530+
return str(self.cache)
531+
510532
def __init__(
511533
self,
512534
n_layers: int,
@@ -519,6 +541,7 @@ def __init__(
519541
dtype=torch.float16,
520542
minus_infinity=-torch.inf,
521543
cache_size=None,
544+
lookahead_enabled: bool = False,
522545
):
523546
if cache_size is None:
524547
cache_size = max_seq_length - seq_length
@@ -532,6 +555,8 @@ def __init__(
532555

533556
self.seq_length = seq_length
534557
self.use_cache_list = use_cache_list
558+
self.lookahead_enabled = lookahead_enabled
559+
self.minus_infinity = minus_infinity
535560

536561
if self.use_cache_list:
537562
self.k_caches = [
@@ -609,10 +634,10 @@ def _update_cache(self, start, length, new_k_caches, new_v_caches):
609634
if self.cache_pos == self.cache_size:
610635
self.cache_pos = 0
611636

612-
def update(self, input_length, new_k_caches, new_v_caches):
637+
def update(self, input_length, new_k_caches, new_v_caches, update_pos=0):
613638
# Copy as much new cache data into cache as possible without wrapping
614639
amount_to_copy = min(input_length, self.cache_size - self.cache_pos)
615-
self._update_cache(0, amount_to_copy, new_k_caches, new_v_caches)
640+
self._update_cache(update_pos, amount_to_copy, new_k_caches, new_v_caches)
616641
if self.input_pos <= self.cache_size:
617642
self.attn_mask[:, (self.input_pos) : (self.input_pos + amount_to_copy)] = (
618643
0.0
@@ -625,7 +650,7 @@ def update(self, input_length, new_k_caches, new_v_caches):
625650
)
626651
if remaining_to_copy > 0:
627652
self._update_cache(
628-
amount_to_copy, remaining_to_copy, new_k_caches, new_v_caches
653+
update_pos + amount_to_copy, remaining_to_copy, new_k_caches, new_v_caches
629654
)
630655

631656
self.input_pos += input_length
@@ -661,3 +686,192 @@ def get_inputs_and_remaining_tokens(self, tokens: List[int]):
661686
self.get_inputs(tokens[0:processed_tokens]),
662687
tokens[processed_tokens:],
663688
)
689+
690+
def _get_lookahead_decoding_mask(
691+
self, ngram_size: int, window_size: int, n_verifications: int
692+
) -> torch.Tensor:
693+
mask = torch.full((self.seq_length, self.seq_length), self.minus_infinity)
694+
mask[0][0] = 0.0
695+
696+
lookahead_submask = torch.triu(
697+
torch.full((window_size, window_size), self.minus_infinity),
698+
diagonal=1,
699+
)
700+
for i in range(ngram_size - 1):
701+
offset = window_size * i
702+
mask[offset : offset + window_size, :window_size] = lookahead_submask
703+
for j in range(1, i + 1):
704+
mask[
705+
offset : offset + window_size,
706+
window_size * j : window_size * (j + 1),
707+
].fill_diagonal_(0.0)
708+
709+
verification_offset = max(window_size * (ngram_size - 1), 1)
710+
verification_submask = torch.triu(
711+
torch.full((ngram_size - 1, ngram_size - 1), self.minus_infinity),
712+
diagonal=1,
713+
)
714+
for i in range(n_verifications):
715+
mask[
716+
verification_offset + i * (ngram_size - 1) : verification_offset
717+
+ (i + 1) * (ngram_size - 1),
718+
verification_offset + i * (ngram_size - 1) : verification_offset
719+
+ (i + 1) * (ngram_size - 1),
720+
] = verification_submask
721+
mask[verification_offset:, :1] = 0.0
722+
723+
return mask
724+
725+
def _get_lookahead_position_offsets(
726+
self, ngram_size: int, window_size: int, n_verifications: int
727+
) -> torch.Tensor:
728+
pos_offsets = torch.zeros(self.seq_length, dtype=torch.int32)
729+
idx = 0
730+
if window_size > 0:
731+
for i in range(ngram_size - 1):
732+
for j in range(window_size):
733+
pos_offsets[idx] = i + j
734+
idx += 1
735+
else:
736+
pos_offsets[0] = 0
737+
idx += 1
738+
739+
# Verification branches: [1, 2, ..., ngram_size - 1].
740+
for _ in range(n_verifications):
741+
for j in range(1, ngram_size):
742+
pos_offsets[idx] = j
743+
idx += 1
744+
745+
return pos_offsets
746+
747+
def lookahead_decode(
748+
self,
749+
model,
750+
init_token: int,
751+
n: int,
752+
ngram_size: int,
753+
window_size: int,
754+
n_verifications: int,
755+
stop_tokens: Optional[List[int]] = None,
756+
ngram_caches: Optional[Dict[int, "InputManager.NGramCache"]] = None,
757+
) -> List[int]:
758+
if not self.lookahead_enabled:
759+
raise RuntimeError("Lookahead decoding is not enabled")
760+
761+
if (ngram_size - 1) * (window_size + n_verifications) > self.seq_length:
762+
raise RuntimeError(
763+
f"Lookahead decoding configuration not compatible with seq_length {self.seq_length}. "
764+
f"Required: {(ngram_size - 1) * (window_size + n_verifications)}"
765+
)
766+
767+
self.attn_mask[:, self.cache_size :] = self._get_lookahead_decoding_mask(
768+
ngram_size, window_size, n_verifications
769+
)
770+
logger.debug("Lookahead decoding mask: ")
771+
for i in range(self.seq_length):
772+
logger.debug(
773+
" ".join(
774+
("X" if x == 0.0 else " ")
775+
for x in self.attn_mask[i][self.cache_size :]
776+
)
777+
)
778+
779+
offsets = self._get_lookahead_position_offsets(
780+
ngram_size, window_size, n_verifications
781+
)
782+
783+
stop_tokens = stop_tokens or []
784+
verification_offset = window_size * (ngram_size - 1)
785+
786+
if ngram_caches is None:
787+
ngram_caches = defaultdict(lambda: InputManager.NGramCache(n_verifications))
788+
new_tokens = [init_token]
789+
x = [init_token] * self.seq_length
790+
inference_count = 0
791+
792+
while len(new_tokens) < n + 1:
793+
cache = ngram_caches[x[0]]
794+
for i, ngram in enumerate(cache):
795+
for j, token in enumerate(ngram):
796+
x[verification_offset + i * (ngram_size - 1) + j] = token
797+
798+
logits, new_k, new_v = model(
799+
tokens=torch.tensor([x], dtype=torch.int64),
800+
input_pos=torch.tensor([self.input_pos], dtype=torch.long),
801+
k_caches=self.k_caches,
802+
v_caches=self.v_caches,
803+
attn_mask=self.attn_mask,
804+
input_len=torch.tensor([len(x)], dtype=torch.long),
805+
rope_indices=self.input_pos + offsets,
806+
)
807+
inference_count += 1
808+
809+
# Greedy only
810+
y = logits[0].argmax(dim=-1).tolist()
811+
new_tokens.append(y[0])
812+
logger.debug(f"{self.input_pos}: x = {x[0]}, y = {y[0]}")
813+
if new_tokens[-1] in stop_tokens:
814+
break
815+
816+
# Collect new n-grams.
817+
for i in range(window_size):
818+
key = x[i]
819+
suffix = []
820+
for j in range(1, ngram_size - 1):
821+
suffix.append(x[i + j * window_size])
822+
suffix.append(y[i + window_size * (ngram_size - 2)])
823+
ngram_caches[key].add(suffix)
824+
825+
# Verification.
826+
longest_match = []
827+
matched_branch = None
828+
for i in range(n_verifications):
829+
match = [y[0]]
830+
j = 0
831+
# for j in range(ngram_size - 1):
832+
while (
833+
j < ngram_size - 1
834+
and x[verification_offset + (ngram_size - 1) * i + j] == match[-1]
835+
):
836+
match.append(y[verification_offset + (ngram_size - 1) * i + j])
837+
j += 1
838+
if len(match) - 1 > len(longest_match):
839+
longest_match = match[1:]
840+
matched_branch = i
841+
842+
if matched_branch is not None:
843+
logger.debug(
844+
f"Matched {len(longest_match)} additional tokens from n-grams: {longest_match}"
845+
)
846+
for stop in stop_tokens:
847+
if stop in longest_match:
848+
longest_match = longest_match[: longest_match.index(stop) + 1]
849+
850+
new_tokens.extend(longest_match)
851+
branch_offset = verification_offset + (ngram_size - 1) * matched_branch
852+
self.update(
853+
input_length=len(longest_match),
854+
new_k_caches=new_k,
855+
new_v_caches=new_v,
856+
update_pos=branch_offset,
857+
)
858+
else:
859+
self.update(input_length=1, new_k_caches=new_k, new_v_caches=new_v)
860+
861+
# Update lookahead branch.
862+
for i in range(ngram_size - 2):
863+
for j in range(window_size):
864+
x[window_size * i + j] = x[window_size * (i + 1) + j]
865+
for j in range(window_size):
866+
x[window_size * (ngram_size - 2) + j] = y[
867+
window_size * (ngram_size - 2) + j
868+
]
869+
870+
x[0] = new_tokens[-1]
871+
if new_tokens[-1] in stop_tokens:
872+
break
873+
874+
logger.info(
875+
f"Generated {len(new_tokens) - 1} tokens with {inference_count} inference(s)."
876+
)
877+
return new_tokens

0 commit comments

Comments
 (0)