OpenFPM  5.2.0
Project that contains the implementation of distributed structures
merge_ofp.cuh
1 /*
2  * merge_ofp.cuh
3  *
4  * Created on: May 15, 2019
5  * Author: i-bird
6  */
7 
8  #ifndef MERGE_OFP_HPP_
9  #define MERGE_OFP_HPP_
10 
11  #ifdef __NVCC__
12 
13  #include "Vector/map_vector.hpp"
14  #include "util/cuda_util.hpp"
15 
16  #ifndef CUDA_ON_CPU
17  // Here we have for sure CUDA >= 11
18  #ifdef __HIP__
19  #undef __CUDACC__
20  #undef __CUDA__
21  #include <thrust/merge.h>
22  #include <thrust/execution_policy.h>
23  #define __CUDACC__
24  #define __CUDA__
25  #else
26  #include <thrust/merge.h>
27  #include <thrust/execution_policy.h>
28  #endif
29  #endif
30 
31 
namespace openfpm
{
	/*! \brief Merge two (key,value) sequences into a single output sequence
	 *
	 * Keys drive the ordering; every value is moved together with its key.
	 * Both inputs are expected to be already ordered according to comp
	 * (this is a precondition of thrust::merge_by_key, and the host fallback
	 * relies on it as well).
	 *
	 * When CUDA_ON_CPU is defined a sequential host loop is used, otherwise the
	 * work is delegated to thrust::merge_by_key with the thrust::device policy.
	 *
	 * In the host loop an element of B is emitted before the current element
	 * of A only when comp(b_key,a_key) is true, so on ties the element of A
	 * comes first.
	 *
	 * \param a_keys keys of the first sequence
	 * \param a_vals values of the first sequence
	 * \param a_count number of elements in the first sequence
	 * \param b_keys keys of the second sequence
	 * \param b_vals values of the second sequence
	 * \param b_count number of elements in the second sequence
	 * \param c_keys output keys, must have room for a_count + b_count elements
	 * \param c_vals output values, must have room for a_count + b_count elements
	 * \param comp comparison functor, invoked as comp(b_key,a_key)
	 * \param gpuContext gpu context (not referenced by either code path)
	 *
	 */
	template<typename a_keys_it, typename a_vals_it,
			 typename b_keys_it, typename b_vals_it,
			 typename c_keys_it, typename c_vals_it,
			 typename comp_t, typename context_t>
	void merge(a_keys_it a_keys, a_vals_it a_vals, int a_count,
			   b_keys_it b_keys, b_vals_it b_vals, int b_count,
			   c_keys_it c_keys, c_vals_it c_vals, comp_t comp, context_t& gpuContext)
	{
#ifdef CUDA_ON_CPU

		int ia = 0;
		int ib = 0;
		int out = 0;

		// Head-to-head phase: both sequences still have elements
		while (ia < a_count && ib < b_count)
		{
			if (comp(b_keys[ib],a_keys[ia]))
			{
				// B is strictly before A
				c_keys[out] = b_keys[ib];
				c_vals[out] = b_vals[ib];
				++ib;
			}
			else
			{
				// A goes first (also on ties)
				c_keys[out] = a_keys[ia];
				c_vals[out] = a_vals[ia];
				++ia;
			}
			++out;
		}

		// B is exhausted: copy what is left of A
		for ( ; ia < a_count ; ++ia, ++out)
		{
			c_keys[out] = a_keys[ia];
			c_vals[out] = a_vals[ia];
		}

		// A is exhausted: copy what is left of B
		for ( ; ib < b_count ; ++ib, ++out)
		{
			c_keys[out] = b_keys[ib];
			c_vals[out] = b_vals[ib];
		}

#else

		// The HIP and the CUDA builds issue exactly the same call
		thrust::merge_by_key(thrust::device,
							 a_keys, a_keys + a_count,
							 b_keys, b_keys + b_count,
							 a_vals, b_vals,
							 c_keys, c_vals, comp);

#endif
	}
}
107 
108  #endif /* __NVCC__ */
109 
110  #endif /* MERGE_OFP_HPP_ */
convert a type into constant type
Definition: aggregate.hpp:302