Skip to content

[SYCL][UR][L0 v2] get rid of std::function in memory.hpp #18655

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ ur_result_t getMemPtr(ur_mem_handle_t memObj,
urAccessMode =
ur_mem_buffer_t::getDeviceAccessMode(properties->memoryAccess);
}
ptr = ur_cast<char *>(
memBuffer->getDevicePtr(device, urAccessMode, 0, memBuffer->getSize(),
[&](void *, void *, size_t) {}));
wait_list_view emptyWaitList(nullptr, 0);
ptr = ur_cast<char *>(memBuffer->getDevicePtr(
device, urAccessMode, 0, memBuffer->getSize(), nullptr, emptyWaitList));
}
assert(ptrStorage != nullptr);
ptrStorage->push_back(std::make_unique<char *>(ptr));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "../ur_interface_loader.hpp"
#include "context.hpp"
#include "kernel.hpp"
#include "memory.hpp"

ur_command_list_manager::ur_command_list_manager(
ur_context_handle_t context, ur_device_handle_t device,
Expand Down Expand Up @@ -44,12 +45,7 @@ ur_result_t ur_command_list_manager::appendGenericFillUnlocked(

auto pDst = ur_cast<char *>(dst->getDevicePtr(
device, ur_mem_buffer_t::device_access_mode_t::read_only, offset, size,
[&](void *src, void *dst, size_t size) {
ZE2UR_CALL_THROWS(zeCommandListAppendMemoryCopy,
(zeCommandList.get(), dst, src, size, nullptr,
waitListView.num, waitListView.handles));
waitListView.clear();
}));
zeCommandList.get(), waitListView));

// PatternSize must be a power of two for zeCommandListAppendMemoryFill.
// When it's not, the fill is emulated with zeCommandListAppendMemoryCopy.
Expand Down Expand Up @@ -87,21 +83,11 @@ ur_result_t ur_command_list_manager::appendGenericCopyUnlocked(

auto pSrc = ur_cast<char *>(src->getDevicePtr(
device, ur_mem_buffer_t::device_access_mode_t::read_only, srcOffset, size,
[&](void *src, void *dst, size_t size) {
ZE2UR_CALL_THROWS(zeCommandListAppendMemoryCopy,
(zeCommandList.get(), dst, src, size, nullptr,
waitListView.num, waitListView.handles));
waitListView.clear();
}));
zeCommandList.get(), waitListView));

auto pDst = ur_cast<char *>(dst->getDevicePtr(
device, ur_mem_buffer_t::device_access_mode_t::write_only, dstOffset,
size, [&](void *src, void *dst, size_t size) {
ZE2UR_CALL_THROWS(zeCommandListAppendMemoryCopy,
(zeCommandList.get(), dst, src, size, nullptr,
waitListView.num, waitListView.handles));
waitListView.clear();
}));
size, zeCommandList.get(), waitListView));

ZE2UR_CALL(zeCommandListAppendMemoryCopy,
(zeCommandList.get(), pDst, pSrc, size, zeSignalEvent,
Expand Down Expand Up @@ -130,20 +116,10 @@ ur_result_t ur_command_list_manager::appendRegionCopyUnlocked(

auto pSrc = ur_cast<char *>(src->getDevicePtr(
device, ur_mem_buffer_t::device_access_mode_t::read_only, 0,
src->getSize(), [&](void *src, void *dst, size_t size) {
ZE2UR_CALL_THROWS(zeCommandListAppendMemoryCopy,
(zeCommandList.get(), dst, src, size, nullptr,
waitListView.num, waitListView.handles));
waitListView.clear();
}));
src->getSize(), zeCommandList.get(), waitListView));
auto pDst = ur_cast<char *>(dst->getDevicePtr(
device, ur_mem_buffer_t::device_access_mode_t::write_only, 0,
dst->getSize(), [&](void *src, void *dst, size_t size) {
ZE2UR_CALL_THROWS(zeCommandListAppendMemoryCopy,
(zeCommandList.get(), dst, src, size, nullptr,
waitListView.num, waitListView.handles));
waitListView.clear();
}));
dst->getSize(), zeCommandList.get(), waitListView));

ZE2UR_CALL(zeCommandListAppendMemoryCopyRegion,
(zeCommandList.get(), pDst, &zeParams.dstRegion, zeParams.dstPitch,
Expand Down Expand Up @@ -213,16 +189,9 @@ ur_result_t ur_command_list_manager::appendKernelLaunch(

auto waitListView = getWaitListView(phEventWaitList, numEventsInWaitList);

auto memoryMigrate = [&](void *src, void *dst, size_t size) {
ZE2UR_CALL_THROWS(zeCommandListAppendMemoryCopy,
(zeCommandList.get(), dst, src, size, nullptr,
waitListView.num, waitListView.handles));
waitListView.clear();
};

UR_CALL(hKernel->prepareForSubmission(context, device, pGlobalWorkOffset,
workDim, WG[0], WG[1], WG[2],
memoryMigrate));
zeCommandList.get(), waitListView));

TRACK_SCOPE_LATENCY(
"ur_command_list_manager::zeCommandListAppendLaunchKernel");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
#include "common.hpp"
#include "context.hpp"
#include "event_pool_cache.hpp"
#include "memory.hpp"
#include "queue_api.hpp"
#include <ze_api.h>

struct ur_mem_buffer_t;

struct wait_list_view {
ze_event_handle_t *handles;
uint32_t num;
Expand Down
7 changes: 4 additions & 3 deletions unified-runtime/source/adapters/level_zero/v2/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ ur_result_t ur_kernel_handle_t_::prepareForSubmission(
ur_context_handle_t hContext, ur_device_handle_t hDevice,
const size_t *pGlobalWorkOffset, uint32_t workDim, uint32_t groupSizeX,
uint32_t groupSizeY, uint32_t groupSizeZ,
std::function<void(void *, void *, size_t)> migrate) {
ze_command_list_handle_t commandList, wait_list_view &waitListView) {
auto hZeKernel = getZeHandle(hDevice);

if (pGlobalWorkOffset != NULL) {
Expand All @@ -293,8 +293,9 @@ ur_result_t ur_kernel_handle_t_::prepareForSubmission(
if (pending.hMem) {
if (!pending.hMem->isImage()) {
auto hBuffer = pending.hMem->getBuffer();
zePtr = hBuffer->getDevicePtr(hDevice, pending.mode, 0,
hBuffer->getSize(), migrate);
zePtr =
hBuffer->getDevicePtr(hDevice, pending.mode, 0, hBuffer->getSize(),
commandList, waitListView);
} else {
auto hImage = static_cast<ur_mem_image_t *>(pending.hMem->getImage());
zePtr = reinterpret_cast<void *>(hImage->getZeImage());
Expand Down
13 changes: 7 additions & 6 deletions unified-runtime/source/adapters/level_zero/v2/kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,13 @@ struct ur_kernel_handle_t_ : ur_object {

// Set all required values for the kernel before submission (including pending
// memory allocations).
ur_result_t
prepareForSubmission(ur_context_handle_t hContext, ur_device_handle_t hDevice,
const size_t *pGlobalWorkOffset, uint32_t workDim,
uint32_t groupSizeX, uint32_t groupSizeY,
uint32_t groupSizeZ,
std::function<void(void *, void *, size_t)> migrate);
ur_result_t prepareForSubmission(ur_context_handle_t hContext,
ur_device_handle_t hDevice,
const size_t *pGlobalWorkOffset,
uint32_t workDim, uint32_t groupSizeX,
uint32_t groupSizeY, uint32_t groupSizeZ,
ze_command_list_handle_t cmdList,
wait_list_view &waitListView);

private:
// Keep the program of the kernel.
Expand Down
110 changes: 67 additions & 43 deletions unified-runtime/source/adapters/level_zero/v2/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,24 @@ ur_usm_handle_t::ur_usm_handle_t(ur_context_handle_t hContext, size_t size,
: ur_mem_buffer_t(hContext, size, device_access_mode_t::read_write),
ptr(const_cast<void *>(ptr)) {}

void *ur_usm_handle_t::getDevicePtr(
ur_device_handle_t /*hDevice*/, device_access_mode_t /*access*/,
size_t offset, size_t /*size*/,
std::function<void(void *src, void *dst, size_t)> /*migrate*/) {
void *ur_usm_handle_t::getDevicePtr(ur_device_handle_t /*hDevice*/,
device_access_mode_t /*access*/,
size_t offset, size_t /*size*/,
ze_command_list_handle_t /*cmdList*/,
wait_list_view & /*waitListView*/) {
return ur_cast<char *>(ptr) + offset;
}

void *
ur_usm_handle_t::mapHostPtr(ur_map_flags_t /*flags*/, size_t offset,
size_t /*size*/,
std::function<void(void *src, void *dst, size_t)>) {
void *ur_usm_handle_t::mapHostPtr(ur_map_flags_t /*flags*/, size_t offset,
size_t /*size*/,
ze_command_list_handle_t /*cmdList*/,
wait_list_view & /*waitListView*/) {
return ur_cast<char *>(ptr) + offset;
}

void ur_usm_handle_t::unmapHostPtr(
void * /*pMappedPtr*/, std::function<void(void *src, void *dst, size_t)>) {
void ur_usm_handle_t::unmapHostPtr(void * /*pMappedPtr*/,
ze_command_list_handle_t /*cmdList*/,
wait_list_view & /*waitListView*/) {
/* nop */
}

Expand Down Expand Up @@ -106,14 +108,14 @@ ur_integrated_buffer_handle_t::~ur_integrated_buffer_handle_t() {

void *ur_integrated_buffer_handle_t::getDevicePtr(
ur_device_handle_t /*hDevice*/, device_access_mode_t /*access*/,
size_t offset, size_t /*size*/,
std::function<void(void *src, void *dst, size_t)> /*migrate*/) {
size_t offset, size_t /*size*/, ze_command_list_handle_t /*cmdList*/,
wait_list_view & /*waitListView*/) {
return ur_cast<char *>(ptr.get()) + offset;
}

void *ur_integrated_buffer_handle_t::mapHostPtr(
ur_map_flags_t /*flags*/, size_t offset, size_t /*size*/,
std::function<void(void *src, void *dst, size_t)> /*migrate*/) {
ze_command_list_handle_t /*cmdList*/, wait_list_view & /*waitListView*/) {
// TODO: if writeBackPtr is set, we should map to that pointer
// because that's what SYCL expects, SYCL will attempt to call free
// on the resulting pointer leading to double free with the current
Expand All @@ -122,7 +124,8 @@ void *ur_integrated_buffer_handle_t::mapHostPtr(
}

void ur_integrated_buffer_handle_t::unmapHostPtr(
void * /*pMappedPtr*/, std::function<void(void *src, void *dst, size_t)>) {
void * /*pMappedPtr*/, ze_command_list_handle_t /*cmdList*/,
wait_list_view & /*waitListView*/) {
// TODO: if writeBackPtr is set, we should copy the data back
/* nop */
}
Expand Down Expand Up @@ -250,8 +253,8 @@ void *ur_discrete_buffer_handle_t::getActiveDeviceAlloc(size_t offset) {

void *ur_discrete_buffer_handle_t::getDevicePtr(
ur_device_handle_t hDevice, device_access_mode_t /*access*/, size_t offset,
size_t /*size*/,
std::function<void(void *src, void *dst, size_t)> /*migrate*/) {
size_t /*size*/, ze_command_list_handle_t /*cmdList*/,
wait_list_view & /*waitListView*/) {
TRACK_SCOPE_LATENCY("ur_discrete_buffer_handle_t::getDevicePtr");

if (!activeAllocationDevice) {
Expand Down Expand Up @@ -283,9 +286,22 @@ void *ur_discrete_buffer_handle_t::getDevicePtr(
return getActiveDeviceAlloc(offset);
}

void *ur_discrete_buffer_handle_t::mapHostPtr(
ur_map_flags_t flags, size_t offset, size_t size,
std::function<void(void *src, void *dst, size_t)> migrate) {
// Enqueues a copy of `size` bytes from `src` to `dst` on `cmdList`.
//
// The copy is appended with no signal event and is made to wait on the events
// currently held by `waitListView` (its `num` / `handles` members). After the
// append, `waitListView` is cleared via `clear()`.
// NOTE(review): presumably the clear is there so that later commands appended
// by the caller do not wait on the same events again — confirm against
// wait_list_view::clear().
//
// Throws UR_RESULT_ERROR_INVALID_NULL_HANDLE when `cmdList` is null, i.e. when
// a migration turns out to be required but the caller supplied no command
// list to record it on.
// NOTE(review): ZE2UR_CALL_THROWS is assumed to throw on a failing Level Zero
// call (per its name) — confirm against its definition.
static void migrateMemory(ze_command_list_handle_t cmdList, void *src,
                          void *dst, size_t size,
                          wait_list_view &waitListView) {
  // A migration cannot be recorded without a command list.
  if (cmdList == nullptr) {
    throw UR_RESULT_ERROR_INVALID_NULL_HANDLE;
  }

  ZE2UR_CALL_THROWS(zeCommandListAppendMemoryCopy,
                    (cmdList, dst, src, size, nullptr, waitListView.num,
                     waitListView.handles));

  waitListView.clear();
}

void *ur_discrete_buffer_handle_t::mapHostPtr(ur_map_flags_t flags,
size_t offset, size_t size,
ze_command_list_handle_t cmdList,
wait_list_view &waitListView) {
TRACK_SCOPE_LATENCY("ur_discrete_buffer_handle_t::mapHostPtr");
// TODO: use async alloc?

Expand All @@ -309,15 +325,16 @@ void *ur_discrete_buffer_handle_t::mapHostPtr(

if (activeAllocationDevice && (flags & UR_MAP_FLAG_READ)) {
auto srcPtr = getActiveDeviceAlloc(offset);
migrate(srcPtr, hostAllocations.back().ptr.get(), size);
migrateMemory(cmdList, srcPtr, hostAllocations.back().ptr.get(), size,
waitListView);
}

return hostAllocations.back().ptr.get();
}

void ur_discrete_buffer_handle_t::unmapHostPtr(
void *pMappedPtr,
std::function<void(void *src, void *dst, size_t)> migrate) {
void ur_discrete_buffer_handle_t::unmapHostPtr(void *pMappedPtr,
ze_command_list_handle_t cmdList,
wait_list_view &waitListView) {
TRACK_SCOPE_LATENCY("ur_discrete_buffer_handle_t::unmapHostPtr");

auto hostAlloc =
Expand All @@ -341,8 +358,9 @@ void ur_discrete_buffer_handle_t::unmapHostPtr(
// UR_MAP_FLAG_WRITE_INVALIDATE_REGION when there is an active device
// allocation. is this correct?
if (activeAllocationDevice) {
migrate(hostAlloc->ptr.get(), getActiveDeviceAlloc(hostAlloc->offset),
hostAlloc->size);
migrateMemory(cmdList, hostAlloc->ptr.get(),
getActiveDeviceAlloc(hostAlloc->offset), hostAlloc->size,
waitListView);
}

hostAllocations.erase(hostAlloc);
Expand All @@ -361,18 +379,20 @@ ur_shared_buffer_handle_t::ur_shared_buffer_handle_t(

void *ur_shared_buffer_handle_t::getDevicePtr(
ur_device_handle_t, device_access_mode_t, size_t offset, size_t,
std::function<void(void *src, void *dst, size_t)>) {
ze_command_list_handle_t /*cmdList*/, wait_list_view & /*waitListView*/) {
return reinterpret_cast<char *>(ptr.get()) + offset;
}

void *ur_shared_buffer_handle_t::mapHostPtr(
ur_map_flags_t, size_t offset, size_t,
std::function<void(void *src, void *dst, size_t)>) {
void *
ur_shared_buffer_handle_t::mapHostPtr(ur_map_flags_t, size_t offset, size_t,
ze_command_list_handle_t /*cmdList*/,
wait_list_view & /*waitListView*/) {
return reinterpret_cast<char *>(ptr.get()) + offset;
}

void ur_shared_buffer_handle_t::unmapHostPtr(
void *, std::function<void(void *src, void *dst, size_t)>) {
void *, ze_command_list_handle_t /*cmdList*/,
wait_list_view & /*waitListView*/) {
// nop
}

Expand Down Expand Up @@ -403,24 +423,27 @@ ur_mem_sub_buffer_t::~ur_mem_sub_buffer_t() {
ur::level_zero::urMemRelease(hParent);
}

void *ur_mem_sub_buffer_t::getDevicePtr(
ur_device_handle_t hDevice, device_access_mode_t access, size_t offset,
size_t size, std::function<void(void *src, void *dst, size_t)> migrate) {
void *ur_mem_sub_buffer_t::getDevicePtr(ur_device_handle_t hDevice,
device_access_mode_t access,
size_t offset, size_t size,
ze_command_list_handle_t cmdList,
wait_list_view &waitListView) {
return hParent->getBuffer()->getDevicePtr(
hDevice, access, offset + this->offset, size, std::move(migrate));
hDevice, access, offset + this->offset, size, cmdList, waitListView);
}

void *ur_mem_sub_buffer_t::mapHostPtr(
ur_map_flags_t flags, size_t offset, size_t size,
std::function<void(void *src, void *dst, size_t)> migrate) {
void *ur_mem_sub_buffer_t::mapHostPtr(ur_map_flags_t flags, size_t offset,
size_t size,
ze_command_list_handle_t cmdList,
wait_list_view &waitListView) {
return hParent->getBuffer()->mapHostPtr(flags, offset + this->offset, size,
std::move(migrate));
cmdList, waitListView);
}

void ur_mem_sub_buffer_t::unmapHostPtr(
void *pMappedPtr,
std::function<void(void *src, void *dst, size_t)> migrate) {
return hParent->getBuffer()->unmapHostPtr(pMappedPtr, std::move(migrate));
// Unmaps a host pointer for this sub-buffer by delegating to the parent
// buffer. `pMappedPtr`, `cmdList` and `waitListView` are passed through
// unchanged; no sub-buffer offset is applied here.
void ur_mem_sub_buffer_t::unmapHostPtr(void *pMappedPtr,
                                       ze_command_list_handle_t cmdList,
                                       wait_list_view &waitListView) {
  auto parentBuffer = hParent->getBuffer();
  parentBuffer->unmapHostPtr(pMappedPtr, cmdList, waitListView);
}

ur_shared_mutex &ur_mem_sub_buffer_t::getMutex() {
Expand Down Expand Up @@ -690,9 +713,10 @@ ur_result_t urMemGetNativeHandle(ur_mem_handle_t hMem,

std::scoped_lock<ur_shared_mutex> lock(hBuffer->getMutex());

wait_list_view emptyWaitListView(nullptr, 0);
auto ptr = hBuffer->getDevicePtr(
hDevice, ur_mem_buffer_t::device_access_mode_t::read_write, 0,
hBuffer->getSize(), nullptr);
hBuffer->getSize(), nullptr, emptyWaitListView);
*phNativeMem = reinterpret_cast<ur_native_handle_t>(ptr);
return UR_RESULT_SUCCESS;
} catch (...) {
Expand Down
Loading