Skip to content

Commit 858fd15

Browse files
authored
[UR][Offload] Add OL_RETURN_ON_ERR macro (#19170)
Add a macro for calling Offload API functions and returning a UR error on failure. Some Offload calls were unchecked before.
1 parent 345b9ab commit 858fd15

File tree

10 files changed

+49
-84
lines changed

10 files changed

+49
-84
lines changed

unified-runtime/source/adapters/offload/common.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#pragma once
1212

1313
#include "ur/ur.hpp"
14+
#include "ur2offload.hpp"
1415
#include <atomic>
1516

1617
namespace ur::offload {
@@ -23,3 +24,8 @@ using handle_base = ur::handle_base<ur::offload::ddi_getter>;
2324
struct RefCounted : ur::offload::handle_base {
2425
std::atomic_uint32_t RefCount = 1;
2526
};
27+
28+
#define OL_RETURN_ON_ERR(call) \
29+
if (auto OlRes = call) { \
30+
return offloadResultToUR(OlRes); \
31+
}

unified-runtime/source/adapters/offload/device.cpp

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -94,17 +94,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
9494
}
9595

9696
if (pPropSizeRet) {
97-
if (auto Res =
98-
olGetDeviceInfoSize(hDevice->OffloadDevice, olInfo, pPropSizeRet)) {
99-
return offloadResultToUR(Res);
100-
}
97+
OL_RETURN_ON_ERR(
98+
olGetDeviceInfoSize(hDevice->OffloadDevice, olInfo, pPropSizeRet));
10199
}
102100

103101
if (pPropValue) {
104-
if (auto Res = olGetDeviceInfo(hDevice->OffloadDevice, olInfo, propSize,
105-
pPropValue)) {
106-
return offloadResultToUR(Res);
107-
}
102+
OL_RETURN_ON_ERR(
103+
olGetDeviceInfo(hDevice->OffloadDevice, olInfo, propSize, pPropValue));
108104
// Need to explicitly map this type
109105
if (olInfo == OL_DEVICE_INFO_TYPE) {
110106
auto urPropPtr = reinterpret_cast<ur_device_type_t *>(pPropValue);
@@ -149,8 +145,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceSelectBinary(
149145
uint32_t NumBinaries, uint32_t *pSelectedBinary) {
150146

151147
ol_platform_backend_t Backend;
152-
olGetPlatformInfo(hDevice->Platform->OffloadPlatform,
153-
OL_PLATFORM_INFO_BACKEND, sizeof(Backend), &Backend);
148+
OL_RETURN_ON_ERR(olGetPlatformInfo(hDevice->Platform->OffloadPlatform,
149+
OL_PLATFORM_INFO_BACKEND, sizeof(Backend),
150+
&Backend));
154151

155152
const char *ImageTarget = UR_DEVICE_BINARY_TARGET_UNKNOWN;
156153
if (Backend == OL_PLATFORM_BACKEND_CUDA) {

unified-runtime/source/adapters/offload/enqueue.cpp

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
6868
LaunchArgs.DynSharedMemory = 0;
6969

7070
ol_event_handle_t EventOut;
71-
auto Ret =
71+
OL_RETURN_ON_ERR(
7272
olLaunchKernel(hQueue->OffloadQueue, hQueue->OffloadDevice,
7373
hKernel->OffloadKernel, hKernel->Args.getStorage(),
74-
hKernel->Args.getStorageSize(), &LaunchArgs, &EventOut);
75-
76-
if (Ret != OL_SUCCESS) {
77-
return offloadResultToUR(Ret);
78-
}
74+
hKernel->Args.getStorageSize(), &LaunchArgs, &EventOut));
7975

8076
if (phEvent) {
8177
auto *Event = new ur_event_handle_t_();
@@ -112,11 +108,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
112108
char *DevPtr =
113109
reinterpret_cast<char *>(std::get<BufferMem>(hBuffer->Mem).Ptr);
114110

115-
olMemcpy(hQueue->OffloadQueue, pDst, Adapter->HostDevice, DevPtr + offset,
116-
hQueue->OffloadDevice, size, phEvent ? &EventOut : nullptr);
111+
OL_RETURN_ON_ERR(olMemcpy(hQueue->OffloadQueue, pDst, Adapter->HostDevice,
112+
DevPtr + offset, hQueue->OffloadDevice, size,
113+
phEvent ? &EventOut : nullptr));
117114

118115
if (blockingRead) {
119-
olWaitQueue(hQueue->OffloadQueue);
116+
OL_RETURN_ON_ERR(olWaitQueue(hQueue->OffloadQueue));
120117
}
121118

122119
if (phEvent) {
@@ -143,18 +140,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
143140
char *DevPtr =
144141
reinterpret_cast<char *>(std::get<BufferMem>(hBuffer->Mem).Ptr);
145142

146-
auto Res =
147-
olMemcpy(hQueue->OffloadQueue, DevPtr + offset, hQueue->OffloadDevice,
148-
pSrc, Adapter->HostDevice, size, phEvent ? &EventOut : nullptr);
149-
if (Res) {
150-
return offloadResultToUR(Res);
151-
}
143+
OL_RETURN_ON_ERR(olMemcpy(hQueue->OffloadQueue, DevPtr + offset,
144+
hQueue->OffloadDevice, pSrc, Adapter->HostDevice,
145+
size, phEvent ? &EventOut : nullptr));
152146

153147
if (blockingWrite) {
154-
auto Res = olWaitQueue(hQueue->OffloadQueue);
155-
if (Res) {
156-
return offloadResultToUR(Res);
157-
}
148+
OL_RETURN_ON_ERR(olWaitQueue(hQueue->OffloadQueue));
158149
}
159150

160151
if (phEvent) {

unified-runtime/source/adapters/offload/event.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,7 @@ UR_APIEXPORT ur_result_t UR_APICALL
4343
urEventWait(uint32_t numEvents, const ur_event_handle_t *phEventWaitList) {
4444
for (uint32_t i = 0; i < numEvents; i++) {
4545
if (phEventWaitList[i]->OffloadEvent) {
46-
auto Res = olWaitEvent(phEventWaitList[i]->OffloadEvent);
47-
if (Res) {
48-
return offloadResultToUR(Res);
49-
}
46+
OL_RETURN_ON_ERR(olWaitEvent(phEventWaitList[i]->OffloadEvent));
5047
}
5148
}
5249
return UR_RESULT_SUCCESS;

unified-runtime/source/adapters/offload/memory.cpp

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -35,20 +35,17 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreate(
3535
auto AllocMode = BufferMem::AllocMode::Default;
3636

3737
if (flags & UR_MEM_FLAG_ALLOC_HOST_POINTER) {
38-
auto Res = olMemAlloc(OffloadDevice, OL_ALLOC_TYPE_HOST, size, &HostPtr);
39-
if (Res) {
40-
return offloadResultToUR(Res);
41-
}
38+
OL_RETURN_ON_ERR(
39+
olMemAlloc(OffloadDevice, OL_ALLOC_TYPE_HOST, size, &HostPtr));
40+
4241
// TODO: We (probably) need something like cuMemHostGetDevicePointer
4342
// for this to work everywhere. For now assume the managed host pointer is
4443
// device-accessible.
4544
Ptr = HostPtr;
4645
AllocMode = BufferMem::AllocMode::AllocHostPtr;
4746
} else {
48-
auto Res = olMemAlloc(OffloadDevice, OL_ALLOC_TYPE_DEVICE, size, &Ptr);
49-
if (Res) {
50-
return offloadResultToUR(Res);
51-
}
47+
OL_RETURN_ON_ERR(
48+
olMemAlloc(OffloadDevice, OL_ALLOC_TYPE_DEVICE, size, &Ptr));
5249
if (flags & UR_MEM_FLAG_ALLOC_COPY_HOST_POINTER) {
5350
AllocMode = BufferMem::AllocMode::CopyIn;
5451
}
@@ -59,11 +56,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreate(
5956
hContext, ParentBuffer, flags, AllocMode, Ptr, HostPtr, size});
6057

6158
if (PerformInitialCopy) {
62-
auto Res = olMemcpy(nullptr, Ptr, OffloadDevice, HostPtr,
63-
Adapter->HostDevice, size, nullptr);
64-
if (Res) {
65-
return offloadResultToUR(Res);
66-
}
59+
OL_RETURN_ON_ERR(olMemcpy(nullptr, Ptr, OffloadDevice, HostPtr,
60+
Adapter->HostDevice, size, nullptr));
6761
}
6862

6963
*phBuffer = URMemObj.release();
@@ -85,10 +79,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemRelease(ur_mem_handle_t hMem) {
8579
if (hMem->MemType == ur_mem_handle_t_::Type::Buffer) {
8680
// TODO: Handle registered host memory
8781
auto &BufferImpl = std::get<BufferMem>(MemObjPtr->Mem);
88-
auto Res = olMemFree(BufferImpl.Ptr);
89-
if (Res) {
90-
return offloadResultToUR(Res);
91-
}
82+
OL_RETURN_ON_ERR(olMemFree(BufferImpl.Ptr));
9283
}
9384

9485
return UR_RESULT_SUCCESS;

unified-runtime/source/adapters/offload/platform.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,17 +68,13 @@ urPlatformGetInfo(ur_platform_handle_t hPlatform, ur_platform_info_t propName,
6868
}
6969

7070
if (pPropSizeRet) {
71-
if (auto Res = olGetPlatformInfoSize(hPlatform->OffloadPlatform, olInfo,
72-
pPropSizeRet)) {
73-
return offloadResultToUR(Res);
74-
}
71+
OL_RETURN_ON_ERR(olGetPlatformInfoSize(hPlatform->OffloadPlatform, olInfo,
72+
pPropSizeRet));
7573
}
7674

7775
if (pPropValue) {
78-
if (auto Res = olGetPlatformInfo(hPlatform->OffloadPlatform, olInfo,
79-
propSize, pPropValue)) {
80-
return offloadResultToUR(Res);
81-
}
76+
OL_RETURN_ON_ERR(olGetPlatformInfo(hPlatform->OffloadPlatform, olInfo,
77+
propSize, pPropValue));
8278
}
8379

8480
return UR_RESULT_SUCCESS;

unified-runtime/source/adapters/offload/program.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary(
8787
if (auto Parser = HipOffloadBundleParser::load(RealBinary, RealLength)) {
8888
std::string DevName{};
8989
size_t DevNameLength;
90-
olGetDeviceInfoSize(phDevices[0]->OffloadDevice, OL_DEVICE_INFO_NAME,
91-
&DevNameLength);
90+
OL_RETURN_ON_ERR(olGetDeviceInfoSize(phDevices[0]->OffloadDevice,
91+
OL_DEVICE_INFO_NAME, &DevNameLength));
9292
DevName.resize(DevNameLength);
93-
olGetDeviceInfo(phDevices[0]->OffloadDevice, OL_DEVICE_INFO_NAME,
94-
DevNameLength, DevName.data());
93+
OL_RETURN_ON_ERR(olGetDeviceInfo(phDevices[0]->OffloadDevice,
94+
OL_DEVICE_INFO_NAME, DevNameLength,
95+
DevName.data()));
9596

9697
auto Res = Parser->extract(DevName, RealBinary, RealLength);
9798
if (Res != UR_RESULT_SUCCESS) {

unified-runtime/source/adapters/offload/queue.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueRetain(ur_queue_handle_t hQueue) {
6161

6262
UR_APIEXPORT ur_result_t UR_APICALL urQueueRelease(ur_queue_handle_t hQueue) {
6363
if (--hQueue->RefCount == 0) {
64-
auto Res = olDestroyQueue(hQueue->OffloadQueue);
65-
if (Res) {
66-
return offloadResultToUR(Res);
67-
}
64+
OL_RETURN_ON_ERR(olDestroyQueue(hQueue->OffloadQueue));
6865
delete hQueue;
6966
}
7067

unified-runtime/source/adapters/offload/ur2offload.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
88
//
99
//===----------------------------------------------------------------------===//
10+
#pragma once
1011

1112
#include <OffloadAPI.h>
1213
#include <ur_api.h>

unified-runtime/source/adapters/offload/usm.cpp

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMHostAlloc(ur_context_handle_t hContext,
2020
const ur_usm_desc_t *,
2121
ur_usm_pool_handle_t,
2222
size_t size, void **ppMem) {
23-
auto Res = olMemAlloc(hContext->Device->OffloadDevice, OL_ALLOC_TYPE_HOST,
24-
size, ppMem);
25-
26-
if (Res != OL_SUCCESS) {
27-
return offloadResultToUR(Res);
28-
}
23+
OL_RETURN_ON_ERR(olMemAlloc(hContext->Device->OffloadDevice,
24+
OL_ALLOC_TYPE_HOST, size, ppMem));
2925

3026
hContext->AllocTypeMap.insert_or_assign(*ppMem, OL_ALLOC_TYPE_HOST);
3127
return UR_RESULT_SUCCESS;
@@ -34,12 +30,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMHostAlloc(ur_context_handle_t hContext,
3430
UR_APIEXPORT ur_result_t UR_APICALL urUSMDeviceAlloc(
3531
ur_context_handle_t hContext, ur_device_handle_t, const ur_usm_desc_t *,
3632
ur_usm_pool_handle_t, size_t size, void **ppMem) {
37-
auto Res = olMemAlloc(hContext->Device->OffloadDevice, OL_ALLOC_TYPE_DEVICE,
38-
size, ppMem);
39-
40-
if (Res != OL_SUCCESS) {
41-
return offloadResultToUR(Res);
42-
}
33+
OL_RETURN_ON_ERR(olMemAlloc(hContext->Device->OffloadDevice,
34+
OL_ALLOC_TYPE_DEVICE, size, ppMem));
4335

4436
hContext->AllocTypeMap.insert_or_assign(*ppMem, OL_ALLOC_TYPE_DEVICE);
4537
return UR_RESULT_SUCCESS;
@@ -48,12 +40,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMDeviceAlloc(
4840
UR_APIEXPORT ur_result_t UR_APICALL urUSMSharedAlloc(
4941
ur_context_handle_t hContext, ur_device_handle_t, const ur_usm_desc_t *,
5042
ur_usm_pool_handle_t, size_t size, void **ppMem) {
51-
auto Res = olMemAlloc(hContext->Device->OffloadDevice, OL_ALLOC_TYPE_MANAGED,
52-
size, ppMem);
53-
54-
if (Res != OL_SUCCESS) {
55-
return offloadResultToUR(Res);
56-
}
43+
OL_RETURN_ON_ERR(olMemAlloc(hContext->Device->OffloadDevice,
44+
OL_ALLOC_TYPE_MANAGED, size, ppMem));
5745

5846
hContext->AllocTypeMap.insert_or_assign(*ppMem, OL_ALLOC_TYPE_MANAGED);
5947
return UR_RESULT_SUCCESS;

0 commit comments

Comments
 (0)