OpenFPM  5.2.0
Project that contain the implementation of distributed structures
sort_ofp.cuh
1 /*
2  * sort_ofp.cuh
3  *
4  * Created on: Aug 23, 2019
5  * Author: i-bird
6  */
7 
8 #ifndef SORT_OFP_CUH_
9 #define SORT_OFP_CUH_
10 
11 
12 #ifdef __NVCC__
13 
14 #include "util/cuda_util.hpp"
15 #include "util/ofp_context.hpp"
16 
17 #if CUDART_VERSION >= 11000
18  // Here we have for sure CUDA >= 11
19  #ifndef CUDA_ON_CPU
20  #ifdef __HIP__
21  #include "hipcub/hipcub.hpp"
22  #else
23  #include "cub/cub.cuh"
24  #endif
25  #endif
26 #else
27  // Here we have old CUDA
28  #include "cub_old/cub.cuh"
29 #endif
30 
31 
32 template<typename key_t, typename val_t>
33 struct key_val_ref;
34 
35 template<typename key_t, typename val_t>
36 struct key_val
37 {
38  key_t key;
39  val_t val;
40 
41  key_val(const key_t & k, const val_t & v)
42  :key(k),val(v)
43  {}
44 
45  key_val(const key_val_ref<key_t,val_t> & tmp)
46  {
47  this->operator=(tmp);
48  }
49 
50  bool operator<(const key_val & tmp) const
51  {
52  return key < tmp.key;
53  }
54 
55  bool operator>(const key_val & tmp) const
56  {
57  return key > tmp.key;
58  }
59 
60  key_val & operator=(const key_val_ref<key_t,val_t> & tmp)
61  {
62  key = tmp.key;
63  val = tmp.val;
64 
65  return *this;
66  }
67 };
68 
69 
70 template<typename key_t, typename val_t>
71 struct key_val_ref
72 {
73  key_t & key;
74  val_t & val;
75 
76  key_val_ref(key_t & k, val_t & v)
77  :key(k),val(v)
78  {}
79 
80  key_val_ref(key_val_ref<key_t,val_t> && tmp)
81  :key(tmp.key),val(tmp.val)
82  {}
83 
84  key_val_ref & operator=(const key_val<key_t,val_t> & tmp)
85  {
86  key = tmp.key;
87  val = tmp.val;
88 
89  return *this;
90  }
91 
92  key_val_ref & operator=(const key_val_ref<key_t,val_t> & tmp)
93  {
94  key = tmp.key;
95  val = tmp.val;
96 
97  return *this;
98  }
99 
100  bool operator<(const key_val_ref<key_t,val_t> & tmp)
101  {
102  return key < tmp.key;
103  }
104 
105  bool operator>(const key_val_ref<key_t,val_t> & tmp)
106  {
107  return key > tmp.key;
108  }
109 
110  bool operator<(const key_val<key_t,val_t> & tmp)
111  {
112  return key < tmp.key;
113  }
114 
115  bool operator>(const key_val<key_t,val_t> & tmp)
116  {
117  return key > tmp.key;
118  }
119 };
120 
121 
122 template<typename key_t, typename val_t>
123 struct key_val_it
124 {
125  key_t * key;
126  val_t * val;
127 
128 
129  key_val_it & operator+=(int delta)
130  {
131  key += delta;
132  val += delta;
133  return *this;
134  }
135 
136  bool operator==(const key_val_it & tmp)
137  {
138  return (key == tmp.key && val == tmp.val);
139  }
140 
141  key_val_ref<key_t,val_t> operator*()
142  {
143  return key_val_ref<key_t,val_t>(*key,*val);
144  }
145 
146  key_val_ref<key_t,val_t> operator[](int i)
147  {
148  return key_val_ref<key_t,val_t>(*key,*val);
149  }
150 
151  key_val_it operator+(size_t count) const
152  {
153  key_val_it tmp(key+count,val+count);
154 
155  return tmp;
156  }
157 
158 
159  size_t operator-(key_val_it & tmp) const
160  {
161  return key - tmp.key;
162  }
163 
164  key_val_it operator-(size_t count) const
165  {
166  key_val_it tmp(key-count,val-count);
167 
168  return tmp;
169  }
170 
171  key_val_it & operator++()
172  {
173  ++key;
174  ++val;
175 
176  return *this;
177  }
178 
179  key_val_it operator++(int)
180  {
181  key_val_it temp = *this;
182  ++*this;
183  return temp;
184  }
185 
186  key_val_it & operator--()
187  {
188  --key;
189  --val;
190 
191  return *this;
192  }
193 
194  bool operator!=(const key_val_it & tmp) const
195  {
196  return key != tmp.key && val != tmp.val;
197  }
198 
199  bool operator<(const key_val_it & tmp) const
200  {
201  return key < tmp.key;
202  }
203 
204  bool operator>(const key_val_it & tmp) const
205  {
206  return key > tmp.key;
207  }
208 
209  bool operator>=(const key_val_it & tmp) const
210  {
211  return key >= tmp.key;
212  }
213 
214  key_val_it<key_t,val_t> & operator=(key_val_it<key_t,val_t> & tmp)
215  {
216  key = tmp.key;
217  val = tmp.val;
218 
219  return *this;
220  }
221 
222  key_val_it() {}
223 
224  key_val_it(const key_val_it<key_t,val_t> & tmp)
225  :key(tmp.key),val(tmp.val)
226  {}
227 
228  key_val_it(key_t * key, val_t * val)
229  :key(key),val(val)
230  {}
231 };
232 
233 template<typename key_t, typename val_t>
234 void swap(key_val_ref<key_t,val_t> a, key_val_ref<key_t,val_t> b)
235 {
236  key_t kt = a.key;
237  a.key = b.key;
238  b.key = kt;
239 
240  val_t vt = a.val;
241  a.val = b.val;
242  b.val = vt;
243 }
244 
245 namespace std
246 {
247  template<typename key_t, typename val_t>
248  struct iterator_traits<key_val_it<key_t,val_t>>
249  {
250  typedef size_t difference_type; //almost always ptrdiff_t
251  typedef key_val<key_t,val_t> value_type; //almost always T
252  typedef key_val<key_t,val_t> & reference; //almost always T& or const T&
253  typedef key_val<key_t,val_t> & pointer; //almost always T* or const T*
254  typedef std::random_access_iterator_tag iterator_category; //usually std::forward_iterator_tag or similar
255  };
256 }
257 
258 
259 namespace openfpm
260 {
261  template<typename key_t, typename val_t,
262  typename comp_t>
263  void sort(key_t* keys_input, val_t* vals_input, int count,
264  comp_t comp, gpu::ofp_context_t& gpuContext)
265  {
266 #ifdef CUDA_ON_CPU
267 
268  key_val_it<key_t,val_t> kv(keys_input,vals_input);
269 
270  std::sort(kv,kv+count,comp);
271 
272 #else
273  #ifdef __HIP__
274 
275  void *d_temp_storage = NULL;
276  size_t temp_storage_bytes = 0;
277 
278  auto & temporal2 = gpuContext.getTemporalCUB2();
279  temporal2.resize(sizeof(key_t)*count);
280 
281  auto & temporal3 = gpuContext.getTemporalCUB3();
282  temporal3.resize(sizeof(val_t)*count);
283 
284  if (std::is_same<gpu::template less_t<key_t>,comp_t>::value == true)
285  {
286  hipcub::DeviceRadixSort::SortPairs(d_temp_storage,
287  temp_storage_bytes,
288  keys_input,
289  (key_t *)temporal2.template getDeviceBuffer<0>(),
290  vals_input,
291  (val_t *)temporal3.template getDeviceBuffer<0>(),
292  count);
293 
294  auto & temporal = gpuContext.getTemporalCUB();
295  temporal.resize(temp_storage_bytes);
296 
297  d_temp_storage = temporal.template getDeviceBuffer<0>();
298 
299  // Run
300  hipcub::DeviceRadixSort::SortPairs(d_temp_storage,
301  temp_storage_bytes,
302  keys_input,
303  (key_t *)temporal2.template getDeviceBuffer<0>(),
304  vals_input,
305  (val_t *)temporal3.template getDeviceBuffer<0>(),
306  count);
307  }
308  else if (std::is_same<gpu::template greater_t<key_t>,comp_t>::value == true)
309  {
310  hipcub::DeviceRadixSort::SortPairsDescending(d_temp_storage,
311  temp_storage_bytes,
312  keys_input,
313  (key_t *)temporal2.template getDeviceBuffer<0>(),
314  vals_input,
315  (val_t *)temporal3.template getDeviceBuffer<0>(),
316  count);
317 
318  auto & temporal = gpuContext.getTemporalCUB();
319  temporal.resize(temp_storage_bytes);
320 
321  d_temp_storage = temporal.template getDeviceBuffer<0>();
322 
323  // Run
324  hipcub::DeviceRadixSort::SortPairsDescending(d_temp_storage,
325  temp_storage_bytes,
326  keys_input,
327  (key_t *)temporal2.template getDeviceBuffer<0>(),
328  vals_input,
329  (val_t *)temporal3.template getDeviceBuffer<0>(),
330  count);
331  }
332 
333  cudaMemcpy(keys_input,temporal2.getDeviceBuffer<0>(),sizeof(key_t)*count,cudaMemcpyDeviceToDevice);
334  cudaMemcpy(vals_input,temporal3.getDeviceBuffer<0>(),sizeof(val_t)*count,cudaMemcpyDeviceToDevice);
335 
336  #else
337 
338  void *d_temp_storage = NULL;
339  size_t temp_storage_bytes = 0;
340 
341  auto & temporal2 = gpuContext.getTemporalCUB2();
342  temporal2.resize(sizeof(key_t)*count);
343 
344  auto & temporal3 = gpuContext.getTemporalCUB3();
345  temporal3.resize(sizeof(val_t)*count);
346 
347  if (std::is_same<gpu::template less_t<key_t>,comp_t>::value == true)
348  {
349  cub::DeviceRadixSort::SortPairs(d_temp_storage,
350  temp_storage_bytes,
351  keys_input,
352  (key_t *)temporal2.template getDeviceBuffer<0>(),
353  vals_input,
354  (val_t *)temporal3.template getDeviceBuffer<0>(),
355  count);
356 
357  auto & temporal = gpuContext.getTemporalCUB();
358  temporal.resize(temp_storage_bytes);
359 
360  d_temp_storage = temporal.template getDeviceBuffer<0>();
361 
362  // Run
363  cub::DeviceRadixSort::SortPairs(d_temp_storage,
364  temp_storage_bytes,
365  keys_input,
366  (key_t *)temporal2.template getDeviceBuffer<0>(),
367  vals_input,
368  (val_t *)temporal3.template getDeviceBuffer<0>(),
369  count);
370  }
371  else if (std::is_same<gpu::template greater_t<key_t>,comp_t>::value == true)
372  {
374  temp_storage_bytes,
375  keys_input,
376  (key_t *)temporal2.template getDeviceBuffer<0>(),
377  vals_input,
378  (val_t *)temporal3.template getDeviceBuffer<0>(),
379  count);
380 
381  auto & temporal = gpuContext.getTemporalCUB();
382  temporal.resize(temp_storage_bytes);
383 
384  d_temp_storage = temporal.template getDeviceBuffer<0>();
385 
386  // Run
388  temp_storage_bytes,
389  keys_input,
390  (key_t *)temporal2.template getDeviceBuffer<0>(),
391  vals_input,
392  (val_t *)temporal3.template getDeviceBuffer<0>(),
393  count);
394  }
395 
396  cudaMemcpy(keys_input,temporal2.getDeviceBuffer<0>(),sizeof(key_t)*count,cudaMemcpyDeviceToDevice);
397  cudaMemcpy(vals_input,temporal3.getDeviceBuffer<0>(),sizeof(val_t)*count,cudaMemcpyDeviceToDevice);
398 
399  #endif
400 #endif
401  }
402 }
403 
404 #endif
405 
406 
407 #endif /* SORT_OFP_CUH_ */
convert a type into constant type
Definition: aggregate.hpp:302
static CUB_RUNTIME_FUNCTION cudaError_t SortPairs(void *d_temp_storage, size_t &temp_storage_bytes, const KeyT *d_keys_in, KeyT *d_keys_out, const ValueT *d_values_in, ValueT *d_values_out, int num_items, int begin_bit=0, int end_bit=sizeof(KeyT) *8, cudaStream_t stream=0, bool debug_synchronous=false)
Sorts key-value pairs into ascending order. (~2N auxiliary storage required)
static CUB_RUNTIME_FUNCTION cudaError_t SortPairsDescending(void *d_temp_storage, size_t &temp_storage_bytes, const KeyT *d_keys_in, KeyT *d_keys_out, const ValueT *d_values_in, ValueT *d_values_out, int num_items, int begin_bit=0, int end_bit=sizeof(KeyT) *8, cudaStream_t stream=0, bool debug_synchronous=false)
Sorts key-value pairs into descending order. (~2N auxiliary storage required).