diff --git a/unified-runtime/source/adapters/native_cpu/adapter.cpp b/unified-runtime/source/adapters/native_cpu/adapter.cpp index 3fd6d4256825b..19333ddc10654 100644 --- a/unified-runtime/source/adapters/native_cpu/adapter.cpp +++ b/unified-runtime/source/adapters/native_cpu/adapter.cpp @@ -10,17 +10,18 @@ #include "adapter.hpp" #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "ur_api.h" struct ur_adapter_handle_t_ : ur::native_cpu::handle_base { - std::atomic RefCount = 0; logger::Logger &logger = logger::get_logger("native_cpu"); + ur::RefCount RefCount; } Adapter; UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet( uint32_t, ur_adapter_handle_t *phAdapters, uint32_t *pNumAdapters) { if (phAdapters) { - Adapter.RefCount++; + Adapter.RefCount.retain(); *phAdapters = &Adapter; } if (pNumAdapters) { @@ -30,12 +31,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet( } UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) { - Adapter.RefCount--; + Adapter.RefCount.release(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) { - Adapter.RefCount++; + Adapter.RefCount.retain(); return UR_RESULT_SUCCESS; } @@ -57,7 +58,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t, case UR_ADAPTER_INFO_BACKEND: return ReturnValue(UR_BACKEND_NATIVE_CPU); case UR_ADAPTER_INFO_REFERENCE_COUNT: - return ReturnValue(Adapter.RefCount.load()); + return ReturnValue(Adapter.RefCount.getCount()); case UR_ADAPTER_INFO_VERSION: return ReturnValue(uint32_t{1}); default: diff --git a/unified-runtime/source/adapters/native_cpu/common.hpp b/unified-runtime/source/adapters/native_cpu/common.hpp index e768a4b2ac7f4..d54b315664083 100644 --- a/unified-runtime/source/adapters/native_cpu/common.hpp +++ b/unified-runtime/source/adapters/native_cpu/common.hpp @@ -44,22 +44,13 @@ struct ddi_getter { using handle_base = ur::handle_base; } // namespace ur::native_cpu -// Todo: replace this 
with a common helper once it is available -struct RefCounted : ur::native_cpu::handle_base { - std::atomic_uint32_t _refCount; - uint32_t incrementReferenceCount() { return ++_refCount; } - uint32_t decrementReferenceCount() { return --_refCount; } - RefCounted() : handle_base(), _refCount{1} {} - uint32_t getReferenceCount() const { return _refCount; } -}; - // Base class to store common data -struct ur_object : RefCounted { +struct ur_object { ur_shared_mutex Mutex; }; template inline void decrementOrDelete(T *refC) { - if (refC->decrementReferenceCount() == 0) + if (refC->RefCount.release()) delete refC; } diff --git a/unified-runtime/source/adapters/native_cpu/context.cpp b/unified-runtime/source/adapters/native_cpu/context.cpp index 5b7e8fc839884..ee5ad0d926e91 100644 --- a/unified-runtime/source/adapters/native_cpu/context.cpp +++ b/unified-runtime/source/adapters/native_cpu/context.cpp @@ -30,7 +30,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextCreate( UR_APIEXPORT ur_result_t UR_APICALL urContextRetain(ur_context_handle_t hContext) { - hContext->incrementReferenceCount(); + hContext->RefCount.retain(); return UR_RESULT_SUCCESS; } @@ -51,7 +51,7 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName, case UR_CONTEXT_INFO_DEVICES: return returnValue(hContext->_device); case UR_CONTEXT_INFO_REFERENCE_COUNT: - return returnValue(uint32_t{hContext->getReferenceCount()}); + return returnValue(uint32_t{hContext->RefCount.getCount()}); case UR_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT: return returnValue(true); case UR_CONTEXT_INFO_USM_FILL2D_SUPPORT: diff --git a/unified-runtime/source/adapters/native_cpu/context.hpp b/unified-runtime/source/adapters/native_cpu/context.hpp index b9d2d22dd1565..5c775a9234c86 100644 --- a/unified-runtime/source/adapters/native_cpu/context.hpp +++ b/unified-runtime/source/adapters/native_cpu/context.hpp @@ -15,6 +15,7 @@ #include #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "device.hpp" 
#include "ur/ur.hpp" @@ -83,7 +84,7 @@ static usm_alloc_info get_alloc_info(void *ptr) { } // namespace native_cpu -struct ur_context_handle_t_ : RefCounted { +struct ur_context_handle_t_ { ur_context_handle_t_(ur_device_handle_t_ *phDevices) : _device{phDevices} {} ur_device_handle_t _device; @@ -135,6 +136,8 @@ struct ur_context_handle_t_ : RefCounted { return ptr; } + ur::RefCount RefCount; + private: std::mutex alloc_mutex; std::set allocations; diff --git a/unified-runtime/source/adapters/native_cpu/event.cpp b/unified-runtime/source/adapters/native_cpu/event.cpp index 91b8fb302eb18..87897fe66ad5e 100644 --- a/unified-runtime/source/adapters/native_cpu/event.cpp +++ b/unified-runtime/source/adapters/native_cpu/event.cpp @@ -28,7 +28,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventGetInfo(ur_event_handle_t hEvent, case UR_EVENT_INFO_COMMAND_TYPE: return ReturnValue(hEvent->getCommandType()); case UR_EVENT_INFO_REFERENCE_COUNT: - return ReturnValue(hEvent->getReferenceCount()); + return ReturnValue(hEvent->RefCount.getCount()); case UR_EVENT_INFO_COMMAND_EXECUTION_STATUS: return ReturnValue(hEvent->getExecutionStatus()); case UR_EVENT_INFO_CONTEXT: @@ -69,7 +69,7 @@ urEventWait(uint32_t numEvents, const ur_event_handle_t *phEventWaitList) { } UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) { - hEvent->incrementReferenceCount(); + hEvent->RefCount.retain(); return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/native_cpu/event.hpp b/unified-runtime/source/adapters/native_cpu/event.hpp index 479c671b38cd1..778cd4eef6abe 100644 --- a/unified-runtime/source/adapters/native_cpu/event.hpp +++ b/unified-runtime/source/adapters/native_cpu/event.hpp @@ -8,14 +8,17 @@ // //===----------------------------------------------------------------------===// #pragma once -#include "common.hpp" -#include "ur_api.h" + #include #include #include #include -struct ur_event_handle_t_ : RefCounted { +#include "common.hpp" +#include 
"common/ur_ref_count.hpp" +#include "ur_api.h" + +struct ur_event_handle_t_ { ur_event_handle_t_(ur_queue_handle_t queue, ur_command_t command_type); @@ -55,6 +58,8 @@ struct ur_event_handle_t_ : RefCounted { uint64_t get_end_timestamp() const { return timestamp_end; } + ur::RefCount RefCount; + private: ur_queue_handle_t queue; ur_context_handle_t context; diff --git a/unified-runtime/source/adapters/native_cpu/kernel.cpp b/unified-runtime/source/adapters/native_cpu/kernel.cpp index ac11331357f39..622bbc5cab38d 100644 --- a/unified-runtime/source/adapters/native_cpu/kernel.cpp +++ b/unified-runtime/source/adapters/native_cpu/kernel.cpp @@ -95,7 +95,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetInfo(ur_kernel_handle_t hKernel, case UR_KERNEL_INFO_FUNCTION_NAME: return ReturnValue(hKernel->_name.c_str()); case UR_KERNEL_INFO_REFERENCE_COUNT: - return ReturnValue(uint32_t{hKernel->getReferenceCount()}); + return ReturnValue(uint32_t{hKernel->RefCount.getCount()}); case UR_KERNEL_INFO_ATTRIBUTES: return ReturnValue(""); case UR_KERNEL_INFO_SPILL_MEM_SIZE: @@ -194,7 +194,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetSubGroupInfo( } UR_APIEXPORT ur_result_t UR_APICALL urKernelRetain(ur_kernel_handle_t hKernel) { - hKernel->incrementReferenceCount(); + hKernel->RefCount.retain(); return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/native_cpu/kernel.hpp b/unified-runtime/source/adapters/native_cpu/kernel.hpp index 8daf23feb65f5..09a02578ccf36 100644 --- a/unified-runtime/source/adapters/native_cpu/kernel.hpp +++ b/unified-runtime/source/adapters/native_cpu/kernel.hpp @@ -8,13 +8,15 @@ #pragma once +#include +#include + #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "memory.hpp" #include "nativecpu_state.hpp" #include "program.hpp" -#include #include -#include using nativecpu_kernel_t = void(void *const *, native_cpu::state *); using nativecpu_ptr_t = nativecpu_kernel_t *; @@ -27,7 +29,7 @@ struct local_arg_info_t { : 
argIndex(argIndex), argSize(argSize) {} }; -struct ur_kernel_handle_t_ : RefCounted { +struct ur_kernel_handle_t_ { ur_kernel_handle_t_(ur_program_handle_t hProgram, const char *name, nativecpu_task_t subhandler) @@ -188,10 +190,12 @@ struct ur_kernel_handle_t_ : RefCounted { void addPtrArg(void *Ptr, size_t Index) { Args.addPtrArg(Index, Ptr); } void addArgReference(ur_mem_handle_t Arg) { - Arg->incrementReferenceCount(); + Arg->RefCount.retain(); ReferencedArgs.push_back(Arg); } + ur::RefCount RefCount; + private: void removeArgReferences() { for (auto arg : ReferencedArgs) diff --git a/unified-runtime/source/adapters/native_cpu/memory.hpp b/unified-runtime/source/adapters/native_cpu/memory.hpp index ca6e3e77f5e87..1fd2962e4d342 100644 --- a/unified-runtime/source/adapters/native_cpu/memory.hpp +++ b/unified-runtime/source/adapters/native_cpu/memory.hpp @@ -15,6 +15,7 @@ #include #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "context.hpp" struct ur_mem_handle_t_ : ur_object { @@ -43,6 +44,8 @@ struct ur_mem_handle_t_ : ur_object { char *_mem; bool _ownsMem; + ur::RefCount RefCount; + private: const bool IsImage; }; diff --git a/unified-runtime/source/adapters/native_cpu/program.cpp b/unified-runtime/source/adapters/native_cpu/program.cpp index fee72f8a6bc3c..d408588ecbfc5 100644 --- a/unified-runtime/source/adapters/native_cpu/program.cpp +++ b/unified-runtime/source/adapters/native_cpu/program.cpp @@ -171,7 +171,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp( UR_APIEXPORT ur_result_t UR_APICALL urProgramRetain(ur_program_handle_t hProgram) { - hProgram->incrementReferenceCount(); + hProgram->RefCount.retain(); return UR_RESULT_SUCCESS; } @@ -205,7 +205,7 @@ urProgramGetInfo(ur_program_handle_t hProgram, ur_program_info_t propName, switch (propName) { case UR_PROGRAM_INFO_REFERENCE_COUNT: - return returnValue(hProgram->getReferenceCount()); + return returnValue(hProgram->RefCount.getCount()); case UR_PROGRAM_INFO_CONTEXT: return 
returnValue(nullptr); case UR_PROGRAM_INFO_NUM_DEVICES: diff --git a/unified-runtime/source/adapters/native_cpu/program.hpp b/unified-runtime/source/adapters/native_cpu/program.hpp index d58412751e8f2..82d7a7ea4ec35 100644 --- a/unified-runtime/source/adapters/native_cpu/program.hpp +++ b/unified-runtime/source/adapters/native_cpu/program.hpp @@ -10,23 +10,22 @@ #pragma once +#include +#include + #include +#include "common/ur_ref_count.hpp" #include "context.hpp" -#include -#include - namespace native_cpu { using WGSize_t = std::array; } -struct ur_program_handle_t_ : RefCounted { +struct ur_program_handle_t_ { ur_program_handle_t_(ur_context_handle_t ctx, const unsigned char *pBinary) : _ctx{ctx}, _ptr{pBinary} {} - uint32_t getReferenceCount() const noexcept { return _refCount; } - ur_context_handle_t _ctx; const unsigned char *_ptr; struct _compare { @@ -41,6 +40,8 @@ struct ur_program_handle_t_ : RefCounted { std::unordered_map KernelMaxWorkGroupSizeMD; std::unordered_map KernelMaxLinearWorkGroupSizeMD; + + ur::RefCount RefCount; }; // The nativecpu_entry struct is also defined as LLVM-IR in the diff --git a/unified-runtime/source/adapters/native_cpu/queue.cpp b/unified-runtime/source/adapters/native_cpu/queue.cpp index 5de7037519490..31732763687ef 100644 --- a/unified-runtime/source/adapters/native_cpu/queue.cpp +++ b/unified-runtime/source/adapters/native_cpu/queue.cpp @@ -28,7 +28,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueGetInfo(ur_queue_handle_t hQueue, case UR_QUEUE_INFO_DEVICE: return ReturnValue(hQueue->getDevice()); case UR_QUEUE_INFO_REFERENCE_COUNT: - return ReturnValue(hQueue->getReferenceCount()); + return ReturnValue(hQueue->RefCount.getCount()); case UR_QUEUE_INFO_EMPTY: return ReturnValue(hQueue->isEmpty()); default: @@ -48,7 +48,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueCreate( } UR_APIEXPORT ur_result_t UR_APICALL urQueueRetain(ur_queue_handle_t hQueue) { - hQueue->incrementReferenceCount(); + hQueue->RefCount.retain(); return 
UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/native_cpu/queue.hpp b/unified-runtime/source/adapters/native_cpu/queue.hpp index 9fd28c3e7ff00..a7fe86305953f 100644 --- a/unified-runtime/source/adapters/native_cpu/queue.hpp +++ b/unified-runtime/source/adapters/native_cpu/queue.hpp @@ -8,12 +8,15 @@ // //===----------------------------------------------------------------------===// #pragma once + +#include + #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "event.hpp" #include "ur_api.h" -#include -struct ur_queue_handle_t_ : RefCounted { +struct ur_queue_handle_t_ { ur_queue_handle_t_(ur_device_handle_t device, ur_context_handle_t context, const ur_queue_properties_t *pProps) : device(device), context(context), @@ -43,7 +46,7 @@ struct ur_queue_handle_t_ : RefCounted { auto ev = *events.begin(); // ur_event_handle_t_::wait removes itself from the events set in the // queue. - ev->incrementReferenceCount(); + ev->RefCount.retain(); // Unlocking mutex for removeEvent and for event callbacks that may need // to acquire it. lock.unlock(); @@ -64,6 +67,8 @@ struct ur_queue_handle_t_ : RefCounted { return events.size() == 0; } + ur::RefCount RefCount; + private: ur_device_handle_t device; ur_context_handle_t context;