|
9 | 9 | #include <nlohmann/json.hpp>
|
10 | 10 |
|
11 | 11 | #if defined(_WIN32)
|
| 12 | +# ifndef NOMINMAX |
| 13 | +# define NOMINMAX |
| 14 | +# endif |
12 | 15 | # include <windows.h>
|
13 | 16 | # include <io.h>
|
14 | 17 | #else
|
@@ -940,16 +943,29 @@ static int apply_chat_template(const struct common_chat_templates * tmpls, Llama
|
940 | 943 | static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
|
941 | 944 | std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data) {
|
942 | 945 | const bool is_first = llama_memory_seq_pos_max(llama_get_memory(llama_data.context.get()), 0) == -1;
|
943 |
| - |
944 |
| - const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true); |
945 |
| - prompt_tokens.resize(n_prompt_tokens); |
946 |
| - if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first, |
947 |
| - true) < 0) { |
948 |
| - printe("failed to tokenize the prompt\n"); |
| 946 | + int n_tokens = prompt.size() + 2 * is_first; |
| 947 | + prompt_tokens.resize(n_tokens); |
| 948 | + n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.size(), |
| 949 | + prompt_tokens.data(), prompt_tokens.size(), |
| 950 | + is_first, /*parse_special =*/true); |
| 951 | + if (n_tokens == std::numeric_limits<int32_t>::min()) { |
| 952 | + printe("tokenization failed: input too large\n"); |
949 | 953 | return -1;
|
950 | 954 | }
|
951 |
| - |
952 |
| - return n_prompt_tokens; |
| 955 | + if (n_tokens < 0) { |
| 956 | + prompt_tokens.resize(-n_tokens); |
| 957 | + int check = llama_tokenize(vocab, prompt.c_str(), prompt.size(), |
| 958 | + prompt_tokens.data(), prompt_tokens.size(), |
| 959 | + is_first, /*parse_special =*/true); |
| 960 | + if (check != -n_tokens) { |
| 961 | + printe("failed to tokenize the prompt (size mismatch)\n"); |
| 962 | + return -1; |
| 963 | + } |
| 964 | + n_tokens = check; |
| 965 | + } else { |
| 966 | + prompt_tokens.resize(n_tokens); |
| 967 | + } |
| 968 | + return n_tokens; |
953 | 969 | }
|
954 | 970 |
|
955 | 971 | // Check if we have enough space in the context to evaluate this batch
|
|
0 commit comments