@@ -68,22 +68,10 @@ struct GetValueType<sycl::multi_ptr<ElementType, Space, IsDecorated>> {
68
68
using type = ElementType;
69
69
};
70
70
71
- // since we couldn't assign data to raw memory, it's better to use placement
72
- // for first assignment
73
- template <typename Acc, typename T>
74
- void set_value (Acc ptr, const size_t idx, const T &val, bool is_first) {
75
- if (is_first) {
76
- ::new (ptr + idx) T (val);
77
- } else {
78
- ptr[idx] = val;
79
- }
80
- }
81
-
82
71
template <typename InAcc, typename OutAcc, typename Compare>
83
72
void merge (const size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
84
73
const size_t start_1, const size_t end_1, const size_t end_2,
85
- const size_t start_out, Compare comp, const size_t chunk,
86
- bool is_first) {
74
+ const size_t start_out, Compare comp, const size_t chunk) {
87
75
const size_t start_2 = end_1;
88
76
// Borders of the sequences to merge within this call
89
77
const size_t local_start_1 =
@@ -111,8 +99,7 @@ void merge(const size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
111
99
const size_t l_shift_1 = local_start_1 - start_1;
112
100
const size_t l_shift_2 = l_search_bound_2 - start_2;
113
101
114
- set_value (out_acc1, start_out + l_shift_1 + l_shift_2, local_l_item_1,
115
- is_first);
102
+ out_acc1[start_out + l_shift_1 + l_shift_2] = local_l_item_1;
116
103
117
104
size_t r_search_bound_2{};
118
105
// find right border in 2nd sequence
@@ -123,8 +110,7 @@ void merge(const size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
123
110
const auto r_shift_1 = local_end_1 - 1 - start_1;
124
111
const auto r_shift_2 = r_search_bound_2 - start_2;
125
112
126
- set_value (out_acc1, start_out + r_shift_1 + r_shift_2, local_r_item_1,
127
- is_first);
113
+ out_acc1[start_out + r_shift_1 + r_shift_2] = local_r_item_1;
128
114
}
129
115
130
116
// Handle intermediate items
@@ -138,8 +124,7 @@ void merge(const size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
138
124
const size_t shift_1 = idx - start_1;
139
125
const size_t shift_2 = l_search_bound_2 - start_2;
140
126
141
- set_value (out_acc1, start_out + shift_1 + shift_2, intermediate_item_1,
142
- is_first);
127
+ out_acc1[start_out + shift_1 + shift_2] = intermediate_item_1;
143
128
}
144
129
}
145
130
// Process 2nd sequence
@@ -152,8 +137,7 @@ void merge(const size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
152
137
const size_t l_shift_1 = l_search_bound_1 - start_1;
153
138
const size_t l_shift_2 = local_start_2 - start_2;
154
139
155
- set_value (out_acc1, start_out + l_shift_1 + l_shift_2, local_l_item_2,
156
- is_first);
140
+ out_acc1[start_out + l_shift_1 + l_shift_2] = local_l_item_2;
157
141
158
142
size_t r_search_bound_1{};
159
143
// find right border in 1st sequence
@@ -164,8 +148,7 @@ void merge(const size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
164
148
const size_t r_shift_1 = r_search_bound_1 - start_1;
165
149
const size_t r_shift_2 = local_end_2 - 1 - start_2;
166
150
167
- set_value (out_acc1, start_out + r_shift_1 + r_shift_2, local_r_item_2,
168
- is_first);
151
+ out_acc1[start_out + r_shift_1 + r_shift_2] = local_r_item_2;
169
152
}
170
153
171
154
// Handle intermediate items
@@ -179,8 +162,7 @@ void merge(const size_t offset, InAcc &in_acc1, OutAcc &out_acc1,
179
162
const size_t shift_1 = l_search_bound_1 - start_1;
180
163
const size_t shift_2 = idx - start_2;
181
164
182
- set_value (out_acc1, start_out + shift_1 + shift_2, intermediate_item_2,
183
- is_first);
165
+ out_acc1[start_out + shift_1 + shift_2] = intermediate_item_2;
184
166
}
185
167
}
186
168
}
@@ -200,10 +182,9 @@ void bubble_sort(Iter first, const size_t begin, const size_t end,
200
182
}
201
183
}
202
184
203
- template <typename Group, typename Iter, typename Compare>
185
+ template <typename Group, typename Iter, typename T, typename Compare>
204
186
void merge_sort (Group group, Iter first, const size_t n, Compare comp,
205
- std::byte *scratch) {
206
- using T = typename GetValueType<Iter>::type;
187
+ T *scratch) {
207
188
const size_t idx = group.get_local_linear_id ();
208
189
const size_t local = group.get_local_range ().size ();
209
190
const size_t chunk = (n - 1 ) / local + 1 ;
@@ -212,9 +193,7 @@ void merge_sort(Group group, Iter first, const size_t n, Compare comp,
212
193
bubble_sort (first, idx * chunk, sycl::min ((idx + 1 ) * chunk, n), comp);
213
194
sycl::group_barrier (group);
214
195
215
- T *temp = reinterpret_cast <T *>(scratch);
216
- bool data_in_temp = false ;
217
- bool is_first = true ;
196
+ bool data_in_scratch = false ;
218
197
size_t sorted_size = 1 ;
219
198
while (sorted_size * chunk < n) {
220
199
const size_t start_1 =
@@ -223,26 +202,24 @@ void merge_sort(Group group, Iter first, const size_t n, Compare comp,
223
202
const size_t end_2 = sycl::min (end_1 + sorted_size * chunk, n);
224
203
const size_t offset = chunk * (idx % sorted_size);
225
204
226
- if (!data_in_temp ) {
227
- merge (offset, first, temp , start_1, end_1, end_2, start_1, comp, chunk ,
228
- is_first );
205
+ if (!data_in_scratch ) {
206
+ merge (offset, first, scratch , start_1, end_1, end_2, start_1, comp,
207
+ chunk );
229
208
} else {
230
- merge (offset, temp , first, start_1, end_1, end_2, start_1, comp, chunk ,
231
- /* is_first */ false );
209
+ merge (offset, scratch , first, start_1, end_1, end_2, start_1, comp,
210
+ chunk );
232
211
}
233
212
sycl::group_barrier (group);
234
213
235
- data_in_temp = !data_in_temp ;
214
+ data_in_scratch = !data_in_scratch ;
236
215
sorted_size *= 2 ;
237
- if (is_first)
238
- is_first = false ;
239
216
}
240
217
241
218
// copy back if data is in a temporary storage
242
- if (data_in_temp ) {
219
+ if (data_in_scratch ) {
243
220
for (size_t i = 0 ; i < chunk; ++i) {
244
221
if (idx * chunk + i < n) {
245
- first[idx * chunk + i] = temp [idx * chunk + i];
222
+ first[idx * chunk + i] = scratch [idx * chunk + i];
246
223
}
247
224
}
248
225
sycl::group_barrier (group);
0 commit comments