#include "../../thread/thread_operators.cuh"
#include "../../util_type.cuh"
#include "../../util_ptx.cuh"
#include "../../util_namespace.cuh"
/// Optional outer namespace(s)
CUB_NS_PREFIX

/// CUB namespace
namespace cub {

/**
 * \brief WarpScanShfl provides SHFL-based variants of parallel prefix scan of items
 * partitioned across a CUDA thread warp.
 */
template <
    typename    T,                      ///< Data type being scanned
    int         LOGICAL_WARP_THREADS,   ///< Number of threads per logical warp
    int         PTX_ARCH>               ///< The PTX compute capability for which to specialize this collective
struct WarpScanShfl
{
    //---------------------------------------------------------------------
    // Constants and type definitions
    //---------------------------------------------------------------------

    enum
    {
        /// Whether the logical warp size and the PTX warp size coincide
        IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(PTX_ARCH)),

        /// The number of warp scan steps
        STEPS = Log2<LOGICAL_WARP_THREADS>::VALUE,

        /// The 5-bit SHFL mask for logically splitting warps into sub-segments starts 8-bits up
        SHFL_C = (CUB_WARP_THREADS(PTX_ARCH) - LOGICAL_WARP_THREADS) << 8
    };

    template <typename S>
    struct IntegerTraits
    {
        enum {
            /// Whether the data type is a small (32b or less) unsigned integer for which we can
            /// use a single SHFL instruction per exchange
            IS_SMALL_UNSIGNED = (Traits<S>::CATEGORY == UNSIGNED_INTEGER) && (sizeof(S) <= sizeof(unsigned int))
        };
    };

    /// Shared memory storage layout type
    struct TempStorage {};

    //---------------------------------------------------------------------
    // Thread fields
    //---------------------------------------------------------------------

    /// Lane index in logical warp
    unsigned int lane_id;

    /// Logical warp index in 32-thread physical warp
    unsigned int warp_id;

    /// 32-thread physical warp member mask of logical warp
    unsigned int member_mask;

    //---------------------------------------------------------------------
    // Construction
    //---------------------------------------------------------------------

    /// Constructor
    __device__ __forceinline__ WarpScanShfl(TempStorage &/*temp_storage*/)
    {
        lane_id     = LaneId();
        warp_id     = 0;
        member_mask = 0xffffffffu >> (CUB_WARP_THREADS(PTX_ARCH) - LOGICAL_WARP_THREADS);

        if (!IS_ARCH_WARP)
        {
            warp_id     = lane_id / LOGICAL_WARP_THREADS;
            lane_id     = lane_id % LOGICAL_WARP_THREADS;
            member_mask = member_mask << (warp_id * LOGICAL_WARP_THREADS);
        }
    }
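    // -----------------------------------------------------------------------
    // [Worked illustration; a sketch, not part of CUB] Assuming a 32-thread
    // architecture and LOGICAL_WARP_THREADS == 16, the constants above become:
    //
    //   member_mask = 0xffffffffu >> (32 - 16)  ==  0x0000ffffu   // lanes 0..15
    //   SHFL_C      = (32 - 16) << 8            ==  0x1000
    //
    // In a SHFL control word, bits [12:8] carry the sub-segment mask, so packing
    // (32 - LOGICAL_WARP_THREADS) there keeps the shuffles issued by the scan
    // steps below from ever reading across a logical-warp boundary, while the
    // low bits (first_lane) clamp the source lane.
    // -----------------------------------------------------------------------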
    //---------------------------------------------------------------------
    // Inclusive scan steps
    //---------------------------------------------------------------------

    /// Inclusive prefix scan step (specialized for summation across int32 types)
    __device__ __forceinline__ int InclusiveScanStep(
        int input, cub::Sum /*scan_op*/, int first_lane, int offset)
    {
        int output;
        int shfl_c = first_lane | SHFL_C;   // Shuffle control (mask and first-lane)

        // Use predicate set from SHFL to guard against invalid peers
#ifdef CUB_USE_COOPERATIVE_GROUPS
        asm volatile(
            "{"
            "  .reg .s32 r0;"
            "  .reg .pred p;"
            "  shfl.sync.up.b32 r0|p, %1, %2, %3, %5;"
            "  @p add.s32 r0, r0, %4;"
            "  mov.s32 %0, r0;"
            "}"
            : "=r"(output) : "r"(input), "r"(offset), "r"(shfl_c), "r"(input), "r"(member_mask));
#else
        asm volatile(
            "{"
            "  .reg .s32 r0;"
            "  .reg .pred p;"
            "  shfl.up.b32 r0|p, %1, %2, %3;"
            "  @p add.s32 r0, r0, %4;"
            "  mov.s32 %0, r0;"
            "}"
            : "=r"(output) : "r"(input), "r"(offset), "r"(shfl_c), "r"(input));
#endif

        return output;
    }
    /// Inclusive prefix scan step (specialized for summation across uint32 types)
    __device__ __forceinline__ unsigned int InclusiveScanStep(
        unsigned int input, cub::Sum /*scan_op*/, int first_lane, int offset)
    {
        unsigned int output;
        int shfl_c = first_lane | SHFL_C;   // Shuffle control (mask and first-lane)

        // Use predicate set from SHFL to guard against invalid peers
#ifdef CUB_USE_COOPERATIVE_GROUPS
        asm volatile(
            "{"
            "  .reg .u32 r0;"
            "  .reg .pred p;"
            "  shfl.sync.up.b32 r0|p, %1, %2, %3, %5;"
            "  @p add.u32 r0, r0, %4;"
            "  mov.u32 %0, r0;"
            "}"
            : "=r"(output) : "r"(input), "r"(offset), "r"(shfl_c), "r"(input), "r"(member_mask));
#else
        asm volatile(
            "{"
            "  .reg .u32 r0;"
            "  .reg .pred p;"
            "  shfl.up.b32 r0|p, %1, %2, %3;"
            "  @p add.u32 r0, r0, %4;"
            "  mov.u32 %0, r0;"
            "}"
            : "=r"(output) : "r"(input), "r"(offset), "r"(shfl_c), "r"(input));
#endif

        return output;
    }
    /// Inclusive prefix scan step (specialized for summation across fp32 types)
    __device__ __forceinline__ float InclusiveScanStep(
        float input, cub::Sum /*scan_op*/, int first_lane, int offset)
    {
        float output;
        int shfl_c = first_lane | SHFL_C;   // Shuffle control (mask and first-lane)

        // Use predicate set from SHFL to guard against invalid peers
#ifdef CUB_USE_COOPERATIVE_GROUPS
        asm volatile(
            "{"
            "  .reg .f32 r0;"
            "  .reg .pred p;"
            "  shfl.sync.up.b32 r0|p, %1, %2, %3, %5;"
            "  @p add.f32 r0, r0, %4;"
            "  mov.f32 %0, r0;"
            "}"
            : "=f"(output) : "f"(input), "r"(offset), "r"(shfl_c), "f"(input), "r"(member_mask));
#else
        asm volatile(
            "{"
            "  .reg .f32 r0;"
            "  .reg .pred p;"
            "  shfl.up.b32 r0|p, %1, %2, %3;"
            "  @p add.f32 r0, r0, %4;"
            "  mov.f32 %0, r0;"
            "}"
            : "=f"(output) : "f"(input), "r"(offset), "r"(shfl_c), "f"(input));
#endif

        return output;
    }
    /// Inclusive prefix scan step (specialized for summation across unsigned long long types)
    __device__ __forceinline__ unsigned long long InclusiveScanStep(
        unsigned long long input, cub::Sum /*scan_op*/, int first_lane, int offset)
    {
        unsigned long long output;
        int shfl_c = first_lane | SHFL_C;   // Shuffle control (mask and first-lane)

        // Use predicate set from SHFL to guard against invalid peers
#ifdef CUB_USE_COOPERATIVE_GROUPS
        asm volatile(
            "{"
            "  .reg .u32 lo;"
            "  .reg .u32 hi;"
            "  .reg .pred p;"
            "  .reg .u64 r0;"
            "  mov.b64 {lo, hi}, %1;"
            "  shfl.sync.up.b32 lo|p, lo, %2, %3, %5;"
            "  shfl.sync.up.b32 hi|p, hi, %2, %3, %5;"
            "  mov.b64 r0, {lo, hi};"
            "  @p add.u64 r0, r0, %4;"
            "  mov.u64 %0, r0;"
            "}"
            : "=l"(output) : "l"(input), "r"(offset), "r"(shfl_c), "l"(input), "r"(member_mask));
#else
        asm volatile(
            "{"
            "  .reg .u32 lo;"
            "  .reg .u32 hi;"
            "  .reg .pred p;"
            "  .reg .u64 r0;"
            "  mov.b64 {lo, hi}, %1;"
            "  shfl.up.b32 lo|p, lo, %2, %3;"
            "  shfl.up.b32 hi|p, hi, %2, %3;"
            "  mov.b64 r0, {lo, hi};"
            "  @p add.u64 r0, r0, %4;"
            "  mov.u64 %0, r0;"
            "}"
            : "=l"(output) : "l"(input), "r"(offset), "r"(shfl_c), "l"(input));
#endif

        return output;
    }
    /// Inclusive prefix scan step (specialized for summation across long long types)
    __device__ __forceinline__ long long InclusiveScanStep(
        long long input, cub::Sum /*scan_op*/, int first_lane, int offset)
    {
        long long output;
        int shfl_c = first_lane | SHFL_C;   // Shuffle control (mask and first-lane)

        // Use predicate set from SHFL to guard against invalid peers
#ifdef CUB_USE_COOPERATIVE_GROUPS
        asm volatile(
            "{"
            "  .reg .u32 lo;"
            "  .reg .u32 hi;"
            "  .reg .pred p;"
            "  .reg .s64 r0;"
            "  mov.b64 {lo, hi}, %1;"
            "  shfl.sync.up.b32 lo|p, lo, %2, %3, %5;"
            "  shfl.sync.up.b32 hi|p, hi, %2, %3, %5;"
            "  mov.b64 r0, {lo, hi};"
            "  @p add.s64 r0, r0, %4;"
            "  mov.s64 %0, r0;"
            "}"
            : "=l"(output) : "l"(input), "r"(offset), "r"(shfl_c), "l"(input), "r"(member_mask));
#else
        asm volatile(
            "{"
            "  .reg .u32 lo;"
            "  .reg .u32 hi;"
            "  .reg .pred p;"
            "  .reg .s64 r0;"
            "  mov.b64 {lo, hi}, %1;"
            "  shfl.up.b32 lo|p, lo, %2, %3;"
            "  shfl.up.b32 hi|p, hi, %2, %3;"
            "  mov.b64 r0, {lo, hi};"
            "  @p add.s64 r0, r0, %4;"
            "  mov.s64 %0, r0;"
            "}"
            : "=l"(output) : "l"(input), "r"(offset), "r"(shfl_c), "l"(input));
#endif

        return output;
    }
    /// Inclusive prefix scan step (specialized for summation across fp64 types)
    __device__ __forceinline__ double InclusiveScanStep(
        double input, cub::Sum /*scan_op*/, int first_lane, int offset)
    {
        double output;
        int shfl_c = first_lane | SHFL_C;   // Shuffle control (mask and first-lane)

        // Use predicate set from SHFL to guard against invalid peers
#ifdef CUB_USE_COOPERATIVE_GROUPS
        asm volatile(
            "{"
            "  .reg .u32 lo;"
            "  .reg .u32 hi;"
            "  .reg .pred p;"
            "  .reg .f64 r0;"
            "  mov.b64 %0, %1;"
            "  mov.b64 {lo, hi}, %1;"
            "  shfl.sync.up.b32 lo|p, lo, %2, %3, %4;"
            "  shfl.sync.up.b32 hi|p, hi, %2, %3, %4;"
            "  mov.b64 r0, {lo, hi};"
            "  @p add.f64 %0, %0, r0;"
            "}"
            : "=d"(output) : "d"(input), "r"(offset), "r"(shfl_c), "r"(member_mask));
#else
        asm volatile(
            "{"
            "  .reg .u32 lo;"
            "  .reg .u32 hi;"
            "  .reg .pred p;"
            "  .reg .f64 r0;"
            "  mov.b64 %0, %1;"
            "  mov.b64 {lo, hi}, %1;"
            "  shfl.up.b32 lo|p, lo, %2, %3;"
            "  shfl.up.b32 hi|p, hi, %2, %3;"
            "  mov.b64 r0, {lo, hi};"
            "  @p add.f64 %0, %0, r0;"
            "}"
            : "=d"(output) : "d"(input), "r"(offset), "r"(shfl_c));
#endif

        return output;
    }
    /// Inclusive prefix scan step (generic)
    template <typename _T, typename ScanOpT>
    __device__ __forceinline__ _T InclusiveScanStep(
        _T input, ScanOpT scan_op, int first_lane, int offset)
    {
        _T temp = ShuffleUp<LOGICAL_WARP_THREADS>(input, offset, first_lane, member_mask);

        // Apply the scan operator only if the pulled value came from a valid peer
        _T output = scan_op(temp, input);
        if (static_cast<int>(lane_id) < first_lane + offset)
            output = input;

        return output;
    }
    /// Inclusive prefix scan step (specialized for small integers size 32b or less)
    template <typename _T, typename ScanOpT>
    __device__ __forceinline__ _T InclusiveScanStep(
        _T input, ScanOpT scan_op, int first_lane, int offset, Int2Type<true> /*is_small_unsigned*/)
    {
        return InclusiveScanStep(input, scan_op, first_lane, offset);
    }

    /// Inclusive prefix scan step (specialized for types other than small integers size 32b or less)
    template <typename _T, typename ScanOpT>
    __device__ __forceinline__ _T InclusiveScanStep(
        _T input, ScanOpT scan_op, int first_lane, int offset, Int2Type<false> /*is_small_unsigned*/)
    {
        return InclusiveScanStep(input, scan_op, first_lane, offset);
    }

    //---------------------------------------------------------------------
    // Broadcast
    //---------------------------------------------------------------------

    /// Broadcast
    __device__ __forceinline__ T Broadcast(
        T input,            ///< [in] The value to broadcast
        int src_lane)       ///< [in] Which warp lane is to do the broadcasting
    {
        return ShuffleIndex<LOGICAL_WARP_THREADS>(input, src_lane, member_mask);
    }
    //---------------------------------------------------------------------
    // Inclusive operations
    //---------------------------------------------------------------------

    /// Inclusive scan
    template <typename _T, typename ScanOpT>
    __device__ __forceinline__ void InclusiveScan(
        _T input, _T &inclusive_output, ScanOpT scan_op)
    {
        inclusive_output = input;

        // Within a logical warp, every lane's segment begins at lane 0
        int segment_first_lane = 0;

        // Iterate scan steps
        #pragma unroll
        for (int STEP = 0; STEP < STEPS; STEP++)
        {
            inclusive_output = InclusiveScanStep(
                inclusive_output, scan_op, segment_first_lane, (1 << STEP),
                Int2Type<IntegerTraits<T>::IS_SMALL_UNSIGNED>());
        }
    }
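    /// [Illustrative sketch; not part of CUB] The loop above is the classic
    /// Kogge-Stone pattern: shuffle-up offsets of 1, 2, 4, ..., so that after
    /// step k each lane holds the scan of its trailing 2^(k+1) inputs. An
    /// intrinsic-level rendering for summation, assuming CUDA 9+ and a T that
    /// supports operator+ and is accepted by __shfl_up_sync:
    __device__ __forceinline__ T InclusiveSumReference(T input)
    {
        #pragma unroll
        for (int STEP = 0; STEP < STEPS; STEP++)
        {
            int offset = 1 << STEP;
            T peer = __shfl_up_sync(member_mask, input, offset, LOGICAL_WARP_THREADS);
            if (static_cast<int>(lane_id) >= offset)
                input += peer;
        }
        return input;
    }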
    /// Inclusive scan, specialized for reduce-value-by-key
    template <typename KeyT, typename ValueT, typename ReductionOpT>
    __device__ __forceinline__ void InclusiveScan(
        KeyValuePair<KeyT, ValueT>  input,              ///< [in] Calling thread's input item
        KeyValuePair<KeyT, ValueT>  &inclusive_output,  ///< [out] Calling thread's output item (may be aliased with input)
        ReduceByKeyOp<ReductionOpT> scan_op)            ///< [in] Binary scan operator
    {
        inclusive_output = input;

        KeyT pred_key = ShuffleUp<LOGICAL_WARP_THREADS>(inclusive_output.key, 1, 0, member_mask);

        unsigned int ballot = WARP_BALLOT((pred_key != inclusive_output.key), member_mask);

        // Mask away all lanes greater than ours
        ballot = ballot & LaneMaskLe();

        // Find index of first set bit
        int segment_first_lane = CUB_MAX(0, 31 - __clz(ballot));

        // Iterate scan steps
        #pragma unroll
        for (int STEP = 0; STEP < STEPS; STEP++)
        {
            inclusive_output.value = InclusiveScanStep(
                inclusive_output.value, scan_op.op, segment_first_lane, (1 << STEP),
                Int2Type<IntegerTraits<T>::IS_SMALL_UNSIGNED>());
        }
    }
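    /// [Illustrative sketch; not part of CUB] How the segment head is located
    /// above: each lane flags whether its key differs from its predecessor's,
    /// the flags are gathered with a ballot, lanes above the caller are masked
    /// off, and the highest surviving bit (31 - __clz) names the nearest head
    /// at or below the calling lane, clamped to lane 0. A standalone version
    /// using the raw intrinsic in place of CUB's WARP_BALLOT wrapper:
    __device__ __forceinline__ int SegmentFirstLaneReference(bool is_segment_head)
    {
        unsigned int ballot = __ballot_sync(member_mask, is_segment_head);
        ballot &= LaneMaskLe();                   // Keep only our lane and below
        return CUB_MAX(0, 31 - __clz(ballot));    // Index of the highest set bit
    }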
    /// Inclusive scan with aggregate
    template <typename ScanOpT>
    __device__ __forceinline__ void InclusiveScan(
        T input, T &inclusive_output, ScanOpT scan_op, T &warp_aggregate)
    {
        InclusiveScan(input, inclusive_output, scan_op);

        // Grab aggregate from last warp lane
        warp_aggregate = ShuffleIndex<LOGICAL_WARP_THREADS>(inclusive_output, LOGICAL_WARP_THREADS - 1, member_mask);
    }
    //---------------------------------------------------------------------
    // Get exclusive from inclusive
    //---------------------------------------------------------------------

    /// Update inclusive and exclusive using input and inclusive
    template <typename ScanOpT, typename IsIntegerT>
    __device__ __forceinline__ void Update(
        T /*input*/, T &inclusive, T &exclusive, ScanOpT /*scan_op*/, IsIntegerT /*is_integer*/)
    {
        // Initial value unknown: pull each lane's exclusive result from the lane below
        exclusive = ShuffleUp<LOGICAL_WARP_THREADS>(inclusive, 1, 0, member_mask);
    }

    /// Update inclusive and exclusive using input and inclusive (specialized for summation of integer types)
    __device__ __forceinline__ void Update(
        T input, T &inclusive, T &exclusive, cub::Sum /*scan_op*/, Int2Type<true> /*is_integer*/)
    {
        // Initial value presumed to be 0
        exclusive = inclusive - input;
    }
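    // [Worked example; a sketch] For integer summation the exclusive result
    // falls out arithmetically, with no extra shuffle: inputs {3, 1, 4} give
    // inclusive {3, 4, 8}, so exclusive = inclusive - input = {0, 3, 4}.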
    /// Update inclusive and exclusive using input, inclusive, and initial value
    template <typename ScanOpT, typename IsIntegerT>
    __device__ __forceinline__ void Update(
        T /*input*/, T &inclusive, T &exclusive, ScanOpT scan_op, T initial_value, IsIntegerT /*is_integer*/)
    {
        inclusive = scan_op(initial_value, inclusive);
        exclusive = ShuffleUp<LOGICAL_WARP_THREADS>(inclusive, 1, 0, member_mask);

        if (lane_id == 0)
            exclusive = initial_value;
    }

    /// Update inclusive and exclusive using input, inclusive, and initial value (specialized for summation of integer types)
    __device__ __forceinline__ void Update(
        T input, T &inclusive, T &exclusive, cub::Sum scan_op, T initial_value, Int2Type<true> /*is_integer*/)
    {
        inclusive = scan_op(initial_value, inclusive);
        exclusive = inclusive - input;
    }
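    // [Worked example; a sketch] With initial_value = 10 and inputs {3, 1, 4},
    // inclusive is re-seeded to {13, 14, 18}, and exclusive = inclusive - input
    // yields {10, 13, 14}: lane 0 carries the initial value with no shuffle.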
    /// Update inclusive, exclusive, and warp aggregate using input and inclusive
    template <typename ScanOpT, typename IsIntegerT>
    __device__ __forceinline__ void Update(
        T input, T &inclusive, T &exclusive, T &warp_aggregate,
        ScanOpT scan_op, IsIntegerT is_integer)
    {
        warp_aggregate = ShuffleIndex<LOGICAL_WARP_THREADS>(inclusive, LOGICAL_WARP_THREADS - 1, member_mask);
        Update(input, inclusive, exclusive, scan_op, is_integer);
    }

    /// Update inclusive, exclusive, and warp aggregate using input, inclusive, and initial value
    template <typename ScanOpT, typename IsIntegerT>
    __device__ __forceinline__ void Update(
        T input, T &inclusive, T &exclusive, T &warp_aggregate,
        ScanOpT scan_op, T initial_value, IsIntegerT is_integer)
    {
        warp_aggregate = ShuffleIndex<LOGICAL_WARP_THREADS>(inclusive, LOGICAL_WARP_THREADS - 1, member_mask);
        Update(input, inclusive, exclusive, scan_op, initial_value, is_integer);
    }
};

}               // CUB namespace
CUB_NS_POSTFIX  // Optional outer namespace(s)
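// [Usage sketch] WarpScanShfl is an internal specialization; user code reaches
// it through the public cub::WarpScan wrapper, which selects the SHFL variant
// on architectures that support it. For example, an inclusive prefix sum
// across each 32-thread warp of a 128-thread block:
//
//     #include <cub/warp/warp_scan.cuh>
//
//     __global__ void ExampleKernel(int *d_data)
//     {
//         typedef cub::WarpScan<int> WarpScan;
//         __shared__ typename WarpScan::TempStorage temp_storage[4];
//
//         int warp_id     = threadIdx.x / 32;
//         int thread_data = d_data[threadIdx.x];
//
//         // Compute the warp-wide inclusive prefix sum in place
//         WarpScan(temp_storage[warp_id]).InclusiveSum(thread_data, thread_data);
//
//         d_data[threadIdx.x] = thread_data;
//     }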