OpenFPM_pdata  4.1.0
Project that contains the implementation of distributed structures
 
Loading...
Searching...
No Matches
sort_ofp.cuh
1/*
2 * sort_ofp.cuh
3 *
4 * Created on: Aug 23, 2019
5 * Author: i-bird
6 */
7
8#ifndef SORT_OFP_CUH_
9#define SORT_OFP_CUH_
10
11
12#ifdef __NVCC__
13
14#include "util/cuda_launch.hpp"
15#include "util/ofp_context.hpp"
16
17#if CUDART_VERSION >= 11000
18 // Here we have for sure CUDA >= 11
19 #ifndef CUDA_ON_CPU
20 #ifdef __HIP__
21 #include "hipcub/hipcub.hpp"
22 #else
23 #include "cub/cub.cuh"
24 #endif
25 #endif
26#else
27 // Here we have old CUDA
28 #include "cub_old/cub.cuh"
29#endif
30
31
// Forward declaration: key_val can be built from, and assigned from, a
// key_val_ref, whose definition comes later in this header.
template<typename key_t, typename val_t>
struct key_val_ref;

/* Key/value pair held by value.
 *
 * The ordering operators look at the key only; the value never takes part
 * in a comparison.
 *
 * key_t: type of the key
 * val_t: type of the value
 */
template<typename key_t, typename val_t>
struct key_val
{
	key_t key;
	val_t val;

	// Build the pair from copies of k and v.
	key_val(const key_t & k, const val_t & v)
	:key(k),val(v)
	{}

	// Build the pair from copies of the key and value a proxy refers to.
	// The members are initialised directly (not default-constructed and then
	// assigned), so key_t / val_t do not need a default constructor.
	key_val(const key_val_ref<key_t,val_t> & tmp)
	:key(tmp.key),val(tmp.val)
	{}

	// true when this key is smaller than tmp's key
	bool operator<(const key_val & tmp) const
	{
		return key < tmp.key;
	}

	// true when this key is greater than tmp's key
	bool operator>(const key_val & tmp) const
	{
		return key > tmp.key;
	}

	// Copy the key and value the proxy refers to into this pair.
	key_val & operator=(const key_val_ref<key_t,val_t> & tmp)
	{
		key = tmp.key;
		val = tmp.val;

		return *this;
	}
};
68
69
70template<typename key_t, typename val_t>
71struct key_val_ref
72{
73 key_t & key;
74 val_t & val;
75
76 key_val_ref(key_t & k, val_t & v)
77 :key(k),val(v)
78 {}
79
80 key_val_ref(key_val_ref<key_t,val_t> && tmp)
81 :key(tmp.key),val(tmp.val)
82 {}
83
84 key_val_ref & operator=(const key_val<key_t,val_t> & tmp)
85 {
86 key = tmp.key;
87 val = tmp.val;
88
89 return *this;
90 }
91
92 key_val_ref & operator=(const key_val_ref<key_t,val_t> & tmp)
93 {
94 key = tmp.key;
95 val = tmp.val;
96
97 return *this;
98 }
99
100 bool operator<(const key_val_ref<key_t,val_t> & tmp)
101 {
102 return key < tmp.key;
103 }
104
105 bool operator>(const key_val_ref<key_t,val_t> & tmp)
106 {
107 return key > tmp.key;
108 }
109
110 bool operator<(const key_val<key_t,val_t> & tmp)
111 {
112 return key < tmp.key;
113 }
114
115 bool operator>(const key_val<key_t,val_t> & tmp)
116 {
117 return key > tmp.key;
118 }
119};
120
121
122template<typename key_t, typename val_t>
123struct key_val_it
124{
125 key_t * key;
126 val_t * val;
127
128
129 key_val_it & operator+=(int delta)
130 {
131 key += delta;
132 val += delta;
133 return *this;
134 }
135
136 bool operator==(const key_val_it & tmp)
137 {
138 return (key == tmp.key && val == tmp.val);
139 }
140
141 key_val_ref<key_t,val_t> operator*()
142 {
143 return key_val_ref<key_t,val_t>(*key,*val);
144 }
145
146 key_val_ref<key_t,val_t> operator[](int i)
147 {
148 return key_val_ref<key_t,val_t>(*key,*val);
149 }
150
151 key_val_it operator+(size_t count) const
152 {
153 key_val_it tmp(key+count,val+count);
154
155 return tmp;
156 }
157
158
159 size_t operator-(key_val_it & tmp) const
160 {
161 return key - tmp.key;
162 }
163
164 key_val_it operator-(size_t count) const
165 {
166 key_val_it tmp(key-count,val-count);
167
168 return tmp;
169 }
170
171 key_val_it & operator++()
172 {
173 ++key;
174 ++val;
175
176 return *this;
177 }
178
179 key_val_it operator++(int)
180 {
181 key_val_it temp = *this;
182 ++*this;
183 return temp;
184 }
185
186 key_val_it & operator--()
187 {
188 --key;
189 --val;
190
191 return *this;
192 }
193
194 bool operator!=(const key_val_it & tmp) const
195 {
196 return key != tmp.key && val != tmp.val;
197 }
198
199 bool operator<(const key_val_it & tmp) const
200 {
201 return key < tmp.key;
202 }
203
204 bool operator>(const key_val_it & tmp) const
205 {
206 return key > tmp.key;
207 }
208
209 bool operator>=(const key_val_it & tmp) const
210 {
211 return key >= tmp.key;
212 }
213
214 key_val_it<key_t,val_t> & operator=(key_val_it<key_t,val_t> & tmp)
215 {
216 key = tmp.key;
217 val = tmp.val;
218
219 return *this;
220 }
221
222 key_val_it() {}
223
224 key_val_it(const key_val_it<key_t,val_t> & tmp)
225 :key(tmp.key),val(tmp.val)
226 {}
227
228 key_val_it(key_t * key, val_t * val)
229 :key(key),val(val)
230 {}
231};
232
233template<typename key_t, typename val_t>
234void swap(key_val_ref<key_t,val_t> a, key_val_ref<key_t,val_t> b)
235{
236 key_t kt = a.key;
237 a.key = b.key;
238 b.key = kt;
239
240 val_t vt = a.val;
241 a.val = b.val;
242 b.val = vt;
243}
244
245namespace std
246{
247 template<typename key_t, typename val_t>
248 struct iterator_traits<key_val_it<key_t,val_t>>
249 {
250 typedef size_t difference_type; //almost always ptrdiff_t
251 typedef key_val<key_t,val_t> value_type; //almost always T
252 typedef key_val<key_t,val_t> & reference; //almost always T& or const T&
253 typedef key_val<key_t,val_t> & pointer; //almost always T* or const T*
254 typedef std::random_access_iterator_tag iterator_category; //usually std::forward_iterator_tag or similar
255 };
256}
257
258
259namespace openfpm
260{
261 template<typename key_t, typename val_t,
262 typename comp_t>
263 void sort(key_t* keys_input, val_t* vals_input, int count,
264 comp_t comp, gpu::ofp_context_t& context)
265 {
266#ifdef CUDA_ON_CPU
267
268 key_val_it<key_t,val_t> kv(keys_input,vals_input);
269
270 std::sort(kv,kv+count,comp);
271
272#else
273 #ifdef __HIP__
274
275 void *d_temp_storage = NULL;
276 size_t temp_storage_bytes = 0;
277
278 auto & temporal2 = context.getTemporalCUB2();
279 temporal2.resize(sizeof(key_t)*count);
280
281 auto & temporal3 = context.getTemporalCUB3();
282 temporal3.resize(sizeof(val_t)*count);
283
284 if (std::is_same<gpu::template less_t<key_t>,comp_t>::value == true)
285 {
286 hipcub::DeviceRadixSort::SortPairs(d_temp_storage,
287 temp_storage_bytes,
288 keys_input,
289 (key_t *)temporal2.template getDeviceBuffer<0>(),
290 vals_input,
291 (val_t *)temporal3.template getDeviceBuffer<0>(),
292 count);
293
294 auto & temporal = context.getTemporalCUB();
295 temporal.resize(temp_storage_bytes);
296
297 d_temp_storage = temporal.template getDeviceBuffer<0>();
298
299 // Run
300 hipcub::DeviceRadixSort::SortPairs(d_temp_storage,
301 temp_storage_bytes,
302 keys_input,
303 (key_t *)temporal2.template getDeviceBuffer<0>(),
304 vals_input,
305 (val_t *)temporal3.template getDeviceBuffer<0>(),
306 count);
307 }
308 else if (std::is_same<gpu::template greater_t<key_t>,comp_t>::value == true)
309 {
310 hipcub::DeviceRadixSort::SortPairsDescending(d_temp_storage,
311 temp_storage_bytes,
312 keys_input,
313 (key_t *)temporal2.template getDeviceBuffer<0>(),
314 vals_input,
315 (val_t *)temporal3.template getDeviceBuffer<0>(),
316 count);
317
318 auto & temporal = context.getTemporalCUB();
319 temporal.resize(temp_storage_bytes);
320
321 d_temp_storage = temporal.template getDeviceBuffer<0>();
322
323 // Run
324 hipcub::DeviceRadixSort::SortPairsDescending(d_temp_storage,
325 temp_storage_bytes,
326 keys_input,
327 (key_t *)temporal2.template getDeviceBuffer<0>(),
328 vals_input,
329 (val_t *)temporal3.template getDeviceBuffer<0>(),
330 count);
331 }
332
333 cudaMemcpy(keys_input,temporal2.getDeviceBuffer<0>(),sizeof(key_t)*count,cudaMemcpyDeviceToDevice);
334 cudaMemcpy(vals_input,temporal3.getDeviceBuffer<0>(),sizeof(val_t)*count,cudaMemcpyDeviceToDevice);
335
336 #else
337
338 void *d_temp_storage = NULL;
339 size_t temp_storage_bytes = 0;
340
341 auto & temporal2 = context.getTemporalCUB2();
342 temporal2.resize(sizeof(key_t)*count);
343
344 auto & temporal3 = context.getTemporalCUB3();
345 temporal3.resize(sizeof(val_t)*count);
346
347 if (std::is_same<gpu::template less_t<key_t>,comp_t>::value == true)
348 {
349 cub::DeviceRadixSort::SortPairs(d_temp_storage,
350 temp_storage_bytes,
351 keys_input,
352 (key_t *)temporal2.template getDeviceBuffer<0>(),
353 vals_input,
354 (val_t *)temporal3.template getDeviceBuffer<0>(),
355 count);
356
357 auto & temporal = context.getTemporalCUB();
358 temporal.resize(temp_storage_bytes);
359
360 d_temp_storage = temporal.template getDeviceBuffer<0>();
361
362 // Run
363 cub::DeviceRadixSort::SortPairs(d_temp_storage,
364 temp_storage_bytes,
365 keys_input,
366 (key_t *)temporal2.template getDeviceBuffer<0>(),
367 vals_input,
368 (val_t *)temporal3.template getDeviceBuffer<0>(),
369 count);
370 }
371 else if (std::is_same<gpu::template greater_t<key_t>,comp_t>::value == true)
372 {
374 temp_storage_bytes,
375 keys_input,
376 (key_t *)temporal2.template getDeviceBuffer<0>(),
377 vals_input,
378 (val_t *)temporal3.template getDeviceBuffer<0>(),
379 count);
380
381 auto & temporal = context.getTemporalCUB();
382 temporal.resize(temp_storage_bytes);
383
384 d_temp_storage = temporal.template getDeviceBuffer<0>();
385
386 // Run
388 temp_storage_bytes,
389 keys_input,
390 (key_t *)temporal2.template getDeviceBuffer<0>(),
391 vals_input,
392 (val_t *)temporal3.template getDeviceBuffer<0>(),
393 count);
394 }
395
396 cudaMemcpy(keys_input,temporal2.getDeviceBuffer<0>(),sizeof(key_t)*count,cudaMemcpyDeviceToDevice);
397 cudaMemcpy(vals_input,temporal3.getDeviceBuffer<0>(),sizeof(val_t)*count,cudaMemcpyDeviceToDevice);
398
399 #endif
400#endif
401 }
402}
403
404#endif
405
406
407#endif /* SORT_OFP_CUH_ */
convert a type into constant type
static CUB_RUNTIME_FUNCTION cudaError_t SortPairs(void *d_temp_storage, size_t &temp_storage_bytes, const KeyT *d_keys_in, KeyT *d_keys_out, const ValueT *d_values_in, ValueT *d_values_out, int num_items, int begin_bit=0, int end_bit=sizeof(KeyT) *8, cudaStream_t stream=0, bool debug_synchronous=false)
Sorts key-value pairs into ascending order. (~2N auxiliary storage required)
static CUB_RUNTIME_FUNCTION cudaError_t SortPairsDescending(void *d_temp_storage, size_t &temp_storage_bytes, const KeyT *d_keys_in, KeyT *d_keys_out, const ValueT *d_values_in, ValueT *d_values_out, int num_items, int begin_bit=0, int end_bit=sizeof(KeyT) *8, cudaStream_t stream=0, bool debug_synchronous=false)
Sorts key-value pairs into descending order. (~2N auxiliary storage required).