Skip to content

Commit e98d8a0

Browse files
authored
[UR][L0][L0v2] Refactor reference counting in UR L0 and L0v2 (#19057)
For #18644 Most UR adapters had their own reference counting of some sort. This adds a new `RefCount` class and refactors adapter code so all adapters can share the same code for reference counting. This PR handles L0/L0V2 and I will open more PRs for each adapter in turn.
1 parent c480967 commit e98d8a0

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+190
-143
lines changed

unified-runtime/source/adapters/level_zero/adapter.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ Behavior Summary:
296296
SysMan initialization is skipped.
297297
*/
298298
ur_adapter_handle_t_::ur_adapter_handle_t_()
299-
: handle_base(), logger(logger::get_logger("level_zero")) {
299+
: handle_base(), logger(logger::get_logger("level_zero")), RefCount(0) {
300300
ZeInitDriversResult = ZE_RESULT_ERROR_UNINITIALIZED;
301301
ZeInitResult = ZE_RESULT_ERROR_UNINITIALIZED;
302302
ZesResult = ZE_RESULT_ERROR_UNINITIALIZED;
@@ -675,7 +675,7 @@ ur_result_t urAdapterGet(
675675
}
676676
*Adapters = GlobalAdapter;
677677

678-
if (GlobalAdapter->RefCount++ == 0) {
678+
if (GlobalAdapter->RefCount.retain() == 0) {
679679
adapterStateInit();
680680
}
681681
}
@@ -692,7 +692,7 @@ ur_result_t urAdapterRelease([[maybe_unused]] ur_adapter_handle_t Adapter) {
692692

693693
// NOTE: This does not require guarding with a mutex; the instant the ref
694694
// count hits zero, both Get and Retain are UB.
695-
if (--GlobalAdapter->RefCount == 0) {
695+
if (GlobalAdapter->RefCount.release()) {
696696
auto result = adapterStateTeardown();
697697
#ifdef UR_STATIC_LEVEL_ZERO
698698
// Given static linking of the L0 Loader, we must delay the loader's
@@ -711,7 +711,7 @@ ur_result_t urAdapterRelease([[maybe_unused]] ur_adapter_handle_t Adapter) {
711711

712712
ur_result_t urAdapterRetain([[maybe_unused]] ur_adapter_handle_t Adapter) {
713713
assert(GlobalAdapter && GlobalAdapter == Adapter);
714-
GlobalAdapter->RefCount++;
714+
GlobalAdapter->RefCount.retain();
715715

716716
return UR_RESULT_SUCCESS;
717717
}
@@ -740,7 +740,7 @@ ur_result_t urAdapterGetInfo(ur_adapter_handle_t, ur_adapter_info_t PropName,
740740
case UR_ADAPTER_INFO_BACKEND:
741741
return ReturnValue(UR_BACKEND_LEVEL_ZERO);
742742
case UR_ADAPTER_INFO_REFERENCE_COUNT:
743-
return ReturnValue(GlobalAdapter->RefCount.load());
743+
return ReturnValue(GlobalAdapter->RefCount.getCount());
744744
case UR_ADAPTER_INFO_VERSION: {
745745
#ifdef UR_ADAPTER_LEVEL_ZERO_V2
746746
uint32_t adapterVersion = 2;

unified-runtime/source/adapters/level_zero/adapter.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
//===----------------------------------------------------------------------===//
1010
#pragma once
1111

12+
#include "common/ur_ref_count.hpp"
1213
#include "logger/ur_logger.hpp"
1314
#include "ur_interface_loader.hpp"
14-
#include <atomic>
1515
#include <loader/ur_loader.hpp>
1616
#include <loader/ze_loader.h>
1717
#include <optional>
@@ -26,7 +26,6 @@ class ur_legacy_sink;
2626

2727
struct ur_adapter_handle_t_ : ur::handle_base<ur::level_zero::ddi_getter> {
2828
ur_adapter_handle_t_();
29-
std::atomic<uint32_t> RefCount = 0;
3029

3130
zes_pfnDriverGetDeviceByUuidExp_t getDeviceByUUIdFunctionPtr = nullptr;
3231
zes_pfnDriverGet_t getSysManDriversFunctionPtr = nullptr;
@@ -45,6 +44,8 @@ struct ur_adapter_handle_t_ : ur::handle_base<ur::level_zero::ddi_getter> {
4544
ZeCache<Result<PlatformVec>> PlatformCache;
4645
logger::Logger &logger;
4746
HMODULE processHandle = nullptr;
47+
48+
ur::RefCount RefCount;
4849
};
4950

5051
extern ur_adapter_handle_t_ *GlobalAdapter;

unified-runtime/source/adapters/level_zero/async_alloc.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ ur_result_t urEnqueueUSMFreeExp(
247247
}
248248

249249
size_t size = umfPoolMallocUsableSize(hPool, Mem);
250-
(*Event)->RefCount.increment();
250+
(*Event)->RefCount.retain();
251251
usmPool->AsyncPool.insert(Mem, size, *Event, Queue);
252252

253253
// Signal that USM free event was finished

unified-runtime/source/adapters/level_zero/command_buffer.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -840,13 +840,13 @@ urCommandBufferCreateExp(ur_context_handle_t Context, ur_device_handle_t Device,
840840

841841
ur_result_t
842842
urCommandBufferRetainExp(ur_exp_command_buffer_handle_t CommandBuffer) {
843-
CommandBuffer->RefCount.increment();
843+
CommandBuffer->RefCount.retain();
844844
return UR_RESULT_SUCCESS;
845845
}
846846

847847
ur_result_t
848848
urCommandBufferReleaseExp(ur_exp_command_buffer_handle_t CommandBuffer) {
849-
if (!CommandBuffer->RefCount.decrementAndTest())
849+
if (!CommandBuffer->RefCount.release())
850850
return UR_RESULT_SUCCESS;
851851

852852
UR_CALL(waitForOngoingExecution(CommandBuffer));
@@ -1641,7 +1641,7 @@ ur_result_t enqueueImmediateAppendPath(
16411641
if (CommandBuffer->CurrentSubmissionEvent) {
16421642
UR_CALL(urEventReleaseInternal(CommandBuffer->CurrentSubmissionEvent));
16431643
}
1644-
(*Event)->RefCount.increment();
1644+
(*Event)->RefCount.retain();
16451645
CommandBuffer->CurrentSubmissionEvent = *Event;
16461646

16471647
UR_CALL(Queue->executeCommandList(CommandListHelper, false, false));
@@ -1724,7 +1724,7 @@ ur_result_t enqueueWaitEventPath(ur_exp_command_buffer_handle_t CommandBuffer,
17241724
if (CommandBuffer->CurrentSubmissionEvent) {
17251725
UR_CALL(urEventReleaseInternal(CommandBuffer->CurrentSubmissionEvent));
17261726
}
1727-
(*Event)->RefCount.increment();
1727+
(*Event)->RefCount.retain();
17281728
CommandBuffer->CurrentSubmissionEvent = *Event;
17291729

17301730
UR_CALL(Queue->executeCommandList(SignalCommandList, false /*IsBlocking*/,
@@ -1848,7 +1848,7 @@ urCommandBufferGetInfoExp(ur_exp_command_buffer_handle_t hCommandBuffer,
18481848

18491849
switch (propName) {
18501850
case UR_EXP_COMMAND_BUFFER_INFO_REFERENCE_COUNT:
1851-
return ReturnValue(uint32_t{hCommandBuffer->RefCount.load()});
1851+
return ReturnValue(uint32_t{hCommandBuffer->RefCount.getCount()});
18521852
case UR_EXP_COMMAND_BUFFER_INFO_DESCRIPTOR: {
18531853
ur_exp_command_buffer_desc_t Descriptor{};
18541854
Descriptor.stype = UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC;

unified-runtime/source/adapters/level_zero/command_buffer.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include "common.hpp"
1919

20+
#include "common/ur_ref_count.hpp"
2021
#include "context.hpp"
2122
#include "kernel.hpp"
2223
#include "queue.hpp"
@@ -149,4 +150,6 @@ struct ur_exp_command_buffer_handle_t_ : public ur_object {
149150
// Track handle objects to free when command-buffer is destroyed.
150151
std::vector<std::unique_ptr<ur_exp_command_buffer_command_handle_t_>>
151152
CommandHandles;
153+
154+
ur::RefCount RefCount;
152155
};

unified-runtime/source/adapters/level_zero/common.hpp

Lines changed: 4 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include <level_zero/ze_intel_gpu.h>
3535
#include <umf_pools/disjoint_pool_config_parser.hpp>
3636

37+
#include "common/ur_ref_count.hpp"
3738
#include "logger/ur_logger.hpp"
3839
#include "ur_interface_loader.hpp"
3940

@@ -220,55 +221,9 @@ void zeParseError(ze_result_t ZeError, const char *&ErrorString);
220221
#define ZE_CALL_NOCHECK_NAME(ZeName, ZeArgs, callName) \
221222
ZeCall().doCall(ZeName ZeArgs, callName, #ZeArgs, false)
222223

223-
// This wrapper around std::atomic is created to limit operations with reference
224-
// counter and to make allowed operations more transparent in terms of
225-
// thread-safety in the plugin. increment() and load() operations do not need a
226-
// mutex guard around them since the underlying data is already atomic.
227-
// decrementAndTest() method is used to guard a code which needs to be
228-
// executed when object's ref count becomes zero after release. This method also
229-
// doesn't need a mutex guard because decrement operation is atomic and only one
230-
// thread can reach ref count equal to zero, i.e. only a single thread can pass
231-
// through this check.
232-
struct ReferenceCounter {
233-
ReferenceCounter() : RefCount{1} {}
234-
235-
// Reset the counter to the initial value.
236-
void reset() { RefCount = 1; }
237-
238-
// Used when retaining an object.
239-
void increment() { RefCount++; }
240-
241-
// Supposed to be used in ur*GetInfo* methods where ref count value is
242-
// requested.
243-
uint32_t load() { return RefCount.load(); }
244-
245-
// This method allows to guard a code which needs to be executed when object's
246-
// ref count becomes zero after release. It is important to notice that only a
247-
// single thread can pass through this check. This is true because of several
248-
// reasons:
249-
// 1. Decrement operation is executed atomically.
250-
// 2. It is not allowed to retain an object after its refcount reaches zero.
251-
// 3. It is not allowed to release an object more times than the value of
252-
// the ref count.
253-
// 2. and 3. basically means that we can't use an object at all as soon as its
254-
// refcount reaches zero. Using this check guarantees that code for deleting
255-
// an object and releasing its resources is executed once by a single thread
256-
// and we don't need to use any mutexes to guard access to this object in the
257-
// scope after this check. Of course if we access another objects in this code
258-
// (not the one which is being deleted) then access to these objects must be
259-
// guarded, for example with a mutex.
260-
bool decrementAndTest() { return --RefCount == 0; }
261-
262-
private:
263-
std::atomic<uint32_t> RefCount;
264-
};
265-
266224
// Base class to store common data
267225
struct ur_object : ur::handle_base<ur::level_zero::ddi_getter> {
268-
ur_object() : handle_base(), RefCount{} {}
269-
270-
// Must be atomic to prevent data race when incrementing/decrementing.
271-
ReferenceCounter RefCount;
226+
ur_object() : handle_base() {}
272227

273228
// This mutex protects accesses to all the non-const member variables.
274229
// Exclusive access is required to modify any of these members.
@@ -303,6 +258,8 @@ struct MemAllocRecord : ur_object {
303258
// TODO: this should go away when memory isolation issue is fixed in the Level
304259
// Zero runtime.
305260
ur_context_handle_t Context;
261+
262+
ur::RefCount RefCount;
306263
};
307264

308265
extern usm::DisjointPoolAllConfigs DisjointPoolConfigInstance;

unified-runtime/source/adapters/level_zero/context.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ ur_result_t urContextRetain(
6161

6262
/// [in] handle of the context to get a reference of.
6363
ur_context_handle_t Context) {
64-
Context->RefCount.increment();
64+
Context->RefCount.retain();
6565
return UR_RESULT_SUCCESS;
6666
}
6767

@@ -113,7 +113,7 @@ ur_result_t urContextGetInfo(
113113
case UR_CONTEXT_INFO_NUM_DEVICES:
114114
return ReturnValue(uint32_t(Context->Devices.size()));
115115
case UR_CONTEXT_INFO_REFERENCE_COUNT:
116-
return ReturnValue(uint32_t{Context->RefCount.load()});
116+
return ReturnValue(uint32_t{Context->RefCount.getCount()});
117117
case UR_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT:
118118
// 2D USM memcpy is supported.
119119
return ReturnValue(uint8_t{UseMemcpy2DOperations});
@@ -251,7 +251,7 @@ ur_device_handle_t ur_context_handle_t_::getRootDevice() const {
251251
// from the list of tracked contexts.
252252
ur_result_t ContextReleaseHelper(ur_context_handle_t Context) {
253253

254-
if (!Context->RefCount.decrementAndTest())
254+
if (!Context->RefCount.release())
255255
return UR_RESULT_SUCCESS;
256256

257257
if (IndirectAccessTrackingEnabled) {

unified-runtime/source/adapters/level_zero/context.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "queue.hpp"
2727
#include "usm.hpp"
2828

29+
#include "common/ur_ref_count.hpp"
2930
#include <umf_helpers.hpp>
3031

3132
struct l0_command_list_cache_info {
@@ -358,6 +359,8 @@ struct ur_context_handle_t_ : ur_object {
358359
// Get handle to the L0 context
359360
ze_context_handle_t getZeHandle() const;
360361

362+
ur::RefCount RefCount;
363+
361364
private:
362365
enum EventFlags {
363366
EVENT_FLAG_HOST_VISIBLE = UR_BIT(0),

unified-runtime/source/adapters/level_zero/device.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ ur_result_t urDeviceGetInfo(
470470
return ReturnValue((uint32_t)Device->SubDevices.size());
471471
}
472472
case UR_DEVICE_INFO_REFERENCE_COUNT:
473-
return ReturnValue(uint32_t{Device->RefCount.load()});
473+
return ReturnValue(uint32_t{Device->RefCount.getCount()});
474474
case UR_DEVICE_INFO_SUPPORTED_PARTITIONS: {
475475
// SYCL spec says: if this SYCL device cannot be partitioned into at least
476476
// two sub devices then the returned vector must be empty.
@@ -1666,15 +1666,15 @@ ur_result_t urDeviceGetGlobalTimestamps(
16661666
ur_result_t urDeviceRetain(ur_device_handle_t Device) {
16671667
// The root-device ref-count remains unchanged (always 1).
16681668
if (Device->isSubDevice()) {
1669-
Device->RefCount.increment();
1669+
Device->RefCount.retain();
16701670
}
16711671
return UR_RESULT_SUCCESS;
16721672
}
16731673

16741674
ur_result_t urDeviceRelease(ur_device_handle_t Device) {
16751675
// Root devices are destroyed during the piTearDown process.
16761676
if (Device->isSubDevice()) {
1677-
if (Device->RefCount.decrementAndTest()) {
1677+
if (Device->RefCount.release()) {
16781678
delete Device;
16791679
}
16801680
}

unified-runtime/source/adapters/level_zero/device.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
#include "adapters/level_zero/platform.hpp"
2222
#include "common.hpp"
23+
#include "common/ur_ref_count.hpp"
2324
#include <ur/ur.hpp>
2425
#include <ur_ddi.h>
2526
#include <ze_api.h>
@@ -242,6 +243,8 @@ struct ur_device_handle_t_ : ur_object {
242243

243244
// unique ephemeral identifer of the device in the adapter
244245
std::optional<DeviceId> Id;
246+
247+
ur::RefCount RefCount;
245248
};
246249

247250
inline std::vector<ur_device_handle_t>

unified-runtime/source/adapters/level_zero/event.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,7 @@ ur_result_t urEventGetInfo(
505505
return ReturnValue(Result);
506506
}
507507
case UR_EVENT_INFO_REFERENCE_COUNT: {
508-
return ReturnValue(Event->RefCount.load());
508+
return ReturnValue(Event->RefCount.getCount());
509509
}
510510
default:
511511
UR_LOG(ERR, "Unsupported ParamName in urEventGetInfo: ParamName={}(0x{})",
@@ -874,7 +874,7 @@ ur_result_t
874874
/// [in] handle of the event object
875875
urEventRetain(/** [in] handle of the event object */ ur_event_handle_t Event) {
876876
Event->RefCountExternal++;
877-
Event->RefCount.increment();
877+
Event->RefCount.retain();
878878

879879
return UR_RESULT_SUCCESS;
880880
}
@@ -1088,7 +1088,7 @@ ur_event_handle_t_::~ur_event_handle_t_() {
10881088

10891089
ur_result_t urEventReleaseInternal(ur_event_handle_t Event,
10901090
bool *isEventDeleted) {
1091-
if (!Event->RefCount.decrementAndTest())
1091+
if (!Event->RefCount.release())
10921092
return UR_RESULT_SUCCESS;
10931093

10941094
if (Event->OriginAllocEvent) {
@@ -1524,7 +1524,7 @@ ur_result_t ur_ze_event_list_t::createAndRetainUrZeEventList(
15241524
std::shared_lock<ur_shared_mutex> Lock(CurQueue->LastCommandEvent->Mutex);
15251525
this->ZeEventList[0] = CurQueue->LastCommandEvent->ZeEvent;
15261526
this->UrEventList[0] = CurQueue->LastCommandEvent;
1527-
this->UrEventList[0]->RefCount.increment();
1527+
this->UrEventList[0]->RefCount.retain();
15281528
TmpListLength = 1;
15291529
} else if (EventListLength > 0) {
15301530
this->ZeEventList = new ze_event_handle_t[EventListLength];
@@ -1660,7 +1660,7 @@ ur_result_t ur_ze_event_list_t::createAndRetainUrZeEventList(
16601660
IsInternal, IsMultiDevice));
16611661
MultiDeviceZeEvent = MultiDeviceEvent->ZeEvent;
16621662
const auto &ZeCommandList = CommandList->first;
1663-
EventList[I]->RefCount.increment();
1663+
EventList[I]->RefCount.retain();
16641664

16651665
// Append a Barrier to wait on the original event while signalling the
16661666
// new multi device event.
@@ -1676,11 +1676,11 @@ ur_result_t ur_ze_event_list_t::createAndRetainUrZeEventList(
16761676

16771677
this->ZeEventList[TmpListLength] = MultiDeviceZeEvent;
16781678
this->UrEventList[TmpListLength] = MultiDeviceEvent;
1679-
this->UrEventList[TmpListLength]->RefCount.increment();
1679+
this->UrEventList[TmpListLength]->RefCount.retain();
16801680
} else {
16811681
this->ZeEventList[TmpListLength] = EventList[I]->ZeEvent;
16821682
this->UrEventList[TmpListLength] = EventList[I];
1683-
this->UrEventList[TmpListLength]->RefCount.increment();
1683+
this->UrEventList[TmpListLength]->RefCount.retain();
16841684
}
16851685

16861686
if (QueueLock.has_value()) {

unified-runtime/source/adapters/level_zero/event.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include <zes_api.h>
2626

2727
#include "common.hpp"
28+
#include "common/ur_ref_count.hpp"
2829
#include "queue.hpp"
2930
#include "ur_api.h"
3031

@@ -262,6 +263,8 @@ struct ur_event_handle_t_ : ur_object {
262263
// Used only for asynchronous allocations. This is the event originally used
263264
// on async free to indicate when the allocation can be used again.
264265
ur_event_handle_t OriginAllocEvent = nullptr;
266+
267+
ur::RefCount RefCount;
265268
};
266269

267270
// Helper function to implement zeHostSynchronize.

0 commit comments

Comments
 (0)