@@ -152,6 +152,7 @@ typedef struct {
152152 const char * template ;
153153 const struct llama_vocab * vocab ;
154154
155+ char * system_prompt ;
155156 ai_messages messages ;
156157 buffer_t formatted ;
157158 buffer_t response ;
@@ -199,6 +200,7 @@ typedef enum {
199200 AI_MODEL_CHAT_TEMPLATE
200201} ai_model_setting ;
201202
// Canonical role name strings. llm_messages_append compares the incoming role
// pointer against these three by address (not by content): a role that is one
// of these pointers is stored as-is, any other role pointer is duplicated.
203+ const char * ROLE_SYSTEM = "system" ;
202204const char * ROLE_USER = "user" ;
203205const char * ROLE_ASSISTANT = "assistant" ;
204206
@@ -794,7 +796,7 @@ bool llm_messages_append (ai_messages *list, const char *role, const char *conte
794796 list -> capacity = new_cap ;
795797 }
796798
797- bool duplicate_role = ((role != ROLE_USER ) && (role != ROLE_ASSISTANT ));
799+ bool duplicate_role = ((role != ROLE_SYSTEM ) && ( role != ROLE_USER ) && (role != ROLE_ASSISTANT ));
798800 list -> items [list -> count ].role = (duplicate_role ) ? sqlite_strdup (role ) : role ;
799801 list -> items [list -> count ].content = sqlite_strdup (content );
800802 list -> count += 1 ;
@@ -1489,7 +1491,7 @@ static bool llm_chat_check_context (ai_context *ai) {
14891491 llama_sampler_chain_add (ai -> sampler , llama_sampler_init_temp (0.8 ));
14901492 llama_sampler_chain_add (ai -> sampler , llama_sampler_init_dist ((uint32_t )LLAMA_DEFAULT_SEED ));
14911493 }
1492-
1494+
14931495 // initialize the chat struct if already created
14941496 if (ai -> chat .uuid [0 ] != '\0' ) return true;
14951497
@@ -1647,10 +1649,30 @@ static bool llm_chat_run (ai_context *ai, ai_cursor *c, const char *user_prompt)
16471649 sqlite_common_set_error (ai -> context , ai -> vtab , SQLITE_ERROR , "Failed to append message" );
16481650 return false;
16491651 }
1652+
1653+ // add system prompt if available, models expect it to be the first message
1654+ llama_chat_message * new_items = messages -> items ;
1655+ if (ai -> chat .system_prompt ) {
1656+ size_t n = messages -> count + (ai -> chat .system_prompt ? 1 : 0 );
1657+ llama_chat_message * new_items = sqlite3_realloc64 (messages -> items , n * sizeof (llama_chat_message ));
1658+ if (!new_items )
1659+ return false;
1660+
1661+ messages -> items = new_items ;
1662+ messages -> capacity = n ;
1663+
1664+ int idx = 0 ;
1665+ new_items [0 ].role = ROLE_SYSTEM ;
1666+ new_items [0 ].content = ai -> chat .system_prompt ;
1667+ idx = 1 ;
1668+ for (size_t i = 0 ; i < messages -> count ; ++ i ) {
1669+ new_items [idx ++ ] = messages -> items [i ];
1670+ }
1671+ }
16501672
16511673 // transform a list of messages (the context) into
16521674 // <|user|>What is AI?<|end|><|assistant|>AI stands for Artificial Intelligence...<|end|><|user|>Can you give an example?<|end|><|assistant|>...
1653- int32_t new_len = llama_chat_apply_template (template , messages -> items , messages -> count , true, formatted -> data , formatted -> capacity );
1675+ int32_t new_len = llama_chat_apply_template (template , new_items , messages -> count , true, formatted -> data , formatted -> capacity );
16541676 if (new_len > formatted -> capacity ) {
16551677 if (buffer_resize (formatted , new_len * 2 ) == false) return false;
16561678 new_len = llama_chat_apply_template (template , messages -> items , messages -> count , true, formatted -> data , formatted -> capacity );
@@ -2015,6 +2037,27 @@ static void llm_chat_respond (sqlite3_context *context, int argc, sqlite3_value
20152037 llm_chat_run (ai , NULL , user_prompt );
20162038}
20172039
2040+ static void llm_chat_system_prompt (sqlite3_context * context , int argc , sqlite3_value * * argv )
2041+ {
2042+ if (llm_check_context (context ) == false)
2043+ return ;
2044+
2045+ int types [] = {SQLITE_TEXT };
2046+ if (sqlite_sanity_function (context , "llm_chat_system_prompt" , argc , argv , 1 , types , true, false) == false)
2047+ return ;
2048+
2049+ const char * system_prompt = (const char * )sqlite3_value_text (argv [0 ]);
2050+ ai_context * ai = (ai_context * )sqlite3_user_data (context );
2051+
2052+ if (llm_chat_check_context (ai ) == false)
2053+ return ;
2054+
2055+ if (ai -> chat .system_prompt ) {
2056+ sqlite3_free (ai -> chat .system_prompt );
2057+ }
2058+ ai -> chat .system_prompt = sqlite3_mprintf ("%s" , system_prompt );
2059+ }
2060+
20182061// MARK: - LLM Sampler -
20192062
20202063static void llm_sampler_init_greedy (sqlite3_context * context , int argc , sqlite3_value * * argv ) {
@@ -2852,7 +2895,10 @@ SQLITE_AI_API int sqlite3_ai_init (sqlite3 *db, char **pzErrMsg, const sqlite3_a
28522895 if (rc != SQLITE_OK ) goto cleanup ;
28532896
28542897 rc = sqlite3_create_function (db , "llm_chat_respond" , 1 , SQLITE_UTF8 , ctx , llm_chat_respond , NULL , NULL );
2855- if (rc != SQLITE_OK ) goto cleanup ;
2898+
2899+ rc = sqlite3_create_function (db , "llm_chat_system_prompt" , 1 , SQLITE_UTF8 , ctx , llm_chat_system_prompt , NULL , NULL );
2900+ if (rc != SQLITE_OK )
2901+ goto cleanup ;
28562902
28572903 rc = sqlite3_create_module (db , "llm_chat" , & llm_chat , ctx );
28582904 if (rc != SQLITE_OK ) goto cleanup ;
0 commit comments