@@ -17,6 +17,8 @@
 #include <string>
 #include <vector>
 
+using json = nlohmann::ordered_json;
+
 static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) {
     auto time = std::chrono::system_clock::to_time_t(now);
     auto local_time = *std::localtime(&time);
@@ -140,6 +142,7 @@ struct templates_params {
     bool add_generation_prompt = true;
     bool enable_thinking = true;
     std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
+    json extra_context;
 };
 
 common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
@@ -720,16 +723,23 @@ static void foreach_function(const json & tools, const std::function<void(const
 
 static std::string apply(
     const common_chat_template & tmpl,
-    const nlohmann::ordered_json & messages,
-    const nlohmann::ordered_json & tools,
-    bool add_generation_prompt,
-    const nlohmann::ordered_json & extra_context = nlohmann::ordered_json())
+    const struct templates_params & inputs,
+    const std::optional<json> & messages_override = std::nullopt,
+    const std::optional<json> & tools_override = std::nullopt,
+    const std::optional<json> & additional_context = std::nullopt)
 {
     minja::chat_template_inputs tmpl_inputs;
-    tmpl_inputs.messages = messages;
-    tmpl_inputs.tools = tools;
-    tmpl_inputs.add_generation_prompt = add_generation_prompt;
-    tmpl_inputs.extra_context = extra_context;
+    tmpl_inputs.messages = messages_override ? *messages_override : inputs.messages;
+    if (tools_override) {
+        tmpl_inputs.tools = *tools_override;
+    } else {
+        tmpl_inputs.tools = inputs.tools.empty() ? json() : inputs.tools;
+    }
+    tmpl_inputs.add_generation_prompt = inputs.add_generation_prompt;
+    tmpl_inputs.extra_context = inputs.extra_context;
+    if (additional_context) {
+        tmpl_inputs.extra_context.merge_patch(*additional_context);
+    }
     // TODO: add flag to control date/time, if only for testing purposes.
     // tmpl_inputs.now = std::chrono::system_clock::now();
 
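For readers skimming the call-site hunks below: after this refactor, every caller passes the whole templates_params and overrides only what it needs. A minimal sketch of the three resulting patterns, assuming a tmpl and inputs already in scope (the message content and date value here are illustrative only, not from the PR):

    // 1) Plain render: messages, tools and add_generation_prompt all come from `inputs`.
    auto p1 = apply(tmpl, inputs);

    // 2) Override only the messages; tools still come from inputs.tools.
    json tweaked = json::array({
        json {{"role", "user"}, {"content", "hi"}},  // illustrative message
    });
    auto p2 = apply(tmpl, inputs, /* messages_override= */ tweaked);

    // 3) Merge extra Jinja variables on top of inputs.extra_context; merge_patch
    //    follows RFC 7386 (objects merge recursively, nulls delete keys).
    auto p3 = apply(tmpl, inputs, std::nullopt, std::nullopt, json {{"date_string", "01 Jan 2025"}});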
@@ -828,7 +838,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
         inputs.messages,
         "Respond in JSON format, either with `tool_call` (a request to call tools) or with `response` reply to the user's request");
 
-    data.prompt = apply(tmpl, tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs, /* messages_override= */ tweaked_messages);
     data.format = COMMON_CHAT_FORMAT_GENERIC;
     return data;
 }
@@ -904,7 +914,7 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
     data.preserved_tokens = {
         "[TOOL_CALLS]",
     };
-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs);
     data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
     return data;
 }
@@ -934,7 +944,7 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
             adjusted_messages.push_back(msg);
         }
     }
-    data.prompt = apply(tmpl, adjusted_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {});
+    data.prompt = apply(tmpl, inputs, /* messages_override= */ adjusted_messages);
     data.format = COMMON_CHAT_FORMAT_COMMAND_R7B;
     if (string_ends_with(data.prompt, "<|START_THINKING|>")) {
         if (!inputs.enable_thinking) {
@@ -1122,7 +1132,7 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
     } else {
         data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
     }
-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
+    data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ std::nullopt, json {
         {"date_string", format_time(inputs.now, "%d %b %Y")},
         {"tools_in_user_message", false},
         {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
@@ -1187,7 +1197,7 @@ static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool w
 
 static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
-    auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    auto prompt = apply(tmpl, inputs);
 
     // Hacks to fix the official (broken) prompt.
     // It is advisable to use --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja instead,
@@ -1282,7 +1292,7 @@ static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
 static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
     LOG_DBG("%s\n", __func__);
     common_chat_params data;
-    data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
+    data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ json(), json {
         {"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
         {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
     });
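A reading of this hunk, inferred from the old nullptr argument: firefunction deliberately hides inputs.tools from the template (it receives them serialized through the functions variable instead), so the new code passes an explicit empty json() as tools_override rather than letting apply() fall back to inputs.tools. Sketched against the signature above:

    apply(tmpl, inputs);                                              // template renders inputs.tools
    apply(tmpl, inputs, std::nullopt, /* tools_override= */ json());  // template sees no tools at all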
@@ -1338,7 +1348,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
     // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
     // If the function is python, we also allow raw python code (if the line after `python\n` doesn't start w/ opening `{`), which the model seems to prefer for multiline code.
     common_chat_params data;
-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs);
     data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
     if (inputs.tools.is_array() && !inputs.tools.empty()) {
         data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
@@ -1465,7 +1475,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
         data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
     }
 
-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs);
     // TODO: if (has_raw_python)
     return data;
 }
@@ -1498,14 +1508,15 @@ static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser
 static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
 
-    json additional_context = {
+    json extra_context = json {
         {"enable_thinking", inputs.enable_thinking},
     };
+    extra_context.update(inputs.extra_context);
 
-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, additional_context);
+    data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ std::nullopt, extra_context);
     data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO;
     if (string_ends_with(data.prompt, "<think>\n")) {
-        if (!inputs.enable_thinking) {
+        if (!extra_context["enable_thinking"]) {
             data.prompt += "</think>";
         } else {
             data.thinking_forced_open = true;
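Worth spelling out for this hunk: nlohmann's update() does a shallow, last-writer-wins merge of top-level object keys, so a user-supplied enable_thinking in inputs.extra_context overrides the default computed from inputs.enable_thinking. That is why the <think> check now reads extra_context["enable_thinking"] instead of inputs.enable_thinking. A small sketch of the two merge primitives this PR uses (values illustrative):

    json a = {{"enable_thinking", true}};
    a.update(json {{"enable_thinking", false}});  // shallow merge: a == {"enable_thinking": false}

    json b = {{"keep", 1}, {"drop", 2}};
    b.merge_patch({{"drop", nullptr}});           // RFC 7386: null deletes the key; b == {"keep": 1}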
@@ -1691,7 +1702,7 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
 
 static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
-    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs);
     data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
     data.grammar_lazy = false;
     if (!inputs.json_schema.is_null()) {
@@ -1722,6 +1733,12 @@ static common_chat_params common_chat_templates_apply_jinja(
     params.enable_thinking = inputs.enable_thinking;
     params.grammar = inputs.grammar;
     params.now = inputs.now;
+
+    params.extra_context = json::object();
+    for (auto el : inputs.chat_template_kwargs) {
+        params.extra_context[el.first] = json::parse(el.second);
+    }
+
     if (!inputs.json_schema.empty()) {
        params.json_schema = json::parse(inputs.json_schema);
     }
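One consequence of the loop above: each chat_template_kwargs value goes through json::parse, so values must be valid JSON texts, not bare strings. A minimal sketch of the same parsing, assuming only that inputs.chat_template_kwargs is a string-to-string map filled by some server-side option (the kwargs shown are hypothetical):

    #include <map>
    #include <string>
    #include <nlohmann/json.hpp>
    using json = nlohmann::ordered_json;

    // Builds extra_context the same way as the loop above.
    static json parse_template_kwargs(const std::map<std::string, std::string> & kwargs) {
        json extra = json::object();
        for (const auto & el : kwargs) {
            extra[el.first] = json::parse(el.second); // throws json::parse_error on invalid JSON text
        }
        return extra;
    }

    // parse_template_kwargs({{"enable_thinking", "false"}})  -> {"enable_thinking": false}
    // parse_template_kwargs({{"custom_tag", "\"<eot>\""}})   -> {"custom_tag": "<eot>"} (strings need quotes)
    // parse_template_kwargs({{"custom_tag", "<eot>"}})       -> throws: not a valid JSON text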