Skip to content

[SYCL] Use shared_ptr instead of manual changing UR counters #18565

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
0edc034
[SYCL] Use shared_ptr instead of manual changing UR counters
Alexandr-Konovalov May 20, 2025
3a3f017
Code formatting.
Alexandr-Konovalov May 20, 2025
f0be3fe
Code formatting.
Alexandr-Konovalov May 20, 2025
dad0228
Merge branch 'sycl' into Alexandr-Konovalov/UR-counters-shared_ptr
Alexandr-Konovalov May 20, 2025
a11ef81
Clarify comment, drop unneeded move.
Alexandr-Konovalov May 21, 2025
93d9cc4
Extend scope of checked calls.
Alexandr-Konovalov May 22, 2025
e9852a6
Merge branch 'sycl' into Alexandr-Konovalov/UR-counters-shared_ptr
Alexandr-Konovalov May 28, 2025
d5e20f6
Use raw pointer to reference adapter.
Alexandr-Konovalov May 28, 2025
db56adf
Code formatting.
Alexandr-Konovalov May 28, 2025
7715cc6
Delete copy constructor and copy-assignment operator.
Alexandr-Konovalov May 30, 2025
54d94f1
Merge branch 'sycl' into Alexandr-Konovalov/UR-counters-shared_ptr
Alexandr-Konovalov May 30, 2025
4a20d5f
Return shared_ptr, not individual handlers, from getCGKernelInfo().
Alexandr-Konovalov Jun 2, 2025
76980f2
Code formatting.
Alexandr-Konovalov Jun 2, 2025
8c16b6e
Remove sycl/test/native_cpu/multiple_definitions.cpp.
Alexandr-Konovalov Jun 2, 2025
441344d
Merge branch 'sycl' into Alexandr-Konovalov/UR-counters-shared_ptr
Alexandr-Konovalov Jun 2, 2025
cbb03d2
Update sycl/source/detail/program_manager/program_manager.cpp
Alexandr-Konovalov Jun 2, 2025
37bde6f
Keep reference to Adapter, not pointer.
Alexandr-Konovalov Jun 2, 2025
8c0c60a
Code formatting.
Alexandr-Konovalov Jun 2, 2025
728f698
Extend comment.
Alexandr-Konovalov Jun 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 7 additions & 21 deletions sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1480,8 +1480,7 @@ bool exec_graph_impl::needsScheduledUpdate(
}

void exec_graph_impl::populateURKernelUpdateStructs(
const std::shared_ptr<node_impl> &Node,
std::pair<ur_program_handle_t, ur_kernel_handle_t> &BundleObjs,
const std::shared_ptr<node_impl> &Node, FastKernelCacheValPtr &BundleObjs,
std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t> &MemobjDescs,
std::vector<ur_kernel_arg_mem_obj_properties_t> &MemobjProps,
std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t> &PtrDescs,
Expand Down Expand Up @@ -1517,12 +1516,11 @@ void exec_graph_impl::populateURKernelUpdateStructs(
UrKernel = SyclKernelImpl->getHandleRef();
EliminatedArgMask = SyclKernelImpl->getKernelArgMask();
} else {
ur_program_handle_t UrProgram = nullptr;
std::tie(UrKernel, std::ignore, EliminatedArgMask, UrProgram) =
sycl::detail::ProgramManager::getInstance().getOrCreateKernel(
ContextImpl, DeviceImpl, ExecCG.MKernelName,
ExecCG.MKernelNameBasedCachePtr);
BundleObjs = std::make_pair(UrProgram, UrKernel);
BundleObjs = sycl::detail::ProgramManager::getInstance().getOrCreateKernel(
ContextImpl, DeviceImpl, ExecCG.MKernelName,
ExecCG.MKernelNameBasedCachePtr);
UrKernel = BundleObjs->MKernelHandle;
EliminatedArgMask = BundleObjs->MKernelArgMask;
}

// Remove eliminated args
Expand Down Expand Up @@ -1717,8 +1715,7 @@ void exec_graph_impl::updateURImpl(
std::vector<sycl::detail::NDRDescT> NDRDescList(NumUpdatableNodes);
std::vector<ur_exp_command_buffer_update_kernel_launch_desc_t> UpdateDescList(
NumUpdatableNodes);
std::vector<std::pair<ur_program_handle_t, ur_kernel_handle_t>>
KernelBundleObjList(NumUpdatableNodes);
std::vector<FastKernelCacheValPtr> KernelBundleObjList(NumUpdatableNodes);

size_t StructListIndex = 0;
for (auto &Node : Nodes) {
Expand All @@ -1743,17 +1740,6 @@ void exec_graph_impl::updateURImpl(
const sycl::detail::AdapterPtr &Adapter = ContextImpl->getAdapter();
Adapter->call<sycl::detail::UrApiKind::urCommandBufferUpdateKernelLaunchExp>(
CommandBuffer, UpdateDescList.size(), UpdateDescList.data());

for (auto &BundleObjs : KernelBundleObjList) {
// We retained these objects by inside populateUpdateStruct() by calling
// getOrCreateKernel()
if (auto &UrKernel = BundleObjs.second; nullptr != UrKernel) {
Adapter->call<sycl::detail::UrApiKind::urKernelRelease>(UrKernel);
}
if (auto &UrProgram = BundleObjs.first; nullptr != UrProgram) {
Adapter->call<sycl::detail::UrApiKind::urProgramRelease>(UrProgram);
}
}
}

modifiable_command_graph::modifiable_command_graph(
Expand Down
3 changes: 1 addition & 2 deletions sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1521,8 +1521,7 @@ class exec_graph_impl {
/// @param[out] NDRDesc ND-Range to update.
/// @param[out] UpdateDesc Base struct in the pointer chain.
void populateURKernelUpdateStructs(
const std::shared_ptr<node_impl> &Node,
std::pair<ur_program_handle_t, ur_kernel_handle_t> &BundleObjs,
const std::shared_ptr<node_impl> &Node, FastKernelCacheValPtr &BundleObjs,
std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t> &MemobjDescs,
std::vector<ur_kernel_arg_mem_obj_properties_t> &MemobjProps,
std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t> &PtrDescs,
Expand Down
44 changes: 40 additions & 4 deletions sycl/source/detail/kernel_name_based_cache_t.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,47 @@ namespace sycl {
inline namespace _V1 {
namespace detail {
using FastKernelCacheKeyT = std::pair<ur_device_handle_t, ur_context_handle_t>;
using FastKernelCacheValT =
std::tuple<ur_kernel_handle_t, std::mutex *, const KernelArgMask *,
ur_program_handle_t>;

struct FastKernelCacheVal {
ur_kernel_handle_t MKernelHandle; /* UR kernel handle pointer. */
std::mutex *MMutex; /* Mutex guarding this kernel. When
caching is disabled, the pointer is
nullptr. */
const KernelArgMask *MKernelArgMask; /* Eliminated kernel argument mask. */
ur_program_handle_t MProgramHandle; /* UR program handle corresponding to
this kernel. */
const Adapter &MAdapterPtr; /* We can keep reference to the adapter
because during 2-stage shutdown the kernel
cache is destroyed deliberately before the
adapter. */

FastKernelCacheVal(ur_kernel_handle_t KernelHandle, std::mutex *Mutex,
const KernelArgMask *KernelArgMask,
ur_program_handle_t ProgramHandle,
const Adapter &AdapterPtr)
: MKernelHandle(KernelHandle), MMutex(Mutex),
MKernelArgMask(KernelArgMask), MProgramHandle(ProgramHandle),
MAdapterPtr(AdapterPtr) {}

~FastKernelCacheVal() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we set to nullptr after "Release"? It should be relatively cheap but can save lots of time debugging obscure errors.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

if (MKernelHandle)
MAdapterPtr.call<sycl::detail::UrApiKind::urKernelRelease>(MKernelHandle);
if (MProgramHandle)
MAdapterPtr.call<sycl::detail::UrApiKind::urProgramRelease>(
MProgramHandle);
MKernelHandle = nullptr;
MMutex = nullptr;
MKernelArgMask = nullptr;
MProgramHandle = nullptr;
}

FastKernelCacheVal(const FastKernelCacheVal &) = delete;
FastKernelCacheVal &operator=(const FastKernelCacheVal &) = delete;
};
using FastKernelCacheValPtr = std::shared_ptr<FastKernelCacheVal>;

using FastKernelSubcacheMapT =
::boost::unordered_flat_map<FastKernelCacheKeyT, FastKernelCacheValT>;
::boost::unordered_flat_map<FastKernelCacheKeyT, FastKernelCacheValPtr>;

using FastKernelSubcacheMutexT = SpinLock;
using FastKernelSubcacheReadLockT = std::lock_guard<FastKernelSubcacheMutexT>;
Expand Down
14 changes: 7 additions & 7 deletions sycl/source/detail/kernel_program_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ class KernelProgramCache {
return std::make_pair(It->second, DidInsert);
}

FastKernelCacheValT
FastKernelCacheValPtr
tryToGetKernelFast(KernelNameStrRefT KernelName, ur_device_handle_t Device,
FastKernelSubcacheT *KernelSubcacheHint) {
FastKernelCacheWriteLockT Lock(MFastKernelCacheMutex);
Expand All @@ -486,27 +486,27 @@ class KernelProgramCache {
traceKernel("Kernel fetched.", KernelName, true);
return It->second;
}
return std::make_tuple(nullptr, nullptr, nullptr, nullptr);
return FastKernelCacheValPtr();
}

void saveKernel(KernelNameStrRefT KernelName, ur_device_handle_t Device,
FastKernelCacheValT CacheVal,
const FastKernelCacheValPtr &CacheVal,
FastKernelSubcacheT *KernelSubcacheHint) {
ur_program_handle_t Program = std::get<3>(CacheVal);
if (SYCLConfig<SYCL_IN_MEM_CACHE_EVICTION_THRESHOLD>::
isProgramCacheEvictionEnabled()) {
// Save kernel in fast cache only if the corresponding program is also
// in the cache.
auto LockedCache = acquireCachedPrograms();
auto &ProgCache = LockedCache.get();
if (ProgCache.ProgramSizeMap.find(Program) ==
if (ProgCache.ProgramSizeMap.find(CacheVal->MProgramHandle) ==
ProgCache.ProgramSizeMap.end())
return;
}

// Save reference between the program and the fast cache key.
FastKernelCacheWriteLockT Lock(MFastKernelCacheMutex);
MProgramToFastKernelCacheKeyMap[Program].emplace_back(KernelName, Device);
MProgramToFastKernelCacheKeyMap[CacheVal->MProgramHandle].emplace_back(
KernelName, Device);

// if no insertion took place, then some other thread has already inserted
// smth in the cache
Expand All @@ -518,7 +518,7 @@ class KernelProgramCache {
FastKernelSubcacheWriteLockT SubcacheLock{KernelSubcacheHint->Mutex};
ur_context_handle_t Context = getURContext();
KernelSubcacheHint->Map.emplace(FastKernelCacheKeyT(Device, Context),
std::move(CacheVal));
CacheVal);
}

// Expects locked program cache
Expand Down
38 changes: 14 additions & 24 deletions sycl/source/detail/program_manager/program_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
#include <detail/device_impl.hpp>
#include <detail/event_impl.hpp>
#include <detail/global_handler.hpp>
#include <detail/kernel_name_based_cache_t.hpp>
#include <detail/persistent_device_code_cache.hpp>
#include <detail/platform_impl.hpp>
#include <detail/program_manager/program_manager.hpp>
Expand Down Expand Up @@ -1108,11 +1107,8 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
Adapter->call<UrApiKind::urProgramRetain>(ResProgram);
return ResProgram;
}
// When caching is enabled, the returned UrProgram and UrKernel will
// already have their ref count incremented.
std::tuple<ur_kernel_handle_t, std::mutex *, const KernelArgMask *,
ur_program_handle_t>
ProgramManager::getOrCreateKernel(

FastKernelCacheValPtr ProgramManager::getOrCreateKernel(
const ContextImplPtr &ContextImpl, device_impl &DeviceImpl,
KernelNameStrRefT KernelName,
KernelNameBasedCacheT *KernelNameBasedCachePtr, const NDRDescT &NDRDesc) {
Expand All @@ -1129,18 +1125,11 @@ ProgramManager::getOrCreateKernel(
KernelNameBasedCachePtr ? &KernelNameBasedCachePtr->FastKernelSubcache
: nullptr;
if (SYCLConfig<SYCL_CACHE_IN_MEM>::get()) {
auto ret_tuple =
auto KernelCacheValPtr =
Cache.tryToGetKernelFast(KernelName, UrDevice, CacheHintPtr);
constexpr size_t Kernel = 0; // see FastKernelCacheValT tuple
constexpr size_t Program = 3; // see FastKernelCacheValT tuple
if (std::get<Kernel>(ret_tuple)) {
// Pulling a copy of a kernel and program from the cache,
// so we need to retain those resources.
ContextImpl->getAdapter()->call<UrApiKind::urKernelRetain>(
std::get<Kernel>(ret_tuple));
ContextImpl->getAdapter()->call<UrApiKind::urProgramRetain>(
std::get<Program>(ret_tuple));
return ret_tuple;
if (auto KernelCacheValPtr =
Cache.tryToGetKernelFast(KernelName, UrDevice, CacheHintPtr)) {
return KernelCacheValPtr;
}
}

Expand Down Expand Up @@ -1179,20 +1168,21 @@ ProgramManager::getOrCreateKernel(
// threads when caching is disabled, so we can return
// nullptr for the mutex.
auto [Kernel, ArgMask] = BuildF();
return make_tuple(Kernel, nullptr, ArgMask, Program);
return std::make_shared<FastKernelCacheVal>(
Kernel, nullptr, ArgMask, Program, *ContextImpl->getAdapter().get());
}

auto BuildResult = Cache.getOrBuild<errc::invalid>(GetCachedBuildF, BuildF);
// getOrBuild is not supposed to return nullptr
assert(BuildResult != nullptr && "Invalid build result");
const KernelArgMaskPairT &KernelArgMaskPair = BuildResult->Val;
auto ret_val = std::make_tuple(KernelArgMaskPair.first,
&(BuildResult->MBuildResultMutex),
KernelArgMaskPair.second, Program);
auto ret_val = std::make_shared<FastKernelCacheVal>(
KernelArgMaskPair.first, &(BuildResult->MBuildResultMutex),
KernelArgMaskPair.second, Program, *ContextImpl->getAdapter().get());
// If caching is enabled, one copy of the kernel handle will be
// stored in the cache, and one handle is returned to the
// caller. In that case, we need to increase the ref count of the
// kernel.
// stored in FastKernelCacheVal, and one is in
// KernelProgramCache::MKernelsPerProgramCache. To cover
// MKernelsPerProgramCache, we need to increase the ref count of the kernel.
ContextImpl->getAdapter()->call<UrApiKind::urKernelRetain>(
KernelArgMaskPair.first);
Cache.saveKernel(KernelName, UrDevice, ret_val, CacheHintPtr);
Expand Down
4 changes: 2 additions & 2 deletions sycl/source/detail/program_manager/program_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <detail/device_global_map_entry.hpp>
#include <detail/host_pipe_map_entry.hpp>
#include <detail/kernel_arg_mask.hpp>
#include <detail/kernel_name_based_cache_t.hpp>
#include <detail/spec_constant_impl.hpp>
#include <sycl/detail/cg_types.hpp>
#include <sycl/detail/common.hpp>
Expand Down Expand Up @@ -197,8 +198,7 @@ class ProgramManager {
const DevImgPlainWithDeps *DevImgWithDeps = nullptr,
const SerializedObj &SpecConsts = {});

std::tuple<ur_kernel_handle_t, std::mutex *, const KernelArgMask *,
ur_program_handle_t>
FastKernelCacheValPtr
getOrCreateKernel(const ContextImplPtr &ContextImpl, device_impl &DeviceImpl,
KernelNameStrRefT KernelName,
KernelNameBasedCacheT *KernelNameBasedCachePtr,
Expand Down
Loading