39 #include "../block/block_load.cuh"
40 #include "../block/block_store.cuh"
41 #include "../block/block_scan.cuh"
42 #include "../block/block_discontinuity.cuh"
43 #include "../iterator/cache_modified_input_iterator.cuh"
44 #include "../iterator/constant_input_iterator.cuh"
45 #include "../util_namespace.cuh"
63 int _ITEMS_PER_THREAD,
89 typename AgentReduceByKeyPolicyT,
90 typename KeysInputIteratorT,
91 typename UniqueOutputIteratorT,
92 typename ValuesInputIteratorT,
93 typename AggregatesOutputIteratorT,
94 typename NumRunsOutputIteratorT,
96 typename ReductionOpT,
// Element type produced by the keys input iterator, taken from
// std::iterator_traits<KeysInputIteratorT>::value_type.
105 typedef typename std::iterator_traits<KeysInputIteratorT>::value_type KeyInputT;
109 typename std::iterator_traits<KeysInputIteratorT>::value_type,
110 typename std::iterator_traits<UniqueOutputIteratorT>::value_type>::Type
KeyOutputT;
// Element type produced by the values input iterator, taken from
// std::iterator_traits<ValuesInputIteratorT>::value_type.
113 typedef typename std::iterator_traits<ValuesInputIteratorT>::value_type ValueInputT;
117 typename std::iterator_traits<ValuesInputIteratorT>::value_type,
118 typename std::iterator_traits<AggregatesOutputIteratorT>::value_type>::Type
ValueOutputT;
130 template <
typename _EqualityOpT>
137 __host__ __device__ __forceinline__
141 template <
typename T>
142 __host__ __device__ __forceinline__
bool operator()(
const T &a,
const T &b,
int idx)
const
156 BLOCK_THREADS = AgentReduceByKeyPolicyT::BLOCK_THREADS,
157 ITEMS_PER_THREAD = AgentReduceByKeyPolicyT::ITEMS_PER_THREAD,
158 TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD,
159 TWO_PHASE_SCATTER = (ITEMS_PER_THREAD > 1),
// Iterator type used for the keys input.
// The choice is driven by IsPointer<KeysInputIteratorT>::VALUE: the two candidates are a
// CacheModifiedInputIterator (using the policy's LOAD_MODIFIER, yielding KeyInputT, indexed
// by OffsetT) and the caller's KeysInputIteratorT itself.
// NOTE(review): If<> is defined outside this view; presumably it selects the first type when
// the condition is true (i.e. raw pointers get wrapped) -- confirm against its definition.
166 typedef typename If<IsPointer<KeysInputIteratorT>::VALUE,
167 CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, KeyInputT, OffsetT>,
168 KeysInputIteratorT>::Type
169 WrappedKeysInputIteratorT;
// Iterator type used for the values input.
// The choice is driven by IsPointer<ValuesInputIteratorT>::VALUE: the two candidates are a
// CacheModifiedInputIterator (using the policy's LOAD_MODIFIER, yielding ValueInputT, indexed
// by OffsetT) and the caller's ValuesInputIteratorT itself.
// NOTE(review): If<> is defined outside this view; presumably it selects the first type when
// the condition is true (i.e. raw pointers get wrapped) -- confirm against its definition.
172 typedef typename If<IsPointer<ValuesInputIteratorT>::VALUE,
173 CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, ValueInputT, OffsetT>,
174 ValuesInputIteratorT>::Type
175 WrappedValuesInputIteratorT;
// Input-iterator type derived from the aggregates OUTPUT iterator type.
// The choice is driven by IsPointer<AggregatesOutputIteratorT>::VALUE: the two candidates are
// a CacheModifiedInputIterator (using the policy's LOAD_MODIFIER, yielding ValueInputT,
// indexed by OffsetT) and AggregatesOutputIteratorT itself.
// NOTE(review): the name suggests this is used to read back previously written aggregates
// during a fix-up step; no use of it is visible in this excerpt -- confirm at the use site.
// NOTE(review): the wrapped element type is ValueInputT rather than ValueOutputT; verify this
// is intentional when the input and output value types differ.
178 typedef typename If<IsPointer<AggregatesOutputIteratorT>::VALUE,
179 CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, ValueInputT, OffsetT>,
180 AggregatesOutputIteratorT>::Type
181 WrappedFixupInputIteratorT;
// Shorthand for ReduceBySegmentOp instantiated with the caller-supplied ReductionOpT.
// NOTE(review): ReduceBySegmentOp is declared outside this view; presumably it is the
// segment-aware scan operator applied to (flag, value) pairs -- confirm against its definition.
184 typedef ReduceBySegmentOp<ReductionOpT> ReduceBySegmentOpT;
191 AgentReduceByKeyPolicyT::LOAD_ALGORITHM>
199 AgentReduceByKeyPolicyT::LOAD_ALGORITHM>
203 typedef BlockDiscontinuity<
206 BlockDiscontinuityKeys;
212 AgentReduceByKeyPolicyT::SCAN_ALGORITHM>
216 typedef TilePrefixCallbackOp<
220 TilePrefixCallbackOpT;
// Array type of TILE_ITEMS + 1 KeyOutputT elements (TILE_ITEMS is
// BLOCK_THREADS * ITEMS_PER_THREAD).
// NOTE(review): the purpose of the extra (+1) slot is not visible in this excerpt.
223 typedef KeyOutputT KeyExchangeT[TILE_ITEMS + 1];
// Array type of TILE_ITEMS + 1 ValueOutputT elements (TILE_ITEMS is
// BLOCK_THREADS * ITEMS_PER_THREAD).
// NOTE(review): the purpose of the extra (+1) slot is not visible in this excerpt.
224 typedef ValueOutputT ValueExchangeT[TILE_ITEMS + 1];
270 __device__ __forceinline__
301 KeyValuePairT (&scatter_items)[ITEMS_PER_THREAD],
303 OffsetT (&segment_indices)[ITEMS_PER_THREAD])
307 for (
int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
309 if (segment_flags[ITEM])
311 d_unique_out[segment_indices[ITEM]] = scatter_items[ITEM].key;
325 KeyValuePairT (&scatter_items)[ITEMS_PER_THREAD],
327 OffsetT (&segment_indices)[ITEMS_PER_THREAD],
329 OffsetT num_tile_segments_prefix)
335 for (
int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
337 if (segment_flags[ITEM])
339 temp_storage.raw_exchange.Alias()[segment_indices[ITEM] - num_tile_segments_prefix] = scatter_items[ITEM];
345 for (
int item = threadIdx.x; item < num_tile_segments; item += BLOCK_THREADS)
357 __device__ __forceinline__
void Scatter(
358 KeyValuePairT (&scatter_items)[ITEMS_PER_THREAD],
360 OffsetT (&segment_indices)[ITEMS_PER_THREAD],
362 OffsetT num_tile_segments_prefix)
365 if (TWO_PHASE_SCATTER && (num_tile_segments > BLOCK_THREADS))
372 num_tile_segments_prefix);
391 template <
bool IS_LAST_TILE>
401 OffsetT head_flags[ITEMS_PER_THREAD];
402 OffsetT segment_indices[ITEMS_PER_THREAD];
414 if (threadIdx.x == 0)
416 tile_predecessor = (tile_idx == 0) ?
437 head_flags, keys, prev_keys, flag_op, tile_predecessor);
443 head_flags, keys, prev_keys, flag_op, tile_predecessor);
448 for (
int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
450 scan_items[ITEM].
value = values[ITEM];
451 scan_items[ITEM].
key = head_flags[ITEM];
462 num_segments_prefix = 0;
463 total_aggregate = block_aggregate;
466 if ((!IS_LAST_TILE) && (threadIdx.x == 0))
475 block_aggregate = prefix_op.GetBlockAggregate();
476 num_segments_prefix = prefix_op.GetExclusivePrefix().key;
477 total_aggregate = prefix_op.GetInclusivePrefix();
482 for (
int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
484 scatter_items[ITEM].
key = prev_keys[ITEM];
485 scatter_items[ITEM].
value = scan_items[ITEM].
value;
486 segment_indices[ITEM] = scan_items[ITEM].
key;
495 OffsetT num_tile_segments = block_aggregate.key;
496 Scatter(scatter_items, head_flags, segment_indices, num_tile_segments, num_segments_prefix);
499 if ((IS_LAST_TILE) && (threadIdx.x == BLOCK_THREADS - 1))
501 OffsetT num_segments = num_segments_prefix + num_tile_segments;
504 if (num_remaining == TILE_ITEMS)
506 d_unique_out[num_segments] = keys[ITEMS_PER_THREAD - 1];
530 if (num_remaining > TILE_ITEMS)
533 ConsumeTile<false>(num_remaining, tile_idx, tile_offset,
tile_state);
535 else if (num_remaining > 0)
538 ConsumeTile<true>(num_remaining, tile_idx, tile_offset,
tile_state);