Part 3 — Warp Shuffles, Reductions, and Cooperative Groups
In this section, we will look at a special set of instructions that break the clean SIMT picture of CUDA: Warp shuffles, votes, and reductions allow all threads in one warp to interact directly with each other, without the need to go through shared memory.
3.1 Reduction: from Shared Memory to Warp Shuffles
This section focuses on the block-wide reduction — how threads in a single block combine blockDim.x values into one. The accompanying programs launch exactly one block on blockDim.x inputs so the reduction itself is the only thing happening. A real reduction over a large array launches many blocks and combines their partials in a second pass or via an atomic; that scaffolding is orthogonal to how a block reduces internally and is left out here.
Baseline: shared memory data exchange
The canonical parallel reduction sums (or maxes) an array of values held in registers across a block. The classic approach stages the work in shared memory, halving the active thread count each round:
// Block-wide sum reduction. Thread 0 writes the block's sum to *out.
// Requires: blockDim.x is a power of two, n == blockDim.x, launched with
// (blockDim.x * sizeof(float)) bytes of dynamic shared memory.
__global__ void reduce_smem(const float* __restrict__ in,
float* __restrict__ out)
{
extern __shared__ float smem[]; // size = blockDim.x floats
unsigned tid = threadIdx.x;
smem[tid] = in[tid];
__syncthreads(); // all writes visible to the block
// Tree reduction: halve the number of active threads each round.
// Round k: threads [0, s) add smem[tid + s] into smem[tid].
for (unsigned s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) smem[tid] += smem[tid + s];
__syncthreads(); // full-block barrier every round
}
if (tid == 0) *out = smem[0]; // thread 0 holds the block sum
}
A few things are worth pointing out:
- Two shared-memory reads and one write per active thread per round. The __syncthreads() between rounds forces every thread in the block to wait for the slowest one before the next halving can begin.
- blockDim.x floats of shared memory: for a 1024-thread block that is 4 KB just to hold the in-flight reduction state, which competes with other shared-memory consumers for the same SM budget.
- O(log blockDim.x) full-block barriers. Each __syncthreads() is a hardware barrier that drains all warps, so the total synchronization cost grows with block size.
The host launcher allocates device buffers, requests blockDim.x floats of dynamic shared memory at launch, and copies the scalar result back:
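#include <cassert>
#include <cuda_runtime.h>
// Assumes CUDA_CHECK(call) is an error-checking macro defined elsewhere
// (e.g. one that reports and aborts when call does not return cudaSuccess).
// Host-side launcher: allocate device buffers, launch one block of BLOCK
// threads with BLOCK floats of dynamic shared memory, copy the scalar
// result back. BLOCK must be a power of two; n must equal BLOCK because
// each thread reads exactly one input element.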
static float launch_reduce_smem(const float* h_in, int n) {
constexpr int BLOCK = 256;
assert(n == BLOCK && "single-block demo: n must equal BLOCK");
float* d_in = nullptr;
float* d_out = nullptr;
CUDA_CHECK(cudaMalloc(&d_in, n * sizeof(float)));
CUDA_CHECK(cudaMalloc(&d_out, sizeof(float)));
CUDA_CHECK(cudaMemcpy(d_in, h_in, n * sizeof(float),
cudaMemcpyHostToDevice));
size_t smem_bytes = BLOCK * sizeof(float);
reduce_smem<<<1, BLOCK, smem_bytes>>>(d_in, d_out);
CUDA_CHECK(cudaGetLastError());
float h_out = 0.0f;
CUDA_CHECK(cudaMemcpy(&h_out, d_out, sizeof(float),
cudaMemcpyDeviceToHost));
CUDA_CHECK(cudaFree(d_in));
CUDA_CHECK(cudaFree(d_out));
return h_out;
}
This works, but the costs above are exactly what warp shuffles are designed to eliminate.
Enter warp shuffles
To avoid these costs at the warp level, CUDA provides a family of warp shuffle functions:
T __shfl_sync(unsigned mask, T var, int srcLane, int width=warpSize);
T __shfl_up_sync(unsigned mask, T var, unsigned int delta, int width=warpSize);
T __shfl_down_sync(unsigned mask, T var, unsigned int delta, int width=warpSize);
T __shfl_xor_sync(unsigned mask, T var, int laneMask, int width=warpSize);
Warp shuffle instructions move a register value directly from one lane's register file to another's, with no shared memory and no block-wide barrier; the exchange completes in a single instruction. As the _sync suffix indicates, these functions also synchronize the participating lanes. A lane participates if its bit is set in the mask argument; every participating lane must execute the same intrinsic with the same mask, and each waits until all non-exited participating lanes have arrived. To use the entire warp, set mask to 0xffffffff.
Deadlock: mismatched shuffle calls
if (threadIdx.x < 16) {
    var = __shfl_sync(0xffffffff, var, 0);
} else {
    var = __shfl_sync(0xffffffff, var, 16);
}
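The idea here would be to broadcast the value of lane 0 to the lower half of the warp, and the value of lane 16 to the upper half. However, both calls name the full warp in their mask, while each is executed by only half of it. In the model the _sync intrinsics promise, neither call can proceed, because each waits for the other half of the warp to arrive at the same shuffle. Formally this is undefined behavior (the CUDA programming guide requires every non-exited lane named in mask to execute the same intrinsic with the same mask), so in practice it may hang or silently produce wrong results.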
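One way to get the intended effect is to issue a single shuffle that every lane executes, either with a per-lane source, var = __shfl_sync(0xffffffff, var, (threadIdx.x < 16) ? 0 : 16), or with width=16, var = __shfl_sync(0xffffffff, var, 0, 16), which makes lane 0 of each 16-lane segment the source. The sketch below is a host-side C++ model, not device code: it emulates the 32 lanes as an array and mirrors the documented lane arithmetic of __shfl_sync (srcLane taken modulo width, within the caller's segment). The name shfl_idx_model is hypothetical.

```cpp
#include <array>
#include <cassert>

constexpr int kWarpSize = 32;
using Warp = std::array<float, kWarpSize>;

// Host-side model (not device code) of
//   __shfl_sync(0xffffffff, var, srcLane, width)
// when every lane passes the same srcLane: each lane reads lane
// (srcLane mod width) of its own width-sized segment.
Warp shfl_idx_model(const Warp& var, int srcLane, int width) {
    assert(width > 0 && width <= kWarpSize && (width & (width - 1)) == 0);
    assert(srcLane >= 0);
    Warp out{};
    for (int lane = 0; lane < kWarpSize; ++lane) {
        int segmentStart = lane & ~(width - 1);  // first lane of this lane's segment
        out[lane] = var[segmentStart + srcLane % width];
    }
    return out;
}
```

With var[i] = 10 * i, shfl_idx_model(var, 0, 16) yields 0 in lanes 0-15 and 160 in lanes 16-31: exactly the two broadcasts the divergent version was trying to express, issued by one convergent call.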
While __shfl_sync allows full generality by setting arbitrary source lanes, for many practical cases the
specialized variants are more convenient, as they directly encode a required pattern. For example, a warp-level
reduction can be implemented using __shfl_down_sync.
A warp-wide sum reduction collapses to 5 rounds of val += __shfl_down_sync(0xffffffff, val, delta) with delta = 16, 8, 4, 2, 1. Visually, after each round (8-lane illustration, so 3 rounds with delta = 4, 2, 1; * marks values that no longer contribute to the result):
Initial:   [a, b, c, d, e, f, g, h]
+delta=4:  [a+e, b+f, c+g, d+h, *, *, *, *]
+delta=2:  [a+e+c+g, b+f+d+h, *, *, *, *, *, *]
+delta=1:  [a+b+c+d+e+f+g+h, *, *, *, *, *, *, *]   ← lane 0
However, even though each step is a single register-to-register instruction rather than a shared-memory round trip, the sequence above is still latency-bound because of its long dependency chain: each round's shuffle needs the sum produced by the previous round, so the five shuffles (each taking several cycles) and their adds execute strictly one after another.
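To make the data movement concrete, here is a host-side C++ model (not device code) of the full-warp pattern. shfl_down_model is a hypothetical helper that mirrors the documented semantics of __shfl_down_sync with all lanes active: lane i reads lane i + delta, and the upper delta lanes of each width-sized segment get their own value back. Every lane performs the add, so after five rounds lane 0 holds the full sum while the other lanes hold partial or double-counted values that the kernel simply ignores.

```cpp
#include <array>
#include <cassert>

constexpr int kWarpSize = 32;
using Warp = std::array<float, kWarpSize>;

// Model of __shfl_down_sync(0xffffffff, var, delta, width): lane i reads
// lane i + delta if that lane is in the same width-sized segment, otherwise
// it gets its own value back.
Warp shfl_down_model(const Warp& var, int delta, int width = kWarpSize) {
    assert(width > 0 && width <= kWarpSize && (width & (width - 1)) == 0);
    Warp out{};
    for (int lane = 0; lane < kWarpSize; ++lane) {
        int segmentLane = lane & (width - 1);  // lane index within its segment
        out[lane] = (segmentLane + delta < width) ? var[lane + delta] : var[lane];
    }
    return out;
}

// Five rounds of val += shfl_down(val, delta) with delta = 16, 8, 4, 2, 1.
// All lanes execute the add; only lane 0's result is meaningful.
Warp warp_reduce_sum_model(Warp val) {
    for (int delta = 16; delta > 0; delta >>= 1) {
        Warp shifted = shfl_down_model(val, delta);  // all lanes read "simultaneously"
        for (int lane = 0; lane < kWarpSize; ++lane) val[lane] += shifted[lane];
    }
    return val;
}
```

With inputs 1..32, lane 0 ends with 528, while lane 31, whose source is out of range in every round, adds its own value to itself five times and ends with 32 * 32 = 1024.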
※ Broadcast result
If the result of a reduction is needed by all threads, a simple implementation would be
to do the reduction as described above, then broadcast the result to all lanes using __shfl_sync with 0 as the source lane.
However, it is possible to achieve the same effect with one fewer shuffle.
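A sketch of one way to do this, assuming a full 32-lane warp: replace __shfl_down_sync with __shfl_xor_sync, giving a butterfly exchange. In each round lane i combines its value with that of lane i ^ laneMask, so after five rounds every lane holds the complete sum and no separate broadcast is needed. The host model below is plain C++ with hypothetical helper names; on the device, the loop body corresponds to val += __shfl_xor_sync(0xffffffff, val, laneMask).

```cpp
#include <array>
#include <cassert>

constexpr int kWarpSize = 32;
using Warp = std::array<float, kWarpSize>;

// Model of __shfl_xor_sync(0xffffffff, var, laneMask): lane i reads lane i ^ laneMask.
Warp shfl_xor_model(const Warp& var, int laneMask) {
    Warp out{};
    for (int lane = 0; lane < kWarpSize; ++lane) out[lane] = var[lane ^ laneMask];
    return out;
}

// Butterfly all-reduce: after rounds laneMask = 16, 8, 4, 2, 1 every lane holds
// the sum of all 32 inputs (5 shuffles instead of 5 + 1 for reduce-then-broadcast).
Warp warp_allreduce_sum_model(Warp val) {
    for (int laneMask = 16; laneMask > 0; laneMask >>= 1) {
        Warp partner = shfl_xor_model(val, laneMask);
        for (int lane = 0; lane < kWarpSize; ++lane) val[lane] += partner[lane];
    }
    return val;
}
```

Because each pair of partners adds the same two operands (only in swapped order), all lanes end up with bitwise-identical results.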
Integer reductions: __reduce_*_sync
For integer types, CUDA exposes a family of single-instruction warp reductions introduced in Ampere (sm_80):
// Overloaded for int and unsigned; result broadcast to all lanes in mask
__reduce_add_sync(mask, value) // warp-wide sum
__reduce_min_sync(mask, value) // warp-wide min
__reduce_max_sync(mask, value) // warp-wide max
__reduce_and_sync(mask, value) // bitwise AND (unsigned only)
__reduce_or_sync(mask, value) // bitwise OR (unsigned only)
__reduce_xor_sync(mask, value) // bitwise XOR (unsigned only)
// Example: count threads in this warp where predicate is true
int count = __reduce_add_sync(0xffffffff, (int)(myVal > threshold));
Unlike the __shfl_down_sync loop, these are single instructions, and they broadcast the result to all lanes in the mask, not just lane 0. There are no floating-point versions; however, float minimum and maximum can be built on top of the integer operations.
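The counting example also has a classic formulation that works without __reduce_add_sync: __popc(__ballot_sync(0xffffffff, myVal > threshold)), since the ballot sets exactly one bit per lane whose predicate is true. The host model below uses hypothetical names, with std::bitset standing in for __popc, and shows that both formulations produce the same count.

```cpp
#include <array>
#include <bitset>
#include <cassert>
#include <cstdint>

constexpr int kWarpSize = 32;

// Model of __ballot_sync(0xffffffff, value > threshold): bit i is set iff lane i's predicate holds.
std::uint32_t ballot_model(const std::array<int, kWarpSize>& values, int threshold) {
    std::uint32_t bits = 0;
    for (int lane = 0; lane < kWarpSize; ++lane)
        if (values[lane] > threshold) bits |= (1u << lane);
    return bits;
}

// Count via ballot + population count (the role __popc plays on the device).
int count_via_ballot(const std::array<int, kWarpSize>& values, int threshold) {
    return static_cast<int>(std::bitset<32>(ballot_model(values, threshold)).count());
}

// Count via a sum of 0/1 predicates, mirroring __reduce_add_sync(mask, (int)(myVal > threshold)).
int count_via_sum(const std::array<int, kWarpSize>& values, int threshold) {
    int count = 0;
    for (int v : values) count += (v > threshold) ? 1 : 0;
    return count;
}
```

With values[i] = i and threshold 20, lanes 21-31 vote true: the ballot is 0xFFE00000 and both counts are 11.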
Floating-point reductions (advanced)
The hardware redux unit only compares integers — but for min/max, that is less of a restriction than it looks.
IEEE-754 was designed so that for non-negative floats, the bit pattern, read as an unsigned integer, is monotone
in the float value (sign bit 0, then exponent, then mantissa in decreasing significance). So if all participating
values are known to be non-negative (absolute values, probabilities, squared distances, ...), a float max is just an
integer max on the raw bits:
// All lanes hold non-NaN floats with the sign bit clear (non-negative, no -0.0f)
float mx = __uint_as_float(__reduce_max_sync(0xffffffff, __float_as_uint(v)));
For arbitrary signs (but assuming no NaNs), a small monotone bijection fixes the ordering: negative floats have the sign bit set and grow downward as their unsigned bit pattern grows, so we flip all bits of negatives and set the sign bit of non-negatives:
__device__ unsigned float_to_ordered(float f) {
unsigned u = __float_as_uint(f);
return (u >> 31) ? ~u : (u | 0x80000000u);
}
__device__ float ordered_to_float(unsigned u) {
return __uint_as_float((u >> 31) ? (u & 0x7fffffffu) : ~u);
}
float mx = ordered_to_float(__reduce_max_sync(0xffffffff, float_to_ordered(v)));
(You may recognize this as the same transform used to radix-sort floats.)
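The transform is easy to sanity-check on the host. The sketch below ports float_to_ordered and ordered_to_float to standard C++, with std::memcpy replacing __float_as_uint and __uint_as_float (the host_ names are hypothetical). It checks that the encoding is strictly increasing along a sorted list of floats that includes -0.0f and the infinities, and that it round-trips exactly. It also shows why the raw-bits shortcut above needs the sign bit clear: the bit pattern of -0.0f, read as an unsigned integer, is larger than that of any positive float.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <limits>

// Host stand-ins for __float_as_uint / __uint_as_float.
std::uint32_t host_float_as_uint(float f) { std::uint32_t u; std::memcpy(&u, &f, sizeof u); return u; }
float host_uint_as_float(std::uint32_t u) { float f; std::memcpy(&f, &u, sizeof f); return f; }

// Same transform as float_to_ordered / ordered_to_float above.
std::uint32_t host_float_to_ordered(float f) {
    std::uint32_t u = host_float_as_uint(f);
    return (u >> 31) ? ~u : (u | 0x80000000u);  // negatives: flip all bits; otherwise set sign bit
}
float host_ordered_to_float(std::uint32_t u) {
    return host_uint_as_float((u >> 31) ? (u & 0x7fffffffu) : ~u);
}

// True if the encoding strictly increases along the (ascending, NaN-free) array.
bool encoding_is_monotone(const float* sorted, int n) {
    for (int i = 1; i < n; ++i)
        if (host_float_to_ordered(sorted[i - 1]) >= host_float_to_ordered(sorted[i])) return false;
    return true;
}
```

Note that -0.0f encodes to 0x7fffffff and +0.0f to 0x80000000, so the encoding also orders the two zeros.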
Starting with Blackwell, the hardware can do this natively: PTX ISA 8.6 adds redux.sync.{min,max}{.abs}{.NaN}.f32
for the sm_100a architecture- and family-specific targets, including proper NaN handling via the .NaN modifier
(returns canonical NaN if any input is NaN) and a free .abs. As of
CUDA 13.2, these are not exposed as __reduce_*_sync overloads, so you need inline PTX:
float mx;
asm("redux.sync.max.NaN.f32 %0, %1, 0xffffffff;" : "=f"(mx) : "f"(v));
Block-wide reduction combining both
The standard pattern: shuffle-reduce within each warp, write warp results to shared memory, shuffle-reduce the partial results in the first warp:
// Block-wide sum reduction via warp shuffles + a tiny shared-memory stage.
// Thread 0 writes the block's sum to *out.
// Requires: blockDim.x is a multiple of 32 and <= 1024, n == blockDim.x.
__global__ void reduce_shfl(const float* __restrict__ in,
float* __restrict__ out)
{
__shared__ float warp_sums[32]; // at most 32 warps in a 1024-thread block
int lane = threadIdx.x & 31;
int warpId = threadIdx.x >> 5;
float val = in[threadIdx.x];
// Stage 1: each warp reduces into its lane 0 -- no shared memory, no barrier.
val += __shfl_down_sync(0xffffffff, val, 16);
val += __shfl_down_sync(0xffffffff, val, 8);
val += __shfl_down_sync(0xffffffff, val, 4);
val += __shfl_down_sync(0xffffffff, val, 2);
val += __shfl_down_sync(0xffffffff, val, 1);
// Stage 2: lane 0 of each warp deposits its partial in shared memory.
if (lane == 0) warp_sums[warpId] = val;
__syncthreads(); // the ONE block-wide barrier
// Stage 3: first warp reduces the warp partials. Inactive lanes contribute 0.
int nWarps = blockDim.x >> 5;
if (warpId == 0) {
val = (lane < nWarps) ? warp_sums[lane] : 0.0f;
val += __shfl_down_sync(0xffffffff, val, 16);
val += __shfl_down_sync(0xffffffff, val, 8);
val += __shfl_down_sync(0xffffffff, val, 4);
val += __shfl_down_sync(0xffffffff, val, 2);
val += __shfl_down_sync(0xffffffff, val, 1);
if (lane == 0) *out = val; // lane 0 of warp 0 holds the block sum
}
}
The launcher is nearly identical to the shared-memory version, with one difference: no dynamic shared memory is requested at launch, because the kernel only needs a small fixed __shared__ array internally:
// Host-side launcher -- same shape as launch_reduce_smem, with one
// difference: no dynamic shared memory is requested at launch, because
// the kernel only uses a small fixed __shared__ array internally
// (32 floats, one per warp).
static float launch_reduce_shfl(const float* h_in, int n) {
constexpr int BLOCK = 256;
assert(n == BLOCK && "single-block demo: n must equal BLOCK");
float* d_in = nullptr;
float* d_out = nullptr;
CUDA_CHECK(cudaMalloc(&d_in, n * sizeof(float)));
CUDA_CHECK(cudaMalloc(&d_out, sizeof(float)));
CUDA_CHECK(cudaMemcpy(d_in, h_in, n * sizeof(float),
cudaMemcpyHostToDevice));
reduce_shfl<<<1, BLOCK>>>(d_in, d_out);
CUDA_CHECK(cudaGetLastError());
float h_out = 0.0f;
CUDA_CHECK(cudaMemcpy(&h_out, d_out, sizeof(float),
cudaMemcpyDeviceToHost));
CUDA_CHECK(cudaFree(d_in));
CUDA_CHECK(cudaFree(d_out));
return h_out;
}
Compare against the shared-memory version:
| Resource | smem version | shuffle version |
|---|---|---|
| __syncthreads() calls | O(log blockDim.x) | exactly 1 |
| Shared memory | blockDim.x floats | nWarps ≤ 32 floats |
| Cross-lane mechanism | shared-memory load/store | register-to-register shuffle |
One __syncthreads() instead of O(log N), and shared memory only used for O(warp count) values, not O(block size).
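The two-stage structure of reduce_shfl can be mirrored on the host to check its bookkeeping, in particular that stage 3 pads lanes >= nWarps with 0. The sketch below is plain C++ with a hypothetical name, not device code. Each warp's five-shuffle tree is replaced by a sequential sum, so for arbitrary floats the last bits may differ from the device result, while small integer-valued inputs sum exactly either way.

```cpp
#include <cassert>
#include <vector>

// Host model of reduce_shfl's structure for a block of in.size() threads
// (a multiple of 32, at most 1024).
float block_reduce_model(const std::vector<float>& in) {
    const int blockDim = static_cast<int>(in.size());
    assert(blockDim > 0 && blockDim % 32 == 0 && blockDim <= 1024);
    const int nWarps = blockDim / 32;

    // Stages 1 + 2: each warp reduces its 32 values; "lane 0" stores the partial.
    float warp_sums[32];
    for (int w = 0; w < nWarps; ++w) {
        float s = 0.0f;
        for (int lane = 0; lane < 32; ++lane) s += in[w * 32 + lane];
        warp_sums[w] = s;
    }
    // (the single __syncthreads() of the kernel sits here)

    // Stage 3: warp 0 reduces 32 slots; slots >= nWarps contribute 0.
    float total = 0.0f;
    for (int lane = 0; lane < 32; ++lane) total += (lane < nWarps) ? warp_sums[lane] : 0.0f;
    return total;
}
```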
3.2 Masks, __activemask, and Why It Is Not Enough
Every warp-level intrinsic takes a mask argument. The hardware checks that every lane named in the mask executes the instruction with the same mask before proceeding. This can easily lead to deadlocks in branching code.
In this section, we will look at how to handle this properly.
__activemask() — tempting but wrong
CUDA provides an intrinsic that, at first glance, seems to be exactly what is needed to fix deadlocks
of this kind: __activemask() returns a bitmask with a bit set for each lane of the warp that is executing the call together with the caller. This would suggest a pattern like:
// WRONG — do not do this
unsigned mask = __activemask();
val = __shfl_down_sync(mask, val, delta);
__activemask() returns the bitmask of lanes converged with the caller at the moment of the call. Since Volta, lanes in a warp are allowed to have independent program counters, and the only places they are guaranteed to be brought back together are the _sync intrinsics. The start of an if, the merge after it, and ordinary straight-line code in between are all places where lanes may drift apart — whether they actually do is a runtime decision.
This makes __activemask() unreliable in any divergent context: code built on it may well produce the correct result in almost every run, yet it contains a subtle bug that is hard to reproduce.
This looks like a chicken-and-egg problem: we want a warp-synchronous instruction to tell us which lanes will participate, but every warp-synchronous instruction itself requires a mask.
__ballot_sync — the correct tool
The key is to lift the mask generation to happen before any divergent code. If there is no divergence, we can safely
use the mask 0xffffffff for the __ballot_sync intrinsic:
__ballot_sync(mask, predicate) atomically: (1) evaluates predicate on each lane named in mask, (2) builds a 32-bit result where bit i is set if lane i's predicate is true, and (3) returns the same result to all lanes in mask. Because all lanes agree on the result, it is safe to use as the mask for a subsequent shuffle:
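// Correct pattern: determine which lanes have valid data, then reduce only those.
// (lane = threadIdx.x & 31; assuming myIndex grows with lane, the valid lanes form a prefix.)
unsigned activeMask = __ballot_sync(0xffffffff, myIndex < N);
if (myIndex < N) {
    // activeMask is identical on every lane that reaches this point
    float other = __shfl_down_sync(activeMask, val, 16);
    // A source lane outside activeMask returns an undefined value,
    // so only accumulate contributions from participating lanes.
    if (lane + 16 < 32 && ((activeMask >> (lane + 16)) & 1u)) val += other;
    // ... repeat for delta = 8, 4, 2, 1
}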
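Shuffling with a ballot mask still needs care: according to the CUDA documentation, a lane that reads from a lane not participating in the shuffle gets an undefined value, so a correct reduction only accumulates contributions from lanes in the mask. The host model below uses plain C++ and hypothetical names. It emulates such a guarded tree reduction for a mask whose set bits form a prefix 0..k-1, which is what myIndex < N produces when myIndex grows with the lane index, and lane 0 ends up with the sum of exactly the k valid values. Skipping a missing contribution amounts to adding 0 for lanes outside the mask.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

constexpr int kWarpSize = 32;
using Warp = std::array<float, kWarpSize>;

// Guarded tree reduction over the lanes set in activeMask: lane i only adds
// the value of lane i + delta when that source lane is also in the mask.
float masked_warp_sum_model(Warp val, std::uint32_t activeMask) {
    assert(activeMask & 1u);                                   // lane 0 collects the result
    for (int delta = 16; delta > 0; delta >>= 1) {
        Warp next = val;                                       // values from the previous round
        for (int lane = 0; lane < kWarpSize; ++lane) {
            if (!((activeMask >> lane) & 1u)) continue;        // lane not participating
            int src = lane + delta;
            if (src < kWarpSize && ((activeMask >> src) & 1u)) // source lane holds valid data
                next[lane] = val[lane] + val[src];
        }
        val = next;
    }
    return val[0];
}

// Ballot of (lane < k): the mask produced when only the first k lanes are valid.
std::uint32_t prefix_mask(int k) { return (k >= 32) ? 0xffffffffu : ((1u << k) - 1u); }
```

Lanes outside the mask can hold arbitrary garbage; they are never read, so for every k from 1 to 32 lane 0 ends up with exactly the sum of the first k values.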
Subwarp tiling
So far, we have used warp-level intrinsics to handle operations that require the full warp to cooperate.
However, the shuffle intrinsics take an additional, optional width parameter (defaulting to warpSize), which splits the warp into independent segments of any power of two between 1 and 32 lanes.
What would be a scenario in which this would be useful? Consider as input an N x 64 matrix of floats, where the task is to compute the maximum element in each row. In the full-warp scenario, each thread loads only two values, followed by 5 rounds of shuffles to reduce across the warp; the cost is dominated by communication overhead. Ideally, each thread would load all 64 values of a row and reduce them locally, requiring zero communication, but this would lead to a terrible memory access pattern.
With a warp loading 32 consecutive floats, each load instruction fetches 128 contiguous bytes, a full cache line. This is what makes coalesced memory access efficient. However, as we have seen in the previous section, vectorized loads let a single thread load up to 16 contiguous bytes (one float4). In that case, 8 threads are sufficient to fetch an entire cache line.
This suggests the following pattern of subwarp tiling:
Eight consecutive threads are jointly responsible for handling the 64 values of a row of the matrix.
This gives us 8 operations per thread for 3 communication rounds, a much better ratio than the full-warp scenario.
Further, with 32 cooperative threads, we would not be able to fully vectorize the memory access; each thread would load
only two values. But with eight threads, we can load inputs as float4 for additional efficiency.
The reduction would then look as follows:
// 8-lane max-reduction over 16 float4s (64 floats) of one row.
__device__ float reduce_subwarp(unsigned mask, const float4* row) {
int lane = threadIdx.x % 8;
float4 a = row[lane];
float4 b = row[lane + 8];
float am = fmaxf(fmaxf(a.x, a.y), fmaxf(a.z, a.w));
float bm = fmaxf(fmaxf(b.x, b.y), fmaxf(b.z, b.w));
float mx = fmaxf(am, bm);
// Halve the distance each round; lane 0 of each 8-lane group ends with the group max.
mx = fmaxf(__shfl_down_sync(mask, mx, 4, 8), mx);
mx = fmaxf(__shfl_down_sync(mask, mx, 2, 8), mx);
mx = fmaxf(__shfl_down_sync(mask, mx, 1, 8), mx);
return mx;
}
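Because we pass width=8, each group of eight lanes behaves as a separate mini-warp: in the first shuffle (delta=4), lanes 4-7 of each group would read from lanes 8-11 of that group, which do not exist, so their own values are returned unchanged instead of being taken from the neighboring group. As per the documentation: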
If width is less than warpSize, then each subsection of the warp behaves as a separate entity with a starting logical lane ID of 0.
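This segment behavior can be checked with a small host model (plain C++, hypothetical helper names) of __shfl_down_sync with width=8. Each group of eight lanes is treated independently, and lanes whose source would fall outside their group keep their own value. After the rounds delta = 4, 2, 1 with max-combining, lane 0 of every group holds that group's maximum.

```cpp
#include <algorithm>
#include <array>
#include <cassert>

constexpr int kWarpSize = 32;
using Warp = std::array<float, kWarpSize>;

// Model of __shfl_down_sync(mask, var, delta, width) with all lanes active:
// the source lane must lie in the caller's own width-sized segment.
Warp shfl_down_width_model(const Warp& var, int delta, int width) {
    Warp out{};
    for (int lane = 0; lane < kWarpSize; ++lane) {
        int segmentLane = lane % width;  // logical lane ID within the segment
        out[lane] = (segmentLane + delta < width) ? var[lane + delta] : var[lane];
    }
    return out;
}

// Per-segment max: afterwards, lane 8*g holds the max of lanes 8*g .. 8*g+7.
Warp segment_max_model(Warp mx, int width = 8) {
    for (int delta = width / 2; delta > 0; delta >>= 1) {
        Warp other = shfl_down_width_model(mx, delta, width);
        for (int lane = 0; lane < kWarpSize; ++lane) mx[lane] = std::max(other[lane], mx[lane]);
    }
    return mx;
}
```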
The deadlock trap with sub-warp tiling
Consider the example above, extended to a full kernel:
__global__ void reduce_max_group_naive(const float4* __restrict__ inputs,
int cols_f4, int num_groups,
float* __restrict__ out)
{
for (int g = threadIdx.x / 8; g < num_groups; g += blockDim.x / 8) {
float mx = reduce_subwarp(0xffffffff, inputs + g * cols_f4);
if ((threadIdx.x % 8) == 0) out[g] = mx;
}
}
What happens when num_groups is not a multiple of the number of subwarps in the block (blockDim.x / 8)?
✦ Solution
If the number of groups is still a multiple of 4, then the loop body still proceeds as a converged warp. But if the number of groups is not a multiple of 4, then for some warps there will be some threads that do not enter the loop. This means that the synchronization mask 0xffffffff cannot be fulfilled. Should we expect the code above to hang forever, then?
Actually no. Synchronization functions (__syncthreads, __syncwarp as well as the warp shuffles) wait for all non-exited threads
to reach the synchronization point. The threads for which the loop condition is false will just end their program and EXIT.
Independent thread scheduling is what makes this work: with a program counter available to every thread, the scheduler can see that part of the warp is blocked waiting for the sync, and schedule the other threads of the warp until they exit.
Once you understand why the example above does not hang, you should also be able to explain why this modification is problematic:
Why does adding __syncthreads() after the loop cause a deadlock here?
__global__ void reduce_max_group_naive(const float4* __restrict__ inputs,
int cols_f4, int num_groups,
float* __restrict__ out)
{
for (int g = threadIdx.x / 8; g < num_groups; g += blockDim.x / 8) {
float mx = reduce_subwarp(0xffffffff, inputs + g * cols_f4);
if ((threadIdx.x % 8) == 0) out[g] = mx;
}
__syncthreads();
// some additional work
}
✦ Solution
Unlike in the example above, the threads that do not enter the loop cannot exit the kernel. Instead, they wait at __syncthreads() for the rest of the block, while their warp-mates inside the loop wait at a shuffle whose mask 0xffffffff names the lanes now stuck at the barrier. Neither side can make progress: a classic deadlock.
Fortunately, we already know the tool that can be used to fix this: __ballot_sync.
✦ Solution
__global__ void reduce_max_group_correct(const float4* __restrict__ inputs,
int cols_f4, int num_groups,
float* __restrict__ out)
{
int g = threadIdx.x / 8;
unsigned mask = __ballot_sync(0xffffffff, g < num_groups);
while (mask) {
if (g < num_groups) {
float mx = reduce_subwarp(mask, inputs + g * cols_f4);
if ((threadIdx.x % 8) == 0) out[g] = mx;
}
g += blockDim.x / 8;
mask = __ballot_sync(0xffffffff, g < num_groups);
}
}
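To see how the ballot-driven loop behaves when num_groups is not a multiple of the subwarp count, here is a host model (plain C++, hypothetical names) that records, for one warp, the mask each iteration computes. Every lane of the warp runs the same number of iterations, because the loop condition depends only on the warp-uniform mask; lanes whose own group is out of range simply skip the body.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Masks produced by the while (mask) loop of reduce_max_group_correct for the
// warp with index warpId, given blockDim threads per block and num_groups rows.
std::vector<std::uint32_t> ballot_masks_for_warp(int warpId, int blockDim, int num_groups) {
    assert(blockDim > 0 && blockDim % 32 == 0);
    std::vector<std::uint32_t> masks;
    const int stride = blockDim / 8;                 // subwarps per block
    for (int iter = 0;; ++iter) {
        std::uint32_t mask = 0;
        for (int lane = 0; lane < 32; ++lane) {
            int tid = warpId * 32 + lane;
            int g = tid / 8 + iter * stride;         // this lane's group in iteration iter
            if (g < num_groups) mask |= (1u << lane);
        }
        if (mask == 0) break;                        // the whole warp leaves the loop together
        masks.push_back(mask);
    }
    return masks;
}
```

For blockDim = 64 and num_groups = 10, warp 0 runs twice (masks 0xffffffff, then 0x0000ffff for groups 8 and 9), while warp 1 runs only once.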
3.3 Cooperative Groups
The CUDA headers provide a set of abstractions for managing cooperating threads at different levels of the hierarchy, from subwarps to multiple blocks. These Cooperative Groups are provided in cooperative_groups.h under the cooperative_groups namespace.
Tiled partitions
Handling regular (power-of-two) subwarps (and more generally, subsets of threads within one block) can be achieved using the tiled_partition utility.
It takes a parent group as input and returns a subgroup (tile) with the specified number of threads. To start, we can obtain a group representing the whole current block with this_thread_block().
#include <cooperative_groups.h>
namespace cg = cooperative_groups;
// Full warp — equivalent to mask 0xffffffff (all 32 lanes)
auto warp = cg::tiled_partition<32>(cg::this_thread_block());
// Half-warp tile — 16 lanes, mask computed from calling context
auto half = cg::tiled_partition<16>(warp);
// Quarter-warp tile — 8 lanes
auto tile8 = cg::tiled_partition<8>(warp);
Tile size must be a compile-time power of 2.
This subdivision gives rise to a hierarchy of coordinates:
group.thread_rank() returns the index of a thread within the group, and
group.size() returns the number of threads in the group.
If the group is a partition of a larger group, then
group.meta_group_rank() provides the index of this subgroup in the parent group,
and group.meta_group_size() gives the total number of subgroups that comprise the parent.
As the reference is to a direct parent, these two constructions yield different results:
auto warp = cg::tiled_partition<32>(cg::this_thread_block());
auto tile8 = cg::tiled_partition<8>(warp);
auto tile4 = cg::tiled_partition<4>(warp);
// meta_group_size() is a runtime query, not a constant expression,
// so check it with (device-side) assert from <cassert>.
assert(tile8.meta_group_size() == 4);
assert(tile4.meta_group_size() == 8);
but
auto warp = cg::tiled_partition<32>(cg::this_thread_block());
auto tile8 = cg::tiled_partition<8>(warp);
auto tile4 = cg::tiled_partition<4>(tile8);
assert(tile8.meta_group_size() == 4);
assert(tile4.meta_group_size() == 2);
In both versions, tile4 corresponds to 4 threads, but in the first it is interpreted as the whole warp partitioned into 8 subgroups, whereas in the second example it is a subgroup of 8 partitioned into 2 groups of 4 threads.
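The numbers follow from simple arithmetic on the parent group: for a tile of Size threads carved from a parent of P threads (P a multiple of Size), meta_group_size() is P / Size, and meta_group_rank() is the thread's rank in the parent divided by Size. The host-side sketch below reproduces both constructions for lane 13 of a warp. TileCoords and partition_coords are hypothetical names illustrating that arithmetic; they are not the CG implementation.

```cpp
#include <cassert>

struct TileCoords {
    int thread_rank;      // rank within the tile
    int meta_group_rank;  // which tile of the parent this thread belongs to
    int meta_group_size;  // number of tiles the parent was split into
};

// Coordinates of the thread with rank parentRank in a parent of parentSize
// threads, after partitioning the parent into tiles of tileSize threads.
TileCoords partition_coords(int parentRank, int parentSize, int tileSize) {
    assert(parentSize % tileSize == 0 && parentRank >= 0 && parentRank < parentSize);
    return {parentRank % tileSize, parentRank / tileSize, parentSize / tileSize};
}
```

Lane 13 is tile 3 of 8 when the warp is split directly into tiles of 4, but tile 1 of 2 when its tile of 8 (where it has rank 5) is split again.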
※ Overhead of cooperative groups
In general, the abstractions provided in cg are quite lightweight, and their performance impact is small.
But there is overhead, because cooperative groups need to work in general cases and thus may not assume anything about the
structure of your kernel.
For example, the size of a cg::this_thread_block() group is computed as blockDim.x * blockDim.y * blockDim.z (and thread_rank() linearizes a 3D thread index), even though many kernels do not use dimensions y and z.
Group-level operations
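#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>   // cg::reduce, cg::plus, cg::greater
namespace cg = cooperative_groups;
auto tile = cg::tiled_partition<8>(cg::this_thread_block());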
float val = /* ... */;
// Shuffle — source is tile-local (0..7), not warp-local (0..31)
float neighbor = tile.shfl_down(val, 4);
// Reduction — single call, result broadcast to all members
float sum = cg::reduce(tile, val, cg::plus<float>());
float mx = cg::reduce(tile, val, cg::greater<float>()); // greater → max
// Vote
int anyActive = tile.any(val > 0.f);
unsigned ballot = tile.ballot(val > 0.f); // bit i set if member i's predicate true
cg::reduce is built on warp-level intrinsics such as shuffles, but it is implemented so that it can be used safely in divergent code, provided all threads of the tile reach the call. Instead of the ballot-based approach presented above, it relies on a mask that depends on the thread's position in the warp:
For synchronizing groups of 16 threads, for example, the _sync intrinsics are called with mask 0x0000ffff from the lower 16 threads, and with mask 0xffff0000 from the upper 16 threads.
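The thread-dependent mask can be written down directly: for a power-of-two tile size S, a lane's tile occupies S consecutive bits starting at bit lane & ~(S - 1). The helper below is a hypothetical host-side C++ sketch of that computation, not the code CG actually uses internally.

```cpp
#include <cassert>
#include <cstdint>

// Mask naming exactly the lanes of the tile that contains `lane`,
// for a power-of-two tile size S in [1, 32].
std::uint32_t tile_mask(int lane, int S) {
    assert(S >= 1 && S <= 32 && (S & (S - 1)) == 0 && lane >= 0 && lane < 32);
    if (S == 32) return 0xffffffffu;          // avoid the undefined shift 1u << 32
    std::uint32_t tileBits = (1u << S) - 1u;  // S low bits set
    return tileBits << (lane & ~(S - 1));     // move them to the tile's first lane
}
```

For S = 16 this gives 0x0000ffff for lanes 0-15 and 0xffff0000 for lanes 16-31, matching the description above.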
Rewriting the sub-warp tiling deadlock example with CG
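#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace cg = cooperative_groups;
constexpr int TILE = 8;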
// 8-lane max-reduction over 16 float4s (64 floats) of one row.
__device__ float reduce_subwarp_cg(cg::thread_block_tile<TILE> tile,
const float4* row)
{
int lane = tile.thread_rank();
float4 a = row[lane];
float4 b = row[lane + TILE];
float am = fmaxf(fmaxf(a.x, a.y), fmaxf(a.z, a.w));
float bm = fmaxf(fmaxf(b.x, b.y), fmaxf(b.z, b.w));
float mx = fmaxf(am, bm);
return cg::reduce(tile, mx, cg::greater<float>());
}
__global__ void reduce_rows_cg(const float4* __restrict__ inputs,
int cols_f4, int num_rows,
float* __restrict__ out)
{
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<32>(block);
auto tile = cg::tiled_partition<TILE>(warp);
int row = (blockIdx.x * (int)block.size() + (int)block.thread_rank()) / TILE;
if (row >= num_rows) return; // whole tile exits together -- no partial tiles
float mx = reduce_subwarp_cg(tile, inputs + row * cols_f4);
if (tile.thread_rank() == 0)
out[row] = mx;
}
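The if (row >= num_rows) return guard is safe here because all TILE threads of a tile compute the same row (assuming the block size is a multiple of TILE), so entire tiles either return or continue together. Since cg::reduce only synchronizes the lanes of its own tile, no group operation waits for a thread that has returned.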
CG and __syncthreads
CG also exposes block- and warp-level synchronization as methods on the group objects:
auto block = cg::this_thread_block();
block.sync(); // equivalent to __syncthreads()
auto warp = cg::tiled_partition<32>(cg::this_thread_block());
warp.sync(); // equivalent to __syncwarp(0xffffffff)
The advantage is composability: a __device__ function that accepts a cg::thread_group can synchronize its argument without knowing whether it is a block, a warp, or a tile.
Raw intrinsics vs. cooperative groups
For reference, the correspondence between the raw warp/block intrinsics from §3.1–§3.2 and their CG equivalents:
| Operation | Raw intrinsic | Cooperative groups |
|---|---|---|
| Block barrier | __syncthreads() | block.sync() |
| Warp barrier | __syncwarp(0xffffffff) | warp.sync() |
| Shuffle (down) | __shfl_down_sync(mask, v, d, width) | tile.shfl_down(v, d) |
| Shuffle (arbitrary) | __shfl_sync(mask, v, srcLane, width) | tile.shfl(v, srcLane) |
| Ballot | __ballot_sync(mask, pred) | tile.ballot(pred) |
| Any / all | __any_sync(mask, pred) / __all_sync(mask, pred) | tile.any(pred) / tile.all(pred) |
| Reduction (max) | __shfl_xor_sync loop with fmaxf | cg::reduce(tile, v, cg::greater<T>()) |
The CG column never needs a mask argument: tile membership is determined at the tile's construction and carried with the object.
3.4 Exercise — Row-wise Softmax in BF16
In this exercise, implement a row-wise softmax kernel over a 2D BF16 matrix whose rows have 128 elements each, using cooperative groups.
Background: row-wise softmax
Softmax normalizes each row of a matrix into a probability distribution:
$$\text{output}[\text{row}][\text{col}] = \frac{\exp(\text{input}[\text{row}][\text{col}])}{\sum_{j} \exp(\text{input}[\text{row}][j])}$$
Interface
The function to be implemented is:
void softmax128(__nv_bfloat16* out, const __nv_bfloat16* in, int rows);
Here, rows is the number of rows, 128 elements each; out is the destination
array and in is the input array.
The number of rows can be an arbitrary positive 32-bit integer (so the total element count, 128 * rows, can exceed the range of int); the starting addresses of out and in are guaranteed to align to 256 bytes.
The values of in are limited to -32000 < in[i] < 32000.
Hints
※ Numerically-stable softmax
Calculating the softmax formula directly easily leads to numerical problems, as exp(x) can overflow, despite the fact
that the final softmax result is bounded between 0 and 1. To prevent this problem, the following equivalent formula
can be used instead:
$$\text{shift} = \max_j \text{input}[j]$$
$$\text{output}[i] = \frac{\exp(\text{input}[i] - \text{shift})}{\sum_j \exp(\text{input}[j] - \text{shift})}$$
This naturally leads to a three-pass algorithm, first calculating the maximum, then accumulating the denominator, and finally computing the softmax.
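For checking a kernel, a CPU reference of the three-pass algorithm can be handy. The sketch below works in float on a single row; BF16 conversion and the mapping of rows to tiles are deliberately left out, and the function name is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Numerically stable softmax of one row, in three passes.
std::vector<float> softmax_row_reference(const std::vector<float>& in) {
    assert(!in.empty());
    // Pass 1: shift = max_j in[j]
    const float shift = *std::max_element(in.begin(), in.end());
    // Pass 2: denominator = sum_j exp(in[j] - shift); every term is <= 1 and the largest is exactly 1
    float denom = 0.0f;
    for (float x : in) denom += std::exp(x - shift);
    // Pass 3: normalize
    std::vector<float> out(in.size());
    for (std::size_t i = 0; i < in.size(); ++i) out[i] = std::exp(in[i] - shift) / denom;
    return out;
}
```

Inputs near the top of the allowed range, such as {31000, 31000, 30000}, would overflow exp without the shift; with it, the result is {0.5, 0.5, 0} up to rounding.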
You can test your solution locally:
python run.py test