14 #include "util/cuda_launch.hpp" 16 #if CUDART_VERSION >= 11000 20 #include "hipcub/hipcub.hpp" 22 #include "cub/cub.cuh" 31 #include "util/cuda/moderngpu/kernel_mergesort.hxx" 34 #include "util/cuda/ofp_context.hxx" 36 template<
typename key_t,
typename val_t>
39 template<
typename key_t,
typename val_t>
45 key_val(
const key_t & k,
const val_t & v)
49 key_val(
const key_val_ref<key_t,val_t> & tmp)
54 bool operator<(
const key_val & tmp)
const 59 bool operator>(
const key_val & tmp)
const 64 key_val & operator=(
const key_val_ref<key_t,val_t> & tmp)
74 template<
typename key_t,
typename val_t>
80 key_val_ref(key_t & k, val_t & v)
84 key_val_ref(key_val_ref<key_t,val_t> && tmp)
85 :key(tmp.key),val(tmp.val)
88 key_val_ref & operator=(
const key_val<key_t,val_t> & tmp)
96 key_val_ref & operator=(
const key_val_ref<key_t,val_t> & tmp)
// Less-than against another key_val_ref: only the key members are compared,
// val takes no part in the result.
104 bool operator<(
const key_val_ref<key_t,val_t> & tmp)
106 return key < tmp.key;
// Greater-than against another key_val_ref: only the key members are compared,
// val takes no part in the result.
109 bool operator>(
const key_val_ref<key_t,val_t> & tmp)
111 return key > tmp.key;
// Less-than against a key_val (the owning pair type): compares this key with
// tmp.key only.
114 bool operator<(
const key_val<key_t,val_t> & tmp)
116 return key < tmp.key;
// Greater-than against a key_val (the owning pair type): compares this key with
// tmp.key only.
119 bool operator>(
const key_val<key_t,val_t> & tmp)
121 return key > tmp.key;
126 template<
typename key_t,
typename val_t>
133 key_val_it & operator+=(
int delta)
// Two iterators are equal only when BOTH the key pointer and the val pointer
// match.
// NOTE(review): no const qualifier is visible on this member in the listing,
// so it could not be called on a const key_val_it — confirm against the file.
140 bool operator==(
const key_val_it & tmp)
142 return (key == tmp.key && val == tmp.val);
// Dereference: builds a key_val_ref from the current key and val elements
// (*key, *val) and returns it by value.
145 key_val_ref<key_t,val_t> operator*()
147 return key_val_ref<key_t,val_t>(*key,*val);
150 key_val_ref<key_t,val_t> operator[](
int i)
152 return key_val_ref<key_t,val_t>(*key,*val);
155 key_val_it operator+(
size_t count)
const 157 key_val_it tmp(key+count,val+count);
// Distance between two iterators, computed from the key pointers only and
// returned as size_t.
// NOTE(review): the result type is unsigned, so if tmp is ahead of this
// iterator the negative pointer difference converts to a very large value.
// NOTE(review): tmp is taken by non-const reference, so a temporary such as
// (it + n) cannot be passed as the right-hand operand.
163 size_t operator-(key_val_it & tmp)
const 165 return key - tmp.key;
168 key_val_it operator-(
size_t count)
const 170 key_val_it tmp(key-count,val-count);
175 key_val_it & operator++()
183 key_val_it operator++(
int)
185 key_val_it temp = *
this;
190 key_val_it & operator--()
198 bool operator!=(
const key_val_it & tmp)
const 200 return key != tmp.key && val != tmp.val;
// Ordering between iterators: compares the key pointers only; the val pointer
// is not consulted.
203 bool operator<(
const key_val_it & tmp)
const 205 return key < tmp.key;
// Ordering between iterators: compares the key pointers only; the val pointer
// is not consulted.
208 bool operator>(
const key_val_it & tmp)
const 210 return key > tmp.key;
// Ordering between iterators: compares the key pointers only; the val pointer
// is not consulted.
213 bool operator>=(
const key_val_it & tmp)
const 215 return key >= tmp.key;
218 key_val_it<key_t,val_t> & operator=(key_val_it<key_t,val_t> & tmp)
228 key_val_it(
const key_val_it<key_t,val_t> & tmp)
229 :key(tmp.key),val(tmp.val)
232 key_val_it(key_t * key, val_t * val)
237 template<
typename key_t,
typename val_t>
238 void swap(key_val_ref<key_t,val_t> a, key_val_ref<key_t,val_t> b)
// Specialization of iterator_traits for key_val_it, declaring it a
// random-access iterator whose value type is key_val<key_t,val_t>.
// NOTE(review): difference_type is size_t (unsigned); the standard iterator
// requirements expect a signed integer type here — confirm this is intended.
// NOTE(review): pointer is declared as key_val<key_t,val_t> & (a reference),
// not a pointer type — looks like a copy of the reference typedef; confirm.
// NOTE(review): reference is key_val &, while key_val_it::operator* returns a
// key_val_ref by value — verify that algorithms using this typedef still work.
251 template<
typename key_t,
typename val_t>
252 struct iterator_traits<key_val_it<key_t,val_t>>
254 typedef size_t difference_type;
255 typedef key_val<key_t,val_t> value_type;
256 typedef key_val<key_t,val_t> & reference;
257 typedef key_val<key_t,val_t> & pointer;
258 typedef std::random_access_iterator_tag iterator_category;
265 template<
typename key_t,
typename val_t,
267 void sort(key_t* keys_input, val_t* vals_input,
int count,
268 comp_t comp, mgpu::ofp_context_t& context)
272 key_val_it<key_t,val_t> kv(keys_input,vals_input);
274 std::sort(kv,kv+count,comp);
282 void *d_temp_storage = NULL;
283 size_t temp_storage_bytes = 0;
285 auto & temporal2 = context.getTemporalCUB2();
286 temporal2.resize(
sizeof(key_t)*count);
288 auto & temporal3 = context.getTemporalCUB3();
289 temporal3.resize(
sizeof(val_t)*count);
291 if (std::is_same<mgpu::template less_t<key_t>,comp_t>::value ==
true)
293 hipcub::DeviceRadixSort::SortPairs(d_temp_storage,
296 (key_t *)temporal2.template getDeviceBuffer<0>(),
298 (val_t *)temporal3.template getDeviceBuffer<0>(),
301 auto & temporal = context.getTemporalCUB();
302 temporal.resize(temp_storage_bytes);
304 d_temp_storage = temporal.template getDeviceBuffer<0>();
307 hipcub::DeviceRadixSort::SortPairs(d_temp_storage,
310 (key_t *)temporal2.template getDeviceBuffer<0>(),
312 (val_t *)temporal3.template getDeviceBuffer<0>(),
315 else if (std::is_same<mgpu::template greater_t<key_t>,comp_t>::value ==
true)
317 hipcub::DeviceRadixSort::SortPairsDescending(d_temp_storage,
320 (key_t *)temporal2.template getDeviceBuffer<0>(),
322 (val_t *)temporal3.template getDeviceBuffer<0>(),
325 auto & temporal = context.getTemporalCUB();
326 temporal.resize(temp_storage_bytes);
328 d_temp_storage = temporal.template getDeviceBuffer<0>();
331 hipcub::DeviceRadixSort::SortPairsDescending(d_temp_storage,
334 (key_t *)temporal2.template getDeviceBuffer<0>(),
336 (val_t *)temporal3.template getDeviceBuffer<0>(),
340 cudaMemcpy(keys_input,temporal2.getDeviceBuffer<0>(),
sizeof(key_t)*count,cudaMemcpyDeviceToDevice);
341 cudaMemcpy(vals_input,temporal3.getDeviceBuffer<0>(),
sizeof(val_t)*count,cudaMemcpyDeviceToDevice);
346 void *d_temp_storage = NULL;
347 size_t temp_storage_bytes = 0;
349 auto & temporal2 = context.getTemporalCUB2();
350 temporal2.resize(
sizeof(key_t)*count);
352 auto & temporal3 = context.getTemporalCUB3();
353 temporal3.resize(
sizeof(val_t)*count);
355 if (std::is_same<mgpu::template less_t<key_t>,comp_t>::value ==
true)
360 (key_t *)temporal2.template getDeviceBuffer<0>(),
362 (val_t *)temporal3.template getDeviceBuffer<0>(),
365 auto & temporal = context.getTemporalCUB();
366 temporal.resize(temp_storage_bytes);
368 d_temp_storage = temporal.template getDeviceBuffer<0>();
374 (key_t *)temporal2.template getDeviceBuffer<0>(),
376 (val_t *)temporal3.template getDeviceBuffer<0>(),
379 else if (std::is_same<mgpu::template greater_t<key_t>,comp_t>::value ==
true)
384 (key_t *)temporal2.template getDeviceBuffer<0>(),
386 (val_t *)temporal3.template getDeviceBuffer<0>(),
389 auto & temporal = context.getTemporalCUB();
390 temporal.resize(temp_storage_bytes);
392 d_temp_storage = temporal.template getDeviceBuffer<0>();
398 (key_t *)temporal2.template getDeviceBuffer<0>(),
400 (val_t *)temporal3.template getDeviceBuffer<0>(),
404 cudaMemcpy(keys_input,temporal2.getDeviceBuffer<0>(),
sizeof(key_t)*count,cudaMemcpyDeviceToDevice);
405 cudaMemcpy(vals_input,temporal3.getDeviceBuffer<0>(),
sizeof(val_t)*count,cudaMemcpyDeviceToDevice);
410 mgpu::mergesort(keys_input,vals_input,count,comp,context);
convert a type into constant type
static CUB_RUNTIME_FUNCTION cudaError_t SortPairsDescending(void *d_temp_storage, size_t &temp_storage_bytes, const KeyT *d_keys_in, KeyT *d_keys_out, const ValueT *d_values_in, ValueT *d_values_out, int num_items, int begin_bit=0, int end_bit=sizeof(KeyT) *8, cudaStream_t stream=0, bool debug_synchronous=false)
Sorts key-value pairs into descending order. (~2N auxiliary storage required).
static CUB_RUNTIME_FUNCTION cudaError_t SortPairs(void *d_temp_storage, size_t &temp_storage_bytes, const KeyT *d_keys_in, KeyT *d_keys_out, const ValueT *d_values_in, ValueT *d_values_out, int num_items, int begin_bit=0, int end_bit=sizeof(KeyT) *8, cudaStream_t stream=0, bool debug_synchronous=false)
Sorts key-value pairs into ascending order. (~2N auxiliary storage required)