Commit 692e3cd

memory : rename interface to llama_memory_context_i (#14296)
* memory : rename interface to llama_memory_context_i

  ggml-ci

* cont : fix comments

* cont : use "mctx" for referencing a memory context

  ggml-ci
1 parent b23fa0b commit 692e3cd

14 files changed: +339 / -341 lines
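The rename is mechanical: the llama_memory_state_i interface (and the "mstate" naming convention) becomes llama_memory_context_i / "mctx" throughout. Below is a minimal sketch of the call pattern, assembled from the call sites visible in this diff; the interface itself is declared in the memory module, which is not part of this excerpt, so treat names and signatures as illustrative rather than verbatim.

// illustrative only: mirrors the decode() loop shown in llama-context.cpp below
llama_memory_context_ptr mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
    return -2; // could not prepare the batch for the memory module
}

do {
    const auto & ubatch = mctx->get_ubatch(); // current micro-batch
    if (!mctx->apply()) {                     // apply the memory context before building the graph
        break;
    }
    // ... build and compute the graph for `ubatch` ...
} while (mctx->next());                       // advance to the next micro-batch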

src/llama-context.cpp

Lines changed: 39 additions & 39 deletions
@@ -280,16 +280,16 @@ llama_context::llama_context(
 
         // simulate full KV cache
 
-        const auto mstate = memory->init_full();
-        if (!mstate) {
+        const auto mctx = memory->init_full();
+        if (!mctx) {
             throw std::runtime_error("failed to initialize KV cache");
         }
 
         cross.v_embd.clear();
 
         // reserve pp graph first so that buffers are only allocated once
         {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -300,7 +300,7 @@ llama_context::llama_context(
 
         // reserve with tg graph to get the number of splits and nodes
         {
-            auto * gf = graph_reserve(1, 1, 1, mstate.get());
+            auto * gf = graph_reserve(1, 1, 1, mctx.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }
@@ -311,7 +311,7 @@ llama_context::llama_context(
 
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
         {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -444,8 +444,8 @@ bool llama_context::kv_self_update(bool optimize) {
         optimize |= memory_force_optimize;
         memory_force_optimize = false;
 
-        const auto mstate = memory->init_update(this, optimize);
-        switch (mstate->get_status()) {
+        const auto mctx = memory->init_update(this, optimize);
+        switch (mctx->get_status()) {
             case LLAMA_MEMORY_STATUS_SUCCESS:
                 {
                     // noop
@@ -463,22 +463,22 @@ bool llama_context::kv_self_update(bool optimize) {
                 }
         }
 
-        if (!mstate->apply()) {
+        if (!mctx->apply()) {
             LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
         }
     }
 
     // if the memory module did any computation, we have to reserve a new worst-case graph
     {
-        const auto mstate = memory->init_full();
-        if (!mstate) {
-            throw std::runtime_error("failed to initialize memory state");
+        const auto mctx = memory->init_full();
+        if (!mctx) {
+            throw std::runtime_error("failed to initialize memory context");
         }
 
         const uint32_t n_seqs = cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
-        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
+        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
         if (!gf) {
             LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
         }
@@ -678,9 +678,9 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
-llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) {
-    if (mstate && !mstate->apply()) {
-        LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__);
+llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
+    if (mctx && !mctx->apply()) {
+        LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
         ret = GGML_STATUS_FAILED;
         return nullptr;
     }
@@ -692,7 +692,7 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
         return nullptr;
     }
 
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate);
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
     if (!res) {
         LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
         ret = GGML_STATUS_FAILED;
@@ -933,21 +933,21 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // handle any pending defrags/shifts
     kv_self_update(false);
 
-    llama_memory_state_ptr mstate;
+    llama_memory_context_ptr mctx;
 
     while (true) {
-        mstate = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
-        if (!mstate) {
+        mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
+        if (!mctx) {
             return -2;
         }
 
-        switch (mstate->get_status()) {
+        switch (mctx->get_status()) {
             case LLAMA_MEMORY_STATUS_SUCCESS:
                 {
                 } break;
             case LLAMA_MEMORY_STATUS_NO_UPDATE:
                 {
-                    LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status());
+                    LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status());
 
                     return -2;
                 }
@@ -987,7 +987,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     int64_t n_outputs_prev = 0;
 
     do {
-        const auto & ubatch = mstate->get_ubatch();
+        const auto & ubatch = mctx->get_ubatch();
 
         // count the outputs in this ubatch
         {
@@ -1009,7 +1009,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
         ggml_status status;
-        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status);
+        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
 
         if (!res) {
             // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1126,7 +1126,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
         }
 
         n_outputs_prev += n_outputs;
-    } while (mstate->next());
+    } while (mctx->next());
 
     // set to total number of outputs in the batch, for use in llama_get_logits_ith
     n_outputs = n_outputs_all;
@@ -1292,7 +1292,7 @@ ggml_cgraph * llama_context::graph_init() {
     return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
 }
 
-ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) {
+ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
     LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
 
     if (n_tokens % n_seqs != 0) {
@@ -1312,7 +1312,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
 
     auto * gf = graph_init();
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
 
     this->n_outputs = save_n_outputs;
 
@@ -1333,11 +1333,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
 }
 
 llm_graph_result_ptr llama_context::graph_build(
-        ggml_context * ctx,
-        ggml_cgraph * gf,
-        const llama_ubatch & ubatch,
-        llm_graph_type gtype,
-        const llama_memory_state_i * mstate) {
+        ggml_context * ctx,
+        ggml_cgraph * gf,
+        const llama_ubatch & ubatch,
+        llm_graph_type gtype,
+        const llama_memory_context_i * mctx) {
     return model.build_graph(
         {
             /*.ctx =*/ ctx,
@@ -1349,7 +1349,7 @@ llm_graph_result_ptr llama_context::graph_build(
             /*.backend_cpu =*/ backend_cpu,
             /*.cvec =*/ &cvec,
             /*.loras =*/ &loras,
-            /*.mstate =*/ mstate,
+            /*.mctx =*/ mctx,
             /*.cross =*/ &cross,
             /*.n_outputs =*/ n_outputs,
             /*.cb =*/ graph_get_cb(),
@@ -2042,8 +2042,8 @@ void llama_context::opt_epoch_iter(
 
         uint32_t n_outputs_all = n_tokens_all;
 
-        auto mstate = memory->init_batch(*balloc, cparams.n_ubatch, true);
-        if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
+        auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true);
+        if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
             LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
             break;
         }
@@ -2056,17 +2056,17 @@ void llama_context::opt_epoch_iter(
 
         uint32_t pos_batch = 0;
         do {
-            const auto & ubatch = mstate->get_ubatch();
+            const auto & ubatch = mctx->get_ubatch();
 
             n_outputs = ubatch.n_tokens;
 
-            if (!mstate->apply()) {
-                LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
+            if (!mctx->apply()) {
+                LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__);
                 break;
             }
 
             auto * gf = graph_init();
-            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get());
+            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
 
             struct ggml_context * ctx_compute_opt;
             {
@@ -2101,7 +2101,7 @@ void llama_context::opt_epoch_iter(
             ggml_free(ctx_compute_opt);
 
             pos_batch += ubatch.n_tokens;
-        } while (mstate->next());
+        } while (mctx->next());
     }
 }

src/llama-context.h

Lines changed: 12 additions & 12 deletions
@@ -18,7 +18,7 @@ class llama_io_read_i;
 class llama_io_write_i;
 
 struct llama_memory_i;
-struct llama_memory_state_i;
+struct llama_memory_context_i;
 
 struct llama_context {
     // init scheduler and compute buffers, reserve worst-case graphs
@@ -93,14 +93,14 @@ struct llama_context {
             int32_t il_end);
 
     // process a single ubatch with a specific graph type
-    // if memory_state is provided, it will be applied first to the context's memory
+    // if memory_context is provided, it will be applied first to the context's memory
    // ret contains the status of the graph computation
    // returns nullptr only if ret != GGML_STATUS_SUCCESS
    llm_graph_result_ptr process_ubatch(
-            const llama_ubatch & ubatch,
-            llm_graph_type gtype,
-            llama_memory_state_i * mstate,
-            ggml_status & ret);
+            const llama_ubatch & ubatch,
+            llm_graph_type gtype,
+            llama_memory_context_i * mctx,
+            ggml_status & ret);
 
     int encode(const llama_batch & batch_inp);
     int decode(const llama_batch & batch_inp);
@@ -197,15 +197,15 @@ struct llama_context {
     ggml_status graph_compute(ggml_cgraph * gf, bool batched);
 
     // reserve a graph with a dummy ubatch of the specified size
-    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate);
+    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
 
 private:
     llm_graph_result_ptr graph_build(
-            ggml_context * ctx,
-            ggml_cgraph * gf,
-            const llama_ubatch & ubatch,
-            llm_graph_type gtype,
-            const llama_memory_state_i * mstate);
+            ggml_context * ctx,
+            ggml_cgraph * gf,
+            const llama_ubatch & ubatch,
+            llm_graph_type gtype,
+            const llama_memory_context_i * mctx);
 
     llm_graph_cb graph_get_cb() const;
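The header comment above summarizes the process_ubatch contract: if a memory context is passed, it is applied to the context's memory before the graph is built, and nullptr is returned only when ret != GGML_STATUS_SUCCESS. A hedged caller-side sketch, assembled from the decode() call site in llama-context.cpp above rather than copied verbatim:

ggml_status status;
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
if (!res) {
    // status != GGML_STATUS_SUCCESS here; decode() removes the positions of the
    // failed ubatch from the KV cache before returning an error
}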
