Skip to content

Commit f8ac7f2

Browse files
[SYCL] Cache implicit local arg position info (#18654)
1 parent 67d92b5 commit f8ac7f2

File tree

4 files changed

+32
-13
lines changed

4 files changed

+32
-13
lines changed

sycl/source/detail/kernel_name_based_cache_t.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@ struct FastKernelSubcacheT {
7474
// Bundle of cached values; judging by the type name it is presumably kept per
// kernel name (the owner and lookup key are not visible here).
struct KernelNameBasedCacheT {
7575
FastKernelSubcacheT FastKernelSubcache;
7676
std::optional<bool> UsesAssert;
77+
// The implicit local argument position is itself an optional int. It is
78+
// wrapped in a second optional so that "not looked up yet" (outer empty) can
79+
// be told apart from "looked up, no position recorded" (inner empty).
80+
std::optional<std::optional<int>> ImplicitLocalArgPos;
7781
};
7882

7983
} // namespace detail

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1841,12 +1841,24 @@ void ProgramManager::cacheKernelImplicitLocalArg(RTDeviceBinaryImage &Img) {
18411841
}
18421842
}
18431843

1844-
std::optional<int>
1845-
ProgramManager::kernelImplicitLocalArgPos(KernelNameStrRefT KernelName) const {
1846-
auto it = m_KernelImplicitLocalArgPos.find(KernelName);
1847-
if (it != m_KernelImplicitLocalArgPos.end())
1848-
return it->second;
1849-
return {};
1844+
// Returns the implicit local argument position stored for KernelName in
// m_KernelImplicitLocalArgPos, or an empty optional when the map has no entry
// for that name.
//
// KernelName: key used for the map lookup.
// KernelNameBasedCachePtr: may be null. When non-null, the lookup result is
// stored in its ImplicitLocalArgPos field on the first call and returned from
// that field on later calls, skipping the map lookup.
//
// NOTE(review): the cache fill below is not guarded by any lock visible in
// this function - confirm callers serialize access to the cache entry.
// NOTE(review): once filled, the cached value is never refreshed here -
// confirm m_KernelImplicitLocalArgPos cannot gain an entry for this kernel
// after the first call.
std::optional<int> ProgramManager::kernelImplicitLocalArgPos(
1845+
KernelNameStrRefT KernelName,
1846+
KernelNameBasedCacheT *KernelNameBasedCachePtr) const {
1847+
// Uncached path: plain lookup in the map, empty optional when not found.
auto getLocalArgPos = [&]() -> std::optional<int> {
1848+
auto it = m_KernelImplicitLocalArgPos.find(KernelName);
1849+
if (it != m_KernelImplicitLocalArgPos.end())
1850+
return it->second;
1851+
return {};
1852+
};
1853+
1854+
// No cache entry supplied: do the map lookup on every call.
if (!KernelNameBasedCachePtr)
1855+
return getLocalArgPos();
1856+
std::optional<std::optional<int>> &ImplicitLocalArgPos =
1857+
KernelNameBasedCachePtr->ImplicitLocalArgPos;
1858+
// An empty outer optional means nothing has been cached yet. An empty inner
// optional is a valid cached result ("no position recorded") and is kept.
if (!ImplicitLocalArgPos.has_value()) {
1859+
ImplicitLocalArgPos = getLocalArgPos();
1860+
}
1861+
// The outer optional is always engaged at this point, so value() cannot throw.
return ImplicitLocalArgPos.value();
18501862
}
18511863

18521864
static bool isBfloat16DeviceLibImage(sycl_device_binary RawImg,

sycl/source/detail/program_manager/program_manager.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -373,8 +373,9 @@ class ProgramManager {
373373

374374
SanitizerType kernelUsesSanitizer() const { return m_SanitizerFoundInImage; }
375375

376-
std::optional<int>
377-
kernelImplicitLocalArgPos(KernelNameStrRefT KernelName) const;
376+
std::optional<int> kernelImplicitLocalArgPos(
377+
KernelNameStrRefT KernelName,
378+
KernelNameBasedCacheT *KernelNameBasedCachePtr) const;
378379

379380
std::set<RTDeviceBinaryImage *>
380381
getRawDeviceImages(const std::vector<kernel_id> &KernelIDs);

sycl/source/detail/scheduler/commands.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2389,8 +2389,9 @@ static ur_result_t SetKernelParamsAndLaunch(
23892389
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
23902390
bool IsCooperative, bool KernelUsesClusterLaunch,
23912391
uint32_t WorkGroupMemorySize, const RTDeviceBinaryImage *BinImage,
2392-
KernelNameStrRefT KernelName, void *KernelFuncPtr = nullptr,
2393-
int KernelNumArgs = 0,
2392+
KernelNameStrRefT KernelName,
2393+
KernelNameBasedCacheT *KernelNameBasedCachePtr,
2394+
void *KernelFuncPtr = nullptr, int KernelNumArgs = 0,
23942395
detail::kernel_param_desc_t (*KernelParamDescGetter)(int) = nullptr,
23952396
bool KernelHasSpecialCaptures = true) {
23962397
const AdapterPtr &Adapter = Queue.getAdapter();
@@ -2437,7 +2438,8 @@ static ur_result_t SetKernelParamsAndLaunch(
24372438
}
24382439

24392440
std::optional<int> ImplicitLocalArg =
2440-
ProgramManager::getInstance().kernelImplicitLocalArgPos(KernelName);
2441+
ProgramManager::getInstance().kernelImplicitLocalArgPos(
2442+
KernelName, KernelNameBasedCachePtr);
24412443
// Set the implicit local memory buffer to support
24422444
// get_work_group_scratch_memory. This is for backend not supporting
24432445
// CUDA-style local memory setting. Note that we may have -1 as a position,
@@ -2752,8 +2754,8 @@ void enqueueImpKernel(
27522754
*Queue, Args, DeviceImageImpl, Kernel, NDRDesc, EventsWaitList,
27532755
OutEventImpl, EliminatedArgMask, getMemAllocationFunc,
27542756
KernelIsCooperative, KernelUsesClusterLaunch, WorkGroupMemorySize,
2755-
BinImage, KernelName, KernelFuncPtr, KernelNumArgs,
2756-
KernelParamDescGetter, KernelHasSpecialCaptures);
2757+
BinImage, KernelName, KernelNameBasedCachePtr, KernelFuncPtr,
2758+
KernelNumArgs, KernelParamDescGetter, KernelHasSpecialCaptures);
27572759
}
27582760
if (UR_RESULT_SUCCESS != Error) {
27592761
// If we have got non-success error code, let's analyze it to emit nice

0 commit comments

Comments
 (0)