
FlashAttention & LLM Inference on GPUs

Writing a FlashAttention CUDA kernel from scratch, tiling the attention matrix to avoid materializing N×N memory, building a KV cache for token generation, and running GPT-2 with custom kernels end-to-end.

Introduction: The Attention Bottleneck

Large language models spend most of their compute time in matrix multiplies and attention kernels. They have billions of parameters and dozens of layers, but the operation that dominates the compute budget, especially at long context lengths, is self-attention.

The naive implementation of self-attention has a problem that’s easy to miss until you actually profile it: it materializes a sequence-length × sequence-length matrix in GPU global memory. For a 4096-token context, that’s a 4096×4096 matrix of fp16 values, or 32 MB, just for intermediate attention scores, for one head, in one layer, for one batch element. Multiply by the number of heads and a batch size of 64, and a single layer’s score tensors reach tens of gigabytes. Across 64 layers, the GPU has to push all of that through HBM on every forward pass. The arithmetic units sit idle waiting for data. You’re deep in memory-bound territory on the Roofline model.

FlashAttention (Dao et al., 2022) solves this with an observation that sounds deceptively simple: you can compute exact attention without ever materializing the full N×N score matrix in global memory. You tile the computation and use an “online softmax” algorithm to accumulate the correct result incrementally. The result is O(N) additional memory in sequence length when head dimension is treated as constant (equivalently O(N·d) if d is explicit), for an operation that seemed to require O(N²) memory.
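The arithmetic above can be checked with a few lines of Python. This sketch compares one naive score matrix with the per-row statistics an online-softmax kernel keeps instead (a running max and a running sum per query row). The function names and the choice of FP32 for the statistics are illustrative assumptions, not taken from the project.

```python
def naive_scores_bytes(seq_len: int, bytes_per_elem: int = 2) -> int:
    """Bytes for one materialized [seq_len, seq_len] score matrix (one head, one batch element)."""
    return seq_len * seq_len * bytes_per_elem

def online_softmax_state_bytes(seq_len: int, bytes_per_elem: int = 4) -> int:
    """Bytes for the row statistics: one running max and one running sum per query row."""
    return 2 * seq_len * bytes_per_elem

MiB = 1024 * 1024
n = 4096
print(naive_scores_bytes(n) / MiB)           # 32.0  (MiB of FP16 scores)
print(online_softmax_state_bytes(n) / 1024)  # 32.0  (KiB of FP32 statistics)
```

Doubling the sequence length quadruples the first number but only doubles the second, which is the O(N²) versus O(N) distinction in concrete terms.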

The capstone project of Georgia Tech’s CS 8803: GPU Hardware and Software asks you to implement this from scratch: a numerically stable softmax kernel, tiled GEMM kernels, the FlashAttention forward pass in CUDA, a KV cache for autoregressive decoding, and finally an end-to-end inference pipeline running GPT-2 with your custom attention kernels.

It ties together everything from the course: CUDA kernel writing and profiling (Projects 1–2), understanding of latency and throughput tradeoffs (Projects 3–4), and the GPU’s relationship to the ML software stack (Module 12). It’s the most ambitious project in the course.

Note

I keep this write-up focused on the implementation ideas behind the capstone project. For the broader course overview, see: GPU Hardware and Software: A Retrospective.

Self-Attention from Scratch

The Scaled Dot-Product Formula

The core attention operation takes three matrices, Query (Q), Key (K), and Value (V), and produces an output:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V$$

Conceptually, each query “attends to” all keys by computing dot products. Those dot products are softmax-normalized into attention weights, and the weights are used to compute a weighted sum of the values.
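As a concrete reference, here is the formula written out in plain Python for a single head with no batching. It is a minimal sketch for checking small cases by hand, not a performance implementation; the function names are illustrative.

```python
import math

def softmax(row):
    m = max(row)                                   # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    """Scaled dot-product attention. Q: [n_q][d], K: [n_k][d], V: [n_k][d_v] as lists of lists."""
    d = len(Q[0])
    out = []
    for q in Q:
        # One row of scores: this query dotted with every key, scaled by 1/sqrt(d)
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in K]
        weights = softmax(scores)                  # probability distribution over key positions
        # Weighted sum of the value vectors
        out.append([sum(w * v[j] for w, v in zip(weights, V)) for j in range(len(V[0]))])
    return out

# The query matches the first key far better, so the output leans heavily toward V[0].
Q = [[10.0, 0.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 0.0], [0.0, 1.0]]
print(attention(Q, K, V))
```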

In a language model, Q, K, and V are all derived from the same input sequence by multiplying it with learned weight matrices ($W_Q$, $W_K$, $W_V$). This is self-attention: the sequence attends to itself, and each position weights every position, including its own.

Figure 1: The QK^T computation is a GEMM: A=[batch, seq, d_head] times B^T=[batch, d_head, seq] produces S=[batch, seq, seq]. For seq=4096, this is a 4096×4096 matrix per head per batch element, which is the memory bottleneck.

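The scaling by $1/\sqrt{d}$ keeps the dot products from growing with the head dimension: if the entries of a query and a key are independent with zero mean and unit variance, their dot product has variance $d$, and dividing by $\sqrt{d}$ brings it back to unit variance. Without the scaling, large scores push softmax into its saturating regime, causing vanishing gradients during training. The softmax is applied row-wise: each query’s dot products with all keys are normalized into a probability distribution over positions.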

Multi-Headed Attention

In practice, models use Multi-Head Attention (MHA): instead of one set of Q/K/V projections, you have h heads, each with its own projections into a smaller subspace ($d_{\text{head}} = d_{\text{model}} / h$).

# Conceptual multi-head attention, illustrates the pattern, not the project code.
# Q, K, V are assumed to be already projected: [B, N, D], with D divisible by num_heads.
import math
import torch
import torch.nn.functional as F

def multihead_attention(Q, K, V, num_heads):
    B, N, D = Q.shape
    head_dim = D // num_heads

    # Reshape: [B, N, D] → [B, N, num_heads, head_dim] → [B, num_heads, N, head_dim]
    Q = Q.view(B, N, num_heads, head_dim).transpose(1, 2)
    K = K.view(B, N, num_heads, head_dim).transpose(1, 2)
    V = V.view(B, N, num_heads, head_dim).transpose(1, 2)

    # Scaled dot-product attention: [B, num_heads, N, N]
    scores = (Q @ K.transpose(-2, -1)) / math.sqrt(head_dim)
    weights = F.softmax(scores, dim=-1)   # normalize over key positions
    output  = weights @ V                 # [B, num_heads, N, head_dim]

    # Reassemble heads: [B, num_heads, N, head_dim] → [B, N, D]
    return output.transpose(1, 2).contiguous().view(B, N, D)

On the GPU, MHA maps cleanly to batching: all heads can be computed as a single batched matrix multiplication, because the head dimension becomes an extra batch dimension. No sequential loop over heads is required.

Causal Masking

For autoregressive text generation, a token at position i should only be able to attend to tokens at positions ≤ i, so it can’t look at future tokens that haven’t been generated yet. This is enforced by a causal mask: before the softmax, set the strictly upper-triangular entries of the score matrix (key position j > query position i) to $-\infty$. After softmax, those positions become 0 (zero weight), so future tokens are effectively ignored.

# Apply causal mask (True above the diagonal, i.e. key position > query position)
import torch

mask = torch.triu(torch.ones(N, N, device=scores.device), diagonal=1).bool()
scores.masked_fill_(mask, float('-inf'))

This is another motivation for FlashAttention: the causal mask means about half the score matrix elements are immediately discarded. Materializing the full N×N matrix just to fill half of it with $-\infty$ is especially wasteful.
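The "about half" figure is easy to verify. This sketch counts the masked entries of an N×N causal mask, and also counts how many tiles lie entirely above the diagonal, which is the part a tiled kernel could skip outright. Square tiles of one size and the helper names are assumptions for illustration.

```python
def masked_fraction(n: int) -> float:
    """Fraction of an n x n score matrix set to -inf by a causal mask (key j > query i)."""
    masked = sum(1 for i in range(n) for j in range(n) if j > i)
    return masked / (n * n)                     # equals (n - 1) / (2n)

def skippable_tile_fraction(n: int, tile: int) -> float:
    """Fraction of (query-tile, key-tile) pairs lying entirely above the diagonal."""
    n_tiles = (n + tile - 1) // tile
    skippable = 0
    for qt in range(n_tiles):
        last_query = min((qt + 1) * tile, n) - 1
        for kt in range(n_tiles):
            if kt * tile > last_query:          # every key in this tile is in the future
                skippable += 1
    return skippable / (n_tiles * n_tiles)

print(masked_fraction(128))                # 0.49609375
print(skippable_tile_fraction(4096, 64))   # 0.4921875
```

The tile-level fraction is slightly lower because diagonal tiles are only partly masked: they still have to be computed, with an element-wise mask applied inside the tile.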

Custom CUDA Kernels: The Building Blocks

Before implementing FlashAttention, you need two lower-level building blocks: a numerically stable softmax kernel and tiled GEMM kernels. These are the primitives that the attention kernel assembles.

Numerically Stable Softmax

The naive softmax formula is:

$$\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}$$

The problem: for large input values, $e^{x_i}$ overflows float32 (max representable ~3.4×10³⁸, so any exponent above about 88.7 overflows). Attention scores can easily reach 100+ for long sequences with large embeddings. A score of 90 gives $e^{90} \approx 1.2 \times 10^{39}$, which overflows.

The standard fix is to subtract the row maximum before exponentiating:

$$\text{softmax}(x_i) = \frac{e^{x_i - x_{\max}}}{\sum_j e^{x_j - x_{\max}}}$$

This is mathematically identical (the $e^{-x_{\max}}$ terms cancel) but numerically stable, because the largest exponentiated value is $e^0 = 1$.
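A quick way to see the difference is in Python. Python floats are FP64, so the naive form fails for exponents above about 709.8 rather than FP32's ~88.7, but the failure mode is the same. The sketch below is illustrative, not the project's kernel.

```python
import math

def naive_softmax(xs):
    exps = [math.exp(x) for x in xs]           # math.exp raises OverflowError past ~709.78
    total = sum(exps)
    return [e / total for e in exps]

def stable_softmax(xs):
    m = max(xs)                                # the largest term becomes exp(0) = 1
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

scores = [1000.0, 1001.0, 1002.0]
try:
    naive_softmax(scores)
except OverflowError as err:
    print("naive softmax failed:", err)        # naive softmax failed: math range error
print(stable_softmax(scores))                  # same result as for [0.0, 1.0, 2.0]
```

Because softmax is invariant to adding a constant to every input, shifting by the max changes nothing mathematically; it only moves the computation into a range the floating-point format can represent.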

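In CUDA, computing the row maximum requires a parallel reduction within a thread block. Once each thread has written its value to shared memory, the reduction runs in log₂(blockDim.x) passes: in each pass, the active threads (half as many as in the previous pass) each compare a pair of values and write the larger one back to shared memory, then the block synchronizes. After the last pass, smem[0] holds the block maximum.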

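// Conceptual parallel reduction for row maximum, illustrates the pattern.
// One block per row; assumes blockDim.x == BLOCK_SIZE, a power of two.
#define BLOCK_SIZE 256
__global__ void row_max_kernel(const float* input, float* row_max, int n_cols) {
    __shared__ float smem[BLOCK_SIZE];
    const float* row = input + (size_t)blockIdx.x * n_cols;
    // Each thread folds a strided slice of the row, so rows longer than the block work.
    float local = -INFINITY;
    for (int c = threadIdx.x; c < n_cols; c += blockDim.x)
        local = fmaxf(local, row[c]);
    smem[threadIdx.x] = local;
    __syncthreads();
    // Tree reduction: each pass halves the number of active threads.
    for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
        if (threadIdx.x < stride)
            smem[threadIdx.x] = fmaxf(smem[threadIdx.x], smem[threadIdx.x + stride]);
        __syncthreads();
    }
    if (threadIdx.x == 0)
        row_max[blockIdx.x] = smem[0];  // smem[0] now holds the row maximum
}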

The batched softmax kernel maps one CUDA block to one row. gridDim.x handles the batch dimension. This design maximizes shared memory reuse and avoids inter-block synchronization.
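The pass structure of that reduction can be simulated on the host. This sketch mimics one block reducing one row: each "thread" first folds a strided slice of the row, then the shared array is halved log₂(block_size) times. It models the algorithm only (no real threads or `__syncthreads()`), and `block_size` is assumed to be a power of two; the function name is hypothetical.

```python
import math

def block_row_reduce(row, block_size, op, identity):
    """Simulate one CUDA block reducing one row in shared memory; returns (result, passes)."""
    assert block_size > 0 and block_size & (block_size - 1) == 0, "block_size must be a power of two"
    # Phase 1: thread t folds row[t], row[t + block_size], ... into smem[t].
    smem = [identity] * block_size
    for t in range(block_size):
        for x in row[t::block_size]:
            smem[t] = op(smem[t], x)
    # Phase 2: tree reduction; each pass halves the number of active threads.
    passes, stride = 0, block_size // 2
    while stride > 0:
        for t in range(stride):                # threads with t < stride do the work
            smem[t] = op(smem[t], smem[t + stride])
        stride //= 2                           # the real kernel calls __syncthreads() here
        passes += 1
    return smem[0], passes

row = [0.5 * i - 3.0 for i in range(1000)]
row_max, passes = block_row_reduce(row, 256, max, -math.inf)
# A second reduction (sum of shifted exponentials) gives the softmax denominator.
row_sum, _ = block_row_reduce([math.exp(x - row_max) for x in row], 256, lambda a, b: a + b, 0.0)
print(row_max, passes)    # 496.5 8
```

Two such reductions (max, then sum of shifted exponentials) plus an element-wise pass give the stable softmax for a row, all within one block.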

Tiled GEMM

The attention computation requires two GEMM operations: $S = QK^T$ (scores) and $O = \text{softmax}(S) \cdot V$ (output). These have different matrix layouts:

  • $QK^T$: multiply Q (normal, row-major) by $K^T$ (transposed), which uses the GEMM_NT kernel
  • $SV$: multiply S (after softmax, normal) by V (normal), which uses the GEMM_NN kernel

Both use the tiled GEMM pattern from Project 1 (see The Tiling Solution). The attention-specific twist is handling both GEMM_NN and GEMM_NT layouts while keeping shared-memory accesses efficient.

Figure 2: The tiled GEMM pattern: each thread block computes one output tile by iterating over the K-dimension in steps of TILE_SIZE. Each step loads one tile of A and one tile of B into shared memory, synchronizes, and computes partial products. Global memory accesses are reduced by a factor of TILE_SIZE.

For the $QK^T$ case, tile loading works differently because of the transposed layout: you load rows of $K$ and treat them as columns of the B tile. Getting this indexing right without introducing bank conflicts is the fiddly part of GEMM_NT.
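The only difference between the two layouts is how the second operand is indexed while staging tiles. This pure-Python sketch (illustrative names; shared memory and bank conflicts are not modeled) stages tiles the way a thread block would and handles both GEMM_NN and GEMM_NT.

```python
def tiled_gemm(A, B, tile, b_transposed=False):
    """C = A @ B (GEMM_NN) or C = A @ B^T (GEMM_NT), computed tile by tile.

    A is [M][K]. B is [K][N] for NN, or [N][K] for NT (row-major in both cases).
    """
    M, K = len(A), len(A[0])
    N = len(B) if b_transposed else len(B[0])

    def b_elem(k, n):
        # NT: element (k, n) of B^T lives at B[n][k], i.e. rows of B act as columns.
        return B[n][k] if b_transposed else B[k][n]

    C = [[0.0] * N for _ in range(M)]
    for i0 in range(0, M, tile):              # one "thread block" per output tile
        for j0 in range(0, N, tile):
            for k0 in range(0, K, tile):      # march along the shared K dimension
                # Stage one tile of A and one tile of B (the shared-memory loads)
                a_t = [[A[i][k] for k in range(k0, min(k0 + tile, K))]
                       for i in range(i0, min(i0 + tile, M))]
                b_t = [[b_elem(k, j) for j in range(j0, min(j0 + tile, N))]
                       for k in range(k0, min(k0 + tile, K))]
                # Accumulate partial products from the staged tiles
                for ii, a_row in enumerate(a_t):
                    for jj in range(len(b_t[0])):
                        C[i0 + ii][j0 + jj] += sum(a_row[kk] * b_t[kk][jj] for kk in range(len(a_row)))
    return C

Q = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]                       # 2 queries, d = 3
Kmat = [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0], [2.0, 2.0, 2.0]]   # 3 keys, d = 3
print(tiled_gemm(Q, Kmat, tile=2, b_transposed=True))        # [[4.0, 2.0, 12.0], [10.0, 5.0, 30.0]]
```

Passing `Kmat` with `b_transposed=True` gives the same result as passing its explicit transpose with `b_transposed=False`; the NT variant simply avoids materializing $K^T$.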

Why Not Just Use cuBLAS?

A reasonable question: cuBLAS has extremely optimized GEMM routines. Why write your own?

Two reasons. First, understanding: implementing tiled GEMM forces you to internalize the shared memory tiling pattern at a level that reading about it doesn’t. Second, fusion: FlashAttention works because the GEMMs, mask, softmax, and output GEMM are fused into a single kernel to avoid materializing the N×N intermediate matrix. You can’t fuse cuBLAS calls, since they’re black boxes. You need control over the kernel internals.

FlashAttention: Tiling Attention Itself

The Memory Bottleneck

The naive attention implementation’s memory problem:

  1. Compute $S = QK^T/\sqrt{d}$, which writes a [batch, heads, N, N] tensor to global memory
  2. Apply causal mask, which reads and writes the full [batch, heads, N, N] tensor
  3. Compute softmax, which reads and writes the full [batch, heads, N, N] tensor
  4. Compute $O = \text{softmax}(S) \cdot V$, which reads the full [batch, heads, N, N] tensor

Steps 1–4 each touch O(N²) data in global memory. For N=4096, that adds up to six full transfers of a 32 MB tensor (one write, two read-write pairs, and one read) per head, per batch element. The arithmetic intensity is terrible, because you’re doing minimal computation per byte loaded.
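A rough element-count model makes the comparison concrete. It counts only global-memory reads and writes for one head and one batch element, assumes no cache reuse, and charges the tiled kernel one reload of all of K and V per query tile. The formulas are a simplification for illustration, not measurements.

```python
def naive_traffic(n: int, d: int) -> int:
    """Elements moved by the unfused pipeline (steps 1-4 above)."""
    qk = 2 * n * d                      # read Q and K for S = QK^T / sqrt(d)
    scores = (1 + 2 + 2 + 1) * n * n    # write S, mask read/write, softmax read/write, read P
    pv = n * d + n * d                  # read V, write O
    return qk + scores + pv

def flash_traffic(n: int, d: int, block_rows: int) -> int:
    """Elements moved when K and V stream once per query tile and S never leaves the chip."""
    q_tiles = (n + block_rows - 1) // block_rows
    return n * d + q_tiles * 2 * n * d + n * d   # read Q, stream K/V per tile, write O

n, d = 4096, 64
for b_r in (64, 128):
    print(b_r, naive_traffic(n, d) / flash_traffic(n, d, b_r))
```

Under this model the naive cost is dominated by $6N^2$ while the tiled cost is about $2N^2 d / B_r$, so larger query tiles (bounded by shared memory) directly cut HBM traffic.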

Can we do all four steps with a single pass over the data, keeping the intermediate results in fast on-chip shared memory instead of writing them to global memory?

Online Softmax and FlashAttention

A useful mathematical insight is that softmax can be computed incrementally as you see new values, without knowing the final maximum or sum in advance.

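If you’ve already computed a partial softmax for the first j elements, and you see a new element $x_{j+1}$, you can update the running maximum and running sum to incorporate the new value, then correct the partial output accordingly. Concretely, with running max $m$ and running sum $\ell = \sum_i e^{x_i - m}$: $m' = \max(m, x_{j+1})$ and $\ell' = \ell\, e^{m - m'} + e^{x_{j+1} - m'}$. This is the “online softmax” algorithm.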
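The update fits in a few lines. This sketch streams one query's scores and value vectors one at a time, keeping only a running max `m`, a running sum `l`, and an unnormalized output; whenever the max grows, the old state is rescaled by $e^{m_{\text{old}} - m_{\text{new}}}$. It is a scalar-at-a-time illustration with hypothetical names; FlashAttention applies the same rescaling to whole tiles.

```python
import math

def online_attention_row(scores, values):
    """softmax(scores) @ values for one query row, in a single streaming pass."""
    m = -math.inf                       # running max of scores seen so far
    l = 0.0                             # running sum of exp(score - m)
    acc = [0.0] * len(values[0])        # running unnormalized output
    for s, v in zip(scores, values):
        m_new = max(m, s)
        alpha = math.exp(m - m_new)     # rescales the old state; 0.0 on the first step
        p = math.exp(s - m_new)
        l = alpha * l + p
        acc = [alpha * a + p * vi for a, vi in zip(acc, v)]
        m = m_new
    return [a / l for a in acc]         # normalize once at the end

def two_pass_attention_row(scores, values):
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    return [sum(wi * v[j] for wi, v in zip(w, values)) / total for j in range(len(values[0]))]

scores = [0.3, 2.5, -1.0, 4.0, 3.9]
values = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0], [0.5, 0.5]]
print(online_attention_row(scores, values))
print(two_pass_attention_row(scores, values))   # same up to rounding
```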

FlashAttention applies this idea at tile granularity:

  1. Divide Q into blocks of $B_r$ rows, and divide K and V into blocks of $B_c$ rows
  2. For each block of Q: iterate over all blocks of K and V
  3. For each (Q-block, K-block) pair: load into shared memory, compute partial attention scores, apply online softmax update
  4. Accumulate into the output block

At no point is the full N×N score matrix written to global memory. Tiled kernels do still reload K/V blocks once per Q tile, so this is not literally “read each input once.” The paper-level win is that HBM traffic scales with tiled streaming plus the final output, instead of repeated read/write passes over the full score matrix. Memory complexity drops from O(N²) intermediates to O(N) auxiliary state.

# Conceptual FlashAttention-2 forward, illustrates the algorithm,
# not the project-specific implementation. Based on Algorithm 1 of Dao (2023).
import math
import torch

def flash_attention_forward(Q, K, V, B_r, B_c, causal=False):
    # Q, K, V: [B, H, N, d]. B_r and B_c are the tunable tile sizes.
    B, H, N, d = Q.shape
    scale = 1.0 / math.sqrt(d)
    O = torch.empty_like(Q)

    for q_start in range(0, N, B_r):
        q_end = min(q_start + B_r, N)                       # last tile may be short
        q = Q[:, :, q_start:q_end, :].float()               # [B, H, rows, d], FP32 math
        rows = q_end - q_start
        m = q.new_full((B, H, rows), float('-inf'))         # running max, per query row
        l = q.new_zeros((B, H, rows))                       # running sum of exp
        o = q.new_zeros((B, H, rows, d))                    # unnormalized output tile

        for k_start in range(0, N, B_c):
            if causal and k_start >= q_end:
                break                                       # tile lies entirely in the future
            k_end = min(k_start + B_c, N)
            k = K[:, :, k_start:k_end, :].float()
            v = V[:, :, k_start:k_end, :].float()

            s = (q @ k.transpose(-2, -1)) * scale           # [B, H, rows, cols]
            if causal:
                qi = torch.arange(q_start, q_end, device=Q.device).unsqueeze(-1)
                kj = torch.arange(k_start, k_end, device=Q.device).unsqueeze(0)
                s = s.masked_fill(kj > qi, float('-inf'))   # key after query: zero weight

            m_new = torch.maximum(m, s.amax(dim=-1))
            alpha = torch.exp(m - m_new)                    # rescales old state (0 on first tile)
            p = torch.exp(s - m_new.unsqueeze(-1))          # unnormalized tile probabilities
            l = alpha * l + p.sum(dim=-1)
            o = alpha.unsqueeze(-1) * o + p @ v
            m = m_new

        O[:, :, q_start:q_end, :] = (o / l.unsqueeze(-1)).to(Q.dtype)  # normalize once
    return O
Note

There’s a known typo in Algorithm 1 of the original FlashAttention-2 paper. The GitHub issue has the correction. Worth reading before implementing.

Translating to CUDA

The PyTorch version above (Task 3 in the project) is useful for understanding the algorithm. But to get the performance gains, you need to translate it to a CUDA kernel (Task 4) where:

  • Q, K, V tiles are loaded into __shared__ arrays
  • The loop over K/V blocks is the inner loop of the kernel (within a single kernel launch)
  • The running max and sum are kept in registers, not global memory
  • The output accumulation happens in shared memory

Tile size ($B_r \times B_c$) is a critical tuning parameter. Larger tiles mean more reuse per global memory load, but require more shared memory. The maximum head dimension is 128, which constrains how large the tiles can be before you exceed the SM’s shared memory budget.
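A back-of-the-envelope footprint check helps when choosing tile sizes. This sketch assumes one FP16 tile each of Q, K, and V plus an FP32 score tile, with no double buffering or padding; real kernels may keep scores in registers or add buffers for `cp.async`, so treat it as a rough model with a hypothetical function name. The 48 KiB figure is CUDA's default per-block limit for static shared memory; larger allocations need dynamic shared memory and an explicit opt-in, up to a per-architecture maximum.

```python
def attention_tile_smem_bytes(b_r: int, b_c: int, d: int,
                              elem_bytes: int = 2, score_bytes: int = 4) -> int:
    """Shared memory for a Q tile [b_r, d], K and V tiles [b_c, d], and a score tile [b_r, b_c]."""
    qkv = (b_r * d + 2 * b_c * d) * elem_bytes
    scores = b_r * b_c * score_bytes
    return qkv + scores

STATIC_SMEM_LIMIT = 48 * 1024   # default per-block limit for static __shared__ arrays

for b_r, b_c, d in [(64, 64, 64), (64, 64, 128), (32, 32, 128)]:
    need = attention_tile_smem_bytes(b_r, b_c, d)
    print(f"B_r={b_r} B_c={b_c} d={d}: {need // 1024} KiB, fits in 48 KiB: {need <= STATIC_SMEM_LIMIT}")
```

Under these assumptions, doubling the head dimension from 64 to 128 pushes a 64×64 tiling from 40 KiB to 64 KiB, past the static limit, which is why d=128 forces either smaller tiles or dynamic shared memory.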

On newer GPUs, this translation is also a data-movement problem: stage global-memory tiles into shared memory efficiently, then keep the math fed with minimal synchronization overhead. Techniques like asynchronous copy (cp.async) and carefully placed barriers exist to overlap tile movement and compute, but they raise register pressure and shared-memory footprint. That tension is exactly why tile size and block shape tuning remain empirical in practice.

Speedup Results

The performance story is compelling. Against torch.nn.MultiheadAttention (PyTorch’s built-in):

| Configuration | Speedup (FlashAttention) |
|---|---|
| batch=4, seq=128, non-causal | ~2.4× |
| batch=64, seq=4096, non-causal | ~7.7× |
| batch=4, seq=128, causal | ~1.8× |
| batch=64, seq=4096, causal | ~2.0× |

Measurement context for this table: course project harness on an NVIDIA H100, FP16 attention path with FP32 softmax accumulation, wall-clock timings from the provided benchmarking script, and no full multi-run confidence-interval study (so treat values as directional). This is enough to compare kernel variants in-project, but not a publication-grade benchmark suite.

The speedup grows with sequence length, consistent with reduced HBM traffic and better tiling efficiency at long contexts. At seq=4096, the naive implementation spends substantially more time moving intermediate score tensors, while FlashAttention avoids materializing that N×N buffer.

Inference Optimization: Prefill vs. Decode

The Two Phases of LLM Inference

Running a language model in production has two very different computational phases:

Prefill: Process the entire input prompt at once to produce the first output token. The full N-token sequence runs through all layers. This is a GEMM-heavy operation, so all of FlashAttention applies here.

Decode: Generate subsequent tokens one at a time. Each new token attends to all previous tokens, but it’s just one query attending to N keys and values. The attention score computation degenerates from a matrix-matrix multiply to a matrix-vector multiply (GEMV). Algorithmically this is O(N) per token for attention against the cache, and implementation details (cache layout, bandwidth, launch overhead) dominate throughput.

Precision, Bandwidth, and Arithmetic Intensity

Modern inference stacks run with mixed precision because the limiting resource is often bytes moved, not peak FLOPs. FP16/BF16 halves activation and weight footprint versus FP32, which halves the bytes moved for the same traffic pattern (effectively doubling usable memory bandwidth) and doubles the arithmetic intensity of bandwidth-bound kernels. Many deployments push further (FP8/INT8) for weights and activations, but keep numerically sensitive reductions (softmax max/sum, accumulation terms) in higher precision to avoid instability.

That tradeoff shows up directly in FlashAttention. The tile math can run at reduced precision, but the running max and running sum for online softmax should stay in FP32. In other words: lower precision for throughput where error is tolerable, higher precision for normalization paths where small numeric drift compounds across long sequences.
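To see why precision matters so much for decode, count FLOPs and bytes for one query attending to a cache of N keys and values in one head. This simple model counts a multiply-add as two FLOPs and treats only the K/V reads as traffic; the function name is illustrative and the result is an estimate, not a profile.

```python
def decode_attention_intensity(n: int, d: int, bytes_per_elem: int) -> float:
    """FLOPs per byte for q @ K^T and p @ V with one query against n cached tokens."""
    flops = 2 * n * d + 2 * n * d               # n dot products of length d, then weighted sum of V
    bytes_moved = 2 * n * d * bytes_per_elem    # read K and V once each
    return flops / bytes_moved

for name, nbytes in [("FP32", 4), ("FP16/BF16", 2), ("FP8", 1)]:
    print(name, decode_attention_intensity(4096, 64, nbytes))
# FP32 0.5
# FP16/BF16 1.0
# FP8 2.0
```

The intensity does not depend on N at all, and at 0.5 to 2 FLOPs per byte it sits far below the ridge point of a modern GPU's roofline. Decode attention is therefore bandwidth-bound, and halving the element size can roughly halve its runtime.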

The KV Cache

The critical optimization for decode is the KV Cache: pre-compute and store the K and V vectors for all previously processed tokens. On each decode step:

  1. Compute Q, K, V for only the new token
  2. Append the new K and V vectors to the KV cache
  3. Run attention with the new query against all K and V values in the cache (growing with each step)

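Without the KV cache, each decode step would have to re-run the forward pass over the entire sequence so far: recomputing K and V for all N previous tokens (O(N) projection work per step) and redoing attention for every position (O(N²) per step). With the KV cache, each step projects only the new token and runs one query against N cached keys and values, which is O(N) attention work per step and O(N²) over an N-token response.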
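The equivalence behind the KV cache can be checked on a toy single-head layer. In this sketch the `project` function is a hypothetical stand-in for $W_Q$, $W_K$, $W_V$ applied to a scalar token embedding, chosen only to keep the example self-contained. The point is that decoding with cached K/V gives the same output for the newest token as recomputing causal attention over the whole prefix.

```python
import math

def project(x):
    # Hypothetical stand-ins for W_Q, W_K, W_V applied to a scalar embedding x (d = 2)
    return [x, 1.0], [1.0, x], [x * x, -x]          # (q, k, v)

def attend(q, keys, values):
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    return [sum(wi * v[j] for wi, v in zip(w, values)) / total for j in range(len(values[0]))]

class KVCache:
    def __init__(self):
        self.keys, self.values = [], []

    def decode_step(self, x):
        q, k, v = project(x)                        # 1. project only the new token
        self.keys.append(k)                         # 2. append its K and V to the cache
        self.values.append(v)
        return attend(q, self.keys, self.values)    # 3. one query against every cached K/V

def recompute_last(tokens):
    # No cache: re-project every token, then attend from the last (newest) position.
    qkv = [project(x) for x in tokens]
    return attend(qkv[-1][0], [k for _, k, _ in qkv], [v for _, _, v in qkv])

tokens = [0.1, -0.4, 0.7, 0.2, -0.9]
cache = KVCache()
for t in range(len(tokens)):
    cached = cache.decode_step(tokens[t])
    assert all(abs(a - b) < 1e-12 for a, b in zip(cached, recompute_last(tokens[: t + 1])))
print(len(cache.keys))   # 5
```

In a multi-layer model the same argument applies layer by layer: under causal masking, earlier positions' hidden states never change when a new token arrives, so their cached K and V stay valid.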

The KV cache has a memory cost: for GPT-2 (12 layers, 12 heads, head_dim=64), generating 100 tokens on top of a 1024-token prompt requires storing layers × heads × K/V × seq × head_dim = 12 × 12 × 2 × 1124 × 64 = 20,717,568 values per batch element. At FP16, that’s about 41.4 MB per batch element (or ~82.9 MB at FP32), ignoring allocator/paging overhead. (GPT-2’s learned position embeddings stop at 1024 positions, so in practice prompt plus generation must fit in 1024 tokens; a full 1024-token window works out to 18,874,368 values, about 37.7 MB at FP16.) For large models at scale (GPT-4 class, long contexts, large batches), KV cache memory often becomes the binding constraint on batch size.
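The same arithmetic as a small helper, parameterized by the number of KV heads so that it also covers models that share K/V across query heads. The function names and arguments are illustrative; the printed figures reproduce the GPT-2 numbers above.

```python
def kv_cache_elems(layers: int, kv_heads: int, seq_len: int, head_dim: int, batch: int = 1) -> int:
    """Number of cached values: K and V for every layer, KV head, and position."""
    return batch * layers * kv_heads * 2 * seq_len * head_dim

def kv_cache_bytes(layers: int, kv_heads: int, seq_len: int, head_dim: int,
                   bytes_per_elem: int, batch: int = 1) -> int:
    return kv_cache_elems(layers, kv_heads, seq_len, head_dim, batch) * bytes_per_elem

print(kv_cache_elems(12, 12, 1124, 64))        # 20717568 values
print(kv_cache_bytes(12, 12, 1124, 64, 2))     # 41435136 bytes, about 41.4 MB at FP16
print(kv_cache_bytes(12, 12, 1024, 64, 2))     # 37748736 bytes for a full 1024-position window
```

Because the cost is linear in batch size, multiplying by a serving batch of 64 turns tens of megabytes into gigabytes, which is how the cache ends up limiting batch size.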

The state-of-the-art solution to KV cache memory fragmentation is PagedAttention, which treats the KV cache like virtual memory with pages rather than pre-allocating contiguous blocks. But that’s beyond the scope of this project.

The Decode Kernel

The FlashAttention CUDA kernel needs modification for the decode phase. When Q has sequence length 1 (one query), the outer loop over Q-blocks collapses to a single iteration, and you process the one query against all K/V blocks.

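// Decode kernel sketch: Q has seq_len = 1, so the outer B_r loop disappears.
// One block per (batch, head); thread t owns output dimension t (blockDim.x == d).
// Conceptual, not project-specific; dot() is a hypothetical helper, and each thread recomputes every score for simplicity.
float m = -INFINITY, l = 0.0f, acc = 0.0f;   // running max, running sum, output element
for (int kv_block = 0; kv_block < N; kv_block += B_c) {
    // Load K and V tiles [B_c, d] into shared memory, then __syncthreads()
    for (int j = 0; j < min(B_c, N - kv_block); ++j) {
        float s = dot(q, k_tile[j], d) * scale;         // q · k_j: a dot product, not a GEMM
        float m_new = fmaxf(m, s);
        float alpha = expf(m - m_new), p = expf(s - m_new);
        l   = alpha * l + p;                             // online softmax update
        acc = alpha * acc + p * v_tile[j][threadIdx.x];  // accumulate into output
        m   = m_new;
    }
    __syncthreads();  // all threads finish with this tile before it is overwritten
}
out[threadIdx.x] = acc / l;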

The decode kernel produces the same result (up to floating-point rounding) as running the prefill kernel with seq_len=1, but with the outer loop specialized away. That removes the query-tile bookkeeping and can enable better register and shared-memory utilization.

Performance vs. naive PyTorch:

  • TTFT (time to first token, prefill): ~3.5× speedup
  • TBT (time between tokens, decode): ~3.2× speedup

Putting It All Together: Running GPT-2

Task 7 connects the custom attention kernels to a real GPT-2 model loaded from Hugging Face. The model’s attention layers are replaced with the custom implementation. The inference loop runs:

  1. Prefill: Run the input prompt through all 12 GPT-2 layers using the FlashAttention kernel. Store K and V for all layers in the KV cache.
  2. Decode: For each new token: run one forward pass through all layers using the decode kernel, appending new K/V to the cache. Sample from the output distribution. Repeat.

If the attention kernels are numerically correct, GPT-2 should generate coherent text similar in quality to the reference PyTorch implementation. If there are bugs (wrong softmax normalization, off-by-one in tiling, etc.), the model generates garbage. The end-to-end test is a stringent correctness check: any systematic numerical error accumulates across 12 layers and 100 decode steps.

A generated sample on an attention-related prompt produced something like:

“What are the advantages of using flashattention kernels compared to the naive pytorch implementation? Think about the shared memory usage, the computation complexity, and the performance.

For a long time, we didn’t see any advantage in using flashattention kernels, but it’s hard not to find that we’ve become much better…”

It’s GPT-2 (2019), not GPT-4. The factual content is questionable. But the syntax is coherent and the model is clearly doing language modeling, which is strong evidence that the kernels are correct.

Design Questions This Raised

Working through the project, the questions that stayed open were architectural rather than algorithmic:

  • What’s the right tile size? The project required tuning $B_r$ and $B_c$ empirically. The optimal values depend on the GPU’s shared memory budget, the head dimension, and the warp occupancy. Is there a principled way to derive them, or is it always measured?

  • When does causal masking change the tradeoff? With a causal mask, roughly half of all (Q-block, K/V-block) pairs lie entirely “above the diagonal” and produce all-zero attention weights. FlashAttention can skip these blocks entirely, so the causal version can be faster than the non-causal version for large sequences. How should the kernel detect and exploit this?

  • How does FlashAttention interact with quantization? Modern LLM inference often uses FP8 or INT8 weights and activations. The online softmax in FlashAttention requires FP32 accumulators for the running max and sum. Does the tiling approach compose cleanly with quantized arithmetic, or does it require special handling at tile boundaries?

  • How do multi-query attention (MQA) and grouped-query attention (GQA) change things? GPT-2 uses full MHA where each head has its own K and V. Many newer models (Llama 2 70B, Llama 3, Mistral 7B) use GQA, where K and V are shared across groups of Q heads. The KV cache shrinks by the group size (query heads per KV head), but the attention kernel needs to handle the asymmetric head counts.
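The bookkeeping GQA adds is small: each query head reads the K/V of the group it belongs to. This sketch shows one way to compute the head mapping and the resulting cache reduction; the head counts in the example (32 query heads, 8 KV heads) are illustrative rather than taken from a specific model's config, and the function names are hypothetical.

```python
def kv_head_for(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Index of the K/V head that query head `q_head` attends with under GQA."""
    assert num_q_heads % num_kv_heads == 0, "query heads must split evenly into groups"
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size

def cache_ratio(num_q_heads: int, num_kv_heads: int) -> float:
    """KV cache size relative to full MHA (one K/V head per query head)."""
    return num_kv_heads / num_q_heads

# MHA: kv_heads == q_heads; MQA: kv_heads == 1; GQA: anything in between.
print([kv_head_for(h, 32, 8) for h in range(8)])   # [0, 0, 0, 0, 1, 1, 1, 1]
print(cache_ratio(32, 8), cache_ratio(32, 1))      # 0.25 0.03125
```

In a kernel, one common approach is to index the K/V pointer with the group's KV head instead of the query head, so several query heads reuse the same K/V tiles.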

Conclusion: The Course in One Kernel

FlashAttention is where the course concepts finally converge in one kernel design.

Implementing FlashAttention requires understanding the memory hierarchy (shared memory for tiles, global memory for Q/K/V/O), the execution model (warps, blocks, register pressure), the computational arithmetic (numerically stable softmax, online updates, fused kernels), and the ML context (why attention is the bottleneck, prefill vs. decode, KV cache).

It also demonstrates something broader about GPU programming: the wins come from rethinking the algorithm to match the hardware constraints, not from micro-optimizing a fixed algorithm. The naive attention implementation does the right math. FlashAttention does the same math in a different order, one that respects the memory hierarchy. That’s the whole trick.

The 7.7× speedup at seq=4096 isn’t free. It’s the payoff for reducing expensive intermediate-memory traffic and matching the algorithm to GPU memory hierarchy constraints. The gains come from tile boundary handling, online softmax updates, and careful shared memory layout. Writing the kernel teaches you why that speedup is possible in a way that reading the FlashAttention paper doesn’t.

Additional Resources

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré — NeurIPS 2022

The original paper. Section 3 derives the online softmax algorithm and the tiling scheme from first principles. Reading this before implementing is essential, because it explains why the algorithm works, not just the mechanics.

FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning Tri Dao — ICLR 2024

Algorithm 1 is the one you implement in the project. The improvements over FA1 are mainly better parallelism (work is also split along the sequence dimension across thread blocks), better work partitioning between the warps of a block, and fewer non-GEMM FLOPs. Note the known typo in Algorithm 1 (see the GitHub issue in the README).

Programming Massively Parallel Processors: A Hands-on Approach


David B. Kirk and Wen-mei W. Hwu

Chapter 5 (tiled matrix multiplication) is the direct predecessor to the GEMM kernels in this project. Chapter 17 covers tensor cores and the ML hardware context for understanding why FlashAttention matters.

A Note on Code Availability

In accordance with Georgia Tech’s academic integrity policy and the license for course materials, the source code for this project is kept in a private repository. This post follows Dean Joyner’s advice on sharing projects with a focus not on any particular solution but on an abstract overview of the problem and the underlying concepts involved.

I would be happy to discuss implementation details, kernel design choices, or performance results in an interview. Please feel free to reach out to request private access to the repository.