@@ -199,6 +199,7 @@ typedef enum {
199199 AI_MODEL_CHAT_TEMPLATE
200200} ai_model_setting ;
201201
202+ const char * ROLE_SYSTEM = "system" ;
202203const char * ROLE_USER = "user" ;
203204const char * ROLE_ASSISTANT = "assistant" ;
204205
@@ -785,27 +786,58 @@ static bool llm_check_context (sqlite3_context *context) {
785786// MARK: - Chat Messages -
786787
787788bool llm_messages_append (ai_messages * list , const char * role , const char * content ) {
788- if (list -> count >= list -> capacity ) {
789+ if (role == ROLE_SYSTEM && list -> count > 0 ) {
790+ // only one system prompt allowed at the beginning
791+ return false;
792+ }
793+
794+ bool needs_system_message = (list -> count == 0 && role != ROLE_SYSTEM );
795+ size_t required = list -> count + (needs_system_message ? 1 : 0 );
796+ if (required >= list -> capacity ) {
789797 size_t new_cap = list -> capacity ? list -> capacity * 2 : MIN_ALLOC_MESSAGES ;
790798 llama_chat_message * new_items = sqlite3_realloc64 (list -> items , new_cap * sizeof (llama_chat_message ));
791799 if (!new_items ) return false;
792-
800+
793801 list -> items = new_items ;
794802 list -> capacity = new_cap ;
795803 }
796804
797- bool duplicate_role = ((role != ROLE_USER ) && (role != ROLE_ASSISTANT ));
805+ if (needs_system_message ) {
806+ // reserve first item for empty system prompt
807+ list -> items [list -> count ].role = ROLE_SYSTEM ;
808+ list -> items [list -> count ].content = sqlite_strdup ("" );
809+ list -> count += 1 ;
810+ }
811+
812+ bool duplicate_role = ((role != ROLE_SYSTEM ) && (role != ROLE_USER ) && (role != ROLE_ASSISTANT ));
798813 list -> items [list -> count ].role = (duplicate_role ) ? sqlite_strdup (role ) : role ;
799814 list -> items [list -> count ].content = sqlite_strdup (content );
800815 list -> count += 1 ;
801816 return true;
802817}
803818
819+ bool llm_messages_set (ai_messages * list , int pos , const char * role , const char * content ) {
820+ if (pos < 0 || pos >= list -> count )
821+ return false;
822+
823+ bool duplicate_role = ((role != ROLE_SYSTEM ) && (role != ROLE_USER ) && (role != ROLE_ASSISTANT ));
824+ llama_chat_message * message = & list -> items [pos ];
825+
826+ const char * message_role = message -> role ;
827+ if ((message_role != ROLE_SYSTEM ) && (message_role != ROLE_USER ) && (message_role != ROLE_ASSISTANT ))
828+ sqlite3_free ((char * )message_role );
829+ sqlite3_free ((char * )message -> content );
830+
831+ message -> role = (duplicate_role ) ? sqlite_strdup (role ) : role ;
832+ message -> content = sqlite_strdup (content );
833+ return true;
834+ }
835+
804836void llm_messages_free (ai_messages * list ) {
805837 for (size_t i = 0 ; i < list -> count ; ++ i ) {
806838 // check if rule is static
807839 const char * role = list -> items [i ].role ;
808- bool role_tofree = ((role != ROLE_USER ) && (role != ROLE_ASSISTANT ));
840+ bool role_tofree = ((role != ROLE_SYSTEM ) && ( role != ROLE_USER ) && (role != ROLE_ASSISTANT ));
809841 if (role_tofree ) sqlite3_free ((char * )list -> items [i ].role );
810842 // content is always to free
811843 sqlite3_free ((char * )list -> items [i ].content );
@@ -1648,12 +1680,23 @@ static bool llm_chat_run (ai_context *ai, ai_cursor *c, const char *user_prompt)
16481680 return false;
16491681 }
16501682
1683+ // skip empty system message if present
1684+ size_t messages_count = messages -> count ;
1685+ const llama_chat_message * messages_items = messages -> items ;
1686+ if (messages -> count > 0 ) {
1687+ const llama_chat_message first_message = messages -> items [0 ];
1688+ if (first_message .role == ROLE_SYSTEM && first_message .content [0 ] == '\0' ) {
1689+ messages_items = messages -> items + 1 ;
1690+ messages_count = messages -> count - 1 ;
1691+ }
1692+ }
1693+
16511694 // transform a list of messages (the context) into
16521695 // <|user|>What is AI?<|end|><|assistant|>AI stands for Artificial Intelligence...<|end|><|user|>Can you give an example?<|end|><|assistant|>...
1653- int32_t new_len = llama_chat_apply_template (template , messages -> items , messages -> count , true, formatted -> data , formatted -> capacity );
1696+ int32_t new_len = llama_chat_apply_template (template , messages_items , messages_count , true, formatted -> data , formatted -> capacity );
16541697 if (new_len > formatted -> capacity ) {
16551698 if (buffer_resize (formatted , new_len * 2 ) == false) return false;
1656- new_len = llama_chat_apply_template (template , messages -> items , messages -> count , true, formatted -> data , formatted -> capacity );
1699+ new_len = llama_chat_apply_template (template , messages_items , messages_count , true, formatted -> data , formatted -> capacity );
16571700 }
16581701 if ((new_len < 0 ) || (new_len > formatted -> capacity )) {
16591702 sqlite_common_set_error (ai -> context , ai -> vtab , SQLITE_ERROR , "failed to apply chat template" );
@@ -2015,6 +2058,52 @@ static void llm_chat_respond (sqlite3_context *context, int argc, sqlite3_value
20152058 llm_chat_run (ai , NULL , user_prompt );
20162059}
20172060
2061+ static void llm_chat_system_prompt (sqlite3_context * context , int argc , sqlite3_value * * argv ) {
2062+ if (llm_check_context (context ) == false)
2063+ return ;
2064+
2065+ ai_context * ai = (ai_context * )sqlite3_user_data (context );
2066+ if (llm_chat_check_context (ai ) == false)
2067+ return ;
2068+
2069+ ai_messages * messages = & ai -> chat .messages ;
2070+
2071+ // get system role message
2072+ if (argc == 0 ) {
2073+ if (messages -> count == 0 ) {
2074+ sqlite3_result_null (context );
2075+ return ;
2076+ }
2077+
2078+ // only the first message is reserved to the system role
2079+ llama_chat_message * system_message = & messages -> items [0 ];
2080+ const char * content = system_message -> content ;
2081+ if (system_message -> role == ROLE_SYSTEM && content && content [0 ] != '\0' ) {
2082+ sqlite3_result_text (context , content , -1 , SQLITE_TRANSIENT );
2083+ } else {
2084+ sqlite3_result_null (context );
2085+ }
2086+
2087+ return ;
2088+ }
2089+
2090+ bool is_null_prompt = (sqlite3_value_type (argv [0 ]) == SQLITE_NULL );
2091+ int types [1 ];
2092+ types [0 ] = is_null_prompt ? SQLITE_NULL : SQLITE_TEXT ;
2093+
2094+ if (sqlite_sanity_function (context , "llm_chat_system_prompt" , argc , argv , 1 , types , true, false) == false)
2095+ return ;
2096+
2097+ const unsigned char * prompt_text = sqlite3_value_text (argv [0 ]);
2098+ const char * system_prompt = prompt_text ? (const char * )prompt_text : "" ;
2099+ if (!llm_messages_set (messages , 0 , ROLE_SYSTEM , system_prompt )) {
2100+ if (!llm_messages_append (messages , ROLE_SYSTEM , system_prompt )) {
2101+ sqlite_common_set_error (ai -> context , ai -> vtab , SQLITE_ERROR , "Failed to set chat system prompt" );
2102+ return ;
2103+ }
2104+ }
2105+ }
2106+
20182107// MARK: - LLM Sampler -
20192108
20202109static void llm_sampler_init_greedy (sqlite3_context * context , int argc , sqlite3_value * * argv ) {
@@ -2853,6 +2942,12 @@ SQLITE_AI_API int sqlite3_ai_init (sqlite3 *db, char **pzErrMsg, const sqlite3_a
28532942
28542943 rc = sqlite3_create_function (db , "llm_chat_respond" , 1 , SQLITE_UTF8 , ctx , llm_chat_respond , NULL , NULL );
28552944 if (rc != SQLITE_OK ) goto cleanup ;
2945+
2946+ rc = sqlite3_create_function (db , "llm_chat_system_prompt" , 0 , SQLITE_UTF8 , ctx , llm_chat_system_prompt , NULL , NULL );
2947+ if (rc != SQLITE_OK ) goto cleanup ;
2948+
2949+ rc = sqlite3_create_function (db , "llm_chat_system_prompt" , 1 , SQLITE_UTF8 , ctx , llm_chat_system_prompt , NULL , NULL );
2950+ if (rc != SQLITE_OK ) goto cleanup ;
28562951
28572952 rc = sqlite3_create_module (db , "llm_chat" , & llm_chat , ctx );
28582953 if (rc != SQLITE_OK ) goto cleanup ;
0 commit comments