@@ -59,6 +59,7 @@ SQLITE_EXTENSION_INIT1
5959#define OPTION_KEY_ROPE_SCALING_TYPE "rope_scaling_type"
6060#define OPTION_KEY_POOLING_TYPE "pooling_type"
6161#define OPTION_KEY_ATTENTION_TYPE "attention_type"
62+ #define OPTION_KEY_FLASH_ATTN_TYPE "flash_attn_type"
6263
6364#define OPTION_KEY_ROPE_FREQ_BASE "rope_freq_base"
6465#define OPTION_KEY_ROPE_FREQ_SCALE "rope_freq_scale"
@@ -71,9 +72,9 @@ SQLITE_EXTENSION_INIT1
7172#define OPTION_KEY_TYPE_K "type_k"
7273#define OPTION_KEY_TYPE_V "type_v"
7374#define OPTION_KEY_OFFLOAD_KQV "offload_kqv"
74- #define OPTION_KEY_FLASH_ATTN "flash_attn"
7575#define OPTION_KEY_OP_OFFLOAD "op_offload"
7676#define OPTION_KEY_SWA_FULL "swa_full"
77+ #define OPTION_KEY_TYPE_KV_UNIFIED "kv_unified"
7778
7879#define OPTION_KEY_GENERATE_EMBEDDING "generate_embedding"
7980#define OPTION_KEY_NORMALIZE_EMBEDDING "normalize_embedding"
@@ -376,8 +377,9 @@ static bool llm_context_options_callback (void *ctx, void *xdata, const char *ke
376377 }
377378
378379 if (strncasecmp (key , OPTION_KEY_POOLING_TYPE , key_len ) == 0 ) {
379- if (strcasecmp (buffer , "none" ) == 0 ) options -> pooling_type = LLAMA_POOLING_TYPE_MEAN ;
380380 // pooling_type none is not supported, so in this version it is forced to mean so that ONE EMBEDDING will be generated
381+ if (strcasecmp (buffer , "none" ) == 0 ) options -> pooling_type = LLAMA_POOLING_TYPE_MEAN ;
382+ else if (strcasecmp (buffer , "unspecified" ) == 0 ) options -> pooling_type = LLAMA_POOLING_TYPE_MEAN ;
381383 else if (strcasecmp (buffer , "mean" ) == 0 ) options -> pooling_type = LLAMA_POOLING_TYPE_MEAN ;
382384 else if (strcasecmp (buffer , "cls" ) == 0 ) options -> pooling_type = LLAMA_POOLING_TYPE_CLS ;
383385 else if (strcasecmp (buffer , "last" ) == 0 ) options -> pooling_type = LLAMA_POOLING_TYPE_LAST ;
@@ -386,7 +388,8 @@ static bool llm_context_options_callback (void *ctx, void *xdata, const char *ke
386388 }
387389
388390 if (strncasecmp (key , OPTION_KEY_ATTENTION_TYPE , key_len ) == 0 ) {
389- if (strcasecmp (buffer , "causal" ) == 0 ) options -> attention_type = LLAMA_ATTENTION_TYPE_CAUSAL ;
391+ if (strcasecmp (buffer , "unspecified" ) == 0 ) options -> attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED ;
392+ else if (strcasecmp (buffer , "causal" ) == 0 ) options -> attention_type = LLAMA_ATTENTION_TYPE_CAUSAL ;
390393 else if (strcasecmp (buffer , "non_causal" ) == 0 ) options -> attention_type = LLAMA_ATTENTION_TYPE_NON_CAUSAL ;
391394 return true;
392395 }
@@ -400,6 +403,13 @@ static bool llm_context_options_callback (void *ctx, void *xdata, const char *ke
400403 return true;
401404 }
402405
406+ if (strncasecmp (key , OPTION_KEY_FLASH_ATTN_TYPE , key_len ) == 0 ) {
407+ if (strcasecmp (buffer , "auto" ) == 0 ) options -> flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO ;
408+ else if (strcasecmp (buffer , "disabled" ) == 0 ) options -> flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED ;
409+ else if (strcasecmp (buffer , "enabled" ) == 0 ) options -> flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED ;
410+ return true;
411+ }
412+
403413 if (strncasecmp (key , OPTION_KEY_ROPE_FREQ_BASE , key_len ) == 0 ) {
404414 float value = strtof (buffer , NULL );
405415 options -> rope_freq_base = value ;
@@ -454,12 +464,6 @@ static bool llm_context_options_callback (void *ctx, void *xdata, const char *ke
454464 return true;
455465 }
456466
457- if (strncasecmp (key , OPTION_KEY_FLASH_ATTN , key_len ) == 0 ) {
458- int value = (int )strtol (buffer , NULL , 0 );
459- options -> flash_attn = (value != 0 );
460- return true;
461- }
462-
463467 if (strncasecmp (key , OPTION_KEY_OP_OFFLOAD , key_len ) == 0 ) {
464468 int value = (int )strtol (buffer , NULL , 0 );
465469 options -> op_offload = (value != 0 );
@@ -484,6 +488,12 @@ static bool llm_context_options_callback (void *ctx, void *xdata, const char *ke
484488 return true;
485489 }
486490
491+ if (strncasecmp (key , OPTION_KEY_TYPE_KV_UNIFIED , key_len ) == 0 ) {
492+ int value = (int )strtol (buffer , NULL , 0 );
493+ if (value >= 0 ) options -> kv_unified = (value != 0 );
494+ return true;
495+ }
496+
487497 // means ignore unknown keys
488498 return true;
489499}
0 commit comments