Skip to content

wasi-nn: protect the backend lookup table with a lock #4319

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 5, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions core/iwasm/libraries/wasi-nn/src/wasi_nn.c
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@
#define LLAMACPP_BACKEND_LIB "libwasi_nn_llamacpp" LIB_EXTENTION

/* Global variables */
static korp_mutex wasi_nn_lock;
/*
* the "lookup" table is protected by wasi_nn_lock.
*
* an exception: during wasm_runtime_destroy, wasi_nn_destroy tears down
* the table without acquiring the lock. it's ok because there should be
* no other threads using the runtime at this point.
*/
struct backends_api_functions {
void *backend_handle;
api_function functions;
Expand Down Expand Up @@ -109,12 +117,18 @@ wasi_nn_initialize()
{
NN_DBG_PRINTF("[WASI NN General] Initializing wasi-nn");

if (os_mutex_init(&wasi_nn_lock)) {
NN_ERR_PRINTF("Error while initializing global lock");
return false;
}

// hashmap { instance: wasi_nn_ctx }
hashmap = bh_hash_map_create(HASHMAP_INITIAL_SIZE, true, hash_func,
key_equal_func, key_destroy_func,
value_destroy_func);
if (hashmap == NULL) {
NN_ERR_PRINTF("Error while initializing hashmap");
os_mutex_destroy(&wasi_nn_lock);
return false;
}

Expand Down Expand Up @@ -175,6 +189,8 @@ wasi_nn_destroy()

memset(&lookup[i].functions, 0, sizeof(api_function));
}

os_mutex_destroy(&wasi_nn_lock);
}

/* Utils */
Expand Down Expand Up @@ -349,6 +365,8 @@ static bool
detect_and_load_backend(graph_encoding backend_hint,
graph_encoding *loaded_backend)
{
bool ret;

if (backend_hint > autodetect)
return false;

Expand All @@ -360,16 +378,23 @@ detect_and_load_backend(graph_encoding backend_hint,

*loaded_backend = backend_hint;

os_mutex_lock(&wasi_nn_lock);
/* if already loaded */
if (lookup[backend_hint].backend_handle)
if (lookup[backend_hint].backend_handle) {
os_mutex_unlock(&wasi_nn_lock);
return true;
}

const char *backend_lib_name =
graph_encoding_to_backend_lib_name(backend_hint);
if (!backend_lib_name)
if (!backend_lib_name) {
os_mutex_unlock(&wasi_nn_lock);
return false;
}

return prepare_backend(backend_lib_name, lookup + backend_hint);
ret = prepare_backend(backend_lib_name, lookup + backend_hint);
os_mutex_unlock(&wasi_nn_lock);
return ret;
}

/* WASI-NN implementation */
Expand Down
Loading