Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 101 additions & 6 deletions src/sqlite-ai.c
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ typedef enum {
AI_MODEL_CHAT_TEMPLATE
} ai_model_setting;

const char *ROLE_SYSTEM = "system";
const char *ROLE_USER = "user";
const char *ROLE_ASSISTANT = "assistant";

Expand Down Expand Up @@ -785,27 +786,58 @@ static bool llm_check_context (sqlite3_context *context) {
// MARK: - Chat Messages -

bool llm_messages_append (ai_messages *list, const char *role, const char *content) {
if (list->count >= list->capacity) {
if (role == ROLE_SYSTEM && list->count > 0) {
// only one system prompt allowed at the beginning
return false;
}

bool needs_system_message = (list->count == 0 && role != ROLE_SYSTEM);
size_t required = list->count + (needs_system_message ? 1 : 0);
if (required >= list->capacity) {
size_t new_cap = list->capacity ? list->capacity * 2 : MIN_ALLOC_MESSAGES;
llama_chat_message *new_items = sqlite3_realloc64(list->items, new_cap * sizeof(llama_chat_message));
if (!new_items) return false;

list->items = new_items;
list->capacity = new_cap;
}

bool duplicate_role = ((role != ROLE_USER) && (role != ROLE_ASSISTANT));
if (needs_system_message) {
// reserve first item for empty system prompt
list->items[list->count].role = ROLE_SYSTEM;
list->items[list->count].content = sqlite_strdup("");
list->count += 1;
}

bool duplicate_role = ((role != ROLE_SYSTEM) && (role != ROLE_USER) && (role != ROLE_ASSISTANT));
list->items[list->count].role = (duplicate_role) ? sqlite_strdup(role) : role;
list->items[list->count].content = sqlite_strdup(content);
list->count += 1;
return true;
}

bool llm_messages_set (ai_messages *list, int pos, const char *role, const char *content) {
if (pos < 0 || pos >= list->count)
return false;

bool duplicate_role = ((role != ROLE_SYSTEM) && (role != ROLE_USER) && (role != ROLE_ASSISTANT));
llama_chat_message *message = &list->items[pos];

const char *message_role = message->role;
if ((message_role != ROLE_SYSTEM) && (message_role != ROLE_USER) && (message_role != ROLE_ASSISTANT))
sqlite3_free((char *)message_role);
sqlite3_free((char *)message->content);

message->role = (duplicate_role) ? sqlite_strdup(role) : role;
message->content = sqlite_strdup(content);
return true;
}

void llm_messages_free (ai_messages *list) {
for (size_t i = 0; i < list->count; ++i) {
// check if rule is static
const char *role = list->items[i].role;
bool role_tofree = ((role != ROLE_USER) && (role != ROLE_ASSISTANT));
bool role_tofree = ((role != ROLE_SYSTEM) && (role != ROLE_USER) && (role != ROLE_ASSISTANT));
if (role_tofree) sqlite3_free((char *)list->items[i].role);
// content is always to free
sqlite3_free((char *)list->items[i].content);
Expand Down Expand Up @@ -1648,12 +1680,23 @@ static bool llm_chat_run (ai_context *ai, ai_cursor *c, const char *user_prompt)
return false;
}

// skip empty system message if present
size_t messages_count = messages->count;
const llama_chat_message *messages_items = messages->items;
if (messages->count > 0) {
const llama_chat_message first_message = messages->items[0];
if (first_message.role == ROLE_SYSTEM && first_message.content[0] == '\0') {
messages_items = messages->items + 1;
messages_count = messages->count - 1;
}
}

// transform a list of messages (the context) into
// <|user|>What is AI?<|end|><|assistant|>AI stands for Artificial Intelligence...<|end|><|user|>Can you give an example?<|end|><|assistant|>...
int32_t new_len = llama_chat_apply_template(template, messages->items, messages->count, true, formatted->data, formatted->capacity);
int32_t new_len = llama_chat_apply_template(template, messages_items, messages_count, true, formatted->data, formatted->capacity);
if (new_len > formatted->capacity) {
if (buffer_resize(formatted, new_len * 2) == false) return false;
new_len = llama_chat_apply_template(template, messages->items, messages->count, true, formatted->data, formatted->capacity);
new_len = llama_chat_apply_template(template, messages_items, messages_count, true, formatted->data, formatted->capacity);
}
if ((new_len < 0) || (new_len > formatted->capacity)) {
sqlite_common_set_error (ai->context, ai->vtab, SQLITE_ERROR, "failed to apply chat template");
Expand Down Expand Up @@ -2015,6 +2058,52 @@ static void llm_chat_respond (sqlite3_context *context, int argc, sqlite3_value
llm_chat_run(ai, NULL, user_prompt);
}

static void llm_chat_system_prompt(sqlite3_context *context, int argc, sqlite3_value **argv) {
if (llm_check_context(context) == false)
return;

ai_context *ai = (ai_context *)sqlite3_user_data(context);
if (llm_chat_check_context(ai) == false)
return;

ai_messages *messages = &ai->chat.messages;

// get system role message
if (argc == 0) {
if (messages->count == 0) {
sqlite3_result_null(context);
return;
}

// only the first message is reserved to the system role
llama_chat_message *system_message = &messages->items[0];
const char *content = system_message->content;
if (system_message->role == ROLE_SYSTEM && content && content[0] != '\0') {
sqlite3_result_text(context, content, -1, SQLITE_TRANSIENT);
} else {
sqlite3_result_null(context);
}

return;
}

bool is_null_prompt = (sqlite3_value_type(argv[0]) == SQLITE_NULL);
int types[1];
types[0] = is_null_prompt ? SQLITE_NULL : SQLITE_TEXT;

if (sqlite_sanity_function(context, "llm_chat_system_prompt", argc, argv, 1, types, true, false) == false)
return;

const unsigned char *prompt_text = sqlite3_value_text(argv[0]);
const char *system_prompt = prompt_text ? (const char *)prompt_text : "";
if (!llm_messages_set(messages, 0, ROLE_SYSTEM, system_prompt)) {
if (!llm_messages_append(messages, ROLE_SYSTEM, system_prompt)) {
sqlite_common_set_error (ai->context, ai->vtab, SQLITE_ERROR, "Failed to set chat system prompt");
return;
}
}
}

// MARK: - LLM Sampler -

static void llm_sampler_init_greedy (sqlite3_context *context, int argc, sqlite3_value **argv) {
Expand Down Expand Up @@ -2853,6 +2942,12 @@ SQLITE_AI_API int sqlite3_ai_init (sqlite3 *db, char **pzErrMsg, const sqlite3_a

rc = sqlite3_create_function(db, "llm_chat_respond", 1, SQLITE_UTF8, ctx, llm_chat_respond, NULL, NULL);
if (rc != SQLITE_OK) goto cleanup;

rc = sqlite3_create_function(db, "llm_chat_system_prompt", 0, SQLITE_UTF8, ctx, llm_chat_system_prompt, NULL, NULL);
if (rc != SQLITE_OK) goto cleanup;

rc = sqlite3_create_function(db, "llm_chat_system_prompt", 1, SQLITE_UTF8, ctx, llm_chat_system_prompt, NULL, NULL);
if (rc != SQLITE_OK) goto cleanup;

rc = sqlite3_create_module(db, "llm_chat", &llm_chat, ctx);
if (rc != SQLITE_OK) goto cleanup;
Expand Down
Loading
Loading