Commit dbc3d69

Daniele Briggi committed
feat(chat): add system prompt handling in chat messages
1 parent: c3fd345
File tree: 1 file changed (+50 -4 lines)

src/sqlite-ai.c (50 additions, 4 deletions)
@@ -152,6 +152,7 @@ typedef struct {
     const char *template;
     const struct llama_vocab *vocab;

+    char *system_prompt;
     ai_messages messages;
     buffer_t formatted;
     buffer_t response;
@@ -199,6 +200,7 @@ typedef enum {
     AI_MODEL_CHAT_TEMPLATE
 } ai_model_setting;

+const char *ROLE_SYSTEM = "system";
 const char *ROLE_USER = "user";
 const char *ROLE_ASSISTANT = "assistant";
@@ -794,7 +796,7 @@ bool llm_messages_append (ai_messages *list, const char *role, const char *content) {
         list->capacity = new_cap;
     }

-    bool duplicate_role = ((role != ROLE_USER) && (role != ROLE_ASSISTANT));
+    bool duplicate_role = ((role != ROLE_SYSTEM) && (role != ROLE_USER) && (role != ROLE_ASSISTANT));
     list->items[list->count].role = (duplicate_role) ? sqlite_strdup(role) : role;
     list->items[list->count].content = sqlite_strdup(content);
     list->count += 1;
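The duplicate_role test compares pointers, not string contents: the extension's own call sites pass the shared ROLE_* constants, which are stored as-is, while any other role pointer is defensively copied with sqlite_strdup. A minimal sketch of what that convention means for callers (the local list variable and the second call are hypothetical):

    ai_messages list = {0};

    // same pointer as the global constant: stored without copying
    llm_messages_append(&list, ROLE_SYSTEM, "You are concise.");

    // equal contents but a different pointer: treated as foreign and duplicated
    const char *role = "system";
    llm_messages_append(&list, role, "You are concise.");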
@@ -1489,7 +1491,7 @@ static bool llm_chat_check_context (ai_context *ai) {
         llama_sampler_chain_add(ai->sampler, llama_sampler_init_temp(0.8));
         llama_sampler_chain_add(ai->sampler, llama_sampler_init_dist((uint32_t)LLAMA_DEFAULT_SEED));
     }
-
+
     // initialize the chat struct if already created
     if (ai->chat.uuid[0] != '\0') return true;
@@ -1647,10 +1649,30 @@ static bool llm_chat_run (ai_context *ai, ai_cursor *c, const char *user_prompt) {
         sqlite_common_set_error (ai->context, ai->vtab, SQLITE_ERROR, "Failed to append message");
         return false;
     }
+
+    // add system prompt if available, models expect it to be the first message
+    llama_chat_message *new_items = messages->items;
+    if (ai->chat.system_prompt) {
+        size_t n = messages->count + (ai->chat.system_prompt ? 1 : 0);
+        llama_chat_message *new_items = sqlite3_realloc64(messages->items, n * sizeof(llama_chat_message));
+        if (!new_items)
+            return false;
+
+        messages->items = new_items;
+        messages->capacity = n;
+
+        int idx = 0;
+        new_items[0].role = ROLE_SYSTEM;
+        new_items[0].content = ai->chat.system_prompt;
+        idx = 1;
+        for (size_t i = 0; i < messages->count; ++i) {
+            new_items[idx++] = messages->items[i];
+        }
+    }

     // transform a list of messages (the context) into
     // <|user|>What is AI?<|end|><|assistant|>AI stands for Artificial Intelligence...<|end|><|user|>Can you give an example?<|end|><|assistant|>...
-    int32_t new_len = llama_chat_apply_template(template, messages->items, messages->count, true, formatted->data, formatted->capacity);
+    int32_t new_len = llama_chat_apply_template(template, new_items, messages->count, true, formatted->data, formatted->capacity);
     if (new_len > formatted->capacity) {
         if (buffer_resize(formatted, new_len * 2) == false) return false;
         new_len = llama_chat_apply_template(template, messages->items, messages->count, true, formatted->data, formatted->capacity);
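Two details of this prepend are worth flagging. The inner llama_chat_message *new_items declaration shadows the outer one, so if sqlite3_realloc64 moves the block, the outer pointer that later reaches llama_chat_apply_template is left dangling. And because the shift loop copies forward within the same buffer, the system prompt written into slot 0 cascades into every following slot. A minimal overlap-safe variant, assuming the ai_messages layout used in this file (a hypothetical helper, not part of the commit):

    #include <string.h>    // memmove

    // prepend the system prompt, shifting the existing messages up one slot
    static bool llm_messages_prepend_system (ai_messages *messages, const char *system_prompt) {
        if (!system_prompt) return true;    // nothing to prepend

        size_t n = messages->count + 1;
        llama_chat_message *items = sqlite3_realloc64(messages->items, n * sizeof(llama_chat_message));
        if (!items) return false;           // the old buffer stays valid on failure

        // memmove tolerates overlapping source and destination; a forward copy does not
        memmove(&items[1], &items[0], messages->count * sizeof(llama_chat_message));
        items[0].role = ROLE_SYSTEM;
        items[0].content = system_prompt;

        messages->items = items;
        messages->capacity = n;
        messages->count = n;                // keep count in sync for the template call
        return true;
    }

Keeping count in sync matters because the template call above still passes messages->count, which does not include the prepended message; note also that the retry after buffer_resize formats messages->items where the first call formats new_items.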
@@ -2015,6 +2037,27 @@ static void llm_chat_respond (sqlite3_context *context, int argc, sqlite3_value **argv) {
     llm_chat_run(ai, NULL, user_prompt);
 }

+static void llm_chat_system_prompt(sqlite3_context *context, int argc, sqlite3_value **argv)
+{
+    if (llm_check_context(context) == false)
+        return;
+
+    int types[] = {SQLITE_TEXT};
+    if (sqlite_sanity_function(context, "llm_chat_system_prompt", argc, argv, 1, types, true, false) == false)
+        return;
+
+    const char *system_prompt = (const char *)sqlite3_value_text(argv[0]);
+    ai_context *ai = (ai_context *)sqlite3_user_data(context);
+
+    if (llm_chat_check_context(ai) == false)
+        return;
+
+    if (ai->chat.system_prompt) {
+        sqlite3_free(ai->chat.system_prompt);
+    }
+    ai->chat.system_prompt = sqlite3_mprintf("%s", system_prompt);
+}
+
 // MARK: - LLM Sampler -

 static void llm_sampler_init_greedy (sqlite3_context *context, int argc, sqlite3_value **argv) {
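The setter frees any previously stored prompt before saving a copy; sqlite3_mprintf allocates from SQLite's allocator, so the sqlite3_free on the next call is the matching release. Since sqlite3_free(NULL) is a documented no-op, the guard can be folded away; a compact sketch of the same replace-or-clear pattern (hypothetical helper, not part of the commit):

    // replace (or clear) the stored system prompt; sqlite3_free(NULL) is a no-op
    static void ai_set_system_prompt (ai_context *ai, const char *prompt) {
        sqlite3_free(ai->chat.system_prompt);
        ai->chat.system_prompt = prompt ? sqlite3_mprintf("%s", prompt) : NULL;
    }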
@@ -2852,7 +2895,10 @@ SQLITE_AI_API int sqlite3_ai_init (sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi) {
     if (rc != SQLITE_OK) goto cleanup;

     rc = sqlite3_create_function(db, "llm_chat_respond", 1, SQLITE_UTF8, ctx, llm_chat_respond, NULL, NULL);
-    if (rc != SQLITE_OK) goto cleanup;
+
+    rc = sqlite3_create_function(db, "llm_chat_system_prompt", 1, SQLITE_UTF8, ctx, llm_chat_system_prompt, NULL, NULL);
+    if (rc != SQLITE_OK)
+        goto cleanup;

     rc = sqlite3_create_module(db, "llm_chat", &llm_chat, ctx);
     if (rc != SQLITE_OK) goto cleanup;
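Once registered, the new function is an ordinary one-argument SQL scalar, so the system prompt is set with a plain SELECT before the first llm_chat_respond call; llm_chat_run then injects it as the leading message on every turn. A hypothetical driver, assuming a statically linked build where sqlite3_ai_init is callable directly and a model has already been configured:

    #include <sqlite3.h>

    int main(void) {
        sqlite3 *db = NULL;
        if (sqlite3_open(":memory:", &db) != SQLITE_OK) return 1;

        // registers llm_chat_system_prompt, llm_chat_respond, ... (see the hunk above)
        sqlite3_ai_init(db, NULL, NULL);

        // store the system prompt first; it becomes the first chat message
        sqlite3_exec(db, "SELECT llm_chat_system_prompt('You are a terse assistant.');", NULL, NULL, NULL);
        sqlite3_exec(db, "SELECT llm_chat_respond('What is SQLite?');", NULL, NULL, NULL);

        sqlite3_close(db);
        return 0;
    }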
