Skip to content

Commit 2e1f14a

Browse files
authored
[SYCL] Fix UB and alignment issues in the SYCL default sorter (#13975)
Currently `std::byte*` scratch pointer is not aligned and `reinterpret_cast`ed as `T*` where type `T` may have alignment requirement different from `byte*`, this is UB. As a solution, use `std::align` to align the required buffer in the scratch and use placement `new` so that dynamic type of the buffer in the scratch will be `T*`.
1 parent 4222b4c commit 2e1f14a

File tree

2 files changed

+59
-48
lines changed

2 files changed

+59
-48
lines changed

sycl/include/sycl/detail/group_sort_impl.hpp

Lines changed: 18 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -68,22 +68,10 @@ struct GetValueType<sycl::multi_ptr<ElementType, Space, IsDecorated>> {
6868
using type = ElementType;
6969
};
7070

71-
// since we couldn't assign data to raw memory, it's better to use placement
72-
// for first assignment
73-
template <typename Acc, typename T>
74-
void set_value(Acc ptr, const size_t idx, const T &val, bool is_first) {
75-
if (is_first) {
76-
::new (ptr + idx) T(val);
77-
} else {
78-
ptr[idx] = val;
79-
}
80-
}
81-
8271
template <typename InAcc, typename OutAcc, typename Compare>
8372
void merge(const size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
8473
const size_t start_1, const size_t end_1, const size_t end_2,
85-
const size_t start_out, Compare comp, const size_t chunk,
86-
bool is_first) {
74+
const size_t start_out, Compare comp, const size_t chunk) {
8775
const size_t start_2 = end_1;
8876
// Borders of the sequences to merge within this call
8977
const size_t local_start_1 =
@@ -111,8 +99,7 @@ void merge(const size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
11199
const size_t l_shift_1 = local_start_1 - start_1;
112100
const size_t l_shift_2 = l_search_bound_2 - start_2;
113101

114-
set_value(out_acc1, start_out + l_shift_1 + l_shift_2, local_l_item_1,
115-
is_first);
102+
out_acc1[start_out + l_shift_1 + l_shift_2] = local_l_item_1;
116103

117104
size_t r_search_bound_2{};
118105
// find right border in 2nd sequence
@@ -123,8 +110,7 @@ void merge(const size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
123110
const auto r_shift_1 = local_end_1 - 1 - start_1;
124111
const auto r_shift_2 = r_search_bound_2 - start_2;
125112

126-
set_value(out_acc1, start_out + r_shift_1 + r_shift_2, local_r_item_1,
127-
is_first);
113+
out_acc1[start_out + r_shift_1 + r_shift_2] = local_r_item_1;
128114
}
129115

130116
// Handle intermediate items
@@ -138,8 +124,7 @@ void merge(const size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
138124
const size_t shift_1 = idx - start_1;
139125
const size_t shift_2 = l_search_bound_2 - start_2;
140126

141-
set_value(out_acc1, start_out + shift_1 + shift_2, intermediate_item_1,
142-
is_first);
127+
out_acc1[start_out + shift_1 + shift_2] = intermediate_item_1;
143128
}
144129
}
145130
// Process 2nd sequence
@@ -152,8 +137,7 @@ void merge(const size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
152137
const size_t l_shift_1 = l_search_bound_1 - start_1;
153138
const size_t l_shift_2 = local_start_2 - start_2;
154139

155-
set_value(out_acc1, start_out + l_shift_1 + l_shift_2, local_l_item_2,
156-
is_first);
140+
out_acc1[start_out + l_shift_1 + l_shift_2] = local_l_item_2;
157141

158142
size_t r_search_bound_1{};
159143
// find right border in 1st sequence
@@ -164,8 +148,7 @@ void merge(const size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
164148
const size_t r_shift_1 = r_search_bound_1 - start_1;
165149
const size_t r_shift_2 = local_end_2 - 1 - start_2;
166150

167-
set_value(out_acc1, start_out + r_shift_1 + r_shift_2, local_r_item_2,
168-
is_first);
151+
out_acc1[start_out + r_shift_1 + r_shift_2] = local_r_item_2;
169152
}
170153

171154
// Handle intermediate items
@@ -179,8 +162,7 @@ void merge(const size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
179162
const size_t shift_1 = l_search_bound_1 - start_1;
180163
const size_t shift_2 = idx - start_2;
181164

182-
set_value(out_acc1, start_out + shift_1 + shift_2, intermediate_item_2,
183-
is_first);
165+
out_acc1[start_out + shift_1 + shift_2] = intermediate_item_2;
184166
}
185167
}
186168
}
@@ -200,10 +182,9 @@ void bubble_sort(Iter first, const size_t begin, const size_t end,
200182
}
201183
}
202184

203-
template <typename Group, typename Iter, typename Compare>
185+
template <typename Group, typename Iter, typename T, typename Compare>
204186
void merge_sort(Group group, Iter first, const size_t n, Compare comp,
205-
std::byte *scratch) {
206-
using T = typename GetValueType<Iter>::type;
187+
T *scratch) {
207188
const size_t idx = group.get_local_linear_id();
208189
const size_t local = group.get_local_range().size();
209190
const size_t chunk = (n - 1) / local + 1;
@@ -212,9 +193,7 @@ void merge_sort(Group group, Iter first, const size_t n, Compare comp,
212193
bubble_sort(first, idx * chunk, sycl::min((idx + 1) * chunk, n), comp);
213194
sycl::group_barrier(group);
214195

215-
T *temp = reinterpret_cast<T *>(scratch);
216-
bool data_in_temp = false;
217-
bool is_first = true;
196+
bool data_in_scratch = false;
218197
size_t sorted_size = 1;
219198
while (sorted_size * chunk < n) {
220199
const size_t start_1 =
@@ -223,26 +202,24 @@ void merge_sort(Group group, Iter first, const size_t n, Compare comp,
223202
const size_t end_2 = sycl::min(end_1 + sorted_size * chunk, n);
224203
const size_t offset = chunk * (idx % sorted_size);
225204

226-
if (!data_in_temp) {
227-
merge(offset, first, temp, start_1, end_1, end_2, start_1, comp, chunk,
228-
is_first);
205+
if (!data_in_scratch) {
206+
merge(offset, first, scratch, start_1, end_1, end_2, start_1, comp,
207+
chunk);
229208
} else {
230-
merge(offset, temp, first, start_1, end_1, end_2, start_1, comp, chunk,
231-
/*is_first*/ false);
209+
merge(offset, scratch, first, start_1, end_1, end_2, start_1, comp,
210+
chunk);
232211
}
233212
sycl::group_barrier(group);
234213

235-
data_in_temp = !data_in_temp;
214+
data_in_scratch = !data_in_scratch;
236215
sorted_size *= 2;
237-
if (is_first)
238-
is_first = false;
239216
}
240217

241218
// copy back if data is in a temporary storage
242-
if (data_in_temp) {
219+
if (data_in_scratch) {
243220
for (size_t i = 0; i < chunk; ++i) {
244221
if (idx * chunk + i < n) {
245-
first[idx * chunk + i] = temp[idx * chunk + i];
222+
first[idx * chunk + i] = scratch[idx * chunk + i];
246223
}
247224
}
248225
sycl::group_barrier(group);

sycl/include/sycl/ext/oneapi/experimental/group_helpers_sorters.hpp

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,27 @@ template <typename Compare = std::less<>> class default_sorter {
6363
void operator()([[maybe_unused]] Group g, [[maybe_unused]] Ptr first,
6464
[[maybe_unused]] Ptr last) {
6565
#ifdef __SYCL_DEVICE_ONLY__
66+
// Adjust the scratch pointer based on alignment of the type T.
6667
// Per extension specification if scratch size is less than the value
6768
// returned by memory_required then behavior is undefined, so we don't check
6869
// that the scratch size statisfies the requirement.
69-
sycl::detail::merge_sort(g, first, last - first, comp, scratch.data());
70+
using T = typename sycl::detail::GetValueType<Ptr>::type;
71+
T *scratch_begin = nullptr;
72+
size_t n = last - first;
73+
// We must have a barrier here before array placement new because it is
74+
// possible that scratch memory is already in use, so we need to synchronize
75+
// work items.
76+
sycl::group_barrier(g);
77+
if (g.leader()) {
78+
void *scratch_ptr = scratch.data();
79+
size_t space = scratch.size();
80+
scratch_ptr = std::align(alignof(T), n * sizeof(T), scratch_ptr, space);
81+
scratch_begin = ::new (scratch_ptr) T[n];
82+
}
83+
// Broadcast leader's pointer (the beginning of the scratch) to all work
84+
// items in the group.
85+
scratch_begin = sycl::group_broadcast(g, scratch_begin);
86+
sycl::detail::merge_sort(g, first, n, comp, scratch_begin);
7087
#else
7188
throw sycl::exception(
7289
std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),
@@ -77,16 +94,33 @@ template <typename Compare = std::less<>> class default_sorter {
7794
template <typename Group, typename T>
7895
T operator()([[maybe_unused]] Group g, T val) {
7996
#ifdef __SYCL_DEVICE_ONLY__
97+
// Adjust the scratch pointer based on alignment of the type T.
8098
// Per extension specification if scratch size is less than the value
8199
// returned by memory_required then behavior is undefined, so we don't check
82100
// that the scratch size statisfies the requirement.
101+
T *scratch_begin = nullptr;
102+
std::size_t local_id = g.get_local_linear_id();
83103
auto range_size = g.get_local_range().size();
84-
size_t local_id = g.get_local_linear_id();
85-
T *temp = reinterpret_cast<T *>(scratch.data());
86-
::new (temp + local_id) T(val);
87-
sycl::detail::merge_sort(g, temp, range_size, comp,
88-
scratch.data() + range_size * sizeof(T));
89-
val = temp[local_id];
104+
// We must have a barrier here before array placement new because it is
105+
// possible that scratch memory is already in use, so we need to synchronize
106+
// work items.
107+
sycl::group_barrier(g);
108+
if (g.leader()) {
109+
void *scratch_ptr = scratch.data();
110+
size_t space = scratch.size();
111+
scratch_ptr =
112+
std::align(alignof(T), /* output storage and temporary storage */ 2 *
113+
range_size * sizeof(T),
114+
scratch_ptr, space);
115+
scratch_begin = ::new (scratch_ptr) T[2 * range_size];
116+
}
117+
// Broadcast leader's pointer (the beginning of the scratch) to all work
118+
// items in the group.
119+
scratch_begin = sycl::group_broadcast(g, scratch_begin);
120+
scratch_begin[local_id] = val;
121+
sycl::detail::merge_sort(g, scratch_begin, range_size, comp,
122+
scratch_begin + range_size);
123+
val = scratch_begin[local_id];
90124
#else
91125
throw sycl::exception(
92126
std::error_code(PI_ERROR_INVALID_DEVICE, sycl::sycl_category()),

0 commit comments

Comments
 (0)