Skip to content

[SYCL][NFC] Pass adapter by ref in ur::getAdapter and event:getAdapter #19202

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions sycl/source/backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,16 @@ namespace sycl {
inline namespace _V1 {
namespace detail {

static const adapter_impl &getAdapter(backend Backend) {
static adapter_impl &getAdapter(backend Backend) {
switch (Backend) {
case backend::opencl:
return *ur::getAdapter<backend::opencl>();
return ur::getAdapter<backend::opencl>();
case backend::ext_oneapi_level_zero:
return *ur::getAdapter<backend::ext_oneapi_level_zero>();
return ur::getAdapter<backend::ext_oneapi_level_zero>();
case backend::ext_oneapi_cuda:
return *ur::getAdapter<backend::ext_oneapi_cuda>();
return ur::getAdapter<backend::ext_oneapi_cuda>();
case backend::ext_oneapi_hip:
return *ur::getAdapter<backend::ext_oneapi_hip>();
return ur::getAdapter<backend::ext_oneapi_hip>();
default:
throw sycl::exception(
sycl::make_error_code(sycl::errc::runtime),
Expand Down Expand Up @@ -71,7 +71,7 @@ backend convertUrBackend(ur_backend_t UrBackend) {
}

platform make_platform(ur_native_handle_t NativeHandle, backend Backend) {
const adapter_impl &Adapter = getAdapter(Backend);
adapter_impl &Adapter = getAdapter(Backend);

// Create UR platform first.
ur_platform_handle_t UrPlatform = nullptr;
Expand All @@ -84,7 +84,7 @@ platform make_platform(ur_native_handle_t NativeHandle, backend Backend) {

__SYCL_EXPORT device make_device(ur_native_handle_t NativeHandle,
backend Backend) {
const adapter_impl &Adapter = getAdapter(Backend);
adapter_impl &Adapter = getAdapter(Backend);

ur_device_handle_t UrDevice = nullptr;
Adapter.call<UrApiKind::urDeviceCreateWithNativeHandle>(
Expand All @@ -100,7 +100,7 @@ __SYCL_EXPORT context make_context(ur_native_handle_t NativeHandle,
const async_handler &Handler,
backend Backend, bool KeepOwnership,
const std::vector<device> &DeviceList) {
const adapter_impl &Adapter = getAdapter(Backend);
adapter_impl &Adapter = getAdapter(Backend);

ur_context_handle_t UrContext = nullptr;
ur_context_native_properties_t Properties{};
Expand Down Expand Up @@ -193,7 +193,7 @@ std::shared_ptr<detail::kernel_bundle_impl>
make_kernel_bundle(ur_native_handle_t NativeHandle,
const context &TargetContext, bool KeepOwnership,
bundle_state State, backend Backend) {
const adapter_impl &Adapter = getAdapter(Backend);
adapter_impl &Adapter = getAdapter(Backend);
context_impl &ContextImpl = *getSyclObjImpl(TargetContext);

ur_program_handle_t UrProgram = nullptr;
Expand Down
6 changes: 3 additions & 3 deletions sycl/source/backend/level_zero.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ using namespace sycl::detail;

__SYCL_EXPORT device make_device(const platform &Platform,
ur_native_handle_t NativeHandle) {
const auto &Adapter = ur::getAdapter<backend::ext_oneapi_level_zero>();
adapter_impl &Adapter = ur::getAdapter<backend::ext_oneapi_level_zero>();
// Create UR device first.
ur_device_handle_t UrDevice;
Adapter->call<UrApiKind::urDeviceCreateWithNativeHandle>(
NativeHandle, Adapter->getUrAdapter(), nullptr, &UrDevice);
Adapter.call<UrApiKind::urDeviceCreateWithNativeHandle>(
NativeHandle, Adapter.getUrAdapter(), nullptr, &UrDevice);

return detail::createSyclObjFromImpl<device>(
getSyclObjImpl(Platform)->getOrMakeDeviceImpl(UrDevice));
Expand Down
9 changes: 5 additions & 4 deletions sycl/source/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,16 @@ context::context(const std::vector<device> &DeviceList,
impl = detail::context_impl::create(DeviceList, AsyncHandler, PropList);
}
context::context(cl_context ClContext, async_handler AsyncHandler) {
const auto &Adapter = sycl::detail::ur::getAdapter<backend::opencl>();
detail::adapter_impl &Adapter =
sycl::detail::ur::getAdapter<backend::opencl>();

ur_context_handle_t hContext = nullptr;
ur_native_handle_t nativeHandle =
reinterpret_cast<ur_native_handle_t>(ClContext);
Adapter->call<detail::UrApiKind::urContextCreateWithNativeHandle>(
nativeHandle, Adapter->getUrAdapter(), 0, nullptr, nullptr, &hContext);
Adapter.call<detail::UrApiKind::urContextCreateWithNativeHandle>(
nativeHandle, Adapter.getUrAdapter(), 0, nullptr, nullptr, &hContext);

impl = detail::context_impl::create(hContext, AsyncHandler, *Adapter);
impl = detail::context_impl::create(hContext, AsyncHandler, Adapter);
}

template <typename Param>
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/adapter_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class adapter_impl {
return UrPlatforms;
}

ur_adapter_handle_t getUrAdapter() const { return MAdapter; }
ur_adapter_handle_t getUrAdapter() { return MAdapter; }

/// Calls the UR Api, traces the call, and returns the result.
///
Expand Down
5 changes: 2 additions & 3 deletions sycl/source/detail/context_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ context_impl::context_impl(const std::vector<sycl::device> Devices,
}

context_impl::context_impl(ur_context_handle_t UrContext,
async_handler AsyncHandler,
const adapter_impl &Adapter,
async_handler AsyncHandler, adapter_impl &Adapter,
const std::vector<sycl::device> &DeviceList,
bool OwnedByRuntime, private_tag)
: MOwnedByRuntime(OwnedByRuntime), MAsyncHandler(AsyncHandler),
Expand Down Expand Up @@ -366,7 +365,7 @@ std::vector<ur_event_handle_t> context_impl::initializeDeviceGlobals(
InitEventsRef.begin(), InitEventsRef.end(),
[&Adapter](const ur_event_handle_t &Event) {
return get_event_info<info::event::command_execution_status>(
Event, Adapter) == info::event_command_status::complete;
Event, *Adapter) == info::event_command_status::complete;
});
// Release the removed events.
for (auto EventIt = NewEnd; EventIt != InitEventsRef.end(); ++EventIt)
Expand Down
4 changes: 2 additions & 2 deletions sycl/source/detail/context_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ class context_impl : public std::enable_shared_from_this<context_impl> {
/// \param OwnedByRuntime is the flag if ownership is kept by user or
/// transferred to runtime
context_impl(ur_context_handle_t UrContext, async_handler AsyncHandler,
const adapter_impl &Adapter,
adapter_impl &Adapter,
const std::vector<sycl::device> &DeviceList, bool OwnedByRuntime,
private_tag);

context_impl(ur_context_handle_t UrContext, async_handler AsyncHandler,
const adapter_impl &Adapter, private_tag tag)
adapter_impl &Adapter, private_tag tag)
: context_impl(UrContext, AsyncHandler, Adapter,
std::vector<sycl::device>{},
/*OwnedByRuntime*/ true, tag) {}
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/device_global_map_entry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ OwnedUrEvent DeviceGlobalUSMMem::getInitEvent(const AdapterPtr &Adapter) {
// If there is a init event we can remove it if it is done.
if (MInitEvent.has_value()) {
if (get_event_info<info::event::command_execution_status>(
*MInitEvent, Adapter) == info::event_command_status::complete) {
*MInitEvent, *Adapter) == info::event_command_status::complete) {
Adapter->call<UrApiKind::urEventRelease>(*MInitEvent);
MInitEvent = {};
return OwnedUrEvent(Adapter);
Expand Down
22 changes: 11 additions & 11 deletions sycl/source/detail/event_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ event_impl::~event_impl() {
try {
auto Handle = this->getHandle();
if (Handle)
getAdapter()->call<UrApiKind::urEventRelease>(Handle);
getAdapter().call<UrApiKind::urEventRelease>(Handle);
} catch (std::exception &e) {
__SYCL_REPORT_EXCEPTION_TO_STREAM("exception in ~event_impl", e);
}
Expand All @@ -59,7 +59,7 @@ void event_impl::waitInternal(bool *Success) {
if (!MIsHostEvent && Handle) {
// Wait for the native event
ur_result_t Err =
getAdapter()->call_nocheck<UrApiKind::urEventWait>(1, &Handle);
getAdapter().call_nocheck<UrApiKind::urEventWait>(1, &Handle);
// TODO drop the UR_RESULT_ERROR_UKNOWN from here (this was waiting for
// https://github.com/oneapi-src/unified-runtime/issues/1459 which is now
// closed).
Expand All @@ -68,7 +68,7 @@ void event_impl::waitInternal(bool *Success) {
Err == UR_RESULT_ERROR_IN_EVENT_LIST_EXEC_STATUS))
*Success = false;
else {
getAdapter()->checkUrResult(Err);
getAdapter().checkUrResult(Err);
if (Success != nullptr)
*Success = true;
}
Expand Down Expand Up @@ -148,9 +148,9 @@ context_impl &event_impl::getContextImpl() {
return *MContext;
}

const AdapterPtr &event_impl::getAdapter() {
adapter_impl &event_impl::getAdapter() {
initContextIfNeeded();
return MContext->getAdapter();
return *MContext->getAdapter();
}

void event_impl::setStateIncomplete() { MState = HES_NotComplete; }
Expand All @@ -166,7 +166,7 @@ event_impl::event_impl(ur_event_handle_t Event, const context &SyclContext,
MIsFlushed(true), MState(HES_Complete) {

ur_context_handle_t TempContext;
getAdapter()->call<UrApiKind::urEventGetInfo>(
getAdapter().call<UrApiKind::urEventGetInfo>(
this->getHandle(), UR_EVENT_INFO_CONTEXT, sizeof(ur_context_handle_t),
&TempContext, nullptr);

Expand Down Expand Up @@ -519,19 +519,19 @@ ur_native_handle_t event_impl::getNative() {
return {};
initContextIfNeeded();

auto Adapter = getAdapter();
adapter_impl &Adapter = getAdapter();
auto Handle = getHandle();
if (MIsDefaultConstructed && !Handle) {
auto TempContext = MContext.get()->getHandleRef();
ur_event_native_properties_t NativeProperties{};
ur_event_handle_t UREvent = nullptr;
Adapter->call<UrApiKind::urEventCreateWithNativeHandle>(
Adapter.call<UrApiKind::urEventCreateWithNativeHandle>(
0, TempContext, &NativeProperties, &UREvent);
this->setHandle(UREvent);
Handle = UREvent;
}
ur_native_handle_t OutHandle;
Adapter->call<UrApiKind::urEventGetNativeHandle>(Handle, &OutHandle);
Adapter.call<UrApiKind::urEventGetNativeHandle>(Handle, &OutHandle);
if (MContext->getBackend() == backend::opencl)
__SYCL_OCL_CALL(clRetainEvent, ur::cast<cl_event>(OutHandle));
return OutHandle;
Expand Down Expand Up @@ -569,11 +569,11 @@ void event_impl::flushIfNeeded(queue_impl *UserQueue) {

// Check if the task for this event has already been submitted.
ur_event_status_t Status = UR_EVENT_STATUS_QUEUED;
getAdapter()->call<UrApiKind::urEventGetInfo>(
getAdapter().call<UrApiKind::urEventGetInfo>(
Handle, UR_EVENT_INFO_COMMAND_EXECUTION_STATUS, sizeof(ur_event_status_t),
&Status, nullptr);
if (Status == UR_EVENT_STATUS_QUEUED) {
getAdapter()->call<UrApiKind::urQueueFlush>(Queue->getHandleRef());
getAdapter().call<UrApiKind::urQueueFlush>(Queue->getHandleRef());
}
MIsFlushed = true;
}
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/event_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ class event_impl {

/// \return the Adapter associated with the context of this event.
/// Should be called when this is not a Host Event.
const AdapterPtr &getAdapter();
adapter_impl &getAdapter();

/// Associate event with the context.
///
Expand Down
12 changes: 6 additions & 6 deletions sycl/source/detail/event_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,26 @@ inline namespace _V1 {
namespace detail {

template <typename Param>
typename Param::return_type
get_event_profiling_info(ur_event_handle_t Event, const AdapterPtr &Adapter) {
typename Param::return_type get_event_profiling_info(ur_event_handle_t Event,
adapter_impl &Adapter) {
static_assert(is_event_profiling_info_desc<Param>::value,
"Unexpected event profiling info descriptor");
typename Param::return_type Result{0};
// TODO catch an exception and put it to list of asynchronous exceptions
Adapter->call<UrApiKind::urEventGetProfilingInfo>(
Adapter.call<UrApiKind::urEventGetProfilingInfo>(
Event, UrInfoCode<Param>::value, sizeof(Result), &Result, nullptr);
return Result;
}

template <typename Param>
typename Param::return_type get_event_info(ur_event_handle_t Event,
const AdapterPtr &Adapter) {
adapter_impl &Adapter) {
static_assert(is_event_info_desc<Param>::value,
"Unexpected event info descriptor");
typename Param::return_type Result{0};
// TODO catch an exception and put it to list of asynchronous exceptions
Adapter->call<UrApiKind::urEventGetInfo>(Event, UrInfoCode<Param>::value,
sizeof(Result), &Result, nullptr);
Adapter.call<UrApiKind::urEventGetInfo>(Event, UrInfoCode<Param>::value,
sizeof(Result), &Result, nullptr);

// If the status is UR_EVENT_STATUS_QUEUED We need to change it since QUEUE is
// not a valid status in sycl.
Expand Down
8 changes: 4 additions & 4 deletions sycl/source/detail/memory_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,13 @@ static void waitForEvents(const std::vector<EventImplPtr> &Events) {
// Assuming all events will be on the same device or
// devices associated with the same Backend.
if (!Events.empty()) {
const AdapterPtr &Adapter = Events[0]->getAdapter();
adapter_impl &Adapter = Events[0]->getAdapter();
std::vector<ur_event_handle_t> UrEvents(Events.size());
std::transform(
Events.begin(), Events.end(), UrEvents.begin(),
[](const EventImplPtr &EventImpl) { return EventImpl->getHandle(); });
if (!UrEvents.empty() && UrEvents[0]) {
Adapter->call<UrApiKind::urEventWait>(UrEvents.size(), &UrEvents[0]);
Adapter.call<UrApiKind::urEventWait>(UrEvents.size(), &UrEvents[0]);
}
}
}
Expand Down Expand Up @@ -318,8 +318,8 @@ void *MemoryManager::allocateInteropMemObject(
// Retain the event since it will be released during alloca command
// destruction
if (nullptr != OutEventToWait) {
const AdapterPtr &Adapter = InteropEvent->getAdapter();
Adapter->call<UrApiKind::urEventRetain>(OutEventToWait);
adapter_impl &Adapter = InteropEvent->getAdapter();
Adapter.call<UrApiKind::urEventRetain>(OutEventToWait);
}
return UserPtr;
}
Expand Down
8 changes: 4 additions & 4 deletions sycl/source/detail/platform_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ namespace detail {

platform_impl &
platform_impl::getOrMakePlatformImpl(ur_platform_handle_t UrPlatform,
const adapter_impl &Adapter) {
adapter_impl &Adapter) {
std::shared_ptr<platform_impl> Result;
{
const std::lock_guard<std::mutex> Guard(
Expand All @@ -50,8 +50,8 @@ platform_impl::getOrMakePlatformImpl(ur_platform_handle_t UrPlatform,
// Otherwise make the impl. Our ctor/dtor are private, so std::make_shared
// needs a bit of help...
struct creator : platform_impl {
creator(ur_platform_handle_t APlatform, const adapter_impl &AAdapter)
: platform_impl(APlatform, &AAdapter) {}
creator(ur_platform_handle_t APlatform, adapter_impl &AAdapter)
: platform_impl(APlatform, AAdapter) {}
};
Result = std::make_shared<creator>(UrPlatform, Adapter);
PlatformCache.emplace_back(Result);
Expand All @@ -62,7 +62,7 @@ platform_impl::getOrMakePlatformImpl(ur_platform_handle_t UrPlatform,

platform_impl &
platform_impl::getPlatformFromUrDevice(ur_device_handle_t UrDevice,
const adapter_impl &Adapter) {
adapter_impl &Adapter) {
ur_platform_handle_t Plt =
nullptr; // TODO catch an exception and put it to list
// of asynchronous exceptions
Expand Down
15 changes: 6 additions & 9 deletions sycl/source/detail/platform_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,16 @@ class platform_impl : public std::enable_shared_from_this<platform_impl> {
/// Constructs platform_impl from a UR platform handle.
///
/// \param APlatform is a raw plug-in platform handle.
/// \param AAdapter is a plug-in handle.
/// \param Adapter is a plug-in handle.
//
// Platforms can only be created under `GlobalHandler`'s ownership via
// `platform_impl::getOrMakePlatformImpl` method.
explicit platform_impl(ur_platform_handle_t APlatform,
const adapter_impl *AAdapter)
: MPlatform(APlatform) {

MAdapter = const_cast<AdapterPtr>(AAdapter);
explicit platform_impl(ur_platform_handle_t APlatform, adapter_impl &Adapter)
: MPlatform(APlatform), MAdapter(&Adapter) {

// Find out backend of the platform
ur_backend_t UrBackend = UR_BACKEND_UNKNOWN;
AAdapter->call_nocheck<UrApiKind::urPlatformGetInfo>(
Adapter.call_nocheck<UrApiKind::urPlatformGetInfo>(
APlatform, UR_PLATFORM_INFO_BACKEND, sizeof(ur_backend_t), &UrBackend,
nullptr);
MBackend = convertUrBackend(UrBackend);
Expand Down Expand Up @@ -183,7 +180,7 @@ class platform_impl : public std::enable_shared_from_this<platform_impl> {
/// \param Adapter is the UR adapter providing the backend for the platform
/// \return the platform_impl representing the UR platform
static platform_impl &getOrMakePlatformImpl(ur_platform_handle_t UrPlatform,
const adapter_impl &Adapter);
adapter_impl &Adapter);

/// Queries the cache for the specified platform based on an input device.
/// If found, returns the the cached platform_impl, otherwise creates a new
Expand All @@ -195,7 +192,7 @@ class platform_impl : public std::enable_shared_from_this<platform_impl> {
/// platform
/// \return the platform_impl that contains the input device
static platform_impl &getPlatformFromUrDevice(ur_device_handle_t UrDevice,
const adapter_impl &Adapter);
adapter_impl &Adapter);

context_impl &khr_get_default_context();

Expand Down
6 changes: 3 additions & 3 deletions sycl/source/detail/scheduler/commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,12 +345,12 @@ class DispatchHostTask {
std::vector<ur_mem_handle_t> MReqUrMem;

bool waitForEvents() const {
std::map<const AdapterPtr, std::vector<EventImplPtr>>
std::map<adapter_impl *, std::vector<EventImplPtr>>
RequiredEventsPerAdapter;

for (const EventImplPtr &Event : MThisCmd->MPreparedDepsEvents) {
const AdapterPtr &Adapter = Event->getAdapter();
RequiredEventsPerAdapter[Adapter].push_back(Event);
adapter_impl &Adapter = Event->getAdapter();
RequiredEventsPerAdapter[&Adapter].push_back(Event);
}

// wait for dependency device events
Expand Down
Loading