Skip to content

Commit 67ae531

Browse files
authored
metal : fix thread-safety (#14300)
ggml-ci
1 parent 692e3cd commit 67ae531

File tree

1 file changed

+60
-28
lines changed

1 file changed

+60
-28
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 60 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -48,29 +48,39 @@
4848
int mtl_device_ref_count;
4949
id<MTLLibrary> mtl_library;
5050

51+
NSLock * mtl_lock;
52+
5153
bool has_simdgroup_reduction;
5254
bool has_simdgroup_mm;
5355
bool has_residency_sets;
5456
bool has_bfloat;
5557
bool use_bfloat;
5658

59+
size_t max_size;
60+
5761
char name[128];
5862
} g_ggml_ctx_dev_main = {
5963
/*.mtl_device =*/ nil,
6064
/*.mtl_device_ref_count =*/ 0,
6165
/*.mtl_library =*/ nil,
66+
/*.mtl_lock =*/ nil,
6267
/*.has_simdgroup_reduction =*/ false,
6368
/*.has_simdgroup_mm =*/ false,
6469
/*.has_residency_sets =*/ false,
6570
/*.has_bfloat =*/ false,
6671
/*.use_bfloat =*/ false,
72+
/*.max_size =*/ 0,
6773
/*.name =*/ "",
6874
};
6975

7076
// acquire
7177
static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) {
7278
assert(ctx != NULL);
7379

80+
if (ctx->mtl_lock == nil) {
81+
ctx->mtl_lock = [[NSLock alloc] init];
82+
}
83+
7484
if (ctx->mtl_device == nil) {
7585
ctx->mtl_device = MTLCreateSystemDefaultDevice();
7686
}
@@ -94,6 +104,8 @@
94104
ctx->use_bfloat = false;
95105
#endif
96106

107+
ctx->max_size = ctx->mtl_device.maxBufferLength;
108+
97109
strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
98110
}
99111

@@ -110,6 +122,11 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
110122
ctx->mtl_device_ref_count--;
111123

112124
if (ctx->mtl_device_ref_count == 0) {
125+
if (ctx->mtl_lock) {
126+
[ctx->mtl_lock release];
127+
ctx->mtl_lock = nil;
128+
}
129+
113130
if (ctx->mtl_library) {
114131
[ctx->mtl_library release];
115132
ctx->mtl_library = nil;
@@ -977,7 +994,7 @@ @implementation GGMLMetalClass
977994
struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
978995
struct ggml_backend_metal_device_context * ctx_dev = dev->context;
979996

980-
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
997+
id<MTLDevice> device = ctx_dev->mtl_device;
981998

982999
GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
9831000

@@ -991,9 +1008,16 @@ @implementation GGMLMetalClass
9911008
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
9921009

9931010
// load library
994-
if (ctx_dev->mtl_library == nil) {
995-
ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
1011+
{
1012+
[ctx_dev->mtl_lock lock];
1013+
1014+
if (ctx_dev->mtl_library == nil) {
1015+
ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
1016+
}
1017+
1018+
[ctx_dev->mtl_lock unlock];
9961019
}
1020+
9971021
id<MTLLibrary> metal_library = ctx_dev->mtl_library;
9981022
if (metal_library == nil) {
9991023
GGML_LOG_ERROR("%s: error: metal library is nil\n", __func__);
@@ -5284,7 +5308,6 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
52845308
}
52855309

52865310
ggml_backend_metal_buffer_rset_free(ctx);
5287-
ggml_backend_metal_device_rel(buffer->buft->device->context);
52885311

52895312
if (ctx->owned) {
52905313
#if TARGET_OS_OSX
@@ -5393,7 +5416,10 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
53935416
}
53945417

53955418
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)buft->device->context;
5396-
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
5419+
5420+
GGML_ASSERT(ctx_dev->mtl_device != nil);
5421+
5422+
id<MTLDevice> device = ctx_dev->mtl_device;
53975423

53985424
ctx->all_data = ggml_metal_host_malloc(size_aligned);
53995425
ctx->all_size = size_aligned;
@@ -5416,14 +5442,12 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
54165442
if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
54175443
GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
54185444
free(ctx);
5419-
ggml_backend_metal_device_rel(ctx_dev);
54205445
return NULL;
54215446
}
54225447

54235448
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
54245449
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
54255450
free(ctx);
5426-
ggml_backend_metal_device_rel(ctx_dev);
54275451
return NULL;
54285452
}
54295453

@@ -5434,17 +5458,14 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
54345458

54355459
static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
54365460
return 32;
5461+
54375462
GGML_UNUSED(buft);
54385463
}
54395464

54405465
static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
5441-
id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
5442-
const size_t max_size = device.maxBufferLength;
5443-
ggml_backend_metal_device_rel(buft->device->context);
5466+
const size_t max_size = ((struct ggml_backend_metal_device_context *)buft->device->context)->max_size;
54445467

54455468
return max_size;
5446-
5447-
GGML_UNUSED(buft);
54485469
}
54495470

54505471
static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
@@ -5517,7 +5538,10 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
55175538
}
55185539

55195540
struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main;
5520-
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
5541+
5542+
GGML_ASSERT(ctx_dev->mtl_device != nil);
5543+
5544+
id<MTLDevice> device = ctx_dev->mtl_device;
55215545

55225546
// the buffer fits into the max buffer size allowed by the device
55235547
if (size_aligned <= device.maxBufferLength) {
@@ -5573,7 +5597,6 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
55735597
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
55745598
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
55755599
free(ctx);
5576-
ggml_backend_metal_device_rel(ctx_dev);
55775600
return NULL;
55785601
}
55795602

@@ -5589,10 +5612,8 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
55895612
}
55905613

55915614
static void ggml_backend_metal_free(ggml_backend_t backend) {
5592-
struct ggml_backend_metal_context * ctx = backend->context;
5593-
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
5615+
struct ggml_backend_metal_context * ctx = backend->context;
55945616

5595-
ggml_backend_metal_device_rel(ctx_dev);
55965617
ggml_metal_free(ctx);
55975618

55985619
free(backend);
@@ -5732,6 +5753,8 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
57325753

57335754
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
57345755

5756+
GGML_ASSERT(ctx_dev->mtl_device != nil);
5757+
57355758
return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
57365759
}
57375760

@@ -5751,23 +5774,18 @@ void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
57515774
}
57525775

57535776
static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {
5754-
// acq/rel just to populate ctx->name in case it hasn't been done yet
57555777
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
5756-
ggml_backend_metal_device_acq(ctx_dev);
5757-
ggml_backend_metal_device_rel(ctx_dev);
57585778

57595779
return ctx_dev->name;
57605780
}
57615781

57625782
static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
57635783
if (@available(macOS 10.12, iOS 16.0, *)) {
57645784
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
5765-
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
5785+
id<MTLDevice> device = ctx_dev->mtl_device;
57665786

57675787
*total = device.recommendedMaxWorkingSetSize;
57685788
*free = *total - device.currentAllocatedSize;
5769-
5770-
ggml_backend_metal_device_rel(ctx_dev);
57715789
} else {
57725790
*free = 1;
57735791
*total = 1;
@@ -5845,7 +5863,10 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
58455863
}
58465864

58475865
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
5848-
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
5866+
5867+
GGML_ASSERT(ctx_dev->mtl_device != nil);
5868+
5869+
id<MTLDevice> device = ctx_dev->mtl_device;
58495870

58505871
// the buffer fits into the max buffer size allowed by the device
58515872
if (size_aligned <= device.maxBufferLength) {
@@ -5901,7 +5922,6 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
59015922
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
59025923
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
59035924
free(ctx);
5904-
ggml_backend_metal_device_rel(ctx_dev);
59055925
return NULL;
59065926
}
59075927

@@ -5915,8 +5935,9 @@ static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const
59155935
}
59165936

59175937
static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
5918-
return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
5919-
buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
5938+
return
5939+
buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
5940+
buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
59205941

59215942
GGML_UNUSED(dev);
59225943
}
@@ -6001,8 +6022,19 @@ static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t r
60016022
/* .get_proc_address = */ ggml_backend_metal_get_proc_address,
60026023
};
60036024

6025+
// called upon program exit
6026+
static void ggml_metal_cleanup(void) {
6027+
ggml_backend_metal_device_rel(&g_ggml_ctx_dev_main);
6028+
}
6029+
6030+
// TODO: make thread-safe
60046031
ggml_backend_reg_t ggml_backend_metal_reg(void) {
6005-
// TODO: make this thread-safe somehow?
6032+
ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
6033+
6034+
// register cleanup callback
6035+
// TODO: not ideal, but not sure if there is a better way to do this in Objective-C
6036+
atexit(ggml_metal_cleanup);
6037+
60066038
{
60076039
g_ggml_backend_metal_reg = (struct ggml_backend_reg) {
60086040
/* .api_version = */ GGML_BACKEND_API_VERSION,

0 commit comments

Comments
 (0)