Skip to content

Commit 34be76d

Browse files
committed
fix(chat): make chat saves idempotent and support repeated updates for the same uuid
Updated llm_chat_save to use INSERT ... ON CONFLICT for ai_chat_history, ensuring chat metadata is updated if the uuid exists. Added logic to fetch the correct rowid after upsert and to delete all previous messages for the chat before saving new ones, ensuring message history is consistent.
1 parent 6032d3e commit 34be76d

File tree

2 files changed

+144
-5
lines changed

2 files changed

+144
-5
lines changed

src/sqlite-ai.c

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1957,20 +1957,44 @@ static void llm_chat_save (sqlite3_context *context, int argc, sqlite3_value **a
19571957
// start transaction
19581958
sqlite_db_write_simple(context, db, "BEGIN;");
19591959

1960-
// save chat
1961-
const char *sql = "INSERT INTO ai_chat_history (uuid, title, metadata) VALUES (?, ?, ?);";
1960+
// save chat, the ON CONFLICT allows saving multiple times
1961+
const char *sql = "INSERT INTO ai_chat_history (uuid, title, metadata) VALUES (?, ?, ?) "
1962+
"ON CONFLICT(uuid) DO UPDATE SET "
1963+
" title = excluded.title, "
1964+
" metadata = excluded.metadata, "
1965+
" created_at = CURRENT_TIMESTAMP;";
19621966
const char *values[] = {ai->chat.uuid, title, meta};
19631967
int types[] = {SQLITE_TEXT, SQLITE_TEXT, SQLITE_TEXT};
19641968
int lens[] = {-1, -1, -1};
19651969

19661970
int rc = sqlite_db_write(context, db, sql, values, types, lens, 3);
19671971
if (rc != SQLITE_OK) goto abort_save;
1968-
1969-
// loop to save messages (the context)
1972+
1973+
// get the rowid, cannot use sqlite3_last_insert_rowid for the CONFLICT case
19701974
char rowid_s[256];
1971-
sqlite3_int64 rowid = sqlite3_last_insert_rowid(db);
1975+
sqlite3_stmt *pstmt = NULL;
1976+
sql = "SELECT id FROM ai_chat_history WHERE uuid = ?;";
1977+
rc = sqlite3_prepare_v2(db, sql, -1, &pstmt, NULL);
1978+
if (rc != SQLITE_OK) goto abort_save;
1979+
rc = sqlite3_bind_text(pstmt, 1, ai->chat.uuid, -1, SQLITE_STATIC);
1980+
rc = sqlite3_step(pstmt);
1981+
if (rc != SQLITE_ROW) {
1982+
sqlite3_finalize(pstmt);
1983+
goto abort_save;
1984+
}
1985+
sqlite3_int64 rowid = sqlite3_column_int64(pstmt, 0);
1986+
sqlite3_finalize(pstmt);
19721987
snprintf(rowid_s, sizeof(rowid_s), "%lld", (long long)rowid);
1988+
1989+
// delete all messages for this chat id, if any
1990+
sql = "DELETE FROM ai_chat_messages WHERE chat_id = ?;";
1991+
const char *values3[] = {rowid_s};
1992+
int types3[] = {SQLITE_INTEGER};
1993+
int lens3[] = {-1};
1994+
rc = sqlite_db_write(context, db, sql, values3, types3, lens3, 1);
1995+
if (rc != SQLITE_OK) goto abort_save;
19731996

1997+
// loop to save messages (the context)
19741998
sql = "INSERT INTO ai_chat_messages (chat_id, role, content) VALUES (?, ?, ?);";
19751999
int types2[] = {SQLITE_INTEGER, SQLITE_TEXT, SQLITE_TEXT};
19762000

tests/c/unittest.c

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,6 +1008,120 @@ static int test_chat_system_prompt_after_first_response(const test_env *env) {
10081008
return status;
10091009
}
10101010

1011+
static int test_llm_chat_double_save(const test_env *env) {
1012+
sqlite3 *db = NULL;
1013+
bool model_loaded = false;
1014+
bool context_created = false;
1015+
bool chat_created = false;
1016+
int status = 1;
1017+
1018+
if (open_db_and_load(env, &db) != SQLITE_OK) {
1019+
goto done;
1020+
}
1021+
1022+
const char *model = env->model_path ? env->model_path : DEFAULT_MODEL_PATH;
1023+
char sqlbuf[512];
1024+
snprintf(sqlbuf, sizeof(sqlbuf), "SELECT llm_model_load('%s');", model);
1025+
if (exec_expect_ok(env, db, sqlbuf) != 0)
1026+
goto done;
1027+
model_loaded = true;
1028+
1029+
if (exec_expect_ok(env, db,
1030+
"SELECT llm_context_create('context_size=1000');") != 0)
1031+
goto done;
1032+
context_created = true;
1033+
1034+
if (exec_expect_ok(env, db, "SELECT llm_chat_create();") != 0)
1035+
goto done;
1036+
chat_created = true;
1037+
1038+
// First prompt
1039+
const char *prompt1 = "First prompt";
1040+
if (exec_expect_ok(env, db, "SELECT llm_chat_respond('First prompt');") != 0)
1041+
goto done;
1042+
1043+
// First save
1044+
if (exec_expect_ok(env, db, "SELECT llm_chat_save();") != 0)
1045+
goto done;
1046+
1047+
// Second prompt
1048+
const char *prompt2 = "Second prompt";
1049+
if (exec_expect_ok(env, db, "SELECT llm_chat_respond('Second prompt');") != 0)
1050+
goto done;
1051+
1052+
// Second save
1053+
if (exec_expect_ok(env, db, "SELECT llm_chat_save();") != 0)
1054+
goto done;
1055+
1056+
ai_chat_message_row rows[8];
1057+
int count = 0;
1058+
// We expect 4 messages: User1, Assistant1, User2, Assistant2
1059+
if (fetch_ai_chat_messages(env, db, rows, 8, &count) != 0)
1060+
goto done;
1061+
1062+
if (count != 5) {
1063+
fprintf(stderr,
1064+
"[test_llm_chat_double_save] expected 4 message rows, got %d\n",
1065+
count);
1066+
goto done;
1067+
}
1068+
1069+
// Verify order and roles
1070+
if (strcmp(rows[0].role, "system") != 0 ||
1071+
strcmp(rows[0].content, "") != 0) {
1072+
fprintf(stderr,
1073+
"[test_llm_chat_double_save] row 0 mismatch (expected system/'%s', "
1074+
"got %s/'%s')\n",
1075+
"", rows[0].role, rows[0].content);
1076+
goto done;
1077+
}
1078+
if (strcmp(rows[1].role, "user") != 0 ||
1079+
strcmp(rows[1].content, prompt1) != 0) {
1080+
fprintf(stderr,
1081+
"[test_llm_chat_double_save] row 0 mismatch (expected user/'%s', "
1082+
"got %s/'%s')\n",
1083+
prompt1, rows[1].role, rows[1].content);
1084+
goto done;
1085+
}
1086+
if (strcmp(rows[2].role, "assistant") != 0) {
1087+
fprintf(stderr,
1088+
"[test_llm_chat_double_save] row 1 mismatch (expected assistant, "
1089+
"got %s)\n",
1090+
rows[2].role);
1091+
goto done;
1092+
}
1093+
if (strcmp(rows[3].role, "user") != 0 ||
1094+
strcmp(rows[3].content, prompt2) != 0) {
1095+
fprintf(stderr,
1096+
"[test_llm_chat_double_save] row 2 mismatch (expected user/'%s', "
1097+
"got %s/'%s')\n",
1098+
prompt2, rows[3].role, rows[3].content);
1099+
goto done;
1100+
}
1101+
if (strcmp(rows[4].role, "assistant") != 0) {
1102+
fprintf(stderr,
1103+
"[test_llm_chat_double_save] row 3 mismatch (expected assistant, "
1104+
"got %s)\n",
1105+
rows[4].role);
1106+
goto done;
1107+
}
1108+
1109+
status = 0;
1110+
1111+
done:
1112+
if (chat_created)
1113+
exec_expect_ok(env, db, "SELECT llm_chat_free();");
1114+
if (context_created)
1115+
exec_expect_ok(env, db, "SELECT llm_context_free();");
1116+
if (model_loaded)
1117+
exec_expect_ok(env, db, "SELECT llm_model_free();");
1118+
if (db)
1119+
sqlite3_close(db);
1120+
if (status == 0)
1121+
status = assert_sqlite_memory_clean("llm_chat_double_save", env);
1122+
return status;
1123+
}
1124+
10111125
static const test_case TESTS[] = {
10121126
{"issue15_llm_chat_without_context", test_issue15_chat_without_context},
10131127
{"llm_chat_respond_repeated", test_llm_chat_respond_repeated},
@@ -1026,6 +1140,7 @@ static const test_case TESTS[] = {
10261140
{"chat_system_prompt_new_chat", test_chat_system_prompt_new_chat},
10271141
{"chat_system_prompt_replace_previous_prompt", test_chat_system_prompt_replace_previous_prompt},
10281142
{"chat_system_prompt_after_first_response", test_chat_system_prompt_after_first_response},
1143+
{"llm_chat_double_save", test_llm_chat_double_save},
10291144
};
10301145

10311146
int main(int argc, char **argv) {

0 commit comments

Comments
 (0)