Skip to content

Commit 5837b04

Browse files
committed
Add reference counting to the handle_base struct so that all Unified Runtime (UR) handles inherit common reference-counting functionality. Update the Level Zero (L0) and Level Zero v2 (L0V2) adapters to use it.
1 parent b0a40e2 commit 5837b04

27 files changed

+91
-131
lines changed

unified-runtime/source/adapters/level_zero/adapter.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -675,7 +675,7 @@ ur_result_t urAdapterGet(
675675
}
676676
*Adapters = GlobalAdapter;
677677

678-
if (GlobalAdapter->RefCount++ == 0) {
678+
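// NOTE(review): the replaced line tested the pre-increment value (RefCount++ == 0); confirm incrementRefCount() returns the previous count, otherwise adapterStateInit() is never reached.
if (GlobalAdapter->incrementRefCount() == 0) {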
679679
adapterStateInit();
680680
}
681681
}
@@ -692,7 +692,7 @@ ur_result_t urAdapterRelease([[maybe_unused]] ur_adapter_handle_t Adapter) {
692692

693693
// NOTE: This does not require guarding with a mutex; the instant the ref
694694
// count hits zero, both Get and Retain are UB.
695-
if (--GlobalAdapter->RefCount == 0) {
695+
if (GlobalAdapter->decrementRefCount() == 0) {
696696
auto result = adapterStateTeardown();
697697
#ifdef UR_STATIC_LEVEL_ZERO
698698
// Given static linking of the L0 Loader, we must delay the loader's
@@ -709,9 +709,9 @@ ur_result_t urAdapterRelease([[maybe_unused]] ur_adapter_handle_t Adapter) {
709709
return UR_RESULT_SUCCESS;
710710
}
711711

712-
ur_result_t urAdapterRetain([[maybe_unused]] ur_adapter_handle_t Adapter) {
712+
ur_result_t urAdapterRetain([[maybe_unused]] ur_adapter_handle_t Adapter) {
713713
assert(GlobalAdapter && GlobalAdapter == Adapter);
714-
GlobalAdapter->RefCount++;
714+
GlobalAdapter->incrementRefCount();
715715

716716
return UR_RESULT_SUCCESS;
717717
}
@@ -740,12 +740,12 @@ ur_result_t urAdapterGetInfo(ur_adapter_handle_t, ur_adapter_info_t PropName,
740740
case UR_ADAPTER_INFO_BACKEND:
741741
return ReturnValue(UR_BACKEND_LEVEL_ZERO);
742742
case UR_ADAPTER_INFO_REFERENCE_COUNT:
743-
return ReturnValue(GlobalAdapter->RefCount.load());
743+
return ReturnValue(GlobalAdapter->getRefCount());
744744
case UR_ADAPTER_INFO_VERSION: {
745745
#ifdef UR_ADAPTER_LEVEL_ZERO_V2
746746
uint32_t adapterVersion = 2;
747747
#else
748-
uint32_t adapterVersion = 1;
748+
uint32_t adapterVersion = 1;
749749
#endif
750750
return ReturnValue(adapterVersion);
751751
}

unified-runtime/source/adapters/level_zero/adapter.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
#include "logger/ur_logger.hpp"
1313
#include "ur_interface_loader.hpp"
14-
#include <atomic>
1514
#include <loader/ur_loader.hpp>
1615
#include <loader/ze_loader.h>
1716
#include <optional>
@@ -26,7 +25,6 @@ class ur_legacy_sink;
2625

2726
struct ur_adapter_handle_t_ : ur::handle_base<ur::level_zero::ddi_getter> {
2827
ur_adapter_handle_t_();
29-
std::atomic<uint32_t> RefCount = 0;
3028

3129
zes_pfnDriverGetDeviceByUuidExp_t getDeviceByUUIdFunctionPtr = nullptr;
3230
zes_pfnDriverGet_t getSysManDriversFunctionPtr = nullptr;

unified-runtime/source/adapters/level_zero/async_alloc.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ ur_result_t urEnqueueUSMFreeExp(
247247
}
248248

249249
size_t size = umfPoolMallocUsableSize(hPool, Mem);
250-
(*Event)->RefCount.increment();
250+
(*Event)->incrementRefCount();
251251
usmPool->AsyncPool.insert(Mem, size, *Event, Queue);
252252

253253
// Signal that USM free event was finished

unified-runtime/source/adapters/level_zero/command_buffer.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -842,13 +842,13 @@ urCommandBufferCreateExp(ur_context_handle_t Context, ur_device_handle_t Device,
842842

843843
ur_result_t
844844
urCommandBufferRetainExp(ur_exp_command_buffer_handle_t CommandBuffer) {
845-
CommandBuffer->RefCount.increment();
845+
CommandBuffer->incrementRefCount();
846846
return UR_RESULT_SUCCESS;
847847
}
848848

849849
ur_result_t
850850
urCommandBufferReleaseExp(ur_exp_command_buffer_handle_t CommandBuffer) {
851-
if (!CommandBuffer->RefCount.decrementAndTest())
851+
if (CommandBuffer->decrementRefCount() != 0)
852852
return UR_RESULT_SUCCESS;
853853

854854
UR_CALL(waitForOngoingExecution(CommandBuffer));
@@ -1643,7 +1643,7 @@ ur_result_t enqueueImmediateAppendPath(
16431643
if (CommandBuffer->CurrentSubmissionEvent) {
16441644
UR_CALL(urEventReleaseInternal(CommandBuffer->CurrentSubmissionEvent));
16451645
}
1646-
(*Event)->RefCount.increment();
1646+
(*Event)->incrementRefCount();
16471647
CommandBuffer->CurrentSubmissionEvent = *Event;
16481648

16491649
UR_CALL(Queue->executeCommandList(CommandListHelper, false, false));
@@ -1726,7 +1726,7 @@ ur_result_t enqueueWaitEventPath(ur_exp_command_buffer_handle_t CommandBuffer,
17261726
if (CommandBuffer->CurrentSubmissionEvent) {
17271727
UR_CALL(urEventReleaseInternal(CommandBuffer->CurrentSubmissionEvent));
17281728
}
1729-
(*Event)->RefCount.increment();
1729+
(*Event)->incrementRefCount();
17301730
CommandBuffer->CurrentSubmissionEvent = *Event;
17311731

17321732
UR_CALL(Queue->executeCommandList(SignalCommandList, false /*IsBlocking*/,
@@ -1850,7 +1850,7 @@ urCommandBufferGetInfoExp(ur_exp_command_buffer_handle_t hCommandBuffer,
18501850

18511851
switch (propName) {
18521852
case UR_EXP_COMMAND_BUFFER_INFO_REFERENCE_COUNT:
1853-
return ReturnValue(uint32_t{hCommandBuffer->RefCount.load()});
1853+
return ReturnValue(uint32_t{hCommandBuffer->getRefCount()});
18541854
case UR_EXP_COMMAND_BUFFER_INFO_DESCRIPTOR: {
18551855
ur_exp_command_buffer_desc_t Descriptor{};
18561856
Descriptor.stype = UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC;

unified-runtime/source/adapters/level_zero/common.hpp

Lines changed: 1 addition & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -213,55 +213,9 @@ void zeParseError(ze_result_t ZeError, const char *&ErrorString);
213213
#define ZE_CALL_NOCHECK_NAME(ZeName, ZeArgs, callName) \
214214
ZeCall().doCall(ZeName ZeArgs, callName, #ZeArgs, false)
215215

216-
// This wrapper around std::atomic is created to limit operations with reference
217-
// counter and to make allowed operations more transparent in terms of
218-
// thread-safety in the plugin. increment() and load() operations do not need a
219-
// mutex guard around them since the underlying data is already atomic.
220-
// decrementAndTest() method is used to guard a code which needs to be
221-
// executed when object's ref count becomes zero after release. This method also
222-
// doesn't need a mutex guard because decrement operation is atomic and only one
223-
// thread can reach ref count equal to zero, i.e. only a single thread can pass
224-
// through this check.
225-
struct ReferenceCounter {
226-
ReferenceCounter() : RefCount{1} {}
227-
228-
// Reset the counter to the initial value.
229-
void reset() { RefCount = 1; }
230-
231-
// Used when retaining an object.
232-
void increment() { RefCount++; }
233-
234-
// Supposed to be used in ur*GetInfo* methods where ref count value is
235-
// requested.
236-
uint32_t load() { return RefCount.load(); }
237-
238-
// This method allows to guard a code which needs to be executed when object's
239-
// ref count becomes zero after release. It is important to notice that only a
240-
// single thread can pass through this check. This is true because of several
241-
// reasons:
242-
// 1. Decrement operation is executed atomically.
243-
// 2. It is not allowed to retain an object after its refcount reaches zero.
244-
// 3. It is not allowed to release an object more times than the value of
245-
// the ref count.
246-
// 2. and 3. basically means that we can't use an object at all as soon as its
247-
// refcount reaches zero. Using this check guarantees that code for deleting
248-
// an object and releasing its resources is executed once by a single thread
249-
// and we don't need to use any mutexes to guard access to this object in the
250-
// scope after this check. Of course if we access another objects in this code
251-
// (not the one which is being deleted) then access to these objects must be
252-
// guarded, for example with a mutex.
253-
bool decrementAndTest() { return --RefCount == 0; }
254-
255-
private:
256-
std::atomic<uint32_t> RefCount;
257-
};
258-
259216
// Base class to store common data
260217
struct ur_object : ur::handle_base<ur::level_zero::ddi_getter> {
261-
ur_object() : handle_base(), RefCount{} {}
262-
263-
// Must be atomic to prevent data race when incrementing/decrementing.
264-
ReferenceCounter RefCount;
218+
ur_object() : handle_base() {}
265219

266220
// This mutex protects accesses to all the non-const member variables.
267221
// Exclusive access is required to modify any of these members.

unified-runtime/source/adapters/level_zero/context.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ ur_result_t urContextRetain(
6161

6262
/// [in] handle of the context to get a reference of.
6363
ur_context_handle_t Context) {
64-
Context->RefCount.increment();
64+
Context->incrementRefCount();
6565
return UR_RESULT_SUCCESS;
6666
}
6767

@@ -113,7 +113,7 @@ ur_result_t urContextGetInfo(
113113
case UR_CONTEXT_INFO_NUM_DEVICES:
114114
return ReturnValue(uint32_t(Context->Devices.size()));
115115
case UR_CONTEXT_INFO_REFERENCE_COUNT:
116-
return ReturnValue(uint32_t{Context->RefCount.load()});
116+
return ReturnValue(uint32_t{Context->getRefCount()});
117117
case UR_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT:
118118
// 2D USM memcpy is supported.
119119
return ReturnValue(uint8_t{UseMemcpy2DOperations});
@@ -251,7 +251,7 @@ ur_device_handle_t ur_context_handle_t_::getRootDevice() const {
251251
// from the list of tracked contexts.
252252
ur_result_t ContextReleaseHelper(ur_context_handle_t Context) {
253253

254-
if (!Context->RefCount.decrementAndTest())
254+
if (Context->decrementRefCount() != 0)
255255
return UR_RESULT_SUCCESS;
256256

257257
if (IndirectAccessTrackingEnabled) {

unified-runtime/source/adapters/level_zero/device.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ ur_result_t urDeviceGetInfo(
470470
return ReturnValue((uint32_t)Device->SubDevices.size());
471471
}
472472
case UR_DEVICE_INFO_REFERENCE_COUNT:
473-
return ReturnValue(uint32_t{Device->RefCount.load()});
473+
return ReturnValue(uint32_t{Device->getRefCount()});
474474
case UR_DEVICE_INFO_SUPPORTED_PARTITIONS: {
475475
// SYCL spec says: if this SYCL device cannot be partitioned into at least
476476
// two sub devices then the returned vector must be empty.
@@ -1643,15 +1643,15 @@ ur_result_t urDeviceGetGlobalTimestamps(
16431643
ur_result_t urDeviceRetain(ur_device_handle_t Device) {
16441644
// The root-device ref-count remains unchanged (always 1).
16451645
if (Device->isSubDevice()) {
1646-
Device->RefCount.increment();
1646+
Device->incrementRefCount();
16471647
}
16481648
return UR_RESULT_SUCCESS;
16491649
}
16501650

16511651
ur_result_t urDeviceRelease(ur_device_handle_t Device) {
16521652
// Root devices are destroyed during the piTearDown process.
16531653
if (Device->isSubDevice()) {
1654-
if (Device->RefCount.decrementAndTest()) {
1654+
if (Device->decrementRefCount() == 0) {
16551655
delete Device;
16561656
}
16571657
}

unified-runtime/source/adapters/level_zero/event.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,7 @@ ur_result_t urEventGetInfo(
505505
return ReturnValue(Result);
506506
}
507507
case UR_EVENT_INFO_REFERENCE_COUNT: {
508-
return ReturnValue(Event->RefCount.load());
508+
return ReturnValue(Event->getRefCount());
509509
}
510510
default:
511511
UR_LOG(ERR, "Unsupported ParamName in urEventGetInfo: ParamName={}(0x{})",
@@ -874,7 +874,7 @@ ur_result_t
874874
/// [in] handle of the event object
875875
urEventRetain(/** [in] handle of the event object */ ur_event_handle_t Event) {
876876
Event->RefCountExternal++;
877-
Event->RefCount.increment();
877+
Event->incrementRefCount();
878878

879879
return UR_RESULT_SUCCESS;
880880
}
@@ -1088,7 +1088,7 @@ ur_event_handle_t_::~ur_event_handle_t_() {
10881088

10891089
ur_result_t urEventReleaseInternal(ur_event_handle_t Event,
10901090
bool *isEventDeleted) {
1091-
if (!Event->RefCount.decrementAndTest())
1091+
if (Event->decrementRefCount() != 0)
10921092
return UR_RESULT_SUCCESS;
10931093

10941094
if (Event->OriginAllocEvent) {
@@ -1429,7 +1429,7 @@ ur_result_t ur_event_handle_t_::reset() {
14291429
CommandType = UR_EXT_COMMAND_TYPE_USER;
14301430
WaitList = {};
14311431
RefCountExternal = 0;
1432-
RefCount.reset();
1432+
resetRefCount();
14331433
CommandList = std::nullopt;
14341434
completionBatch = std::nullopt;
14351435
OriginAllocEvent = nullptr;
@@ -1524,7 +1524,7 @@ ur_result_t ur_ze_event_list_t::createAndRetainUrZeEventList(
15241524
std::shared_lock<ur_shared_mutex> Lock(CurQueue->LastCommandEvent->Mutex);
15251525
this->ZeEventList[0] = CurQueue->LastCommandEvent->ZeEvent;
15261526
this->UrEventList[0] = CurQueue->LastCommandEvent;
1527-
this->UrEventList[0]->RefCount.increment();
1527+
this->UrEventList[0]->incrementRefCount();
15281528
TmpListLength = 1;
15291529
} else if (EventListLength > 0) {
15301530
this->ZeEventList = new ze_event_handle_t[EventListLength];
@@ -1660,7 +1660,7 @@ ur_result_t ur_ze_event_list_t::createAndRetainUrZeEventList(
16601660
IsInternal, IsMultiDevice));
16611661
MultiDeviceZeEvent = MultiDeviceEvent->ZeEvent;
16621662
const auto &ZeCommandList = CommandList->first;
1663-
EventList[I]->RefCount.increment();
1663+
EventList[I]->incrementRefCount();
16641664

16651665
// Append a Barrier to wait on the original event while signalling the
16661666
// new multi device event.
@@ -1676,11 +1676,11 @@ ur_result_t ur_ze_event_list_t::createAndRetainUrZeEventList(
16761676

16771677
this->ZeEventList[TmpListLength] = MultiDeviceZeEvent;
16781678
this->UrEventList[TmpListLength] = MultiDeviceEvent;
1679-
this->UrEventList[TmpListLength]->RefCount.increment();
1679+
this->UrEventList[TmpListLength]->incrementRefCount();
16801680
} else {
16811681
this->ZeEventList[TmpListLength] = EventList[I]->ZeEvent;
16821682
this->UrEventList[TmpListLength] = EventList[I];
1683-
this->UrEventList[TmpListLength]->RefCount.increment();
1683+
this->UrEventList[TmpListLength]->incrementRefCount();
16841684
}
16851685

16861686
if (QueueLock.has_value()) {

unified-runtime/source/adapters/level_zero/kernel.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -787,7 +787,7 @@ ur_result_t urKernelGetInfo(
787787
}
788788
}
789789
case UR_KERNEL_INFO_REFERENCE_COUNT:
790-
return ReturnValue(uint32_t{Kernel->RefCount.load()});
790+
return ReturnValue(uint32_t{Kernel->getRefCount()});
791791
case UR_KERNEL_INFO_ATTRIBUTES:
792792
try {
793793
uint32_t Size;
@@ -938,15 +938,15 @@ ur_result_t urKernelGetSubGroupInfo(
938938
ur_result_t urKernelRetain(
939939
/// [in] handle for the Kernel to retain
940940
ur_kernel_handle_t Kernel) {
941-
Kernel->RefCount.increment();
941+
Kernel->incrementRefCount();
942942

943943
return UR_RESULT_SUCCESS;
944944
}
945945

946946
ur_result_t urKernelRelease(
947947
/// [in] handle for the Kernel to release
948948
ur_kernel_handle_t Kernel) {
949-
if (!Kernel->RefCount.decrementAndTest())
949+
if (Kernel->decrementRefCount() != 0)
950950
return UR_RESULT_SUCCESS;
951951

952952
auto KernelProgram = Kernel->Program;

unified-runtime/source/adapters/level_zero/memory.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,7 +1052,7 @@ ur_result_t urEnqueueMemBufferMap(
10521052

10531053
// Add the event to the command list.
10541054
CommandList->second.append(reinterpret_cast<ur_event_handle_t>(*Event));
1055-
(*Event)->RefCount.increment();
1055+
(*Event)->incrementRefCount();
10561056

10571057
const auto &ZeCommandList = CommandList->first;
10581058
const auto &WaitList = (*Event)->WaitList;
@@ -1183,7 +1183,7 @@ ur_result_t urEnqueueMemUnmap(
11831183
nullptr /*ForcedCmdQueue*/));
11841184

11851185
CommandList->second.append(reinterpret_cast<ur_event_handle_t>(*Event));
1186-
(*Event)->RefCount.increment();
1186+
(*Event)->incrementRefCount();
11871187

11881188
const auto &ZeCommandList = CommandList->first;
11891189

@@ -1635,14 +1635,14 @@ ur_result_t urMemBufferCreate(
16351635
ur_result_t urMemRetain(
16361636
/// [in] handle of the memory object to get access
16371637
ur_mem_handle_t Mem) {
1638-
Mem->RefCount.increment();
1638+
Mem->incrementRefCount();
16391639
return UR_RESULT_SUCCESS;
16401640
}
16411641

16421642
ur_result_t urMemRelease(
16431643
/// [in] handle of the memory object to release
16441644
ur_mem_handle_t Mem) {
1645-
if (!Mem->RefCount.decrementAndTest())
1645+
if (Mem->decrementRefCount() != 0)
16461646
return UR_RESULT_SUCCESS;
16471647

16481648
if (Mem->isImage()) {
@@ -1848,7 +1848,7 @@ ur_result_t urMemGetInfo(
18481848
return ReturnValue(size_t{Buffer->Size});
18491849
}
18501850
case UR_MEM_INFO_REFERENCE_COUNT: {
1851-
return ReturnValue(Buffer->RefCount.load());
1851+
return ReturnValue(Buffer->getRefCount());
18521852
}
18531853
default: {
18541854
return UR_RESULT_ERROR_INVALID_ENUMERATION;

unified-runtime/source/adapters/level_zero/memory.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ struct ur_buffer final : ur_mem_handle_t_ {
116116
: ur_mem_handle_t_(mem_type_t::buffer, Parent->UrContext), Size(Size),
117117
SubBuffer{{Parent, Origin}} {
118118
// Retain the Parent Buffer due to the Creation of the SubBuffer.
119-
Parent->RefCount.increment();
119+
Parent->incrementRefCount();
120120
}
121121

122122
// Interop-buffer constructor

unified-runtime/source/adapters/level_zero/physical_mem.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,12 @@ ur_result_t urPhysicalMemCreate(
4242
}
4343

4444
ur_result_t urPhysicalMemRetain(ur_physical_mem_handle_t hPhysicalMem) {
45-
hPhysicalMem->RefCount.increment();
45+
hPhysicalMem->incrementRefCount();
4646
return UR_RESULT_SUCCESS;
4747
}
4848

4949
ur_result_t urPhysicalMemRelease(ur_physical_mem_handle_t hPhysicalMem) {
50-
if (!hPhysicalMem->RefCount.decrementAndTest())
50+
if (hPhysicalMem->decrementRefCount() != 0)
5151
return UR_RESULT_SUCCESS;
5252

5353
if (checkL0LoaderTeardown()) {
@@ -68,7 +68,7 @@ ur_result_t urPhysicalMemGetInfo(ur_physical_mem_handle_t hPhysicalMem,
6868

6969
switch (propName) {
7070
case UR_PHYSICAL_MEM_INFO_REFERENCE_COUNT: {
71-
return ReturnValue(hPhysicalMem->RefCount.load());
71+
return ReturnValue(hPhysicalMem->getRefCount());
7272
}
7373
default:
7474
return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION;

0 commit comments

Comments
 (0)