36#include "../../util_arch.cuh"
37#include "../../util_ptx.cuh"
38#include "../../warp/warp_scan.cuh"
39#include "../../util_namespace.cuh"
56struct BlockScanWarpScans
86 T warp_aggregates[
WARPS];
100 _TempStorage &temp_storage;
101 unsigned int linear_tid;
102 unsigned int warp_id;
103 unsigned int lane_id;
114 temp_storage(temp_storage.Alias()),
115 linear_tid(
RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)),
125 template <
typename ScanOp,
int WARP>
133 warp_prefix = block_aggregate;
135 T addend = temp_storage.warp_aggregates[WARP];
136 block_aggregate =
scan_op(block_aggregate, addend);
141 template <
typename ScanOp>
151 template <
typename ScanOp>
159 temp_storage.warp_aggregates[warp_id] = warp_aggregate;
165 block_aggregate = temp_storage.warp_aggregates[0];
186 template <
typename ScanOp>
191 const T &initial_value)
195 warp_prefix =
scan_op(initial_value, warp_prefix);
198 warp_prefix = initial_value;
208 template <
typename ScanOp>
221 template <
typename ScanOp>
225 const T &initial_value,
234 template <
typename ScanOp>
241 WarpScanT my_warp_scan(temp_storage.warp_scan[warp_id]);
245 my_warp_scan.
Scan(input, inclusive_output, exclusive_output,
scan_op);
253 temp_storage.warp_aggregates[warp_id] = inclusive_output;
258 T warp_inclusive, warp_prefix;
262 T warp_val = temp_storage.warp_aggregates[lane_id];
266 warp_prefix = my_warp_scan.
Broadcast(warp_prefix, warp_id);
267 block_aggregate = my_warp_scan.
Broadcast(warp_inclusive,
WARPS - 1);
273 exclusive_output =
scan_op(warp_prefix, exclusive_output);
275 exclusive_output = warp_prefix;
281 template <
typename ScanOp>
285 const T &initial_value,
289 WarpScanT my_warp_scan(temp_storage.warp_scan[warp_id]);
293 my_warp_scan.
Scan(input, inclusive_output, exclusive_output,
scan_op);
301 temp_storage.warp_aggregates[warp_id] = inclusive_output;
306 T warp_inclusive, warp_prefix;
310 T warp_val = temp_storage.warp_aggregates[lane_id];
311 WarpAggregateScanT(temp_storage.inner_scan[warp_id]).Scan(warp_val, warp_inclusive, warp_prefix, initial_value,
scan_op);
314 warp_prefix = my_warp_scan.
Broadcast(warp_prefix, warp_id);
315 block_aggregate = my_warp_scan.
Broadcast(warp_inclusive,
WARPS - 1);
319 exclusive_output =
scan_op(warp_prefix, exclusive_output);
321 exclusive_output = warp_prefix;
328 typename BlockPrefixCallbackOp>
333 BlockPrefixCallbackOp &block_prefix_callback_op)
342 T block_prefix = block_prefix_callback_op(block_aggregate);
346 temp_storage.block_prefix = block_prefix;
347 exclusive_output = block_prefix;
354 T block_prefix = temp_storage.block_prefix;
357 exclusive_output =
scan_op(block_prefix, exclusive_output);
367 template <
typename ScanOp>
379 template <
typename ScanOp>
386 WarpScanT(temp_storage.warp_scan[warp_id]).InclusiveScan(input, inclusive_output,
scan_op);
394 inclusive_output =
scan_op(warp_prefix, inclusive_output);
402 typename BlockPrefixCallbackOp>
407 BlockPrefixCallbackOp &block_prefix_callback_op)
415 T block_prefix = block_prefix_callback_op(block_aggregate);
419 temp_storage.block_prefix = block_prefix;
426 T block_prefix = temp_storage.block_prefix;
427 exclusive_output =
scan_op(block_prefix, exclusive_output);
The WarpScan class provides collective methods for computing a parallel prefix scan of items partitio...
__device__ __forceinline__ void Scan(T input, T &inclusive_output, T &exclusive_output, ScanOp scan_op)
Computes both inclusive and exclusive prefix scans using the specified binary scan functor across the...
__device__ __forceinline__ T Broadcast(T input, unsigned int src_lane)
Broadcast the value input from warp-lane src_lane to all lanes in the warp.
__device__ __forceinline__ int RowMajorTid(int block_dim_x, int block_dim_y, int block_dim_z)
Returns the row-major linear thread identifier for a multidimensional thread block.
__device__ __forceinline__ unsigned int LaneId()
Returns the warp lane ID of the calling thread.
Optional outer namespace(s)
OutputIteratorT ScanTileStateT int ScanOpT scan_op
Binary scan functor.
Alias wrapper allowing storage to be unioned.
Shared memory storage layout type.
WarpScanT::TempStorage warp_scan[WARPS]
Buffer for warp-synchronous scans.
T block_prefix
Shared prefix for the entire thread block.
WarpAggregateScanT::TempStorage inner_scan[WARPS]
Buffer for warp-synchronous scans.
__device__ __forceinline__ void ExclusiveScan(T input, T &exclusive_output, ScanOp scan_op)
Computes an exclusive thread block-wide prefix scan using the specified binary scan_op functor....
@ WARP_THREADS
Number of warp threads.
@ WARPS
Number of active warps.
__device__ __forceinline__ void InclusiveScan(T input, T &inclusive_output, ScanOp scan_op)
Computes an inclusive thread block-wide prefix scan using the specified binary scan_op functor....
__device__ __forceinline__ void ApplyWarpAggregates(T &warp_prefix, ScanOp scan_op, T &block_aggregate, Int2Type< WARP > addend_warp)
__device__ __forceinline__ T ComputeWarpPrefix(ScanOp scan_op, T warp_aggregate, T &block_aggregate)
Use the warp-wide aggregates to compute the calling warp's prefix. Also returns block-wide aggregate ...
__device__ __forceinline__ BlockScanWarpScans(TempStorage &temp_storage)
Constructor.
__device__ __forceinline__ void ApplyWarpAggregates(T &warp_prefix, ScanOp scan_op, T &block_aggregate, Int2Type< WARPS > addend_warp)
WarpScan< T, WARP_THREADS, PTX_ARCH > WarpScanT
WarpScan utility type.
@ BLOCK_THREADS
The thread block size in threads.
WarpScan< T, WARPS, PTX_ARCH > WarpAggregateScanT
WarpScan utility type.
__device__ __forceinline__ void ApplyWarpAggregates(T &warp_prefix, ScanOp scan_op, T &block_aggregate, Int2Type< WARP >)
Allows for the treatment of an integral constant as a type at compile-time (e.g., to achieve static c...
A storage-backing wrapper that allows types with non-trivial constructors to be aliased in unions.