@@ -152,7 +152,6 @@ typedef struct {
     const char *template;
     const struct llama_vocab *vocab;
 
-    char *system_prompt;
     ai_messages messages;
     buffer_t formatted;
     buffer_t response;
@@ -796,18 +795,40 @@ bool llm_messages_append (ai_messages *list, const char *role, const char *content)
         list->capacity = new_cap;
     }
 
+    if (list->count != 0 && role == ROLE_SYSTEM) {
+        // only one system message allowed, at the beginning
+        return false;
+    }
+
     bool duplicate_role = ((role != ROLE_SYSTEM) && (role != ROLE_USER) && (role != ROLE_ASSISTANT));
     list->items[list->count].role = (duplicate_role) ? sqlite_strdup(role) : role;
     list->items[list->count].content = sqlite_strdup(content);
     list->count += 1;
     return true;
 }
 
+bool llm_messages_set (ai_messages *list, int pos, const char *role, const char *content) {
+    if (pos < 0 || (size_t)pos >= list->count)
+        return false;
+
+    bool duplicate_role = ((role != ROLE_SYSTEM) && (role != ROLE_USER) && (role != ROLE_ASSISTANT));
+    llama_chat_message *message = &list->items[pos];
+
+    // free the previous role only if it was heap-allocated (not one of the static constants)
+    const char *message_role = message->role;
+    if ((message_role != ROLE_SYSTEM) && (message_role != ROLE_USER) && (message_role != ROLE_ASSISTANT))
+        sqlite3_free((char *)message_role);
+    sqlite3_free((char *)message->content);
+
+    message->role = (duplicate_role) ? sqlite_strdup(role) : role;
+    message->content = sqlite_strdup(content);
+    return true;
+}
+
void llm_messages_free (ai_messages *list) {
     for (size_t i = 0; i < list->count; ++i) {
         // check if the role is a static constant
         const char *role = list->items[i].role;
-        bool role_tofree = ((role != ROLE_USER) && (role != ROLE_ASSISTANT));
+        bool role_tofree = ((role != ROLE_SYSTEM) && (role != ROLE_USER) && (role != ROLE_ASSISTANT));
         if (role_tofree) sqlite3_free((char *)list->items[i].role);
         // content is always heap-allocated, so it must always be freed
         sqlite3_free((char *)list->items[i].content);
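
A note on the ownership convention these three functions share: roles are compared by pointer identity rather than strcmp, which only works because ROLE_SYSTEM, ROLE_USER, and ROLE_ASSISTANT are presumably static string constants that callers pass through unchanged; any other role string is duplicated on insert and freed on removal. A minimal standalone sketch of the idea, with assumed definitions for the role constants and a hypothetical role_needs_copy helper:

    #include <stdbool.h>

    /* Assumed: shared, statically-allocated role strings. Comparing a role
       against these by address tells us whether the list owns the memory. */
    static const char *const ROLE_SYSTEM    = "system";
    static const char *const ROLE_USER      = "user";
    static const char *const ROLE_ASSISTANT = "assistant";

    /* Hypothetical helper: a role needs its own heap copy (and a later free)
       only when it is not one of the shared constants. */
    static bool role_needs_copy(const char *role) {
        return (role != ROLE_SYSTEM) && (role != ROLE_USER) && (role != ROLE_ASSISTANT);
    }

The corollary is that llm_messages_append and llm_messages_set must receive the constants themselves, not equal copies of those strings, or the identity checks silently misclassify the role as heap-owned.
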
@@ -1491,7 +1512,7 @@ static bool llm_chat_check_context (ai_context *ai) {
         llama_sampler_chain_add(ai->sampler, llama_sampler_init_temp(0.8));
         llama_sampler_chain_add(ai->sampler, llama_sampler_init_dist((uint32_t)LLAMA_DEFAULT_SEED));
     }
-    
+    
     // nothing to do if the chat struct was already initialized
     if (ai->chat.uuid[0] != '\0') return true;
 
@@ -1649,33 +1670,24 @@ static bool llm_chat_run (ai_context *ai, ai_cursor *c, const char *user_prompt)
         sqlite_common_set_error(ai->context, ai->vtab, SQLITE_ERROR, "Failed to append message");
         return false;
     }
-    
-    // add system prompt if available, models expect it to be the first message
-    llama_chat_message *new_items = messages->items;
-    if (ai->chat.system_prompt) {
-        size_t n = messages->count + (ai->chat.system_prompt ? 1 : 0);
-        llama_chat_message *new_items = sqlite3_realloc64(messages->items, n * sizeof(llama_chat_message));
-        if (!new_items)
-            return false;
-
-        messages->items = new_items;
-        messages->capacity = n;
-
-        int idx = 0;
-        new_items[0].role = ROLE_SYSTEM;
-        new_items[0].content = ai->chat.system_prompt;
-        idx = 1;
-        for (size_t i = 0; i < messages->count; ++i) {
-            new_items[idx++] = messages->items[i];
+    
+    // skip empty system message if present
+    size_t messages_count = messages->count;
+    const llama_chat_message *messages_items = messages->items;
+    if (messages->count > 0) {
+        const llama_chat_message first_message = messages->items[0];
+        if (first_message.role == ROLE_SYSTEM && first_message.content[0] == '\0') {
+            messages_items = messages->items + 1;
+            messages_count = messages->count - 1;
         }
     }
 
     // transform a list of messages (the context) into
     // <|user|>What is AI?<|end|><|assistant|>AI stands for Artificial Intelligence...<|end|><|user|>Can you give an example?<|end|><|assistant|>...
-    int32_t new_len = llama_chat_apply_template(template, new_items, messages->count, true, formatted->data, formatted->capacity);
+    int32_t new_len = llama_chat_apply_template(template, messages_items, messages_count, true, formatted->data, formatted->capacity);
     if (new_len > formatted->capacity) {
         if (buffer_resize(formatted, new_len * 2) == false) return false;
-        new_len = llama_chat_apply_template(template, messages->items, messages->count, true, formatted->data, formatted->capacity);
+        new_len = llama_chat_apply_template(template, messages_items, messages_count, true, formatted->data, formatted->capacity);
     }
     if ((new_len < 0) || (new_len > formatted->capacity)) {
         sqlite_common_set_error(ai->context, ai->vtab, SQLITE_ERROR, "failed to apply chat template");
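
The template call above follows llama.cpp's usual grow-and-retry idiom: llama_chat_apply_template returns the number of bytes the formatted chat needs, which may exceed the buffer, so the buffer is grown and the call repeated. A self-contained sketch of the same idiom, assuming a plain malloc'd buffer instead of the extension's buffer_t:

    #include <stdlib.h>
    #include "llama.h"

    /* Sketch: format a chat into *buf, growing the buffer once if too small.
       Returns the formatted length, or a negative value on failure. */
    static int32_t format_chat(const char *tmpl, const llama_chat_message *msgs,
                               size_t n_msgs, char **buf, size_t *cap) {
        int32_t len = llama_chat_apply_template(tmpl, msgs, n_msgs, true, *buf, (int32_t)*cap);
        if (len > (int32_t)*cap) {
            char *grown = realloc(*buf, (size_t)len);  /* the retry needs exactly len bytes */
            if (!grown) return -1;
            *buf = grown;
            *cap = (size_t)len;
            len = llama_chat_apply_template(tmpl, msgs, n_msgs, true, *buf, (int32_t)*cap);
        }
        return len;  /* negative means the template could not be applied */
    }

The extension grows to new_len * 2 instead of exactly new_len, which amortizes future calls; either choice works as long as the second call gets at least new_len bytes.
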
@@ -2037,25 +2049,50 @@ static void llm_chat_respond (sqlite3_context *context, int argc, sqlite3_value **argv)
     llm_chat_run(ai, NULL, user_prompt);
 }
 
-static void llm_chat_system_prompt (sqlite3_context *context, int argc, sqlite3_value **argv)
-{
+static void llm_chat_system_prompt (sqlite3_context *context, int argc, sqlite3_value **argv) {
     if (llm_check_context(context) == false)
         return;
 
-    int types[] = {SQLITE_TEXT};
-    if (sqlite_sanity_function(context, "llm_chat_system_prompt", argc, argv, 1, types, true, false) == false)
+    ai_context *ai = (ai_context *)sqlite3_user_data(context);
+    if (llm_chat_check_context(ai) == false)
         return;
 
-    const char *system_prompt = (const char *)sqlite3_value_text(argv[0]);
-    ai_context *ai = (ai_context *)sqlite3_user_data(context);
+    ai_messages *messages = &ai->chat.messages;
+
+    // no argument: act as a getter and return the current system message
+    if (argc == 0) {
+        if (messages->count == 0) {
+            sqlite3_result_null(context);
+            return;
+        }
+
+        // only the first message is reserved for the system role
+        llama_chat_message *system_message = &messages->items[0];
+        const char *content = system_message->content;
+        if (system_message->role == ROLE_SYSTEM && content && content[0] != '\0') {
+            sqlite3_result_text(context, content, -1, SQLITE_TRANSIENT);
+        } else {
+            sqlite3_result_null(context);
+        }
 
-    if (llm_chat_check_context(ai) == false)
         return;
+    }
+
+    // one argument: set the system prompt (or clear it, if the argument is NULL)
+    bool is_null_prompt = (sqlite3_value_type(argv[0]) == SQLITE_NULL);
+    int types[1];
+    types[0] = is_null_prompt ? SQLITE_NULL : SQLITE_TEXT;
 
-    if (ai->chat.system_prompt) {
-        sqlite3_free(ai->chat.system_prompt);
+    if (sqlite_sanity_function(context, "llm_chat_system_prompt", argc, argv, 1, types, true, false) == false)
+        return;
+
+    // slot 0 holds the system message: overwrite it if it exists, otherwise append it
+    const unsigned char *prompt_text = sqlite3_value_text(argv[0]);
+    const char *system_prompt = prompt_text ? (const char *)prompt_text : "";
+    if (!llm_messages_set(messages, 0, ROLE_SYSTEM, system_prompt)) {
+        if (!llm_messages_append(messages, ROLE_SYSTEM, system_prompt)) {
+            sqlite_common_set_error(ai->context, ai->vtab, SQLITE_ERROR, "Failed to set chat system prompt");
+            return;
+        }
     }
-    ai->chat.system_prompt = sqlite3_mprintf("%s", system_prompt);
 }
 
 // MARK: - LLM Sampler -
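
Net effect of the rewrite: the system prompt now lives in slot 0 of the chat's message list instead of a separate char * field, so it follows the same ownership rules as every other message. A hypothetical smoke test, assuming db is an open connection with the extension loaded and a chat context already created:

    /* Set, read back, and clear the system prompt through the new SQL surface. */
    sqlite3_exec(db, "SELECT llm_chat_system_prompt('You are terse.');", NULL, NULL, NULL);
    sqlite3_exec(db, "SELECT llm_chat_system_prompt();",     NULL, NULL, NULL);  /* returns 'You are terse.' */
    sqlite3_exec(db, "SELECT llm_chat_system_prompt(NULL);", NULL, NULL, NULL);  /* clears it */
    sqlite3_exec(db, "SELECT llm_chat_system_prompt();",     NULL, NULL, NULL);  /* now returns NULL */

Clearing stores an empty string rather than deleting the message, which is exactly why llm_chat_run skips a leading system message with empty content before applying the template.
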
@@ -2895,10 +2932,13 @@ SQLITE_AI_API int sqlite3_ai_init (sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi)
     if (rc != SQLITE_OK) goto cleanup;
 
     rc = sqlite3_create_function(db, "llm_chat_respond", 1, SQLITE_UTF8, ctx, llm_chat_respond, NULL, NULL);
+    if (rc != SQLITE_OK) goto cleanup;
+
+    rc = sqlite3_create_function(db, "llm_chat_system_prompt", 0, SQLITE_UTF8, ctx, llm_chat_system_prompt, NULL, NULL);
+    if (rc != SQLITE_OK) goto cleanup;
 
     rc = sqlite3_create_function(db, "llm_chat_system_prompt", 1, SQLITE_UTF8, ctx, llm_chat_system_prompt, NULL, NULL);
-    if (rc != SQLITE_OK)
-        goto cleanup;
+    if (rc != SQLITE_OK) goto cleanup;
 
     rc = sqlite3_create_module(db, "llm_chat", &llm_chat, ctx);
     if (rc != SQLITE_OK) goto cleanup;
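
The double registration works because SQLite resolves function overloads by argument count: the same C callback is registered with nArg 0 and nArg 1 and branches on argc at call time. A minimal standalone illustration of that pattern (demo_fn and demo_value are hypothetical, not part of the extension):

    #include <sqlite3.h>
    #include <stdio.h>

    static char demo_value[256] = "default";

    /* One callback, two registrations: argc tells us which overload was called. */
    static void demo_fn(sqlite3_context *ctx, int argc, sqlite3_value **argv) {
        if (argc == 0) {  /* getter */
            sqlite3_result_text(ctx, demo_value, -1, SQLITE_TRANSIENT);
        } else {          /* setter */
            const unsigned char *text = sqlite3_value_text(argv[0]);
            snprintf(demo_value, sizeof(demo_value), "%s", text ? (const char *)text : "");
            sqlite3_result_text(ctx, demo_value, -1, SQLITE_TRANSIENT);
        }
    }

    /* Registration mirrors the diff: same name, nArg 0 and nArg 1. */
    static int register_demo(sqlite3 *db) {
        int rc = sqlite3_create_function(db, "demo_fn", 0, SQLITE_UTF8, NULL, demo_fn, NULL, NULL);
        if (rc != SQLITE_OK) return rc;
        return sqlite3_create_function(db, "demo_fn", 1, SQLITE_UTF8, NULL, demo_fn, NULL, NULL);
    }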