14#include "util/cuda_launch.hpp"
15#include "util/ofp_context.hpp"
17#if CUDART_VERSION >= 11000
21 #include "hipcub/hipcub.hpp"
23 #include "cub/cub.cuh"
32template<
typename key_t,
typename val_t>
35template<
typename key_t,
typename val_t>
41 key_val(
const key_t & k,
const val_t & v)
45 key_val(
const key_val_ref<key_t,val_t> & tmp)
50 bool operator<(
const key_val & tmp)
const
55 bool operator>(
const key_val & tmp)
const
60 key_val & operator=(
const key_val_ref<key_t,val_t> & tmp)
70template<
typename key_t,
typename val_t>
76 key_val_ref(key_t & k, val_t & v)
80 key_val_ref(key_val_ref<key_t,val_t> && tmp)
81 :key(tmp.key),val(tmp.val)
84 key_val_ref & operator=(
const key_val<key_t,val_t> & tmp)
92 key_val_ref & operator=(
const key_val_ref<key_t,val_t> & tmp)
// Less-than against another key_val_ref.
// Only the key members are compared; val does not take part in the ordering.
100 bool operator<(
const key_val_ref<key_t,val_t> & tmp)
102 return key < tmp.key;
// Greater-than against another key_val_ref.
// Only the key members are compared; val does not take part in the ordering.
105 bool operator>(
const key_val_ref<key_t,val_t> & tmp)
107 return key > tmp.key;
// Less-than against a key_val (the by-value pair type).
// Only the key members are compared; val does not take part in the ordering.
110 bool operator<(
const key_val<key_t,val_t> & tmp)
112 return key < tmp.key;
// Greater-than against a key_val (the by-value pair type).
// Only the key members are compared; val does not take part in the ordering.
115 bool operator>(
const key_val<key_t,val_t> & tmp)
117 return key > tmp.key;
122template<
typename key_t,
typename val_t>
129 key_val_it & operator+=(
int delta)
136 bool operator==(
const key_val_it & tmp)
138 return (key == tmp.key && val == tmp.val);
// Dereference: returns a key_val_ref constructed from *key and *val, i.e. from
// the key element and the value element the iterator currently points at.
141 key_val_ref<key_t,val_t> operator*()
143 return key_val_ref<key_t,val_t>(*key,*val);
146 key_val_ref<key_t,val_t> operator[](
int i)
148 return key_val_ref<key_t,val_t>(*key,*val);
151 key_val_it operator+(
size_t count)
const
153 key_val_it tmp(key+count,val+count);
159 size_t operator-(key_val_it & tmp)
const
161 return key - tmp.key;
164 key_val_it operator-(
size_t count)
const
166 key_val_it tmp(key-count,val-count);
171 key_val_it & operator++()
179 key_val_it operator++(
int)
181 key_val_it temp = *
this;
186 key_val_it & operator--()
194 bool operator!=(
const key_val_it & tmp)
const
196 return key != tmp.key && val != tmp.val;
// Ordering of iterators: compares the key pointers only; val is not consulted.
199 bool operator<(
const key_val_it & tmp)
const
201 return key < tmp.key;
// Ordering of iterators: compares the key pointers only; val is not consulted.
204 bool operator>(
const key_val_it & tmp)
const
206 return key > tmp.key;
// Ordering of iterators: compares the key pointers only; val is not consulted.
209 bool operator>=(
const key_val_it & tmp)
const
211 return key >= tmp.key;
214 key_val_it<key_t,val_t> & operator=(key_val_it<key_t,val_t> & tmp)
224 key_val_it(
const key_val_it<key_t,val_t> & tmp)
225 :key(tmp.key),val(tmp.val)
228 key_val_it(key_t * key, val_t * val)
233template<
typename key_t,
typename val_t>
234void swap(key_val_ref<key_t,val_t> a, key_val_ref<key_t,val_t> b)
// iterator_traits specialization for key_val_it: declares the member types
// that describe the iterator, tagging it as a random access iterator.
247 template<
typename key_t,
typename val_t>
248 struct iterator_traits<key_val_it<key_t,val_t>>
// NOTE(review): difference_type is an unsigned type here; the standard
// iterator requirements expect a signed one — confirm this is intentional.
250 typedef size_t difference_type;
// Elements are handled by value as key_val pairs.
251 typedef key_val<key_t,val_t> value_type;
252 typedef key_val<key_t,val_t> & reference;
// NOTE(review): pointer is declared as a reference type (same as reference),
// not as a pointer type — confirm this is intentional.
253 typedef key_val<key_t,val_t> & pointer;
254 typedef std::random_access_iterator_tag iterator_category;
261 template<
typename key_t,
typename val_t,
263 void sort(key_t* keys_input, val_t* vals_input,
int count,
268 key_val_it<key_t,val_t> kv(keys_input,vals_input);
270 std::sort(kv,kv+count,comp);
275 void *d_temp_storage = NULL;
276 size_t temp_storage_bytes = 0;
278 auto & temporal2 = context.getTemporalCUB2();
279 temporal2.resize(
sizeof(key_t)*count);
281 auto & temporal3 = context.getTemporalCUB3();
282 temporal3.resize(
sizeof(val_t)*count);
284 if (std::is_same<gpu::template less_t<key_t>,comp_t>::value ==
true)
286 hipcub::DeviceRadixSort::SortPairs(d_temp_storage,
289 (key_t *)temporal2.template getDeviceBuffer<0>(),
291 (val_t *)temporal3.template getDeviceBuffer<0>(),
294 auto & temporal = context.getTemporalCUB();
295 temporal.resize(temp_storage_bytes);
297 d_temp_storage = temporal.template getDeviceBuffer<0>();
300 hipcub::DeviceRadixSort::SortPairs(d_temp_storage,
303 (key_t *)temporal2.template getDeviceBuffer<0>(),
305 (val_t *)temporal3.template getDeviceBuffer<0>(),
308 else if (std::is_same<gpu::template greater_t<key_t>,comp_t>::value ==
true)
310 hipcub::DeviceRadixSort::SortPairsDescending(d_temp_storage,
313 (key_t *)temporal2.template getDeviceBuffer<0>(),
315 (val_t *)temporal3.template getDeviceBuffer<0>(),
318 auto & temporal = context.getTemporalCUB();
319 temporal.resize(temp_storage_bytes);
321 d_temp_storage = temporal.template getDeviceBuffer<0>();
324 hipcub::DeviceRadixSort::SortPairsDescending(d_temp_storage,
327 (key_t *)temporal2.template getDeviceBuffer<0>(),
329 (val_t *)temporal3.template getDeviceBuffer<0>(),
333 cudaMemcpy(keys_input,temporal2.getDeviceBuffer<0>(),
sizeof(key_t)*count,cudaMemcpyDeviceToDevice);
334 cudaMemcpy(vals_input,temporal3.getDeviceBuffer<0>(),
sizeof(val_t)*count,cudaMemcpyDeviceToDevice);
338 void *d_temp_storage = NULL;
339 size_t temp_storage_bytes = 0;
341 auto & temporal2 = context.getTemporalCUB2();
342 temporal2.resize(
sizeof(key_t)*count);
344 auto & temporal3 = context.getTemporalCUB3();
345 temporal3.resize(
sizeof(val_t)*count);
347 if (std::is_same<gpu::template less_t<key_t>,comp_t>::value ==
true)
352 (key_t *)temporal2.template getDeviceBuffer<0>(),
354 (val_t *)temporal3.template getDeviceBuffer<0>(),
357 auto & temporal = context.getTemporalCUB();
358 temporal.resize(temp_storage_bytes);
360 d_temp_storage = temporal.template getDeviceBuffer<0>();
366 (key_t *)temporal2.template getDeviceBuffer<0>(),
368 (val_t *)temporal3.template getDeviceBuffer<0>(),
371 else if (std::is_same<gpu::template greater_t<key_t>,comp_t>::value ==
true)
376 (key_t *)temporal2.template getDeviceBuffer<0>(),
378 (val_t *)temporal3.template getDeviceBuffer<0>(),
381 auto & temporal = context.getTemporalCUB();
382 temporal.resize(temp_storage_bytes);
384 d_temp_storage = temporal.template getDeviceBuffer<0>();
390 (key_t *)temporal2.template getDeviceBuffer<0>(),
392 (val_t *)temporal3.template getDeviceBuffer<0>(),
396 cudaMemcpy(keys_input,temporal2.getDeviceBuffer<0>(),
sizeof(key_t)*count,cudaMemcpyDeviceToDevice);
397 cudaMemcpy(vals_input,temporal3.getDeviceBuffer<0>(),
sizeof(val_t)*count,cudaMemcpyDeviceToDevice);
Converts a type into a constant type.
static CUB_RUNTIME_FUNCTION cudaError_t SortPairs(void *d_temp_storage, size_t &temp_storage_bytes, const KeyT *d_keys_in, KeyT *d_keys_out, const ValueT *d_values_in, ValueT *d_values_out, int num_items, int begin_bit=0, int end_bit=sizeof(KeyT) *8, cudaStream_t stream=0, bool debug_synchronous=false)
Sorts key-value pairs into ascending order. (~2N auxiliary storage required)
static CUB_RUNTIME_FUNCTION cudaError_t SortPairsDescending(void *d_temp_storage, size_t &temp_storage_bytes, const KeyT *d_keys_in, KeyT *d_keys_out, const ValueT *d_values_in, ValueT *d_values_out, int num_items, int begin_bit=0, int end_bit=sizeof(KeyT) *8, cudaStream_t stream=0, bool debug_synchronous=false)
Sorts key-value pairs into descending order. (~2N auxiliary storage required).