Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sycl/source/detail/kernel_name_based_cache_t.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ struct FastKernelSubcacheT {

struct KernelNameBasedCacheT {
FastKernelSubcacheT FastKernelSubcache;
std::optional<std::optional<int>> ImplicitLocalArgPos;
};

} // namespace detail
Expand Down
24 changes: 18 additions & 6 deletions sycl/source/detail/program_manager/program_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1851,12 +1851,24 @@ void ProgramManager::cacheKernelImplicitLocalArg(RTDeviceBinaryImage &Img) {
}
}

std::optional<int>
ProgramManager::kernelImplicitLocalArgPos(KernelNameStrRefT KernelName) const {
auto it = m_KernelImplicitLocalArgPos.find(KernelName);
if (it != m_KernelImplicitLocalArgPos.end())
return it->second;
return {};
std::optional<int> ProgramManager::kernelImplicitLocalArgPos(
KernelNameStrRefT KernelName,
KernelNameBasedCacheT *KernelNameBasedCachePtr) const {
auto getLocalArgPos = [&]() -> std::optional<int> {
auto it = m_KernelImplicitLocalArgPos.find(KernelName);
if (it != m_KernelImplicitLocalArgPos.end())
return it->second;
return {};
};

if (!KernelNameBasedCachePtr)
return getLocalArgPos();
std::optional<std::optional<int>> &ImplicitLocalArgPos =
KernelNameBasedCachePtr->ImplicitLocalArgPos;
if (!ImplicitLocalArgPos.has_value()) {
ImplicitLocalArgPos = getLocalArgPos();
}
return ImplicitLocalArgPos.value();
}

static bool isBfloat16DeviceLibImage(sycl_device_binary RawImg,
Expand Down
5 changes: 3 additions & 2 deletions sycl/source/detail/program_manager/program_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,9 @@ class ProgramManager {

SanitizerType kernelUsesSanitizer() const { return m_SanitizerFoundInImage; }

std::optional<int>
kernelImplicitLocalArgPos(KernelNameStrRefT KernelName) const;
std::optional<int> kernelImplicitLocalArgPos(
KernelNameStrRefT KernelName,
KernelNameBasedCacheT *KernelNameBasedCachePtr) const;

std::set<RTDeviceBinaryImage *>
getRawDeviceImages(const std::vector<kernel_id> &KernelIDs);
Expand Down
12 changes: 7 additions & 5 deletions sycl/source/detail/scheduler/commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2390,8 +2390,9 @@ static ur_result_t SetKernelParamsAndLaunch(
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
bool IsCooperative, bool KernelUsesClusterLaunch,
uint32_t WorkGroupMemorySize, const RTDeviceBinaryImage *BinImage,
KernelNameStrRefT KernelName, void *KernelFuncPtr = nullptr,
int KernelNumArgs = 0,
KernelNameStrRefT KernelName,
KernelNameBasedCacheT *KernelNameBasedCachePtr,
void *KernelFuncPtr = nullptr, int KernelNumArgs = 0,
detail::kernel_param_desc_t (*KernelParamDescGetter)(int) = nullptr,
bool KernelHasSpecialCaptures = true) {
assert(Queue && "Kernel submissions should have an associated queue");
Expand Down Expand Up @@ -2439,7 +2440,8 @@ static ur_result_t SetKernelParamsAndLaunch(
}

std::optional<int> ImplicitLocalArg =
ProgramManager::getInstance().kernelImplicitLocalArgPos(KernelName);
ProgramManager::getInstance().kernelImplicitLocalArgPos(
KernelName, KernelNameBasedCachePtr);
// Set the implicit local memory buffer to support
// get_work_group_scratch_memory. This is for backend not supporting
// CUDA-style local memory setting. Note that we may have -1 as a position,
Expand Down Expand Up @@ -2775,8 +2777,8 @@ void enqueueImpKernel(
Queue, Args, DeviceImageImpl, Kernel, NDRDesc, EventsWaitList,
OutEventImpl, EliminatedArgMask, getMemAllocationFunc,
KernelIsCooperative, KernelUsesClusterLaunch, WorkGroupMemorySize,
BinImage, KernelName, KernelFuncPtr, KernelNumArgs,
KernelParamDescGetter, KernelHasSpecialCaptures);
BinImage, KernelName, KernelNameBasedCachePtr, KernelFuncPtr,
KernelNumArgs, KernelParamDescGetter, KernelHasSpecialCaptures);

const AdapterPtr &Adapter = Queue->getAdapter();
if (!SyclKernelImpl && !MSyclKernel) {
Expand Down
Loading