Skip to content

Commit 05ea17d

Browse files
committed
Refactor UR reference counting with new common RefCount class, specifically for Native CPU adapter.
1 parent 44b4c09 commit 05ea17d

File tree

13 files changed

+68
-43
lines changed

13 files changed

+68
-43
lines changed

unified-runtime/source/adapters/native_cpu/adapter.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,22 @@
1010

1111
#include "adapter.hpp"
1212
#include "common.hpp"
13+
#include "common/ur_ref_count.hpp"
1314
#include "ur_api.h"
1415

1516
struct ur_adapter_handle_t_ : ur::native_cpu::handle_base {
16-
std::atomic<uint32_t> RefCount = 0;
1717
logger::Logger &logger = logger::get_logger("native_cpu");
18+
19+
ur::RefCount &getRefCount() noexcept { return RefCount; }
20+
21+
private:
22+
ur::RefCount RefCount;
1823
} Adapter;
1924

2025
UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(
2126
uint32_t, ur_adapter_handle_t *phAdapters, uint32_t *pNumAdapters) {
2227
if (phAdapters) {
23-
Adapter.RefCount++;
28+
Adapter.getRefCount().retain();
2429
*phAdapters = &Adapter;
2530
}
2631
if (pNumAdapters) {
@@ -30,12 +35,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(
3035
}
3136

3237
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
33-
Adapter.RefCount--;
38+
Adapter.getRefCount().release();
3439
return UR_RESULT_SUCCESS;
3540
}
3641

3742
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) {
38-
Adapter.RefCount++;
43+
Adapter.getRefCount().retain();
3944
return UR_RESULT_SUCCESS;
4045
}
4146

@@ -57,7 +62,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t,
5762
case UR_ADAPTER_INFO_BACKEND:
5863
return ReturnValue(UR_BACKEND_NATIVE_CPU);
5964
case UR_ADAPTER_INFO_REFERENCE_COUNT:
60-
return ReturnValue(Adapter.RefCount.load());
65+
return ReturnValue(Adapter.getRefCount().getCount());
6166
case UR_ADAPTER_INFO_VERSION:
6267
return ReturnValue(uint32_t{1});
6368
default:

unified-runtime/source/adapters/native_cpu/common.hpp

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,22 +44,13 @@ struct ddi_getter {
4444
using handle_base = ur::handle_base<ddi_getter>;
4545
} // namespace ur::native_cpu
4646

47-
// Todo: replace this with a common helper once it is available
48-
struct RefCounted : ur::native_cpu::handle_base {
49-
std::atomic_uint32_t _refCount;
50-
uint32_t incrementReferenceCount() { return ++_refCount; }
51-
uint32_t decrementReferenceCount() { return --_refCount; }
52-
RefCounted() : handle_base(), _refCount{1} {}
53-
uint32_t getReferenceCount() const { return _refCount; }
54-
};
55-
5647
// Base class to store common data
57-
struct ur_object : RefCounted {
48+
struct ur_object {
5849
ur_shared_mutex Mutex;
5950
};
6051

6152
template <typename T> inline void decrementOrDelete(T *refC) {
62-
if (refC->decrementReferenceCount() == 0)
53+
if (refC->getRefCount().release() == 0)
6354
delete refC;
6455
}
6556

unified-runtime/source/adapters/native_cpu/context.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextCreate(
3030

3131
UR_APIEXPORT ur_result_t UR_APICALL
3232
urContextRetain(ur_context_handle_t hContext) {
33-
hContext->incrementReferenceCount();
33+
hContext->getRefCount().retain();
3434
return UR_RESULT_SUCCESS;
3535
}
3636

@@ -51,7 +51,7 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName,
5151
case UR_CONTEXT_INFO_DEVICES:
5252
return returnValue(hContext->_device);
5353
case UR_CONTEXT_INFO_REFERENCE_COUNT:
54-
return returnValue(uint32_t{hContext->getReferenceCount()});
54+
return returnValue(uint32_t{hContext->getRefCount().getCount()});
5555
case UR_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT:
5656
return returnValue(true);
5757
case UR_CONTEXT_INFO_USM_FILL2D_SUPPORT:

unified-runtime/source/adapters/native_cpu/context.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <ur_api.h>
1616

1717
#include "common.hpp"
18+
#include "common/ur_ref_count.hpp"
1819
#include "device.hpp"
1920
#include "ur/ur.hpp"
2021

@@ -83,7 +84,7 @@ static usm_alloc_info get_alloc_info(void *ptr) {
8384

8485
} // namespace native_cpu
8586

86-
struct ur_context_handle_t_ : RefCounted {
87+
struct ur_context_handle_t_ {
8788
ur_context_handle_t_(ur_device_handle_t_ *phDevices) : _device{phDevices} {}
8889

8990
ur_device_handle_t _device;
@@ -135,7 +136,10 @@ struct ur_context_handle_t_ : RefCounted {
135136
return ptr;
136137
}
137138

139+
ur::RefCount &getRefCount() noexcept { return RefCount; }
140+
138141
private:
139142
std::mutex alloc_mutex;
140143
std::set<const void *> allocations;
144+
ur::RefCount RefCount;
141145
};

unified-runtime/source/adapters/native_cpu/event.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventGetInfo(ur_event_handle_t hEvent,
2828
case UR_EVENT_INFO_COMMAND_TYPE:
2929
return ReturnValue(hEvent->getCommandType());
3030
case UR_EVENT_INFO_REFERENCE_COUNT:
31-
return ReturnValue(hEvent->getReferenceCount());
31+
return ReturnValue(hEvent->getRefCount().getCount());
3232
case UR_EVENT_INFO_COMMAND_EXECUTION_STATUS:
3333
return ReturnValue(hEvent->getExecutionStatus());
3434
case UR_EVENT_INFO_CONTEXT:
@@ -69,7 +69,7 @@ urEventWait(uint32_t numEvents, const ur_event_handle_t *phEventWaitList) {
6969
}
7070

7171
UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) {
72-
hEvent->incrementReferenceCount();
72+
hEvent->getRefCount().retain();
7373
return UR_RESULT_SUCCESS;
7474
}
7575

unified-runtime/source/adapters/native_cpu/event.hpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,17 @@
88
//
99
//===----------------------------------------------------------------------===//
1010
#pragma once
11-
#include "common.hpp"
12-
#include "ur_api.h"
11+
1312
#include <cstdint>
1413
#include <future>
1514
#include <mutex>
1615
#include <vector>
1716

18-
struct ur_event_handle_t_ : RefCounted {
17+
#include "common.hpp"
18+
#include "common/ur_ref_count.hpp"
19+
#include "ur_api.h"
20+
21+
struct ur_event_handle_t_ {
1922

2023
ur_event_handle_t_(ur_queue_handle_t queue, ur_command_t command_type);
2124

@@ -55,6 +58,8 @@ struct ur_event_handle_t_ : RefCounted {
5558

5659
uint64_t get_end_timestamp() const { return timestamp_end; }
5760

61+
ur::RefCount &getRefCount() noexcept { return RefCount; }
62+
5863
private:
5964
ur_queue_handle_t queue;
6065
ur_context_handle_t context;
@@ -65,4 +70,5 @@ struct ur_event_handle_t_ : RefCounted {
6570
std::packaged_task<void()> callback;
6671
uint64_t timestamp_start = 0;
6772
uint64_t timestamp_end = 0;
73+
ur::RefCount RefCount;
6874
};

unified-runtime/source/adapters/native_cpu/kernel.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetInfo(ur_kernel_handle_t hKernel,
9595
case UR_KERNEL_INFO_FUNCTION_NAME:
9696
return ReturnValue(hKernel->_name.c_str());
9797
case UR_KERNEL_INFO_REFERENCE_COUNT:
98-
return ReturnValue(uint32_t{hKernel->getReferenceCount()});
98+
return ReturnValue(uint32_t{hKernel->getRefCount().getCount()});
9999
case UR_KERNEL_INFO_ATTRIBUTES:
100100
return ReturnValue("");
101101
case UR_KERNEL_INFO_SPILL_MEM_SIZE:
@@ -194,7 +194,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetSubGroupInfo(
194194
}
195195

196196
UR_APIEXPORT ur_result_t UR_APICALL urKernelRetain(ur_kernel_handle_t hKernel) {
197-
hKernel->incrementReferenceCount();
197+
hKernel->getRefCount().retain();
198198
return UR_RESULT_SUCCESS;
199199
}
200200

unified-runtime/source/adapters/native_cpu/kernel.hpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88

99
#pragma once
1010

11+
#include <cstring>
12+
#include <utility>
13+
1114
#include "common.hpp"
15+
#include "common/ur_ref_count.hpp"
1216
#include "memory.hpp"
1317
#include "nativecpu_state.hpp"
1418
#include "program.hpp"
15-
#include <cstring>
1619
#include <ur_api.h>
17-
#include <utility>
1820

1921
using nativecpu_kernel_t = void(void *const *, native_cpu::state *);
2022
using nativecpu_ptr_t = nativecpu_kernel_t *;
@@ -27,7 +29,7 @@ struct local_arg_info_t {
2729
: argIndex(argIndex), argSize(argSize) {}
2830
};
2931

30-
struct ur_kernel_handle_t_ : RefCounted {
32+
struct ur_kernel_handle_t_ {
3133

3234
ur_kernel_handle_t_(ur_program_handle_t hProgram, const char *name,
3335
nativecpu_task_t subhandler)
@@ -188,10 +190,12 @@ struct ur_kernel_handle_t_ : RefCounted {
188190
void addPtrArg(void *Ptr, size_t Index) { Args.addPtrArg(Index, Ptr); }
189191

190192
void addArgReference(ur_mem_handle_t Arg) {
191-
Arg->incrementReferenceCount();
193+
Arg->getRefCount().getCount();
192194
ReferencedArgs.push_back(Arg);
193195
}
194196

197+
ur::RefCount &getRefCount() noexcept { return RefCount; }
198+
195199
private:
196200
void removeArgReferences() {
197201
for (auto arg : ReferencedArgs)
@@ -209,4 +213,5 @@ struct ur_kernel_handle_t_ : RefCounted {
209213
std::optional<native_cpu::WGSize_t> MaxWGSize = std::nullopt;
210214
std::optional<uint64_t> MaxLinearWGSize = std::nullopt;
211215
std::vector<ur_mem_handle_t> ReferencedArgs;
216+
ur::RefCount RefCount;
212217
};

unified-runtime/source/adapters/native_cpu/memory.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "common.hpp"
1818
#include "context.hpp"
19+
#include "common/ur_ref_count.hpp"
1920

2021
struct ur_mem_handle_t_ : ur_object {
2122
ur_mem_handle_t_(size_t Size, bool _IsImage)
@@ -43,8 +44,11 @@ struct ur_mem_handle_t_ : ur_object {
4344
char *_mem;
4445
bool _ownsMem;
4546

47+
ur::RefCount &getRefCount() noexcept { return RefCount; }
48+
4649
private:
4750
const bool IsImage;
51+
ur::RefCount RefCount;
4852
};
4953

5054
struct ur_buffer final : ur_mem_handle_t_ {

unified-runtime/source/adapters/native_cpu/program.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp(
171171

172172
UR_APIEXPORT ur_result_t UR_APICALL
173173
urProgramRetain(ur_program_handle_t hProgram) {
174-
hProgram->incrementReferenceCount();
174+
hProgram->getRefCount().retain();
175175
return UR_RESULT_SUCCESS;
176176
}
177177

@@ -205,7 +205,7 @@ urProgramGetInfo(ur_program_handle_t hProgram, ur_program_info_t propName,
205205

206206
switch (propName) {
207207
case UR_PROGRAM_INFO_REFERENCE_COUNT:
208-
return returnValue(hProgram->getReferenceCount());
208+
return returnValue(hProgram->getRefCount().getCount());
209209
case UR_PROGRAM_INFO_CONTEXT:
210210
return returnValue(nullptr);
211211
case UR_PROGRAM_INFO_NUM_DEVICES:

unified-runtime/source/adapters/native_cpu/program.hpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,22 @@
1010

1111
#pragma once
1212

13+
#include <array>
14+
#include <map>
15+
1316
#include <ur_api.h>
1417

18+
#include "common/ur_ref_count.hpp"
1519
#include "context.hpp"
1620

17-
#include <array>
18-
#include <map>
19-
2021
namespace native_cpu {
2122
using WGSize_t = std::array<uint32_t, 3>;
2223
}
2324

24-
struct ur_program_handle_t_ : RefCounted {
25+
struct ur_program_handle_t_ {
2526
ur_program_handle_t_(ur_context_handle_t ctx, const unsigned char *pBinary)
2627
: _ctx{ctx}, _ptr{pBinary} {}
2728

28-
uint32_t getReferenceCount() const noexcept { return _refCount; }
29-
3029
ur_context_handle_t _ctx;
3130
const unsigned char *_ptr;
3231
struct _compare {
@@ -41,6 +40,11 @@ struct ur_program_handle_t_ : RefCounted {
4140
std::unordered_map<std::string, native_cpu::WGSize_t>
4241
KernelMaxWorkGroupSizeMD;
4342
std::unordered_map<std::string, uint64_t> KernelMaxLinearWorkGroupSizeMD;
43+
44+
ur::RefCount &getRefCount() noexcept { return RefCount; }
45+
46+
private:
47+
ur::RefCount RefCount;
4448
};
4549

4650
// The nativecpu_entry struct is also defined as LLVM-IR in the

unified-runtime/source/adapters/native_cpu/queue.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueGetInfo(ur_queue_handle_t hQueue,
2828
case UR_QUEUE_INFO_DEVICE:
2929
return ReturnValue(hQueue->getDevice());
3030
case UR_QUEUE_INFO_REFERENCE_COUNT:
31-
return ReturnValue(hQueue->getReferenceCount());
31+
return ReturnValue(hQueue->getRefCount().getCount());
3232
case UR_QUEUE_INFO_EMPTY:
3333
return ReturnValue(hQueue->isEmpty());
3434
default:
@@ -48,7 +48,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueCreate(
4848
}
4949

5050
UR_APIEXPORT ur_result_t UR_APICALL urQueueRetain(ur_queue_handle_t hQueue) {
51-
hQueue->incrementReferenceCount();
51+
hQueue->getRefCount().retain();
5252

5353
return UR_RESULT_SUCCESS;
5454
}

unified-runtime/source/adapters/native_cpu/queue.hpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,15 @@
88
//
99
//===----------------------------------------------------------------------===//
1010
#pragma once
11+
12+
#include <set>
13+
1114
#include "common.hpp"
15+
#include "common/ur_ref_count.hpp"
1216
#include "event.hpp"
1317
#include "ur_api.h"
14-
#include <set>
1518

16-
struct ur_queue_handle_t_ : RefCounted {
19+
struct ur_queue_handle_t_ {
1720
ur_queue_handle_t_(ur_device_handle_t device, ur_context_handle_t context,
1821
const ur_queue_properties_t *pProps)
1922
: device(device), context(context),
@@ -43,7 +46,7 @@ struct ur_queue_handle_t_ : RefCounted {
4346
auto ev = *events.begin();
4447
// ur_event_handle_t_::wait removes itself from the events set in the
4548
// queue.
46-
ev->incrementReferenceCount();
49+
ev->getRefCount().retain();
4750
// Unlocking mutex for removeEvent and for event callbacks that may need
4851
// to acquire it.
4952
lock.unlock();
@@ -64,11 +67,14 @@ struct ur_queue_handle_t_ : RefCounted {
6467
return events.size() == 0;
6568
}
6669

70+
ur::RefCount &getRefCount() noexcept { return RefCount; }
71+
6772
private:
6873
ur_device_handle_t device;
6974
ur_context_handle_t context;
7075
std::set<ur_event_handle_t> events;
7176
const bool inOrder;
7277
const bool profilingEnabled;
7378
std::mutex mutex;
79+
ur::RefCount RefCount;
7480
};

0 commit comments

Comments
 (0)