
[SYCL][NFC] Make kernel_impl::getAdapter() return by reference #19313

Open: wants to merge 2 commits into base: sycl

4 changes: 2 additions & 2 deletions sycl/source/detail/error_handling/error_handling.cpp
@@ -469,7 +469,7 @@ void handleErrorOrWarning(ur_result_t Error, const device_impl &DeviceImpl,

namespace detail::kernel_get_group_info {
void handleErrorOrWarning(ur_result_t Error, ur_kernel_group_info_t Descriptor,
const AdapterPtr &Adapter) {
adapter_impl &Adapter) {
Contributor:
Any reason to drop const?

Contributor Author:
Previously, const applied to the shared_ptr<adapter_impl> handle, not to the adapter_impl object itself. Making adapter_impl const-correct throughout the code is orthogonal to the refactoring I'm doing, and it's not trivial (I tried that before). So, for uniformity across the refactoring, I've replaced const AdapterPtr & with adapter_impl &.
That said, for the handleErrorOrWarning function specifically, since the adapter_impl object is not passed on anywhere, I can use const adapter_impl & if that's desirable.
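
To illustrate the point about where const lands, here is a minimal sketch. It assumes a hypothetical stand-in type, adapter_stub, in place of the real adapter_impl, and the function names viaConstPtr, viaRef and viaConstRef are made up for this example. It only shows that const on a shared_ptr reference does not protect the pointee, while const on the object reference does.

```cpp
#include <cassert>
#include <memory>
#include <type_traits>
#include <utility>

// Hypothetical stand-in for adapter_impl; not the real class.
struct adapter_stub {
  int Calls = 0;
  void checkResult() { ++Calls; }      // non-const member: mutates the object
  int calls() const { return Calls; }  // const member: read-only
};

using AdapterStubPtr = std::shared_ptr<adapter_stub>;

// Dereferencing a const shared_ptr still yields a non-const reference.
static_assert(
    std::is_same<decltype(*std::declval<const AdapterStubPtr &>()),
                 adapter_stub &>::value,
    "const on the shared_ptr does not propagate to the pointee");

// Old style: const binds to the handle, so the pointee stays mutable.
int viaConstPtr(const AdapterStubPtr &Adapter) {
  Adapter->checkResult(); // compiles despite the const
  return Adapter->calls();
}

// New style: plain reference, same mutability as before, no pointer hop.
int viaRef(adapter_stub &Adapter) {
  Adapter.checkResult();
  return Adapter.calls();
}

// Stricter option: const on the object itself.
int viaConstRef(const adapter_stub &Adapter) {
  // Adapter.checkResult(); // would not compile: non-const member call
  return Adapter.calls();
}
```

Under these assumptions, switching from const AdapterStubPtr & to adapter_stub & gives callees the same access they already had, whereas const adapter_stub & only works for callees that restrict themselves to const members.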

assert(Error != UR_RESULT_SUCCESS &&
"Success is expected to be handled on caller side");
switch (Error) {
@@ -483,7 +483,7 @@ void handleErrorOrWarning(ur_result_t Error, ur_kernel_group_info_t Descriptor,
break;
// TODO: Handle other error codes
default:
Adapter->checkUrResult(Error);
Adapter.checkUrResult(Error);
break;
}
}
3 changes: 1 addition & 2 deletions sycl/source/detail/error_handling/error_handling.hpp
@@ -31,8 +31,7 @@ void handleErrorOrWarning(ur_result_t, const device_impl &, ur_kernel_handle_t,

namespace kernel_get_group_info {
/// Analyzes error code of urKernelGetGroupInfo.
void handleErrorOrWarning(ur_result_t, ur_kernel_group_info_t,
const AdapterPtr &);
void handleErrorOrWarning(ur_result_t, ur_kernel_group_info_t, adapter_impl &);
Contributor:
Same question again.

} // namespace kernel_get_group_info

} // namespace detail
6 changes: 3 additions & 3 deletions sycl/source/detail/kernel_impl.cpp
@@ -28,7 +28,7 @@ kernel_impl::kernel_impl(ur_kernel_handle_t Kernel, context_impl &Context,
MIsInterop(true), MKernelArgMaskPtr{ArgMask} {
ur_context_handle_t UrContext = nullptr;
// Using the adapter from the passed ContextImpl
getAdapter()->call<UrApiKind::urKernelGetInfo>(
getAdapter().call<UrApiKind::urKernelGetInfo>(
MKernel, UR_KERNEL_INFO_CONTEXT, sizeof(UrContext), &UrContext, nullptr);
if (Context.getHandleRef() != UrContext)
throw sycl::exception(
@@ -61,7 +61,7 @@ kernel_impl::kernel_impl(ur_kernel_handle_t Kernel, context_impl &ContextImpl,
kernel_impl::~kernel_impl() {
try {
// TODO catch an exception and put it to list of asynchronous exceptions
getAdapter()->call<UrApiKind::urKernelRelease>(MKernel);
getAdapter().call<UrApiKind::urKernelRelease>(MKernel);
} catch (std::exception &e) {
__SYCL_REPORT_EXCEPTION_TO_STREAM("exception in ~kernel_impl", e);
}
@@ -135,7 +135,7 @@ void kernel_impl::enableUSMIndirectAccess() const {
// Some UR Adapters (like OpenCL) require this call to enable USM
// For others, UR will turn this into a NOP.
bool EnableAccess = true;
getAdapter()->call<UrApiKind::urKernelSetExecInfo>(
getAdapter().call<UrApiKind::urKernelSetExecInfo>(
MKernel, UR_KERNEL_EXEC_INFO_USM_INDIRECT_ACCESS, sizeof(ur_bool_t),
nullptr, &EnableAccess);
}
31 changes: 16 additions & 15 deletions sycl/source/detail/kernel_impl.hpp
@@ -75,13 +75,13 @@ class kernel_impl {
/// \return a valid cl_kernel instance
cl_kernel get() const {
ur_native_handle_t nativeHandle = 0;
getAdapter()->call<UrApiKind::urKernelGetNativeHandle>(MKernel,
&nativeHandle);
getAdapter().call<UrApiKind::urKernelGetNativeHandle>(MKernel,
&nativeHandle);
__SYCL_OCL_CALL(clRetainKernel, ur::cast<cl_kernel>(nativeHandle));
return ur::cast<cl_kernel>(nativeHandle);
}

const AdapterPtr &getAdapter() const { return MContext->getAdapter(); }
adapter_impl &getAdapter() const { return *MContext->getAdapter(); }

/// Query information from the kernel object using the info::kernel_info
/// descriptor.
@@ -360,7 +360,7 @@ kernel_impl::queryMaxNumWorkGroups(queue Queue,
throw exception(sycl::make_error_code(errc::invalid),
"The launch work-group size cannot be zero.");

const auto &Adapter = getAdapter();
adapter_impl &Adapter = getAdapter();
const auto &Handle = getHandleRef();
auto Device = Queue.get_device();
auto DeviceHandleRef = sycl::detail::getSyclObjImpl(Device)->getHandleRef();
@@ -373,15 +373,16 @@ kernel_impl::queryMaxNumWorkGroups(queue Queue,
WG[2] = WorkGroupSize[2];

uint32_t GroupCount{0};
if (auto Result = Adapter->call_nocheck<
UrApiKind::urKernelSuggestMaxCooperativeGroupCount>(
Handle, DeviceHandleRef, Dimensions, WG, DynamicLocalMemorySize,
&GroupCount);
if (auto Result =
Adapter
.call_nocheck<UrApiKind::urKernelSuggestMaxCooperativeGroupCount>(
Handle, DeviceHandleRef, Dimensions, WG,
DynamicLocalMemorySize, &GroupCount);
Result != UR_RESULT_ERROR_UNSUPPORTED_FEATURE &&
Result != UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE) {
// The feature is supported and the group size is valid. Check for other
// errors and throw if any.
Adapter->checkUrResult(Result);
Adapter.checkUrResult(Result);
return GroupCount;
}

@@ -452,12 +453,12 @@ inline typename syclex::info::kernel_queue_specific::max_work_group_size::
kernel_impl::ext_oneapi_get_info<
syclex::info::kernel_queue_specific::max_work_group_size>(
queue Queue) const {
const auto &Adapter = getAdapter();
adapter_impl &Adapter = getAdapter();
const auto DeviceNativeHandle =
getSyclObjImpl(Queue.get_device())->getHandleRef();

size_t KernelWGSize = 0;
Adapter->call<UrApiKind::urKernelGetGroupInfo>(
Adapter.call<UrApiKind::urKernelGetGroupInfo>(
MKernel, DeviceNativeHandle, UR_KERNEL_GROUP_INFO_WORK_GROUP_SIZE,
sizeof(size_t), &KernelWGSize, nullptr);
return KernelWGSize;
@@ -508,13 +509,13 @@ ADD_TEMPLATE_METHOD_SPEC(3)
if (WG.size() == 0) \
throw exception(sycl::make_error_code(errc::invalid), \
"The work-group size cannot be zero."); \
const auto &Adapter = getAdapter(); \
adapter_impl &Adapter = getAdapter(); \
const auto DeviceNativeHandle = \
getSyclObjImpl(Queue.get_device())->getHandleRef(); \
uint32_t KernelSubWGSize = 0; \
Adapter->call<UrApiKind::Kind>(MKernel, DeviceNativeHandle, Reg, \
sizeof(uint32_t), &KernelSubWGSize, \
nullptr); \
Adapter.call<UrApiKind::Kind>(MKernel, DeviceNativeHandle, Reg, \
sizeof(uint32_t), &KernelSubWGSize, \
nullptr); \
return KernelSubWGSize; \
}

60 changes: 30 additions & 30 deletions sycl/source/detail/kernel_info.hpp
@@ -45,33 +45,33 @@ template <typename Param>
typename std::enable_if<
std::is_same<typename Param::return_type, std::string>::value,
std::string>::type
get_kernel_info(ur_kernel_handle_t Kernel, const AdapterPtr &Adapter) {
get_kernel_info(ur_kernel_handle_t Kernel, adapter_impl &Adapter) {
static_assert(detail::is_kernel_info_desc<Param>::value,
"Invalid kernel information descriptor");
size_t ResultSize = 0;

// TODO catch an exception and put it to list of asynchronous exceptions
Adapter->call<UrApiKind::urKernelGetInfo>(Kernel, UrInfoCode<Param>::value, 0,
nullptr, &ResultSize);
Adapter.call<UrApiKind::urKernelGetInfo>(Kernel, UrInfoCode<Param>::value, 0,
nullptr, &ResultSize);
if (ResultSize == 0) {
return "";
}
std::vector<char> Result(ResultSize);
// TODO catch an exception and put it to list of asynchronous exceptions
Adapter->call<UrApiKind::urKernelGetInfo>(Kernel, UrInfoCode<Param>::value,
ResultSize, Result.data(), nullptr);
Adapter.call<UrApiKind::urKernelGetInfo>(Kernel, UrInfoCode<Param>::value,
ResultSize, Result.data(), nullptr);
return std::string(Result.data());
}

template <typename Param>
typename std::enable_if<
std::is_same<typename Param::return_type, uint32_t>::value, uint32_t>::type
get_kernel_info(ur_kernel_handle_t Kernel, const AdapterPtr &Adapter) {
get_kernel_info(ur_kernel_handle_t Kernel, adapter_impl &Adapter) {
ur_result_t Result = UR_RESULT_SUCCESS;

// TODO catch an exception and put it to list of asynchronous exceptions
Adapter->call<UrApiKind::urKernelGetInfo>(Kernel, UrInfoCode<Param>::value,
sizeof(uint32_t), &Result, nullptr);
Adapter.call<UrApiKind::urKernelGetInfo>(Kernel, UrInfoCode<Param>::value,
sizeof(uint32_t), &Result, nullptr);
return Result;
}

@@ -80,29 +80,29 @@ template <typename Param>
typename std::enable_if<IsSubGroupInfo<Param>::value>::type
get_kernel_device_specific_info_helper(ur_kernel_handle_t Kernel,
ur_device_handle_t Device,
const AdapterPtr &Adapter, void *Result,
adapter_impl &Adapter, void *Result,
size_t Size) {
Adapter->call<UrApiKind::urKernelGetSubGroupInfo>(
Adapter.call<UrApiKind::urKernelGetSubGroupInfo>(
Kernel, Device, UrInfoCode<Param>::value, Size, Result, nullptr);
}

template <typename Param>
typename std::enable_if<IsKernelInfo<Param>::value>::type
get_kernel_device_specific_info_helper(
ur_kernel_handle_t Kernel, [[maybe_unused]] ur_device_handle_t Device,
const AdapterPtr &Adapter, void *Result, size_t Size) {
Adapter->call<UrApiKind::urKernelGetInfo>(Kernel, UrInfoCode<Param>::value,
Size, Result, nullptr);
adapter_impl &Adapter, void *Result, size_t Size) {
Adapter.call<UrApiKind::urKernelGetInfo>(Kernel, UrInfoCode<Param>::value,
Size, Result, nullptr);
}

template <typename Param>
typename std::enable_if<!IsSubGroupInfo<Param>::value &&
!IsKernelInfo<Param>::value>::type
get_kernel_device_specific_info_helper(ur_kernel_handle_t Kernel,
ur_device_handle_t Device,
const AdapterPtr &Adapter, void *Result,
adapter_impl &Adapter, void *Result,
size_t Size) {
ur_result_t Error = Adapter->call_nocheck<UrApiKind::urKernelGetGroupInfo>(
ur_result_t Error = Adapter.call_nocheck<UrApiKind::urKernelGetGroupInfo>(
Kernel, Device, UrInfoCode<Param>::value, Size, Result, nullptr);
if (Error != UR_RESULT_SUCCESS)
kernel_get_group_info::handleErrorOrWarning(Error, UrInfoCode<Param>::value,
@@ -115,7 +115,7 @@ typename std::enable_if<
typename Param::return_type>::type
get_kernel_device_specific_info(ur_kernel_handle_t Kernel,
ur_device_handle_t Device,
const AdapterPtr &Adapter) {
adapter_impl &Adapter) {
static_assert(is_kernel_device_specific_info_desc<Param>::value,
"Unexpected kernel_device_specific information descriptor");
typename Param::return_type Result = {};
@@ -131,7 +131,7 @@ typename std::enable_if<
sycl::range<3>>::type
get_kernel_device_specific_info(ur_kernel_handle_t Kernel,
ur_device_handle_t Device,
const AdapterPtr &Adapter) {
adapter_impl &Adapter) {
static_assert(is_kernel_device_specific_info_desc<Param>::value,
"Unexpected kernel_device_specific information descriptor");
size_t Result[3] = {0, 0, 0};
@@ -148,7 +148,7 @@ template <typename Param>
uint32_t get_kernel_device_specific_info_with_input(ur_kernel_handle_t Kernel,
ur_device_handle_t Device,
sycl::range<3>,
const AdapterPtr &Adapter) {
adapter_impl &Adapter) {
static_assert(is_kernel_device_specific_info_desc<Param>::value,
"Unexpected kernel_device_specific information descriptor");
static_assert(std::is_same<typename Param::return_type, uint32_t>::value,
@@ -159,7 +159,7 @@ uint32_t get_kernel_device_specific_info_with_input(ur_kernel_handle_t Kernel,

uint32_t Result = 0;
// TODO catch an exception and put it to list of asynchronous exceptions
Adapter->call<UrApiKind::urKernelGetSubGroupInfo>(
Adapter.call<UrApiKind::urKernelGetSubGroupInfo>(
Kernel, Device, UrInfoCode<Param>::value, sizeof(uint32_t), &Result,
nullptr);

@@ -171,35 +171,35 @@ inline ext::intel::info::kernel_device_specific::spill_memory_size::return_type
get_kernel_device_specific_info<
ext::intel::info::kernel_device_specific::spill_memory_size>(
ur_kernel_handle_t Kernel, ur_device_handle_t Device,
const AdapterPtr &Adapter) {
adapter_impl &Adapter) {
size_t ResultSize = 0;

// First call to get the number of device images
Adapter->call<UrApiKind::urKernelGetInfo>(
Adapter.call<UrApiKind::urKernelGetInfo>(
Kernel, UR_KERNEL_INFO_SPILL_MEM_SIZE, 0, nullptr, &ResultSize);

size_t DeviceCount = ResultSize / sizeof(uint32_t);

// Second call to retrieve the data
std::vector<uint32_t> Device2SpillMap(DeviceCount);
Adapter->call<UrApiKind::urKernelGetInfo>(
Adapter.call<UrApiKind::urKernelGetInfo>(
Kernel, UR_KERNEL_INFO_SPILL_MEM_SIZE, ResultSize, Device2SpillMap.data(),
nullptr);

ur_program_handle_t Program;
Adapter->call<UrApiKind::urKernelGetInfo>(Kernel, UR_KERNEL_INFO_PROGRAM,
sizeof(ur_program_handle_t),
&Program, nullptr);
Adapter.call<UrApiKind::urKernelGetInfo>(Kernel, UR_KERNEL_INFO_PROGRAM,
sizeof(ur_program_handle_t),
&Program, nullptr);
// Retrieve the associated device list
size_t URDevicesSize = 0;
Adapter->call<UrApiKind::urProgramGetInfo>(Program, UR_PROGRAM_INFO_DEVICES,
0, nullptr, &URDevicesSize);
Adapter.call<UrApiKind::urProgramGetInfo>(Program, UR_PROGRAM_INFO_DEVICES, 0,
nullptr, &URDevicesSize);

std::vector<ur_device_handle_t> URDevices(URDevicesSize /
sizeof(ur_device_handle_t));
Adapter->call<UrApiKind::urProgramGetInfo>(Program, UR_PROGRAM_INFO_DEVICES,
URDevicesSize, URDevices.data(),
nullptr);
Adapter.call<UrApiKind::urProgramGetInfo>(Program, UR_PROGRAM_INFO_DEVICES,
URDevicesSize, URDevices.data(),
nullptr);
assert(Device2SpillMap.size() == URDevices.size());

// Map the result back to the program devices. UR provides the following
2 changes: 1 addition & 1 deletion sycl/source/detail/scheduler/commands.cpp
@@ -2271,7 +2271,7 @@ static void adjustNDRangePerKernel(NDRDescT &NDR, ur_kernel_handle_t Kernel,
// avoid get_kernel_work_group_info on every kernel run
range<3> WGSize = get_kernel_device_specific_info<
sycl::info::kernel_device_specific::compile_work_group_size>(
Kernel, DeviceImpl.getHandleRef(), DeviceImpl.getAdapter());
Kernel, DeviceImpl.getHandleRef(), *DeviceImpl.getAdapter());

if (WGSize[0] == 0) {
WGSize = {1, 1, 1};