Commit 7a15ddb

feat(chat): add system prompt handling in chat messages (pr #21, issue #16)
New API llm_chat_system_prompt([TEXT text]). The argument can be:
- a TEXT to set the chat system prompt
- NULL to unset the system prompt
- no argument to get the current system prompt

Co-authored-by: Daniele Briggi <=>
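A minimal usage sketch of the three call forms (assumes the extension is loaded and a chat model/context is already set up; outputs are illustrative):

    SELECT llm_chat_system_prompt('You are a helpful assistant.');  -- set
    SELECT llm_chat_system_prompt();                                -- get: 'You are a helpful assistant.'
    SELECT llm_chat_system_prompt(NULL);                            -- unset
    SELECT llm_chat_system_prompt();                                -- get: NULL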
1 parent c3fd345 commit 7a15ddb

File tree

2 files changed: +431 −6


src/sqlite-ai.c

Lines changed: 101 additions & 6 deletions
@@ -199,6 +199,7 @@ typedef enum {
     AI_MODEL_CHAT_TEMPLATE
 } ai_model_setting;
 
+const char *ROLE_SYSTEM = "system";
 const char *ROLE_USER = "user";
 const char *ROLE_ASSISTANT = "assistant";
 
@@ -785,27 +786,58 @@ static bool llm_check_context (sqlite3_context *context) {
 // MARK: - Chat Messages -
 
 bool llm_messages_append (ai_messages *list, const char *role, const char *content) {
-    if (list->count >= list->capacity) {
+    if (role == ROLE_SYSTEM && list->count > 0) {
+        // only one system prompt allowed at the beginning
+        return false;
+    }
+
+    bool needs_system_message = (list->count == 0 && role != ROLE_SYSTEM);
+    size_t required = list->count + (needs_system_message ? 1 : 0);
+    if (required >= list->capacity) {
         size_t new_cap = list->capacity ? list->capacity * 2 : MIN_ALLOC_MESSAGES;
         llama_chat_message *new_items = sqlite3_realloc64(list->items, new_cap * sizeof(llama_chat_message));
         if (!new_items) return false;
-
+
         list->items = new_items;
         list->capacity = new_cap;
     }
 
-    bool duplicate_role = ((role != ROLE_USER) && (role != ROLE_ASSISTANT));
+    if (needs_system_message) {
+        // reserve first item for empty system prompt
+        list->items[list->count].role = ROLE_SYSTEM;
+        list->items[list->count].content = sqlite_strdup("");
+        list->count += 1;
+    }
+
+    bool duplicate_role = ((role != ROLE_SYSTEM) && (role != ROLE_USER) && (role != ROLE_ASSISTANT));
     list->items[list->count].role = (duplicate_role) ? sqlite_strdup(role) : role;
     list->items[list->count].content = sqlite_strdup(content);
     list->count += 1;
     return true;
 }
 
+bool llm_messages_set (ai_messages *list, int pos, const char *role, const char *content) {
+    if (pos < 0 || pos >= list->count)
+        return false;
+
+    bool duplicate_role = ((role != ROLE_SYSTEM) && (role != ROLE_USER) && (role != ROLE_ASSISTANT));
+    llama_chat_message *message = &list->items[pos];
+
+    const char *message_role = message->role;
+    if ((message_role != ROLE_SYSTEM) && (message_role != ROLE_USER) && (message_role != ROLE_ASSISTANT))
+        sqlite3_free((char *)message_role);
+    sqlite3_free((char *)message->content);
+
+    message->role = (duplicate_role) ? sqlite_strdup(role) : role;
+    message->content = sqlite_strdup(content);
+    return true;
+}
+
 void llm_messages_free (ai_messages *list) {
     for (size_t i = 0; i < list->count; ++i) {
         // check if role is static
         const char *role = list->items[i].role;
-        bool role_tofree = ((role != ROLE_USER) && (role != ROLE_ASSISTANT));
+        bool role_tofree = ((role != ROLE_SYSTEM) && (role != ROLE_USER) && (role != ROLE_ASSISTANT));
         if (role_tofree) sqlite3_free((char *)list->items[i].role);
         // content is always freed
         sqlite3_free((char *)list->items[i].content);
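Note on the reserved slot: llm_messages_append now inserts an empty system message at position 0 before the first user/assistant message, so llm_messages_set(list, 0, ROLE_SYSTEM, ...) can install or replace the prompt later without reordering the history. A hedged SQL view of that invariant (assumes a loaded chat model; llm_chat_respond is the existing chat entry point):

    SELECT llm_chat_respond('Hello!');          -- first user message; slot 0 is auto-reserved, empty
    SELECT llm_chat_system_prompt();            -- NULL: the reserved slot has empty content
    SELECT llm_chat_system_prompt('Be terse.'); -- overwrites slot 0 in place via llm_messages_set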
@@ -1648,12 +1680,23 @@ static bool llm_chat_run (ai_context *ai, ai_cursor *c, const char *user_prompt)
         return false;
     }
 
+    // skip empty system message if present
+    size_t messages_count = messages->count;
+    const llama_chat_message *messages_items = messages->items;
+    if (messages->count > 0) {
+        const llama_chat_message first_message = messages->items[0];
+        if (first_message.role == ROLE_SYSTEM && first_message.content[0] == '\0') {
+            messages_items = messages->items + 1;
+            messages_count = messages->count - 1;
+        }
+    }
+
     // transform a list of messages (the context) into
     // <|user|>What is AI?<|end|><|assistant|>AI stands for Artificial Intelligence...<|end|><|user|>Can you give an example?<|end|><|assistant|>...
-    int32_t new_len = llama_chat_apply_template(template, messages->items, messages->count, true, formatted->data, formatted->capacity);
+    int32_t new_len = llama_chat_apply_template(template, messages_items, messages_count, true, formatted->data, formatted->capacity);
     if (new_len > formatted->capacity) {
         if (buffer_resize(formatted, new_len * 2) == false) return false;
-        new_len = llama_chat_apply_template(template, messages->items, messages->count, true, formatted->data, formatted->capacity);
+        new_len = llama_chat_apply_template(template, messages_items, messages_count, true, formatted->data, formatted->capacity);
     }
     if ((new_len < 0) || (new_len > formatted->capacity)) {
         sqlite_common_set_error (ai->context, ai->vtab, SQLITE_ERROR, "failed to apply chat template");
@@ -2015,6 +2058,52 @@ static void llm_chat_respond (sqlite3_context *context, int argc, sqlite3_value
     llm_chat_run(ai, NULL, user_prompt);
 }
 
+static void llm_chat_system_prompt(sqlite3_context *context, int argc, sqlite3_value **argv) {
+    if (llm_check_context(context) == false)
+        return;
+
+    ai_context *ai = (ai_context *)sqlite3_user_data(context);
+    if (llm_chat_check_context(ai) == false)
+        return;
+
+    ai_messages *messages = &ai->chat.messages;
+
+    // get the system role message
+    if (argc == 0) {
+        if (messages->count == 0) {
+            sqlite3_result_null(context);
+            return;
+        }
+
+        // only the first message is reserved for the system role
+        llama_chat_message *system_message = &messages->items[0];
+        const char *content = system_message->content;
+        if (system_message->role == ROLE_SYSTEM && content && content[0] != '\0') {
+            sqlite3_result_text(context, content, -1, SQLITE_TRANSIENT);
+        } else {
+            sqlite3_result_null(context);
+        }
+
+        return;
+    }
+
+    bool is_null_prompt = (sqlite3_value_type(argv[0]) == SQLITE_NULL);
+    int types[1];
+    types[0] = is_null_prompt ? SQLITE_NULL : SQLITE_TEXT;
+
+    if (sqlite_sanity_function(context, "llm_chat_system_prompt", argc, argv, 1, types, true, false) == false)
+        return;
+
+    const unsigned char *prompt_text = sqlite3_value_text(argv[0]);
+    const char *system_prompt = prompt_text ? (const char *)prompt_text : "";
+    if (!llm_messages_set(messages, 0, ROLE_SYSTEM, system_prompt)) {
+        if (!llm_messages_append(messages, ROLE_SYSTEM, system_prompt)) {
+            sqlite_common_set_error (ai->context, ai->vtab, SQLITE_ERROR, "Failed to set chat system prompt");
+            return;
+        }
+    }
+}
+
 // MARK: - LLM Sampler -
 
 static void llm_sampler_init_greedy (sqlite3_context *context, int argc, sqlite3_value **argv) {
@@ -2853,6 +2942,12 @@ SQLITE_AI_API int sqlite3_ai_init (sqlite3 *db, char **pzErrMsg, const sqlite3_a
 
     rc = sqlite3_create_function(db, "llm_chat_respond", 1, SQLITE_UTF8, ctx, llm_chat_respond, NULL, NULL);
     if (rc != SQLITE_OK) goto cleanup;
+
+    rc = sqlite3_create_function(db, "llm_chat_system_prompt", 0, SQLITE_UTF8, ctx, llm_chat_system_prompt, NULL, NULL);
+    if (rc != SQLITE_OK) goto cleanup;
+
+    rc = sqlite3_create_function(db, "llm_chat_system_prompt", 1, SQLITE_UTF8, ctx, llm_chat_system_prompt, NULL, NULL);
+    if (rc != SQLITE_OK) goto cleanup;
 
     rc = sqlite3_create_module(db, "llm_chat", &llm_chat, ctx);
     if (rc != SQLITE_OK) goto cleanup;
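With both arities registered, the same name resolves to the getter (zero arguments) and the setter (one argument). A typical session sketch, hedged since model loading and other setup steps are omitted here and depend on the rest of the extension's API:

    SELECT llm_chat_system_prompt('You are a SQL tutor.');
    SELECT llm_chat_respond('What does GROUP BY do?');  -- the prompt is applied through the chat template
    SELECT llm_chat_system_prompt();                    -- 'You are a SQL tutor.'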
