Skip to content

Commit 2a12ca3

Browse files
author
Daniele Briggi
committed
refact(system-prompt): reserved the first slot of the messages list
1 parent dbc3d69 commit 2a12ca3

File tree

2 files changed

+336
-35
lines changed

2 files changed

+336
-35
lines changed

src/sqlite-ai.c

Lines changed: 75 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,6 @@ typedef struct {
152152
const char *template;
153153
const struct llama_vocab*vocab;
154154

155-
char *system_prompt;
156155
ai_messages messages;
157156
buffer_t formatted;
158157
buffer_t response;
@@ -796,18 +795,40 @@ bool llm_messages_append (ai_messages *list, const char *role, const char *conte
796795
list->capacity = new_cap;
797796
}
798797

798+
if (list->count != 0 && role == ROLE_SYSTEM) {
799+
// only one system message allowed at the beginning
800+
return false;
801+
}
802+
799803
bool duplicate_role = ((role != ROLE_SYSTEM) && (role != ROLE_USER) && (role != ROLE_ASSISTANT));
800804
list->items[list->count].role = (duplicate_role) ? sqlite_strdup(role) : role;
801805
list->items[list->count].content = sqlite_strdup(content);
802806
list->count += 1;
803807
return true;
804808
}
805809

810+
// Replace the message stored at index `pos` with a new role/content pair.
// The previous content (and the previous role, when it was heap-allocated)
// is released. Returns false when `pos` is out of range or when duplication
// of the new values fails — in that case the existing message is left
// completely untouched, so the list stays consistent on OOM.
bool llm_messages_set (ai_messages *list, int pos, const char *role, const char *content) {
    if (pos < 0 || (size_t)pos >= list->count)
        return false;

    // roles other than the three static constants must live on the heap
    bool duplicate_role = ((role != ROLE_SYSTEM) && (role != ROLE_USER) && (role != ROLE_ASSISTANT));

    // duplicate the new values BEFORE freeing the old ones, so an allocation
    // failure cannot leave the slot with dangling/NULL fields
    const char *new_role = (duplicate_role) ? sqlite_strdup(role) : role;
    const char *new_content = sqlite_strdup(content);
    if ((duplicate_role && !new_role) || !new_content) {
        if (duplicate_role) sqlite3_free((char *)new_role);
        sqlite3_free((char *)new_content);
        return false;
    }

    llama_chat_message *message = &list->items[pos];

    // free the old role only when it is not one of the static role constants
    const char *message_role = message->role;
    if ((message_role != ROLE_SYSTEM) && (message_role != ROLE_USER) && (message_role != ROLE_ASSISTANT))
        sqlite3_free((char *)message_role);
    // content is always heap-allocated, always free it
    sqlite3_free((char *)message->content);

    message->role = new_role;
    message->content = new_content;
    return true;
}
826+
806827
void llm_messages_free (ai_messages *list) {
807828
for (size_t i = 0; i < list->count; ++i) {
808829
// check if rule is static
809830
const char *role = list->items[i].role;
810-
bool role_tofree = ((role != ROLE_USER) && (role != ROLE_ASSISTANT));
831+
bool role_tofree = ((role != ROLE_SYSTEM) && (role != ROLE_USER) && (role != ROLE_ASSISTANT));
811832
if (role_tofree) sqlite3_free((char *)list->items[i].role);
812833
// content is always to free
813834
sqlite3_free((char *)list->items[i].content);
@@ -1491,7 +1512,7 @@ static bool llm_chat_check_context (ai_context *ai) {
14911512
llama_sampler_chain_add(ai->sampler, llama_sampler_init_temp(0.8));
14921513
llama_sampler_chain_add(ai->sampler, llama_sampler_init_dist((uint32_t)LLAMA_DEFAULT_SEED));
14931514
}
1494-
1515+
14951516
// initialize the chat struct if already created
14961517
if (ai->chat.uuid[0] != '\0') return true;
14971518

@@ -1649,33 +1670,24 @@ static bool llm_chat_run (ai_context *ai, ai_cursor *c, const char *user_prompt)
16491670
sqlite_common_set_error (ai->context, ai->vtab, SQLITE_ERROR, "Failed to append message");
16501671
return false;
16511672
}
1652-
1653-
// add system prompt if available, models expect it to be the first message
1654-
llama_chat_message *new_items = messages->items;
1655-
if (ai->chat.system_prompt) {
1656-
size_t n = messages->count + (ai->chat.system_prompt ? 1 : 0);
1657-
llama_chat_message *new_items = sqlite3_realloc64(messages->items, n * sizeof(llama_chat_message));
1658-
if (!new_items)
1659-
return false;
1660-
1661-
messages->items = new_items;
1662-
messages->capacity = n;
1663-
1664-
int idx = 0;
1665-
new_items[0].role = ROLE_SYSTEM;
1666-
new_items[0].content = ai->chat.system_prompt;
1667-
idx = 1;
1668-
for (size_t i = 0; i < messages->count; ++i) {
1669-
new_items[idx++] = messages->items[i];
1673+
1674+
// skip empty system message if present
1675+
size_t messages_count = messages->count;
1676+
const llama_chat_message *messages_items = messages->items;
1677+
if (messages->count > 0) {
1678+
const llama_chat_message first_message = messages->items[0];
1679+
if (first_message.role == ROLE_SYSTEM && first_message.content[0] == '\0') {
1680+
messages_items = messages->items + 1;
1681+
messages_count = messages->count - 1;
16701682
}
16711683
}
16721684

16731685
// transform a list of messages (the context) into
16741686
// <|user|>What is AI?<|end|><|assistant|>AI stands for Artificial Intelligence...<|end|><|user|>Can you give an example?<|end|><|assistant|>...
1675-
int32_t new_len = llama_chat_apply_template(template, new_items, messages->count, true, formatted->data, formatted->capacity);
1687+
int32_t new_len = llama_chat_apply_template(template, messages_items, messages_count, true, formatted->data, formatted->capacity);
16761688
if (new_len > formatted->capacity) {
16771689
if (buffer_resize(formatted, new_len * 2) == false) return false;
1678-
new_len = llama_chat_apply_template(template, messages->items, messages->count, true, formatted->data, formatted->capacity);
1690+
new_len = llama_chat_apply_template(template, messages_items, messages_count, true, formatted->data, formatted->capacity);
16791691
}
16801692
if ((new_len < 0) || (new_len > formatted->capacity)) {
16811693
sqlite_common_set_error (ai->context, ai->vtab, SQLITE_ERROR, "failed to apply chat template");
@@ -2037,25 +2049,50 @@ static void llm_chat_respond (sqlite3_context *context, int argc, sqlite3_value
20372049
llm_chat_run(ai, NULL, user_prompt);
20382050
}
20392051

2040-
static void llm_chat_system_prompt(sqlite3_context *context, int argc, sqlite3_value **argv)
2041-
{
2052+
static void llm_chat_system_prompt(sqlite3_context *context, int argc, sqlite3_value **argv) {
20422053
if (llm_check_context(context) == false)
20432054
return;
20442055

2045-
int types[] = {SQLITE_TEXT};
2046-
if (sqlite_sanity_function(context, "llm_chat_system_prompt", argc, argv, 1, types, true, false) == false)
2056+
ai_context *ai = (ai_context *)sqlite3_user_data(context);
2057+
if (llm_chat_check_context(ai) == false)
20472058
return;
20482059

2049-
const char *system_prompt = (const char *)sqlite3_value_text(argv[0]);
2050-
ai_context *ai = (ai_context *)sqlite3_user_data(context);
2060+
ai_messages *messages = &ai->chat.messages;
2061+
2062+
// get system role message
2063+
if (argc == 0) {
2064+
if (messages->count == 0) {
2065+
sqlite3_result_null(context);
2066+
return;
2067+
}
2068+
2069+
// only the first message is reserved to the system role
2070+
llama_chat_message *system_message = &messages->items[0];
2071+
const char *content = system_message->content;
2072+
if (system_message->role == ROLE_SYSTEM && content && content[0] != '\0') {
2073+
sqlite3_result_text(context, content, -1, SQLITE_TRANSIENT);
2074+
} else {
2075+
sqlite3_result_null(context);
2076+
}
20512077

2052-
if (llm_chat_check_context(ai) == false)
20532078
return;
2079+
}
2080+
2081+
bool is_null_prompt = (sqlite3_value_type(argv[0]) == SQLITE_NULL);
2082+
int types[1];
2083+
types[0] = is_null_prompt ? SQLITE_NULL : SQLITE_TEXT;
20542084

2055-
if (ai->chat.system_prompt) {
2056-
sqlite3_free(ai->chat.system_prompt);
2085+
if (sqlite_sanity_function(context, "llm_chat_system_prompt", argc, argv, 1, types, true, false) == false)
2086+
return;
2087+
2088+
const unsigned char *prompt_text = sqlite3_value_text(argv[0]);
2089+
const char *system_prompt = prompt_text ? (const char *)prompt_text : "";
2090+
if (!llm_messages_set(messages, 0, ROLE_SYSTEM, system_prompt)) {
2091+
if (!llm_messages_append(messages, ROLE_SYSTEM, system_prompt)) {
2092+
sqlite_common_set_error (ai->context, ai->vtab, SQLITE_ERROR, "Failed to set chat system prompt");
2093+
return;
2094+
}
20572095
}
2058-
ai->chat.system_prompt = sqlite3_mprintf("%s", system_prompt);
20592096
}
20602097

20612098
// MARK: - LLM Sampler -
@@ -2895,10 +2932,13 @@ SQLITE_AI_API int sqlite3_ai_init (sqlite3 *db, char **pzErrMsg, const sqlite3_a
28952932
if (rc != SQLITE_OK) goto cleanup;
28962933

28972934
rc = sqlite3_create_function(db, "llm_chat_respond", 1, SQLITE_UTF8, ctx, llm_chat_respond, NULL, NULL);
2935+
if (rc != SQLITE_OK) goto cleanup;
2936+
2937+
rc = sqlite3_create_function(db, "llm_chat_system_prompt", 0, SQLITE_UTF8, ctx, llm_chat_system_prompt, NULL, NULL);
2938+
if (rc != SQLITE_OK) goto cleanup;
28982939

28992940
rc = sqlite3_create_function(db, "llm_chat_system_prompt", 1, SQLITE_UTF8, ctx, llm_chat_system_prompt, NULL, NULL);
2900-
if (rc != SQLITE_OK)
2901-
goto cleanup;
2941+
if (rc != SQLITE_OK) goto cleanup;
29022942

29032943
rc = sqlite3_create_module(db, "llm_chat", &llm_chat, ctx);
29042944
if (rc != SQLITE_OK) goto cleanup;

0 commit comments

Comments
 (0)