diff --git a/unified-runtime/source/adapters/level_zero/adapter.cpp b/unified-runtime/source/adapters/level_zero/adapter.cpp index 6b23d0161a4f5..0c1abe7667bd3 100644 --- a/unified-runtime/source/adapters/level_zero/adapter.cpp +++ b/unified-runtime/source/adapters/level_zero/adapter.cpp @@ -296,7 +296,7 @@ Behavior Summary: SysMan initialization is skipped. */ ur_adapter_handle_t_::ur_adapter_handle_t_() - : handle_base(), logger(logger::get_logger("level_zero")) { + : handle_base(), logger(logger::get_logger("level_zero")), RefCount(0) { ZeInitDriversResult = ZE_RESULT_ERROR_UNINITIALIZED; ZeInitResult = ZE_RESULT_ERROR_UNINITIALIZED; ZesResult = ZE_RESULT_ERROR_UNINITIALIZED; @@ -675,7 +675,7 @@ ur_result_t urAdapterGet( } *Adapters = GlobalAdapter; - if (GlobalAdapter->RefCount++ == 0) { + if (GlobalAdapter->RefCount.retain() == 0) { adapterStateInit(); } } @@ -692,7 +692,7 @@ ur_result_t urAdapterRelease([[maybe_unused]] ur_adapter_handle_t Adapter) { // NOTE: This does not require guarding with a mutex; the instant the ref // count hits zero, both Get and Retain are UB. 
- if (--GlobalAdapter->RefCount == 0) { + if (GlobalAdapter->RefCount.release()) { auto result = adapterStateTeardown(); #ifdef UR_STATIC_LEVEL_ZERO // Given static linking of the L0 Loader, we must delay the loader's @@ -711,7 +711,7 @@ ur_result_t urAdapterRelease([[maybe_unused]] ur_adapter_handle_t Adapter) { ur_result_t urAdapterRetain([[maybe_unused]] ur_adapter_handle_t Adapter) { assert(GlobalAdapter && GlobalAdapter == Adapter); - GlobalAdapter->RefCount++; + GlobalAdapter->RefCount.retain(); return UR_RESULT_SUCCESS; } @@ -740,7 +740,7 @@ ur_result_t urAdapterGetInfo(ur_adapter_handle_t, ur_adapter_info_t PropName, case UR_ADAPTER_INFO_BACKEND: return ReturnValue(UR_BACKEND_LEVEL_ZERO); case UR_ADAPTER_INFO_REFERENCE_COUNT: - return ReturnValue(GlobalAdapter->RefCount.load()); + return ReturnValue(GlobalAdapter->RefCount.getCount()); case UR_ADAPTER_INFO_VERSION: { #ifdef UR_ADAPTER_LEVEL_ZERO_V2 uint32_t adapterVersion = 2; diff --git a/unified-runtime/source/adapters/level_zero/adapter.hpp b/unified-runtime/source/adapters/level_zero/adapter.hpp index bb0a9058bce1b..890e39d296672 100644 --- a/unified-runtime/source/adapters/level_zero/adapter.hpp +++ b/unified-runtime/source/adapters/level_zero/adapter.hpp @@ -9,9 +9,9 @@ //===----------------------------------------------------------------------===// #pragma once +#include "common/ur_ref_count.hpp" #include "logger/ur_logger.hpp" #include "ur_interface_loader.hpp" -#include #include #include #include @@ -26,7 +26,6 @@ class ur_legacy_sink; struct ur_adapter_handle_t_ : ur::handle_base { ur_adapter_handle_t_(); - std::atomic RefCount = 0; zes_pfnDriverGetDeviceByUuidExp_t getDeviceByUUIdFunctionPtr = nullptr; zes_pfnDriverGet_t getSysManDriversFunctionPtr = nullptr; @@ -45,6 +44,8 @@ struct ur_adapter_handle_t_ : ur::handle_base { ZeCache> PlatformCache; logger::Logger &logger; HMODULE processHandle = nullptr; + + ur::RefCount RefCount; }; extern ur_adapter_handle_t_ *GlobalAdapter; diff --git 
a/unified-runtime/source/adapters/level_zero/async_alloc.cpp b/unified-runtime/source/adapters/level_zero/async_alloc.cpp index 204b43c3bcc79..6d55fda0f20b1 100644 --- a/unified-runtime/source/adapters/level_zero/async_alloc.cpp +++ b/unified-runtime/source/adapters/level_zero/async_alloc.cpp @@ -247,7 +247,7 @@ ur_result_t urEnqueueUSMFreeExp( } size_t size = umfPoolMallocUsableSize(hPool, Mem); - (*Event)->RefCount.increment(); + (*Event)->RefCount.retain(); usmPool->AsyncPool.insert(Mem, size, *Event, Queue); // Signal that USM free event was finished diff --git a/unified-runtime/source/adapters/level_zero/command_buffer.cpp b/unified-runtime/source/adapters/level_zero/command_buffer.cpp index 0773eadad7fd8..1e68069db51b2 100644 --- a/unified-runtime/source/adapters/level_zero/command_buffer.cpp +++ b/unified-runtime/source/adapters/level_zero/command_buffer.cpp @@ -840,13 +840,13 @@ urCommandBufferCreateExp(ur_context_handle_t Context, ur_device_handle_t Device, ur_result_t urCommandBufferRetainExp(ur_exp_command_buffer_handle_t CommandBuffer) { - CommandBuffer->RefCount.increment(); + CommandBuffer->RefCount.retain(); return UR_RESULT_SUCCESS; } ur_result_t urCommandBufferReleaseExp(ur_exp_command_buffer_handle_t CommandBuffer) { - if (!CommandBuffer->RefCount.decrementAndTest()) + if (!CommandBuffer->RefCount.release()) return UR_RESULT_SUCCESS; UR_CALL(waitForOngoingExecution(CommandBuffer)); @@ -1641,7 +1641,7 @@ ur_result_t enqueueImmediateAppendPath( if (CommandBuffer->CurrentSubmissionEvent) { UR_CALL(urEventReleaseInternal(CommandBuffer->CurrentSubmissionEvent)); } - (*Event)->RefCount.increment(); + (*Event)->RefCount.retain(); CommandBuffer->CurrentSubmissionEvent = *Event; UR_CALL(Queue->executeCommandList(CommandListHelper, false, false)); @@ -1724,7 +1724,7 @@ ur_result_t enqueueWaitEventPath(ur_exp_command_buffer_handle_t CommandBuffer, if (CommandBuffer->CurrentSubmissionEvent) { 
UR_CALL(urEventReleaseInternal(CommandBuffer->CurrentSubmissionEvent)); } - (*Event)->RefCount.increment(); + (*Event)->RefCount.retain(); CommandBuffer->CurrentSubmissionEvent = *Event; UR_CALL(Queue->executeCommandList(SignalCommandList, false /*IsBlocking*/, @@ -1848,7 +1848,7 @@ urCommandBufferGetInfoExp(ur_exp_command_buffer_handle_t hCommandBuffer, switch (propName) { case UR_EXP_COMMAND_BUFFER_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{hCommandBuffer->RefCount.load()}); + return ReturnValue(uint32_t{hCommandBuffer->RefCount.getCount()}); case UR_EXP_COMMAND_BUFFER_INFO_DESCRIPTOR: { ur_exp_command_buffer_desc_t Descriptor{}; Descriptor.stype = UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC; diff --git a/unified-runtime/source/adapters/level_zero/command_buffer.hpp b/unified-runtime/source/adapters/level_zero/command_buffer.hpp index f7b62a9c8dd1e..9fe6ef83ab5bd 100644 --- a/unified-runtime/source/adapters/level_zero/command_buffer.hpp +++ b/unified-runtime/source/adapters/level_zero/command_buffer.hpp @@ -17,6 +17,7 @@ #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "context.hpp" #include "kernel.hpp" #include "queue.hpp" @@ -149,4 +150,6 @@ struct ur_exp_command_buffer_handle_t_ : public ur_object { // Track handle objects to free when command-buffer is destroyed. 
std::vector> CommandHandles; + + ur::RefCount RefCount; }; diff --git a/unified-runtime/source/adapters/level_zero/common.hpp b/unified-runtime/source/adapters/level_zero/common.hpp index 19e22de14605d..cfb19f4977004 100644 --- a/unified-runtime/source/adapters/level_zero/common.hpp +++ b/unified-runtime/source/adapters/level_zero/common.hpp @@ -34,6 +34,7 @@ #include #include +#include "common/ur_ref_count.hpp" #include "logger/ur_logger.hpp" #include "ur_interface_loader.hpp" @@ -220,55 +221,9 @@ void zeParseError(ze_result_t ZeError, const char *&ErrorString); #define ZE_CALL_NOCHECK_NAME(ZeName, ZeArgs, callName) \ ZeCall().doCall(ZeName ZeArgs, callName, #ZeArgs, false) -// This wrapper around std::atomic is created to limit operations with reference -// counter and to make allowed operations more transparent in terms of -// thread-safety in the plugin. increment() and load() operations do not need a -// mutex guard around them since the underlying data is already atomic. -// decrementAndTest() method is used to guard a code which needs to be -// executed when object's ref count becomes zero after release. This method also -// doesn't need a mutex guard because decrement operation is atomic and only one -// thread can reach ref count equal to zero, i.e. only a single thread can pass -// through this check. -struct ReferenceCounter { - ReferenceCounter() : RefCount{1} {} - - // Reset the counter to the initial value. - void reset() { RefCount = 1; } - - // Used when retaining an object. - void increment() { RefCount++; } - - // Supposed to be used in ur*GetInfo* methods where ref count value is - // requested. - uint32_t load() { return RefCount.load(); } - - // This method allows to guard a code which needs to be executed when object's - // ref count becomes zero after release. It is important to notice that only a - // single thread can pass through this check. This is true because of several - // reasons: - // 1. Decrement operation is executed atomically. 
- // 2. It is not allowed to retain an object after its refcount reaches zero. - // 3. It is not allowed to release an object more times than the value of - // the ref count. - // 2. and 3. basically means that we can't use an object at all as soon as its - // refcount reaches zero. Using this check guarantees that code for deleting - // an object and releasing its resources is executed once by a single thread - // and we don't need to use any mutexes to guard access to this object in the - // scope after this check. Of course if we access another objects in this code - // (not the one which is being deleted) then access to these objects must be - // guarded, for example with a mutex. - bool decrementAndTest() { return --RefCount == 0; } - -private: - std::atomic RefCount; -}; - // Base class to store common data struct ur_object : ur::handle_base { - ur_object() : handle_base(), RefCount{} {} - - // Must be atomic to prevent data race when incrementing/decrementing. - ReferenceCounter RefCount; + ur_object() : handle_base() {} // This mutex protects accesses to all the non-const member variables. // Exclusive access is required to modify any of these members. @@ -303,6 +258,8 @@ struct MemAllocRecord : ur_object { // TODO: this should go away when memory isolation issue is fixed in the Level // Zero runtime. ur_context_handle_t Context; + + ur::RefCount RefCount; }; extern usm::DisjointPoolAllConfigs DisjointPoolConfigInstance; diff --git a/unified-runtime/source/adapters/level_zero/context.cpp b/unified-runtime/source/adapters/level_zero/context.cpp index 3209b8b789155..fe690f3673934 100644 --- a/unified-runtime/source/adapters/level_zero/context.cpp +++ b/unified-runtime/source/adapters/level_zero/context.cpp @@ -61,7 +61,7 @@ ur_result_t urContextRetain( /// [in] handle of the context to get a reference of. 
ur_context_handle_t Context) { - Context->RefCount.increment(); + Context->RefCount.retain(); return UR_RESULT_SUCCESS; } @@ -113,7 +113,7 @@ ur_result_t urContextGetInfo( case UR_CONTEXT_INFO_NUM_DEVICES: return ReturnValue(uint32_t(Context->Devices.size())); case UR_CONTEXT_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{Context->RefCount.load()}); + return ReturnValue(uint32_t{Context->RefCount.getCount()}); case UR_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT: // 2D USM memcpy is supported. return ReturnValue(uint8_t{UseMemcpy2DOperations}); @@ -251,7 +251,7 @@ ur_device_handle_t ur_context_handle_t_::getRootDevice() const { // from the list of tracked contexts. ur_result_t ContextReleaseHelper(ur_context_handle_t Context) { - if (!Context->RefCount.decrementAndTest()) + if (!Context->RefCount.release()) return UR_RESULT_SUCCESS; if (IndirectAccessTrackingEnabled) { diff --git a/unified-runtime/source/adapters/level_zero/context.hpp b/unified-runtime/source/adapters/level_zero/context.hpp index 86e0ea27b5c3e..fbcbceb71b7f0 100644 --- a/unified-runtime/source/adapters/level_zero/context.hpp +++ b/unified-runtime/source/adapters/level_zero/context.hpp @@ -26,6 +26,7 @@ #include "queue.hpp" #include "usm.hpp" +#include "common/ur_ref_count.hpp" #include struct l0_command_list_cache_info { @@ -358,6 +359,8 @@ struct ur_context_handle_t_ : ur_object { // Get handle to the L0 context ze_context_handle_t getZeHandle() const; + ur::RefCount RefCount; + private: enum EventFlags { EVENT_FLAG_HOST_VISIBLE = UR_BIT(0), diff --git a/unified-runtime/source/adapters/level_zero/device.cpp b/unified-runtime/source/adapters/level_zero/device.cpp index 6392b3802a199..7fdc25c7382f6 100644 --- a/unified-runtime/source/adapters/level_zero/device.cpp +++ b/unified-runtime/source/adapters/level_zero/device.cpp @@ -470,7 +470,7 @@ ur_result_t urDeviceGetInfo( return ReturnValue((uint32_t)Device->SubDevices.size()); } case UR_DEVICE_INFO_REFERENCE_COUNT: - return 
ReturnValue(uint32_t{Device->RefCount.load()}); + return ReturnValue(uint32_t{Device->RefCount.getCount()}); case UR_DEVICE_INFO_SUPPORTED_PARTITIONS: { // SYCL spec says: if this SYCL device cannot be partitioned into at least // two sub devices then the returned vector must be empty. @@ -1666,7 +1666,7 @@ ur_result_t urDeviceGetGlobalTimestamps( ur_result_t urDeviceRetain(ur_device_handle_t Device) { // The root-device ref-count remains unchanged (always 1). if (Device->isSubDevice()) { - Device->RefCount.increment(); + Device->RefCount.retain(); } return UR_RESULT_SUCCESS; } @@ -1674,7 +1674,7 @@ ur_result_t urDeviceRetain(ur_device_handle_t Device) { ur_result_t urDeviceRelease(ur_device_handle_t Device) { // Root devices are destroyed during the piTearDown process. if (Device->isSubDevice()) { - if (Device->RefCount.decrementAndTest()) { + if (Device->RefCount.release()) { delete Device; } } diff --git a/unified-runtime/source/adapters/level_zero/device.hpp b/unified-runtime/source/adapters/level_zero/device.hpp index a8326c0cf668b..48d0d5c13c579 100644 --- a/unified-runtime/source/adapters/level_zero/device.hpp +++ b/unified-runtime/source/adapters/level_zero/device.hpp @@ -20,6 +20,7 @@ #include "adapters/level_zero/platform.hpp" #include "common.hpp" +#include "common/ur_ref_count.hpp" #include #include #include @@ -242,6 +243,8 @@ struct ur_device_handle_t_ : ur_object { // unique ephemeral identifer of the device in the adapter std::optional Id; + + ur::RefCount RefCount; }; inline std::vector diff --git a/unified-runtime/source/adapters/level_zero/event.cpp b/unified-runtime/source/adapters/level_zero/event.cpp index f06cae5ec0cb3..e1834376f0b41 100644 --- a/unified-runtime/source/adapters/level_zero/event.cpp +++ b/unified-runtime/source/adapters/level_zero/event.cpp @@ -505,7 +505,7 @@ ur_result_t urEventGetInfo( return ReturnValue(Result); } case UR_EVENT_INFO_REFERENCE_COUNT: { - return ReturnValue(Event->RefCount.load()); + return 
ReturnValue(Event->RefCount.getCount()); } default: UR_LOG(ERR, "Unsupported ParamName in urEventGetInfo: ParamName={}(0x{})", @@ -874,7 +874,7 @@ ur_result_t /// [in] handle of the event object urEventRetain(/** [in] handle of the event object */ ur_event_handle_t Event) { Event->RefCountExternal++; - Event->RefCount.increment(); + Event->RefCount.retain(); return UR_RESULT_SUCCESS; } @@ -1088,7 +1088,7 @@ ur_event_handle_t_::~ur_event_handle_t_() { ur_result_t urEventReleaseInternal(ur_event_handle_t Event, bool *isEventDeleted) { - if (!Event->RefCount.decrementAndTest()) + if (!Event->RefCount.release()) return UR_RESULT_SUCCESS; if (Event->OriginAllocEvent) { @@ -1524,7 +1524,7 @@ ur_result_t ur_ze_event_list_t::createAndRetainUrZeEventList( std::shared_lock Lock(CurQueue->LastCommandEvent->Mutex); this->ZeEventList[0] = CurQueue->LastCommandEvent->ZeEvent; this->UrEventList[0] = CurQueue->LastCommandEvent; - this->UrEventList[0]->RefCount.increment(); + this->UrEventList[0]->RefCount.retain(); TmpListLength = 1; } else if (EventListLength > 0) { this->ZeEventList = new ze_event_handle_t[EventListLength]; @@ -1660,7 +1660,7 @@ ur_result_t ur_ze_event_list_t::createAndRetainUrZeEventList( IsInternal, IsMultiDevice)); MultiDeviceZeEvent = MultiDeviceEvent->ZeEvent; const auto &ZeCommandList = CommandList->first; - EventList[I]->RefCount.increment(); + EventList[I]->RefCount.retain(); // Append a Barrier to wait on the original event while signalling the // new multi device event. 
@@ -1676,11 +1676,11 @@ ur_result_t ur_ze_event_list_t::createAndRetainUrZeEventList( this->ZeEventList[TmpListLength] = MultiDeviceZeEvent; this->UrEventList[TmpListLength] = MultiDeviceEvent; - this->UrEventList[TmpListLength]->RefCount.increment(); + this->UrEventList[TmpListLength]->RefCount.retain(); } else { this->ZeEventList[TmpListLength] = EventList[I]->ZeEvent; this->UrEventList[TmpListLength] = EventList[I]; - this->UrEventList[TmpListLength]->RefCount.increment(); + this->UrEventList[TmpListLength]->RefCount.retain(); } if (QueueLock.has_value()) { diff --git a/unified-runtime/source/adapters/level_zero/event.hpp b/unified-runtime/source/adapters/level_zero/event.hpp index 13b36bcdfbe94..c89ee5097c8e8 100644 --- a/unified-runtime/source/adapters/level_zero/event.hpp +++ b/unified-runtime/source/adapters/level_zero/event.hpp @@ -25,6 +25,7 @@ #include #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "queue.hpp" #include "ur_api.h" @@ -262,6 +263,8 @@ struct ur_event_handle_t_ : ur_object { // Used only for asynchronous allocations. This is the event originally used // on async free to indicate when the allocation can be used again. ur_event_handle_t OriginAllocEvent = nullptr; + + ur::RefCount RefCount; }; // Helper function to implement zeHostSynchronize. 
diff --git a/unified-runtime/source/adapters/level_zero/kernel.cpp b/unified-runtime/source/adapters/level_zero/kernel.cpp index 29c6b19e2bbfe..838cb96dc59ed 100644 --- a/unified-runtime/source/adapters/level_zero/kernel.cpp +++ b/unified-runtime/source/adapters/level_zero/kernel.cpp @@ -787,7 +787,7 @@ ur_result_t urKernelGetInfo( } } case UR_KERNEL_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{Kernel->RefCount.load()}); + return ReturnValue(uint32_t{Kernel->RefCount.getCount()}); case UR_KERNEL_INFO_ATTRIBUTES: try { uint32_t Size; @@ -938,7 +938,7 @@ ur_result_t urKernelGetSubGroupInfo( ur_result_t urKernelRetain( /// [in] handle for the Kernel to retain ur_kernel_handle_t Kernel) { - Kernel->RefCount.increment(); + Kernel->RefCount.retain(); return UR_RESULT_SUCCESS; } @@ -946,7 +946,7 @@ ur_result_t urKernelRetain( ur_result_t urKernelRelease( /// [in] handle for the Kernel to release ur_kernel_handle_t Kernel) { - if (!Kernel->RefCount.decrementAndTest()) + if (!Kernel->RefCount.release()) return UR_RESULT_SUCCESS; auto KernelProgram = Kernel->Program; diff --git a/unified-runtime/source/adapters/level_zero/kernel.hpp b/unified-runtime/source/adapters/level_zero/kernel.hpp index 7f80348cda31f..131dba270c05d 100644 --- a/unified-runtime/source/adapters/level_zero/kernel.hpp +++ b/unified-runtime/source/adapters/level_zero/kernel.hpp @@ -9,9 +9,11 @@ //===----------------------------------------------------------------------===// #pragma once +#include + #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "memory.hpp" -#include struct ur_kernel_handle_t_ : ur_object { ur_kernel_handle_t_(bool OwnZeHandle, ur_program_handle_t Program) @@ -106,6 +108,8 @@ struct ur_kernel_handle_t_ : ur_object { // Cache of the kernel properties. 
ZeCache> ZeKernelProperties; ZeCache ZeKernelName; + + ur::RefCount RefCount; }; ur_result_t getZeKernel(ze_device_handle_t hDevice, ur_kernel_handle_t hKernel, diff --git a/unified-runtime/source/adapters/level_zero/memory.cpp b/unified-runtime/source/adapters/level_zero/memory.cpp index 0f6bb37dde904..3b1158645e77a 100644 --- a/unified-runtime/source/adapters/level_zero/memory.cpp +++ b/unified-runtime/source/adapters/level_zero/memory.cpp @@ -1052,7 +1052,7 @@ ur_result_t urEnqueueMemBufferMap( // Add the event to the command list. CommandList->second.append(reinterpret_cast(*Event)); - (*Event)->RefCount.increment(); + (*Event)->RefCount.retain(); const auto &ZeCommandList = CommandList->first; const auto &WaitList = (*Event)->WaitList; @@ -1183,7 +1183,7 @@ ur_result_t urEnqueueMemUnmap( nullptr /*ForcedCmdQueue*/)); CommandList->second.append(reinterpret_cast(*Event)); - (*Event)->RefCount.increment(); + (*Event)->RefCount.retain(); const auto &ZeCommandList = CommandList->first; @@ -1635,14 +1635,14 @@ ur_result_t urMemBufferCreate( ur_result_t urMemRetain( /// [in] handle of the memory object to get access ur_mem_handle_t Mem) { - Mem->RefCount.increment(); + Mem->RefCount.retain(); return UR_RESULT_SUCCESS; } ur_result_t urMemRelease( /// [in] handle of the memory object to release ur_mem_handle_t Mem) { - if (!Mem->RefCount.decrementAndTest()) + if (!Mem->RefCount.release()) return UR_RESULT_SUCCESS; if (Mem->isImage()) { @@ -1848,7 +1848,7 @@ ur_result_t urMemGetInfo( return ReturnValue(size_t{Buffer->Size}); } case UR_MEM_INFO_REFERENCE_COUNT: { - return ReturnValue(Buffer->RefCount.load()); + return ReturnValue(Buffer->RefCount.getCount()); } default: { return UR_RESULT_ERROR_INVALID_ENUMERATION; diff --git a/unified-runtime/source/adapters/level_zero/memory.hpp b/unified-runtime/source/adapters/level_zero/memory.hpp index 715b5b51870c1..f58f189b21c77 100644 --- a/unified-runtime/source/adapters/level_zero/memory.hpp +++ 
b/unified-runtime/source/adapters/level_zero/memory.hpp @@ -19,6 +19,7 @@ #include #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "context.hpp" #include "event.hpp" #include "program.hpp" @@ -90,6 +91,8 @@ struct ur_mem_handle_t_ : ur_object { // Method to get type of the derived object (image or buffer) bool isImage() const { return mem_type == mem_type_t::image; } + ur::RefCount RefCount; + protected: ur_mem_handle_t_(mem_type_t type, ur_context_handle_t Context) : UrContext{Context}, UrDevice{nullptr}, mem_type(type) {} @@ -116,7 +119,7 @@ struct ur_buffer final : ur_mem_handle_t_ { : ur_mem_handle_t_(mem_type_t::buffer, Parent->UrContext), Size(Size), SubBuffer{{Parent, Origin}} { // Retain the Parent Buffer due to the Creation of the SubBuffer. - Parent->RefCount.increment(); + Parent->RefCount.retain(); } // Interop-buffer constructor diff --git a/unified-runtime/source/adapters/level_zero/physical_mem.cpp b/unified-runtime/source/adapters/level_zero/physical_mem.cpp index 5d4d0acce0eb3..8786dc2f784d1 100644 --- a/unified-runtime/source/adapters/level_zero/physical_mem.cpp +++ b/unified-runtime/source/adapters/level_zero/physical_mem.cpp @@ -42,12 +42,12 @@ ur_result_t urPhysicalMemCreate( } ur_result_t urPhysicalMemRetain(ur_physical_mem_handle_t hPhysicalMem) { - hPhysicalMem->RefCount.increment(); + hPhysicalMem->RefCount.retain(); return UR_RESULT_SUCCESS; } ur_result_t urPhysicalMemRelease(ur_physical_mem_handle_t hPhysicalMem) { - if (!hPhysicalMem->RefCount.decrementAndTest()) + if (!hPhysicalMem->RefCount.release()) return UR_RESULT_SUCCESS; if (checkL0LoaderTeardown()) { @@ -68,7 +68,7 @@ ur_result_t urPhysicalMemGetInfo(ur_physical_mem_handle_t hPhysicalMem, switch (propName) { case UR_PHYSICAL_MEM_INFO_REFERENCE_COUNT: { - return ReturnValue(hPhysicalMem->RefCount.load()); + return ReturnValue(hPhysicalMem->RefCount.getCount()); } default: return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION; diff --git 
a/unified-runtime/source/adapters/level_zero/physical_mem.hpp b/unified-runtime/source/adapters/level_zero/physical_mem.hpp index 6ce630bcc5e1f..8e0ab261bcfd4 100644 --- a/unified-runtime/source/adapters/level_zero/physical_mem.hpp +++ b/unified-runtime/source/adapters/level_zero/physical_mem.hpp @@ -10,6 +10,7 @@ #pragma once #include "common.hpp" +#include "common/ur_ref_count.hpp" struct ur_physical_mem_handle_t_ : ur_object { ur_physical_mem_handle_t_(ze_physical_mem_handle_t ZePhysicalMem, @@ -21,4 +22,6 @@ struct ur_physical_mem_handle_t_ : ur_object { // Keeps the PI context of this memory handle. ur_context_handle_t Context; + + ur::RefCount RefCount; }; diff --git a/unified-runtime/source/adapters/level_zero/program.cpp b/unified-runtime/source/adapters/level_zero/program.cpp index 497e3057b7b9b..f41f9f6faf9ff 100644 --- a/unified-runtime/source/adapters/level_zero/program.cpp +++ b/unified-runtime/source/adapters/level_zero/program.cpp @@ -558,14 +558,14 @@ ur_result_t urProgramLinkExp( ur_result_t urProgramRetain( /// [in] handle for the Program to retain ur_program_handle_t Program) { - Program->RefCount.increment(); + Program->RefCount.retain(); return UR_RESULT_SUCCESS; } ur_result_t urProgramRelease( /// [in] handle for the Program to release ur_program_handle_t Program) { - if (!Program->RefCount.decrementAndTest()) + if (!Program->RefCount.release()) return UR_RESULT_SUCCESS; delete Program; @@ -708,7 +708,7 @@ ur_result_t urProgramGetInfo( switch (PropName) { case UR_PROGRAM_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{Program->RefCount.load()}); + return ReturnValue(uint32_t{Program->RefCount.getCount()}); case UR_PROGRAM_INFO_CONTEXT: return ReturnValue(Program->Context); case UR_PROGRAM_INFO_NUM_DEVICES: @@ -1115,7 +1115,7 @@ void ur_program_handle_t_::ur_release_program_resources(bool deletion) { // must be destroyed before the Module can be destroyed. So, be sure // to destroy build log before destroying the module. 
if (!deletion) { - if (!RefCount.decrementAndTest()) { + if (!RefCount.release()) { return; } } diff --git a/unified-runtime/source/adapters/level_zero/program.hpp b/unified-runtime/source/adapters/level_zero/program.hpp index 789daf052ba0c..fb2cc1f12ee5e 100644 --- a/unified-runtime/source/adapters/level_zero/program.hpp +++ b/unified-runtime/source/adapters/level_zero/program.hpp @@ -10,6 +10,7 @@ #pragma once #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "device.hpp" struct ur_program_handle_t_ : ur_object { @@ -226,6 +227,8 @@ struct ur_program_handle_t_ : ur_object { // UR_PROGRAM_INFO_BINARY_SIZES. const std::vector AssociatedDevices; + ur::RefCount RefCount; + private: struct DeviceData { // Log from the result of building the program for the device using diff --git a/unified-runtime/source/adapters/level_zero/queue.cpp b/unified-runtime/source/adapters/level_zero/queue.cpp index dc7d19ffaa007..4cb06f1348b7f 100644 --- a/unified-runtime/source/adapters/level_zero/queue.cpp +++ b/unified-runtime/source/adapters/level_zero/queue.cpp @@ -369,7 +369,7 @@ ur_result_t urQueueGetInfo( case UR_QUEUE_INFO_DEVICE: return ReturnValue(Queue->Device); case UR_QUEUE_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{Queue->RefCount.load()}); + return ReturnValue(uint32_t{Queue->RefCount.getCount()}); case UR_QUEUE_INFO_FLAGS: return ReturnValue(Queue->Properties); case UR_QUEUE_INFO_SIZE: @@ -593,7 +593,7 @@ ur_result_t urQueueRetain( std::scoped_lock Lock(Queue->Mutex); Queue->RefCountExternal++; } - Queue->RefCount.increment(); + Queue->RefCount.retain(); return UR_RESULT_SUCCESS; } @@ -612,7 +612,7 @@ ur_result_t urQueueRelease( // internal reference count. When the External Reference count == 0, then // cleanup of the queue begins and the final decrement of the internal // reference count is completed. 
- static_cast(Queue->RefCount.decrementAndTest()); + static_cast(Queue->RefCount.release()); return UR_RESULT_SUCCESS; } @@ -1389,7 +1389,7 @@ ur_queue_handle_t_::executeCommandList(ur_command_list_ptr_t CommandList, if (!Event->HostVisibleEvent) { Event->HostVisibleEvent = reinterpret_cast(HostVisibleEvent); - HostVisibleEvent->RefCount.increment(); + HostVisibleEvent->RefCount.retain(); } } @@ -1550,7 +1550,7 @@ ur_result_t ur_queue_handle_t_::addEventToQueueCache(ur_event_handle_t Event) { } void ur_queue_handle_t_::active_barriers::add(ur_event_handle_t &Event) { - Event->RefCount.increment(); + Event->RefCount.retain(); Events.push_back(Event); } @@ -1588,7 +1588,7 @@ void ur_queue_handle_t_::clearEndTimeRecordings() { } ur_result_t urQueueReleaseInternal(ur_queue_handle_t Queue) { - if (!Queue->RefCount.decrementAndTest()) + if (!Queue->RefCount.release()) return UR_RESULT_SUCCESS; for (auto &Cache : Queue->EventCaches) { @@ -1921,7 +1921,7 @@ ur_result_t createEventAndAssociateQueue(ur_queue_handle_t Queue, // Append this Event to the CommandList, if any if (CommandList != Queue->CommandListMap.end()) { CommandList->second.append(*Event); - (*Event)->RefCount.increment(); + (*Event)->RefCount.retain(); } // We need to increment the reference counter here to avoid ur_queue_handle_t @@ -1929,7 +1929,7 @@ ur_result_t createEventAndAssociateQueue(ur_queue_handle_t Queue, // urEventRelease requires access to the associated ur_queue_handle_t. // In urEventRelease, the reference counter of the Queue is decremented // to release it. - Queue->RefCount.increment(); + Queue->RefCount.retain(); // SYCL RT does not track completion of the events, so it could // release a PI event as soon as that's not being waited in the app. @@ -1961,7 +1961,7 @@ void ur_queue_handle_t_::CaptureIndirectAccesses() { // SubmissionsCount turns to 0. We don't want to know how many times // allocation was retained by each submission. 
if (Pair.second) - Elem.second.RefCount.increment(); + Elem.second.RefCount.retain(); } } Kernel->SubmissionsCount++; diff --git a/unified-runtime/source/adapters/level_zero/queue.hpp b/unified-runtime/source/adapters/level_zero/queue.hpp index 405929c8f0f0e..a81ef2fecd56a 100644 --- a/unified-runtime/source/adapters/level_zero/queue.hpp +++ b/unified-runtime/source/adapters/level_zero/queue.hpp @@ -25,6 +25,7 @@ #include #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "device.hpp" extern "C" { @@ -692,6 +693,8 @@ struct ur_queue_handle_t_ : ur_object { // Pointer to the unified handle. ur_queue_handle_t_ *UnifiedHandle; + + ur::RefCount RefCount; }; // This helper function creates a ur_event_handle_t and associate a diff --git a/unified-runtime/source/adapters/level_zero/sampler.cpp b/unified-runtime/source/adapters/level_zero/sampler.cpp index 4f6f5760faada..e048cb5fe4408 100644 --- a/unified-runtime/source/adapters/level_zero/sampler.cpp +++ b/unified-runtime/source/adapters/level_zero/sampler.cpp @@ -124,14 +124,14 @@ ur_result_t urSamplerCreate( ur_result_t urSamplerRetain( /// [in] handle of the sampler object to get access ur_sampler_handle_t Sampler) { - Sampler->RefCount.increment(); + Sampler->RefCount.retain(); return UR_RESULT_SUCCESS; } ur_result_t urSamplerRelease( /// [in] handle of the sampler object to release ur_sampler_handle_t Sampler) { - if (!Sampler->RefCount.decrementAndTest()) + if (!Sampler->RefCount.release()) return UR_RESULT_SUCCESS; if (checkL0LoaderTeardown()) { diff --git a/unified-runtime/source/adapters/level_zero/sampler.hpp b/unified-runtime/source/adapters/level_zero/sampler.hpp index 9a834a05215d9..29b2b9617c822 100644 --- a/unified-runtime/source/adapters/level_zero/sampler.hpp +++ b/unified-runtime/source/adapters/level_zero/sampler.hpp @@ -10,6 +10,7 @@ #pragma once #include "common.hpp" +#include "common/ur_ref_count.hpp" struct ur_sampler_handle_t_ : ur_object { ur_sampler_handle_t_(ze_sampler_handle_t 
Sampler) : ZeSampler{Sampler} {} @@ -18,6 +19,8 @@ struct ur_sampler_handle_t_ : ur_object { ze_sampler_handle_t ZeSampler; ZeStruct ZeSamplerDesc; + + ur::RefCount RefCount; }; // Construct ZE sampler desc from UR sampler desc. diff --git a/unified-runtime/source/adapters/level_zero/usm.cpp b/unified-runtime/source/adapters/level_zero/usm.cpp index ab8556dd692c7..e8dfb89df145c 100644 --- a/unified-runtime/source/adapters/level_zero/usm.cpp +++ b/unified-runtime/source/adapters/level_zero/usm.cpp @@ -523,14 +523,14 @@ ur_result_t urUSMPoolCreate( ur_result_t /// [in] pointer to USM memory pool urUSMPoolRetain(ur_usm_pool_handle_t Pool) { - Pool->RefCount.increment(); + Pool->RefCount.retain(); return UR_RESULT_SUCCESS; } ur_result_t /// [in] pointer to USM memory pool urUSMPoolRelease(ur_usm_pool_handle_t Pool) { - if (Pool->RefCount.decrementAndTest()) { + if (Pool->RefCount.release()) { std::scoped_lock ContextLock(Pool->Context->Mutex); Pool->Context->UsmPoolHandles.remove(Pool); delete Pool; @@ -553,7 +553,7 @@ ur_result_t urUSMPoolGetInfo( switch (PropName) { case UR_USM_POOL_INFO_REFERENCE_COUNT: { - return ReturnValue(Pool->RefCount.load()); + return ReturnValue(Pool->RefCount.getCount()); } case UR_USM_POOL_INFO_CONTEXT: { return ReturnValue(Pool->Context); @@ -1250,7 +1250,7 @@ ur_result_t ZeMemFreeHelper(ur_context_handle_t Context, void *Ptr) { if (It == std::end(Context->MemAllocs)) { die("All memory allocations must be tracked!"); } - if (!It->second.RefCount.decrementAndTest()) { + if (!It->second.RefCount.release()) { // Memory can't be deallocated yet. return UR_RESULT_SUCCESS; } @@ -1297,7 +1297,7 @@ ur_result_t USMFreeHelper(ur_context_handle_t Context, void *Ptr, if (It == std::end(Context->MemAllocs)) { die("All memory allocations must be tracked!"); } - if (!It->second.RefCount.decrementAndTest()) { + if (!It->second.RefCount.release()) { // Memory can't be deallocated yet. 
return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/level_zero/usm.hpp b/unified-runtime/source/adapters/level_zero/usm.hpp index b29ea29a7914d..f99a2f79c87b2 100644 --- a/unified-runtime/source/adapters/level_zero/usm.hpp +++ b/unified-runtime/source/adapters/level_zero/usm.hpp @@ -9,13 +9,14 @@ //===----------------------------------------------------------------------===// #pragma once -#include "common.hpp" +#include +#include "common.hpp" +#include "common/ur_ref_count.hpp" #include "enqueued_pool.hpp" #include "event.hpp" #include "ur_api.h" #include "ur_pool_manager.hpp" -#include #include usm::DisjointPoolAllConfigs InitializeDisjointPoolConfig(); @@ -53,6 +54,8 @@ struct ur_usm_pool_handle_t_ : ur_object { ur_context_handle_t Context; + ur::RefCount RefCount; + private: UsmPool *getPool(const usm::pool_descriptor &Desc); usm::pool_manager PoolManager; diff --git a/unified-runtime/source/adapters/level_zero/v2/command_buffer.cpp b/unified-runtime/source/adapters/level_zero/v2/command_buffer.cpp index cc2a88fb409e5..bbc200bac6bab 100644 --- a/unified-runtime/source/adapters/level_zero/v2/command_buffer.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/command_buffer.cpp @@ -258,7 +258,7 @@ urCommandBufferCreateExp(ur_context_handle_t context, ur_device_handle_t device, ur_result_t urCommandBufferRetainExp(ur_exp_command_buffer_handle_t hCommandBuffer) try { - hCommandBuffer->RefCount.increment(); + hCommandBuffer->RefCount.retain(); return UR_RESULT_SUCCESS; } catch (...) 
{ return exceptionToResult(std::current_exception()); @@ -266,7 +266,7 @@ urCommandBufferRetainExp(ur_exp_command_buffer_handle_t hCommandBuffer) try { ur_result_t urCommandBufferReleaseExp(ur_exp_command_buffer_handle_t hCommandBuffer) try { - if (!hCommandBuffer->RefCount.decrementAndTest()) + if (!hCommandBuffer->RefCount.release()) return UR_RESULT_SUCCESS; if (auto executionEvent = hCommandBuffer->getExecutionEventUnlocked()) { @@ -630,7 +630,7 @@ urCommandBufferGetInfoExp(ur_exp_command_buffer_handle_t hCommandBuffer, switch (propName) { case UR_EXP_COMMAND_BUFFER_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{hCommandBuffer->RefCount.load()}); + return ReturnValue(uint32_t{hCommandBuffer->RefCount.getCount()}); case UR_EXP_COMMAND_BUFFER_INFO_DESCRIPTOR: { ur_exp_command_buffer_desc_t Descriptor{}; Descriptor.stype = UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC; diff --git a/unified-runtime/source/adapters/level_zero/v2/command_buffer.hpp b/unified-runtime/source/adapters/level_zero/v2/command_buffer.hpp index 109ec09ce24fd..c7f1d585a5b77 100644 --- a/unified-runtime/source/adapters/level_zero/v2/command_buffer.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/command_buffer.hpp @@ -9,15 +9,18 @@ //===----------------------------------------------------------------------===// #pragma once +#include + #include "../helpers/mutable_helpers.hpp" #include "command_list_manager.hpp" #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "context.hpp" #include "kernel.hpp" #include "lockable.hpp" #include "queue_api.hpp" -#include #include + struct kernel_command_handle; struct ur_exp_command_buffer_handle_t_ : public ur_object { @@ -62,6 +65,8 @@ struct ur_exp_command_buffer_handle_t_ : public ur_object { ur_event_handle_t createEventIfRequested(ur_exp_command_buffer_sync_point_t *retSyncPoint); + ur::RefCount RefCount; + private: // Stores all sync points that are created by the command buffer. 
std::vector syncPoints; diff --git a/unified-runtime/source/adapters/level_zero/v2/command_list_manager.cpp b/unified-runtime/source/adapters/level_zero/v2/command_list_manager.cpp index d29496ec39425..4532e6ded3e1b 100644 --- a/unified-runtime/source/adapters/level_zero/v2/command_list_manager.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/command_list_manager.cpp @@ -916,7 +916,7 @@ ur_result_t ur_command_list_manager::appendNativeCommandExp( void ur_command_list_manager::recordSubmittedKernel( ur_kernel_handle_t hKernel) { submittedKernels.push_back(hKernel); - hKernel->RefCount.increment(); + hKernel->RefCount.retain(); } ze_command_list_handle_t ur_command_list_manager::getZeCommandList() { diff --git a/unified-runtime/source/adapters/level_zero/v2/context.cpp b/unified-runtime/source/adapters/level_zero/v2/context.cpp index 9ad4a4e61d633..bf054a493f656 100644 --- a/unified-runtime/source/adapters/level_zero/v2/context.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/context.cpp @@ -80,12 +80,12 @@ ur_context_handle_t_::ur_context_handle_t_(ze_context_handle_t hContext, defaultUSMPool(this, nullptr), asyncPool(this, nullptr) {} ur_result_t ur_context_handle_t_::retain() { - RefCount.increment(); + RefCount.retain(); return UR_RESULT_SUCCESS; } ur_result_t ur_context_handle_t_::release() { - if (!RefCount.decrementAndTest()) + if (!RefCount.release()) return UR_RESULT_SUCCESS; delete this; @@ -201,7 +201,7 @@ ur_result_t urContextGetInfo(ur_context_handle_t hContext, case UR_CONTEXT_INFO_NUM_DEVICES: return ReturnValue(uint32_t(hContext->getDevices().size())); case UR_CONTEXT_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{hContext->RefCount.load()}); + return ReturnValue(uint32_t{hContext->RefCount.getCount()}); case UR_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT: // TODO: this is currently not implemented return ReturnValue(uint8_t{false}); diff --git a/unified-runtime/source/adapters/level_zero/v2/context.hpp 
b/unified-runtime/source/adapters/level_zero/v2/context.hpp index 8d3cf8ca05579..b1500092a727b 100644 --- a/unified-runtime/source/adapters/level_zero/v2/context.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/context.hpp @@ -14,6 +14,7 @@ #include "command_list_cache.hpp" #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "event_pool_cache.hpp" #include "usm.hpp" @@ -64,6 +65,8 @@ struct ur_context_handle_t_ : ur_object { // For that the Device or its root devices need to be in the context. bool isValidDevice(ur_device_handle_t Device) const; + ur::RefCount RefCount; + private: const v2::raii::ze_context_handle_t hContext; const std::vector hDevices; diff --git a/unified-runtime/source/adapters/level_zero/v2/event.cpp b/unified-runtime/source/adapters/level_zero/v2/event.cpp index 30816a9fcdd6f..2c3c4b9a8685c 100644 --- a/unified-runtime/source/adapters/level_zero/v2/event.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/event.cpp @@ -160,12 +160,12 @@ ze_event_handle_t ur_event_handle_t_::getZeEvent() const { } ur_result_t ur_event_handle_t_::retain() { - RefCount.increment(); + RefCount.retain(); return UR_RESULT_SUCCESS; } ur_result_t ur_event_handle_t_::release() { - if (!RefCount.decrementAndTest()) + if (!RefCount.release()) return UR_RESULT_SUCCESS; if (event_pool) { @@ -258,7 +258,7 @@ ur_result_t urEventGetInfo(ur_event_handle_t hEvent, ur_event_info_t propName, } } case UR_EVENT_INFO_REFERENCE_COUNT: { - return returnValue(hEvent->RefCount.load()); + return returnValue(hEvent->RefCount.getCount()); } case UR_EVENT_INFO_COMMAND_QUEUE: { auto urQueueHandle = reinterpret_cast(hEvent->getQueue()) - diff --git a/unified-runtime/source/adapters/level_zero/v2/event.hpp b/unified-runtime/source/adapters/level_zero/v2/event.hpp index 0e9386578a2f6..9a31c47358947 100644 --- a/unified-runtime/source/adapters/level_zero/v2/event.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/event.hpp @@ -17,6 +17,7 @@ #include 
"adapters/level_zero/v2/queue_api.hpp" #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "event_provider.hpp" namespace v2 { @@ -112,6 +113,8 @@ struct ur_event_handle_t_ : ur_object { uint64_t getEventStartTimestmap() const; uint64_t getEventEndTimestamp(); + ur::RefCount RefCount; + private: ur_event_handle_t_(ur_context_handle_t hContext, event_variant hZeEvent, v2::event_flags_t flags, v2::event_pool *pool); diff --git a/unified-runtime/source/adapters/level_zero/v2/event_pool.cpp b/unified-runtime/source/adapters/level_zero/v2/event_pool.cpp index 99d4852f9ad4a..f5d6d3429dbdf 100644 --- a/unified-runtime/source/adapters/level_zero/v2/event_pool.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/event_pool.cpp @@ -52,8 +52,8 @@ void event_pool::free(ur_event_handle_t event) { freelist.push_back(event); // The event is still in the pool, so we need to increment the refcount - assert(event->RefCount.load() == 0); - event->RefCount.increment(); + assert(event->RefCount.getCount() == 0); + event->RefCount.retain(); } event_provider *event_pool::getProvider() const { return provider.get(); } diff --git a/unified-runtime/source/adapters/level_zero/v2/kernel.cpp b/unified-runtime/source/adapters/level_zero/v2/kernel.cpp index a2189b57536e8..173b51ffc42a5 100644 --- a/unified-runtime/source/adapters/level_zero/v2/kernel.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/kernel.cpp @@ -97,7 +97,7 @@ ur_kernel_handle_t_::ur_kernel_handle_t_( } ur_result_t ur_kernel_handle_t_::release() { - if (!RefCount.decrementAndTest()) + if (!RefCount.release()) return UR_RESULT_SUCCESS; // manually release kernels to allow errors to be propagated @@ -370,7 +370,7 @@ urKernelCreateWithNativeHandle(ur_native_handle_t hNativeKernel, ur_result_t urKernelRetain( /// [in] handle for the Kernel to retain ur_kernel_handle_t hKernel) try { - hKernel->RefCount.increment(); + hKernel->RefCount.retain(); return UR_RESULT_SUCCESS; } catch (...) 
{ return exceptionToResult(std::current_exception()); @@ -634,7 +634,7 @@ ur_result_t urKernelGetInfo(ur_kernel_handle_t hKernel, spills.size()); } case UR_KERNEL_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{hKernel->RefCount.load()}); + return ReturnValue(uint32_t{hKernel->RefCount.getCount()}); case UR_KERNEL_INFO_ATTRIBUTES: { auto attributes = hKernel->getSourceAttributes(); return ReturnValue(static_cast(attributes.data())); diff --git a/unified-runtime/source/adapters/level_zero/v2/kernel.hpp b/unified-runtime/source/adapters/level_zero/v2/kernel.hpp index 0cabb888ac3be..9c823760c42f2 100644 --- a/unified-runtime/source/adapters/level_zero/v2/kernel.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/kernel.hpp @@ -13,6 +13,7 @@ #include "../program.hpp" #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "memory.hpp" struct ur_single_device_kernel_t { @@ -91,6 +92,8 @@ struct ur_kernel_handle_t_ : ur_object { ze_command_list_handle_t cmdList, wait_list_view &waitListView); + ur::RefCount RefCount; + private: // Keep the program of the kernel. 
const ur_program_handle_t hProgram; diff --git a/unified-runtime/source/adapters/level_zero/v2/memory.cpp b/unified-runtime/source/adapters/level_zero/v2/memory.cpp index b1f3829dd6967..9c39a97d163e8 100644 --- a/unified-runtime/source/adapters/level_zero/v2/memory.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/memory.cpp @@ -671,7 +671,7 @@ ur_result_t urMemGetInfo(ur_mem_handle_t hMem, ur_mem_info_t propName, return returnValue(size_t{hMem->getBuffer()->getSize()}); } case UR_MEM_INFO_REFERENCE_COUNT: { - return returnValue(hMem->getObject()->RefCount.load()); + return returnValue(hMem->RefCount.getCount()); } default: { return UR_RESULT_ERROR_INVALID_ENUMERATION; @@ -684,14 +684,14 @@ ur_result_t urMemGetInfo(ur_mem_handle_t hMem, ur_mem_info_t propName, } ur_result_t urMemRetain(ur_mem_handle_t hMem) try { - hMem->getObject()->RefCount.increment(); + hMem->RefCount.retain(); return UR_RESULT_SUCCESS; } catch (...) { return exceptionToResult(std::current_exception()); } ur_result_t urMemRelease(ur_mem_handle_t hMem) try { - if (!hMem->getObject()->RefCount.decrementAndTest()) + if (!hMem->RefCount.release()) return UR_RESULT_SUCCESS; delete hMem; diff --git a/unified-runtime/source/adapters/level_zero/v2/memory.hpp b/unified-runtime/source/adapters/level_zero/v2/memory.hpp index 9c0dc66ef72b4..7201df57c9509 100644 --- a/unified-runtime/source/adapters/level_zero/v2/memory.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/memory.hpp @@ -19,6 +19,7 @@ #include "../image_common.hpp" #include "command_list_manager.hpp" #include "common.hpp" +#include "common/ur_ref_count.hpp" using usm_unique_ptr_t = std::unique_ptr>; @@ -279,16 +280,10 @@ struct ur_mem_handle_t_ : ur::handle_base { mem); } - ur_object *getObject() { - return std::visit( - [](auto &&arg) -> ur_object * { - return static_cast(&arg); - }, - mem); - } - bool isImage() const { return std::holds_alternative(mem); } + ur::RefCount RefCount; + private: template 
ur_mem_handle_t_(std::in_place_type_t, Args &&...args) diff --git a/unified-runtime/source/adapters/level_zero/v2/queue_handle.hpp b/unified-runtime/source/adapters/level_zero/v2/queue_handle.hpp index 9831afdbc9e4c..c414f79a46d71 100644 --- a/unified-runtime/source/adapters/level_zero/v2/queue_handle.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/queue_handle.hpp @@ -13,11 +13,12 @@ #pragma once +#include + #include "../common.hpp" #include "queue_immediate_in_order.hpp" #include "queue_immediate_out_of_order.hpp" #include -#include struct ur_queue_handle_t_ : ur::handle_base { using data_variant = std::variant { ur_result_t queueRetain() { return std::visit( [](auto &q) { - q.RefCount.increment(); + q.RefCount.retain(); return UR_RESULT_SUCCESS; }, queue_data); @@ -59,7 +60,7 @@ struct ur_queue_handle_t_ : ur::handle_base { ur_result_t queueRelease() { return std::visit( [queueHandle = this](auto &q) { - if (!q.RefCount.decrementAndTest()) + if (!q.RefCount.release()) return UR_RESULT_SUCCESS; delete queueHandle; return UR_RESULT_SUCCESS; diff --git a/unified-runtime/source/adapters/level_zero/v2/queue_immediate_in_order.cpp b/unified-runtime/source/adapters/level_zero/v2/queue_immediate_in_order.cpp index cc9b464333e70..85e6c7e2503c9 100644 --- a/unified-runtime/source/adapters/level_zero/v2/queue_immediate_in_order.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/queue_immediate_in_order.cpp @@ -60,7 +60,7 @@ ur_queue_immediate_in_order_t::queueGetInfo(ur_queue_info_t propName, case UR_QUEUE_INFO_DEVICE: return ReturnValue(hDevice); case UR_QUEUE_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{RefCount.load()}); + return ReturnValue(uint32_t{RefCount.getCount()}); case UR_QUEUE_INFO_FLAGS: return ReturnValue(flags); case UR_QUEUE_INFO_SIZE: diff --git a/unified-runtime/source/adapters/level_zero/v2/queue_immediate_in_order.hpp b/unified-runtime/source/adapters/level_zero/v2/queue_immediate_in_order.hpp index 362a6ea31c9f4..74b37d1b40eb3 
100644 --- a/unified-runtime/source/adapters/level_zero/v2/queue_immediate_in_order.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/queue_immediate_in_order.hpp @@ -12,6 +12,7 @@ #include "../common.hpp" #include "../device.hpp" +#include "common/ur_ref_count.hpp" #include "context.hpp" #include "event.hpp" #include "event_pool_cache.hpp" @@ -451,6 +452,8 @@ struct ur_queue_immediate_in_order_t : ur_object, ur_queue_t_ { numEventsInWaitList, phEventWaitList, createEventIfRequested(eventPool.get(), phEvent, this)); } + + ur::RefCount RefCount; }; } // namespace v2 diff --git a/unified-runtime/source/adapters/level_zero/v2/queue_immediate_out_of_order.cpp b/unified-runtime/source/adapters/level_zero/v2/queue_immediate_out_of_order.cpp index bfb6079af3ea5..f0344eba267f9 100644 --- a/unified-runtime/source/adapters/level_zero/v2/queue_immediate_out_of_order.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/queue_immediate_out_of_order.cpp @@ -54,7 +54,7 @@ ur_result_t ur_queue_immediate_out_of_order_t::queueGetInfo( case UR_QUEUE_INFO_DEVICE: return ReturnValue(hDevice); case UR_QUEUE_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{RefCount.load()}); + return ReturnValue(uint32_t{RefCount.getCount()}); case UR_QUEUE_INFO_FLAGS: return ReturnValue(flags); case UR_QUEUE_INFO_SIZE: diff --git a/unified-runtime/source/adapters/level_zero/v2/queue_immediate_out_of_order.hpp b/unified-runtime/source/adapters/level_zero/v2/queue_immediate_out_of_order.hpp index 1d0bf5636d58c..07e8743154ded 100644 --- a/unified-runtime/source/adapters/level_zero/v2/queue_immediate_out_of_order.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/queue_immediate_out_of_order.hpp @@ -11,6 +11,7 @@ #include "../common.hpp" #include "../device.hpp" +#include "common/ur_ref_count.hpp" #include "context.hpp" #include "event.hpp" @@ -503,6 +504,8 @@ struct ur_queue_immediate_out_of_order_t : ur_object, ur_queue_t_ { numEventsInWaitList, phEventWaitList, 
createEventIfRequested(eventPool.get(), phEvent, this)); } + + ur::RefCount RefCount; }; } // namespace v2 diff --git a/unified-runtime/source/adapters/level_zero/v2/usm.cpp b/unified-runtime/source/adapters/level_zero/v2/usm.cpp index f455fd3763554..f95d86e1c33e9 100644 --- a/unified-runtime/source/adapters/level_zero/v2/usm.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/usm.cpp @@ -332,7 +332,7 @@ ur_result_t urUSMPoolCreate( ur_result_t /// [in] pointer to USM memory pool urUSMPoolRetain(ur_usm_pool_handle_t hPool) try { - hPool->RefCount.increment(); + hPool->RefCount.retain(); return UR_RESULT_SUCCESS; } catch (umf_result_t e) { return umf::umf2urResult(e); @@ -343,7 +343,7 @@ urUSMPoolRetain(ur_usm_pool_handle_t hPool) try { ur_result_t /// [in] pointer to USM memory pool urUSMPoolRelease(ur_usm_pool_handle_t hPool) try { - if (hPool->RefCount.decrementAndTest()) { + if (hPool->RefCount.release()) { hPool->getContextHandle()->removeUsmPool(hPool); delete hPool; } @@ -369,7 +369,7 @@ ur_result_t urUSMPoolGetInfo( switch (propName) { case UR_USM_POOL_INFO_REFERENCE_COUNT: { - return ReturnValue(hPool->RefCount.load()); + return ReturnValue(hPool->RefCount.getCount()); } case UR_USM_POOL_INFO_CONTEXT: { return ReturnValue(hPool->getContextHandle()); diff --git a/unified-runtime/source/adapters/level_zero/v2/usm.hpp b/unified-runtime/source/adapters/level_zero/v2/usm.hpp index ff33b5f6bbed1..35e3446b82abc 100644 --- a/unified-runtime/source/adapters/level_zero/v2/usm.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/usm.hpp @@ -14,6 +14,7 @@ #include "../enqueued_pool.hpp" #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "event.hpp" #include "ur_pool_manager.hpp" @@ -49,6 +50,8 @@ struct ur_usm_pool_handle_t_ : ur_object { void cleanupPools(); void cleanupPoolsForQueue(void *hQueue); + ur::RefCount RefCount; + private: ur_context_handle_t hContext; usm::pool_manager poolManager; diff --git 
// Patch target (new file, mode 100644):
//   unified-runtime/source/common/ur_ref_count.hpp
/*
 *
 * Copyright (C) 2025 Intel Corporation
 *
 * Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM
 * Exceptions. See LICENSE.TXT
 *
 * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 *
 */
#ifndef URREFCOUNT_HPP
#define URREFCOUNT_HPP 1

#include <atomic>
#include <cstdint>

namespace ur {

/// Atomic 32-bit reference counter meant to be embedded as a member of a
/// handle object.
///
/// The counter is non-copyable: copying a handle must never duplicate its
/// reference count. Each member function performs a single atomic operation
/// on the underlying `std::atomic_uint32_t`.
class RefCount {
public:
  /// Creates a counter holding `count` references (one by default, i.e. the
  /// reference owned by whoever created the object).
  RefCount(uint32_t count = 1) noexcept : Count(count) {}
  RefCount(const RefCount &) = delete;
  RefCount &operator=(const RefCount &) = delete;

  /// Returns the current number of references.
  uint32_t getCount() const noexcept { return Count.load(); }

  /// Adds one reference.
  ///
  /// \return the count AFTER the increment. The first retain() on a counter
  /// constructed with 0 therefore returns 1, never 0; code that needs to
  /// detect "this was the first reference" must compare the result with 1.
  uint32_t retain() noexcept { return ++Count; }

  /// Drops one reference.
  ///
  /// \return true when this call removed the last reference (the count
  /// reached 0), which is the caller's cue to destroy the owning object.
  ///
  /// The counter is unsigned, so calling release() while the count is
  /// already 0 wraps around instead of failing; retain/release calls must be
  /// balanced by the caller.
  bool release() noexcept { return --Count == 0; }

  /// Overwrites the counter with `value` (1 by default).
  void reset(uint32_t value = 1) noexcept { Count = value; }

private:
  std::atomic_uint32_t Count;
};

} // namespace ur

#endif // URREFCOUNT_HPP