Skip to content

Commit 66aba7a

Browse files
authored
run : avoid double tokenization (#14327)
* run : avoid double tokenization by adopting common_tokenize heuristic * build : fix windows gcc and clang warnings * lint : fixed trailing whitepace * run : fix is_first flag
1 parent f1f5e82 commit 66aba7a

File tree

1 file changed

+24
-8
lines changed

1 file changed

+24
-8
lines changed

tools/run/run.cpp

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
#include <nlohmann/json.hpp>
1010

1111
#if defined(_WIN32)
12+
# ifndef NOMINMAX
13+
# define NOMINMAX
14+
# endif
1215
# include <windows.h>
1316
# include <io.h>
1417
#else
@@ -940,16 +943,29 @@ static int apply_chat_template(const struct common_chat_templates * tmpls, Llama
940943
static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
941944
std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data) {
942945
const bool is_first = llama_memory_seq_pos_max(llama_get_memory(llama_data.context.get()), 0) == -1;
943-
944-
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
945-
prompt_tokens.resize(n_prompt_tokens);
946-
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first,
947-
true) < 0) {
948-
printe("failed to tokenize the prompt\n");
946+
int n_tokens = prompt.size() + 2 * is_first;
947+
prompt_tokens.resize(n_tokens);
948+
n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.size(),
949+
prompt_tokens.data(), prompt_tokens.size(),
950+
is_first, /*parse_special =*/true);
951+
if (n_tokens == std::numeric_limits<int32_t>::min()) {
952+
printe("tokenization failed: input too large\n");
949953
return -1;
950954
}
951-
952-
return n_prompt_tokens;
955+
if (n_tokens < 0) {
956+
prompt_tokens.resize(-n_tokens);
957+
int check = llama_tokenize(vocab, prompt.c_str(), prompt.size(),
958+
prompt_tokens.data(), prompt_tokens.size(),
959+
is_first, /*parse_special =*/true);
960+
if (check != -n_tokens) {
961+
printe("failed to tokenize the prompt (size mismatch)\n");
962+
return -1;
963+
}
964+
n_tokens = check;
965+
} else {
966+
prompt_tokens.resize(n_tokens);
967+
}
968+
return n_tokens;
953969
}
954970

955971
// Check if we have enough space in the context to evaluate this batch

0 commit comments

Comments
 (0)