High-Performance GPU Kernels for Deep Learning
Operations: Implementation and Optimization
1. Introduction
The remarkable advancements in deep learning (DL) are intrinsically linked to the
availability of massively parallel computing hardware, predominantly Graphics
Processing Units (GPUs). Achieving optimal performance—high throughput and low
latency—in DL training and inference hinges critically on the efficiency of the
underlying GPU kernels. These kernels are responsible for executing the core
mathematical operations that constitute neural networks. Developing
high-performance GPU kernels is a complex endeavor, fraught with challenges that
demand a deep understanding of both the algorithms and the hardware architecture.
Key among these challenges is the effective management of the GPU's memory
hierarchy. Data must be meticulously orchestrated between the high-capacity but
slower High Bandwidth Memory (HBM) and the smaller but significantly faster on-chip
Static Random-Access Memory (SRAM), typically utilized as shared memory or L1
cache, and ultimately, registers.1 Minimizing data movement, especially to and from
HBM, and maximizing data reuse within faster memory tiers are paramount, as many
DL workloads are memory-bandwidth-bound rather than compute-bound.2 This
observation underscores a fundamental principle: the efficiency of data transfer often
dictates overall kernel performance more than the raw count of arithmetic operations.
Exploiting the massive parallelism inherent in GPUs is another critical aspect. Kernels
must be designed to effectively utilize thousands of threads, organized into warps and
thread blocks, and distributed across multiple Streaming Multiprocessors (SMs).3
However, merely launching many threads is insufficient; developers must avoid
common performance pitfalls. Uncoalesced memory access, where threads in a warp
access disparate memory locations, leads to multiple, inefficient HBM transactions.4
Shared memory bank conflicts, occurring when multiple threads within a warp attempt
to access different addresses within the same memory bank simultaneously, serialize
access and reduce effective bandwidth.4 Thread divergence, where threads within a
warp follow different execution paths due to conditional logic, diminishes SIMT (Single
Instruction, Multiple Thread) efficiency, as some threads become idle.4
Kernel fusion represents another powerful optimization strategy. By combining
multiple distinct operations into a single, larger kernel, intermediate data can often be
kept within on-chip memory (registers or shared memory), thereby avoiding costly
round trips to HBM.5 This not only reduces memory bandwidth consumption and
latency but also amortizes kernel launch overhead. Examples include fusing Layer
Normalization with preceding bias and residual additions or integrating rescaling
operations directly into quantized matrix multiplication kernels.7
This report aims to provide a detailed examination of the implementation and
optimization strategies for several common and critical deep learning operations,
focusing on the nuances relevant for high-performance GPU kernel development. The
operations covered include matrix multiplication (GEMM), numerically stable Softmax,
Layer Normalization, the forward pass of FlashAttention, common element-wise
operations, and convolution (Conv1D/Conv2D). For each, the discussion will delve into
algorithmic choices, memory access patterns, parallelism strategies, and techniques
to mitigate performance bottlenecks on GPU architectures.
2. Optimizing Matrix Multiplication (GEMM) Kernels
General Matrix Multiplication (GEMM) is a cornerstone operation in deep learning,
forming the computational core of fully connected layers and, often, convolutional
layers (via transformations like im2col). Optimizing GEMM kernels for GPUs involves
sophisticated tiling strategies to exploit data reuse across the memory hierarchy and
careful management of memory access patterns.
2.1. Fundamental Tiling Strategies for Data Reuse
Tiling is a fundamental technique to improve data locality and reduce the number of
accesses to slower memory levels. GEMM kernels typically employ a hierarchy of tiling
strategies.
2.1.1. Shared Memory (LDS) Tiling
The primary goal of shared memory tiling (often referred to as Local Data Store or LDS
tiling on AMD GPUs, or simply shared memory on NVIDIA GPUs) is to minimize
high-latency global memory (HBM) accesses.
● Concept: Input matrices A and B are too large to fit entirely in fast on-chip
memory. Therefore, they are broken down into smaller blocks, or "tiles".1
● Mechanism: Each thread block is assigned the computation of one tile of the
output matrix C. To compute this output tile, the thread block iteratively loads
tiles of A and B from global memory into shared memory. After each pair of tiles
(one from A, one from B) is loaded into shared memory, threads within the block
perform a partial matrix multiplication using only data from shared memory,
accumulating the results in their local registers. A __syncthreads() barrier is
necessary after loading into shared memory and before computation to ensure all
data is available to all threads in the block.1
● Performance: This strategy significantly reduces the number of times each
element of A and B is read from global memory. For instance, if a tile of A is
loaded into shared memory, all threads in the block that need elements from that
tile for their portion of the output C tile can access it from the much faster shared
memory.1 This is critical because global memory accesses can have latencies of
hundreds of cycles.1
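The sketch below illustrates this scheme for FP32 inputs: one thread block per output tile, cooperative tile loads into shared memory, a __syncthreads() barrier, and register accumulation. The tile size, kernel name, and launch configuration are illustrative assumptions, and the matrix dimensions are assumed to be multiples of the tile size.
```cuda
#define TILE_DIM 32

// C[M x N] = A[M x K] * B[K x N], all row-major FP32.
// Each thread block computes one TILE_DIM x TILE_DIM tile of C.
__global__ void sgemm_shared_tiled(const float* A, const float* B, float* C,
                                   int M, int N, int K) {
    __shared__ float As[TILE_DIM][TILE_DIM];
    __shared__ float Bs[TILE_DIM][TILE_DIM];

    int row = blockIdx.y * TILE_DIM + threadIdx.y;  // row of C this thread computes
    int col = blockIdx.x * TILE_DIM + threadIdx.x;  // column of C this thread computes
    float acc = 0.0f;                               // accumulator kept in a register

    // March over the K dimension one pair of tiles at a time.
    for (int t = 0; t < K / TILE_DIM; ++t) {
        // Cooperative, coalesced loads: consecutive threadIdx.x -> consecutive addresses.
        As[threadIdx.y][threadIdx.x] = A[row * K + t * TILE_DIM + threadIdx.x];
        Bs[threadIdx.y][threadIdx.x] = B[(t * TILE_DIM + threadIdx.y) * N + col];
        __syncthreads();  // all tile data must be resident before use

        // Partial dot product entirely out of shared memory.
        for (int k = 0; k < TILE_DIM; ++k)
            acc += As[threadIdx.y][k] * Bs[k][threadIdx.x];
        __syncthreads();  // don't overwrite tiles while others still read them
    }
    C[row * N + col] = acc;
}

// Launch (illustrative): dim3 block(TILE_DIM, TILE_DIM);
//                        dim3 grid(N / TILE_DIM, M / TILE_DIM);
//                        sgemm_shared_tiled<<<grid, block>>>(dA, dB, dC, M, N, K);
```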
2.1.2. Register Tiling
Register tiling further optimizes data reuse at the finest level of the memory
hierarchy—registers, which offer the fastest access.
● Concept: Within the computation of a shared memory tile, each thread is
responsible for a small sub-matrix (or a few elements) of the output C tile.
Register tiling aims to maximize the reuse of data loaded from shared memory
into these registers.1
● Mechanism: A thread loads an element (or several elements) from the shared
memory tiles of A and B into its private registers. It then performs multiple
multiply-accumulate (MAC) operations using these register-held values before
discarding them and loading new values. For example, a thread might load one
element from its row in tile A and multiple elements from its column in tile B (or
vice-versa), performing several MACs.
● Performance: This minimizes accesses to shared memory, which, while much
faster than global memory, is still slower than registers. It also increases
arithmetic intensity at the thread level.
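A register-tiled inner loop might look like the device-function sketch below, which reuses the shared-memory tiles from the previous example; the micro-tile sizes (TM, TN) and the thread-to-tile mapping are illustrative assumptions.
```cuda
// Per-thread register-tiled inner product over one pair of shared-memory tiles.
// Assumes the As/Bs tiles and TILE_DIM from the previous sketch; TM x TN is the
// micro-tile of C owned by this thread (accum must be zero-initialized by the caller).
template <int TILE_DIM, int TM, int TN>
__device__ void micro_tile_mma(const float As[TILE_DIM][TILE_DIM],
                               const float Bs[TILE_DIM][TILE_DIM],
                               float accum[TM][TN],
                               int thread_row, int thread_col) {
    float a_reg[TM];  // column fragment of the A tile, cached in registers
    float b_reg[TN];  // row fragment of the B tile, cached in registers
    for (int k = 0; k < TILE_DIM; ++k) {
        for (int i = 0; i < TM; ++i) a_reg[i] = As[thread_row * TM + i][k];
        for (int j = 0; j < TN; ++j) b_reg[j] = Bs[k][thread_col * TN + j];
        // TM*TN multiply-accumulates touch only registers: shared memory is read
        // TM + TN times per k step instead of 2*TM*TN times.
        for (int i = 0; i < TM; ++i)
            for (int j = 0; j < TN; ++j)
                accum[i][j] += a_reg[i] * b_reg[j];
    }
}
```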
The hierarchical application of these tiling strategies—global memory to shared
memory tiles, and shared memory to register tiles—is a direct reflection of the GPU's
memory architecture and is crucial for approaching peak performance.3
2.2. Memory Access Optimization
Efficiently moving data between memory levels is as important as reusing it. Two key
aspects are coalesced global memory access and avoiding shared memory bank
conflicts.
2.2.1. Coalesced Global Memory Access
● Concept: When threads within a warp (a group of typically 32 threads that
execute in lockstep) access global memory, their individual memory requests can
be "coalesced" by the hardware into a smaller number of wide memory
transactions if the accessed locations are contiguous.4 This is far more efficient
than each thread triggering a separate, narrow transaction.
● Mitigation/Implementation: To achieve coalesced access when loading tiles of
A and B into shared memory, threads must be organized such that thread i, i+1,...,
i+W-1 (where W is warp size) access consecutive memory locations. For a
row-major matrix, if a warp is loading a part of a row, each thread should load an
adjacent element. If loading part of a column, access patterns need to be
designed carefully, potentially involving a transpose in shared memory if
subsequent access patterns favor a different layout. Inefficient global memory
loading can severely undermine the benefits of shared memory tiling, as the initial
data fetch itself becomes a bottleneck.
2.2.2. Shared Memory Bank Conflicts
● Concept: Shared memory is physically divided into a number of equally-sized
memory modules called banks (typically 32 banks, each 32-bits or 64-bits wide).
Multiple simultaneous accesses to different addresses within the same bank by
threads in a warp result in a bank conflict, causing these accesses to be
serialized.4 Accesses to different banks or all threads accessing the exact same
address (broadcast) do not cause conflicts.
● Mitigation/Implementation:
○ Padding: One common technique is to pad the dimensions of shared memory
arrays. For example, with a TILE_DIM × TILE_DIM tile where TILE_DIM is a multiple
of the number of banks (e.g., 32), column-wise access (each thread in a warp
reading a different row of the same column) maps every request to the same bank
and serializes the warp. Declaring the array with one extra column of padding,
e.g., __shared__ float shared_tile[TILE_DIM][TILE_DIM + 1];, changes the mapping
of elements to banks so that consecutive rows of a column fall into different
banks, resolving such conflicts.4
○ Access Pattern Design: Carefully design how threads map to shared memory
locations. For example, if blockDim.x is 32, an access like
sh_mem[threadIdx.y][threadIdx.x] where each thread in a warp has a unique
threadIdx.x but the same threadIdx.y will be conflict-free for row-wise access.
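The classic shared-memory transpose illustrates the padding fix; in the sketch below only the extra padding column distinguishes the conflict-free version from the conflict-prone one. The tile size and kernel shape are illustrative, and the matrix is assumed square with dimensions divisible by the tile size.
```cuda
#define TILE_DIM 32

__global__ void transpose_tile_example(const float* in, float* out, int width) {
    // Conflict-prone alternative: with 32 banks and a 32-float row stride, element
    // [r][c] maps to bank c, so column-wise reads (threads vary r, same c) serialize.
    //   __shared__ float tile[TILE_DIM][TILE_DIM];
    // Padded: a row stride of 33 floats maps element [r][c] to bank (r*33 + c) % 32,
    // distinct for each r, so the column-wise reads below are conflict-free.
    __shared__ float tile[TILE_DIM][TILE_DIM + 1];

    int x = blockIdx.x * TILE_DIM + threadIdx.x;
    int y = blockIdx.y * TILE_DIM + threadIdx.y;
    tile[threadIdx.y][threadIdx.x] = in[y * width + x];    // coalesced row-wise write
    __syncthreads();

    int tx = blockIdx.y * TILE_DIM + threadIdx.x;
    int ty = blockIdx.x * TILE_DIM + threadIdx.y;
    out[ty * width + tx] = tile[threadIdx.x][threadIdx.y]; // column-wise read from SMEM
}
```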
2.3. Accumulation Strategies and Mixed-Precision GEMM
Modern DL often employs mixed-precision training to accelerate computation and
reduce memory footprint, using lower-precision formats like FP16 (half-precision) or
BF16 (BFloat16) for most computations, while accumulating in higher precision (FP32)
to maintain numerical stability.
2.3.1. FP16/BF16 Computation with FP32 Accumulation
● Mechanism: Input matrices A and B are often in FP16 or BF16. The multiplication
A_ik * B_kj is performed using FP16/BF16 arithmetic. However, the resulting partial
products are then cast to FP32 before being added to an FP32 accumulator
register.11 This is crucial because repeated additions in FP16 can quickly lead to
loss of precision or overflow/underflow, especially if the accumulated values
become much larger than the individual addends.12
● FP32 Master Weights: In the broader context of training, model weights are
often stored in FP32 ("master weights"). For each training iteration, an FP16/BF16
copy is made for the forward and backward passes. Gradients, also often
computed in FP16/BF16, are used to update the FP32 master weights.12 This
ensures that small gradient updates are not lost.
● Loss Scaling: When using FP16 for gradients, loss scaling (multiplying the loss by
a scaling factor before backpropagation and then unscaling the gradients before
weight update) is often necessary to prevent small gradient values from
underflowing to zero in FP16 representation.11 While not part of the GEMM kernel
itself, it's an indispensable system-level component for stable mixed-precision
training.
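Stripped of tiling and Tensor Cores, the accumulation pattern reduces to the scalar sketch below: FP16 operands are widened to FP32 before every add, and only the final result is (optionally) narrowed. The kernel shape is an illustrative assumption.
```cuda
#include <cuda_fp16.h>

// Scalar illustration of mixed-precision accumulation: A and B are FP16,
// every partial product is widened to FP32 before being added, and the
// FP32 result is only narrowed (if at all) when written out.
__global__ void gemm_fp16_fp32acc_naive(const __half* A, const __half* B, float* C,
                                        int M, int N, int K) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= M || col >= N) return;

    float acc = 0.0f;  // FP32 accumulator avoids FP16 swamping/overflow over K terms
    for (int k = 0; k < K; ++k) {
        float a = __half2float(A[row * K + k]);
        float b = __half2float(B[k * N + col]);
        acc = fmaf(a, b, acc);  // fused multiply-add in FP32
    }
    C[row * N + col] = acc;    // or __float2half(acc) if an FP16 output is required
}
```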
2.3.2. Role of Tensor Cores
NVIDIA GPUs, starting from the Volta architecture, include specialized hardware units
called Tensor Cores, designed to accelerate GEMM operations, particularly for mixed
precision.3
● Functionality: A Tensor Core performs a small matrix multiply-accumulate,
$D = A \cdot B + C$; on Volta, for example, $A$ and $B$ are 4×4 FP16 matrices and
the accumulator $C$/$D$ is a 4×4 FP16 or FP32 matrix, computed in a single clock
cycle.3 Newer generations support various precisions (including BF16, TF32, INT8,
FP8) and larger matrix fragment sizes.
● Programming Model (e.g., CUTLASS): Directly programming Tensor Cores via
assembly (PTX mma.sync instructions) is complex. Libraries like CUTLASS provide
C++ template abstractions to define and execute GEMMs using Tensor Cores.3
CUTLASS organizes computation hierarchically:
○ Threadblock-level GEMM: Manages loading data tiles from global memory
to shared memory.
○ Warp-level GEMM: Multiple warps within a threadblock fetch data from
shared memory into registers and issue matrix-multiply-accumulate (MMA)
instructions to Tensor Cores. These are often mma.sync (for
warp-synchronous operations) or older wmma
(warp-matrix-multiply-accumulate) intrinsics.3
○ Instruction-level GEMM: The actual MMA instructions executed by the
Tensor Cores.
The shift towards Tensor Cores means that GEMM kernel optimization is now heavily
focused on efficiently feeding these units with data and structuring the
computation to match their operational characteristics, rather than hand-coding
individual scalar floating-point operations.
● Constraints: Tensor Cores often impose constraints on matrix dimensions (e.g.,
M, N, K must be multiples of 8 or 16 for certain FP16 operations) and require
specific data layouts in shared memory for optimal performance.11
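For reference, the warp-level pattern can also be expressed with CUDA's wmma intrinsics rather than CUTLASS; the sketch below assumes 16×16×16 FP16 fragments with FP32 accumulation, a row-major A, a column-major B, and dimensions that are multiples of 16, with one warp per output tile purely for illustration.
```cuda
#include <cuda_fp16.h>
#include <mma.h>
using namespace nvcuda;

// One warp computes one 16x16 tile of C = A*B using Tensor Cores.
// A is row-major [M x K] FP16, B is column-major [K x N] FP16, C is row-major FP32.
__global__ void wmma_gemm_16x16(const half* A, const half* B, float* C,
                                int M, int N, int K) {
    int tile_m = blockIdx.y * 16;   // which 16x16 output tile this warp owns
    int tile_n = blockIdx.x * 16;

    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;
    wmma::fill_fragment(c_frag, 0.0f);  // FP32 accumulator starts at zero

    for (int k = 0; k < K; k += 16) {
        // Load 16x16 operand fragments; the leading dimension is the full matrix stride.
        wmma::load_matrix_sync(a_frag, A + tile_m * K + k, K);
        wmma::load_matrix_sync(b_frag, B + tile_n * K + k, K);
        // Tensor Core MMA: c_frag += a_frag * b_frag, accumulated in FP32.
        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
    }
    wmma::store_matrix_sync(C + tile_m * N + tile_n, c_frag, N, wmma::mem_row_major);
}
// Launch with one 32-thread warp per output tile, e.g.
// wmma_gemm_16x16<<<dim3(N / 16, M / 16), 32>>>(dA, dB, dC, M, N, K);
```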
2.4. INT8 GEMM Kernels
Quantizing models to use 8-bit integers (INT8) for weights and/or activations can
provide significant performance boosts and memory reduction, especially for
inference.
2.4.1. Integer Multiplication and Accumulation
● Mechanism: Input matrices A and B are quantized to INT8. The core operation
involves multiplying INT8 values. An int8 * int8 multiplication results in an INT16
value. These INT16 products are typically accumulated into INT32 registers to
prevent overflow during the summation over the K dimension.8 NVIDIA Tensor
Cores also support INT8 MMA operations, which internally handle the
accumulation in INT32.14
● Output: The final INT32 accumulated sum for each element of the output matrix
C must then be dequantized back to a floating-point format or requantized to
INT8 if the next layer also expects INT8 input.
2.4.2. Quantization/Dequantization and Scaling Factors
● Symmetric Quantization: A common scheme is symmetric quantization.
○ Quantization: $x_q = \mathrm{round}(\mathrm{clip}(x/s, -128, 127))$ for signed
INT8, where $x$ is the original float value and $s$ is the per-tensor or
per-channel scaling factor.15
○ Dequantization: $x \approx x_q \times s$.15
● Role in GEMM: If $A_{fp32} \approx s_A A_{int8}$ and $B_{fp32} \approx s_B B_{int8}$,
then $C_{fp32} = A_{fp32} B_{fp32} \approx s_A s_B (A_{int8} B_{int8})$. The product
$A_{int8} B_{int8}$ is accumulated in INT32 registers, so
$C_{fp32} \approx \mathrm{Accum}_{int32} \times s_A s_B$. If the output also needs
to be INT8 with scale $s_C$, then
$C_{int8} = \mathrm{quantize}(\mathrm{Accum}_{int32} \times (s_A s_B / s_C))$, where
$(s_A s_B / s_C)$ is a combined rescaling factor.
2.4.3. Handling INT32 Intermediate Results and Fused Rescaling
● Challenge: Storing the INT32 accumulated values back to global memory is
inefficient due to increased data size (4x compared to INT8).
● Solution: Fused Kernels: The dequantization (multiplying by $s_A s_B$), any
subsequent requantization (dividing by $s_C$ and casting to INT8), and the
element-wise operations that often follow a GEMM (like bias addition and
activation functions) are fused into the GEMM kernel itself. The INT32
accumulators residing in registers are rescaled and then either stored as
FP32/FP16 or requantized to INT8 and stored, all within the same kernel, avoiding
an HBM write/read of INT32 values.8 For instance, the rescaling by the weight
scale $s_w$ can be fused, as noted in dynamic quantization kernel discussions,
preventing materialization of the INT32 intermediate.8
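A sketch of this fused pattern using the __dp4a intrinsic (a 4-way INT8 dot product accumulated into INT32, available on compute capability 6.1+) is shown below; the pre-transposed B layout, per-tensor rescale factor, and kernel shape are illustrative assumptions.
```cuda
#include <cstdint>

// C_int8[M x N] = requantize( A_int8[M x K] * B_int8[K x N] ), with K a multiple of 4.
// B is pre-transposed ([N x K]) so both operands are read contiguously along K.
// 'rescale' is the combined factor s_A * s_B / s_C applied in registers, so the
// INT32 accumulator never touches HBM.
__global__ void int8_gemm_fused_requant(const int8_t* A, const int8_t* Bt,
                                        int8_t* C, int M, int N, int K,
                                        float rescale) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= M || col >= N) return;

    int32_t acc = 0;  // INT32 accumulator prevents overflow of the int8*int8 sums
    const int* a4 = reinterpret_cast<const int*>(A + row * K);   // 4 packed int8 per int
    const int* b4 = reinterpret_cast<const int*>(Bt + col * K);
    for (int k = 0; k < K / 4; ++k) {
        acc = __dp4a(a4[k], b4[k], acc);  // 4 int8 MACs accumulated into int32
    }

    // Fused dequantize + requantize: scale in a register, clamp, cast, store int8.
    float scaled = static_cast<float>(acc) * rescale;
    int q = __float2int_rn(scaled);
    q = q < -128 ? -128 : (q > 127 ? 127 : q);
    C[row * N + col] = static_cast<int8_t>(q);
}
```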
The evolution of GEMM kernels, particularly with the advent of Tensor Cores and the
push for lower precision, highlights a strong interplay between algorithmic
reformulation (tiling), hardware specialization (Tensor Cores), and numerical
considerations (FP32 accumulation, scaling factors). Optimal performance requires a
holistic approach that considers all these factors.
Table 1: Comparison of MatMul Tiling Strategies
| Strategy | Target Memory Level | Key Optimization Goal | Typical Performance Impact | Complexity |
| --- | --- | --- | --- | --- |
| Naive (Global Memory) | Global | Baseline, direct computation | Very low due to high global memory latency and low data reuse | Low |
| Shared Memory Tiling | Shared (SRAM) | Reduce global memory reads by caching tiles in SRAM | Significant improvement by exploiting data reuse within a thread block 1 | Medium |
| Register Tiling | Register | Reduce shared memory reads by maximizing register reuse | Further improvement by reusing data already in the fastest memory tier 1 | Medium-High |
| Tensor Core Utilization | Register (implicit) | Leverage specialized MMA hardware for high throughput | Highest throughput for supported precisions, but requires specific data layouts/sizes 3 | High |
3. Numerically Stable Softmax Kernels
The Softmax function is essential in classification tasks and attention mechanisms,
converting a vector of arbitrary real values into a probability distribution. Its naive
implementation, however, is prone to numerical instability.
3.1. Standard Softmax Definition and Numerical Challenges
The Softmax function for an element $x_i$ in a vector $x$ is defined as:
$$ \mathrm{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}} $$
Numerical issues arise primarily from the exponentiation term $e^{x_i}$:
● Overflow: If any $x_i$ is a large positive number, $e^{x_i}$ can exceed the
maximum representable value for the floating-point type (e.g., FP32 or FP16),
resulting in inf.16
● Underflow: If $x_i$ is a very negative number, $e^{x_i}$ can underflow to zero.
If all $e^{x_j}$ in the denominator underflow to zero, this leads to a division by
zero (0/0), resulting in NaN.16
3.2. The Max-Subtraction Trick for Numerical Stability
To address these issues, a common technique is to subtract the maximum value of the
vector $x$ from each element $x_j$ before exponentiation. Let $c = \max_k(x_k)$. The
numerically stable Softmax is computed as:
$$ \mathrm{softmax}(x_i) = \frac{e^{x_i - c}}{\sum_j e^{x_j - c}} $$
This transformation does not change the final Softmax output because:
$$ \frac{e^{x_i - c}}{\sum_j e^{x_j - c}} = \frac{e^{x_i} e^{-c}}{\sum_j (e^{x_j} e^{-c})} =
\frac{e^{x_i} e^{-c}}{e^{-c} \sum_j e^{x_j}} = \frac{e^{x_i}}{\sum_j e^{x_j}} $$
The $e^{-c}$ term cancels out from the numerator and the denominator.16
The benefit of this trick is that the maximum value passed to the exponential function
becomes $c - c = 0$, so $e^0 = 1$. This prevents $e^{x_j - c}$ from overflowing. While
it doesn't completely eliminate underflow (if $x_j - c$ is still very negative), it
shifts the values into a numerically safer range.16
3.3. Kernel Implementation Strategy
A high-performance Softmax kernel, typically operating row-wise on an input matrix
(e.g., in attention scores), involves several stages which are ideally fused:
1. Find Row-wise Maximum ($c_i$):
○ For each row $i$ of the input matrix, find the maximum value $c_i$. This is a
parallel reduction operation.
○ Threads within a thread block, cooperatively processing one or more rows,
can load row elements into shared memory.
○ A parallel reduction (e.g., tree-based reduction) is performed in shared
memory to find $c_i$. Each thread block might compute the max for several rows
if rows are short, or multiple thread blocks might cooperate for very long
rows.
2. Subtract Max, Exponentiate, and Sum ($e^{x_{ij} - c_i}$ and $L_i = \sum_j e^{x_{ij} - c_i}$):
○ Once $c_i$ is known (and potentially broadcast or stored in shared memory
accessible to threads working on row $i$), each thread processing element
$x_{ij}$ computes $x'_{ij} = x_{ij} - c_i$.
○ Then, it computes $p'_{ij} = e^{x'_{ij}}$.
○ Another parallel reduction is performed over $p'_{ij}$ for each row $i$ to find
the sum $L_i = \sum_j p'_{ij}$. This reduction also typically uses shared memory.
3. Element-wise Division:
○ Each thread then computes the final Softmax value:
$\mathrm{softmax}(x_{ij}) = p'_{ij} / L_i$. The sum $L_i$ is broadcast or read
from shared memory by the threads working on row $i$.
Optimization Considerations:
● Fusion: These three steps (find max, subtract-exp-sum, divide) should be fused
into a single kernel to avoid intermediate writes to and reads from global memory.
● Shared Memory Usage: Shared memory is crucial for both the reduction to find
$c_i$ and the reduction to find $L_i$, as well as for broadcasting $c_i$ and $L_i$
to the threads processing each row.
● Work Distribution: If rows are long, a single thread block might process one row.
If rows are short, a thread block might process multiple rows to ensure sufficient
parallelism and SM occupancy. Threads within the block would be partitioned to
handle different rows.
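A fused sketch of the three stages, with one thread block per row and shared-memory tree reductions, is shown below. The block size, FP32 data type, and grid-stride handling of long rows are illustrative assumptions; the row is streamed three times here, while production kernels often cache it in registers or shared memory.
```cuda
#include <cfloat>

// One thread block processes one row of a [rows x cols] matrix.
// All three stages (max, exp+sum, divide) are fused: no intermediate tensors
// are materialized in HBM between them.
template <int BLOCK>
__global__ void softmax_rows(const float* in, float* out, int cols) {
    __shared__ float red[BLOCK];
    const float* row_in = in + blockIdx.x * cols;
    float* row_out = out + blockIdx.x * cols;
    int tid = threadIdx.x;

    // Stage 1: row maximum (grid-stride over the row, then tree reduction).
    float m = -FLT_MAX;
    for (int j = tid; j < cols; j += BLOCK) m = fmaxf(m, row_in[j]);
    red[tid] = m;
    __syncthreads();
    for (int s = BLOCK / 2; s > 0; s >>= 1) {
        if (tid < s) red[tid] = fmaxf(red[tid], red[tid + s]);
        __syncthreads();
    }
    float row_max = red[0];   // broadcast c_i to all threads via shared memory
    __syncthreads();

    // Stage 2: sum of exp(x - max).
    float l = 0.0f;
    for (int j = tid; j < cols; j += BLOCK) l += __expf(row_in[j] - row_max);
    red[tid] = l;
    __syncthreads();
    for (int s = BLOCK / 2; s > 0; s >>= 1) {
        if (tid < s) red[tid] += red[tid + s];
        __syncthreads();
    }
    float row_sum = red[0];   // broadcast L_i

    // Stage 3: normalized output, coalesced write.
    for (int j = tid; j < cols; j += BLOCK)
        row_out[j] = __expf(row_in[j] - row_max) / row_sum;
}
// Launch (illustrative): softmax_rows<256><<<rows, 256>>>(d_in, d_out, cols);
```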
The implementation of a numerically stable Softmax kernel showcases how algebraic
manipulation can lead to more robust numerical behavior, and how multi-stage
computations involving reductions and element-wise operations can be efficiently
mapped to GPU architectures through careful use of shared memory and fusion. This
pattern of reduction followed by element-wise operations using the reduced value is
also seen in other normalization layers like LayerNorm.
4. Efficient LayerNorm Kernels
Layer Normalization (LayerNorm) is a widely used technique for normalizing
activations within a neural network layer. It helps stabilize training dynamics and
improve model convergence. Unlike Batch Normalization, LayerNorm computes
normalization statistics (mean and variance) across the feature dimension for each
individual training sample independently.
4.1. LayerNorm Definition and Purpose
For an input vector $x$ (representing the features of a single instance), LayerNorm is
defined as:
$$ y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta $$
where $\mathrm{E}[x]$ is the mean of $x$, $\mathrm{Var}[x]$ is the variance of $x$,
$\epsilon$ is a small constant added for numerical stability (to prevent division by
zero), and $\gamma$ (scale) and $\beta$ (shift) are learnable affine transformation
parameters of the same dimension as $x$.17 In a typical Transformer model, if an
input tensor has shape [batch_size, sequence_length, hidden_size], LayerNorm is often
applied over the hidden_size dimension, meaning mean and variance are computed
independently for each [batch_idx, seq_idx] slice.
4.2. Implementation as Reduction and Broadcast
Implementing LayerNorm efficiently on a GPU involves two main phases: a reduction
phase to compute statistics and a broadcast/element-wise phase to apply the
normalization and affine transformation.
4.2.1. Reduction Phase (Mean and Variance Calculation)
For each input instance (e.g., each row in a [sequence_length, hidden_size] tensor,
normalizing along hidden_size):
1. Sum and Sum of Squares: The mean and variance require the sum of elements
($\sum_j x_j$) and the sum of squares of elements ($\sum_j x_j^2$) along the
normalization axis.
○ A naive kernel might perform two separate passes over the data for each
instance: one to compute the sum for the mean, and another to compute the
sum of squared differences from the mean for the variance.18 This is inefficient
due to repeated global memory reads.
○ An optimized approach computes both $\sum_j x_j$ and $\sum_j x_j^2$ in a single
pass.18 Each thread participating in the normalization of a given instance loads
an element $x_j$, adds it to a partial sum accumulator, and adds $x_j^2$ to a
partial sum-of-squares accumulator. These partial sums are then reduced across
all participating threads for that instance, typically using shared memory.
2. Mean Calculation: $\mu = (\sum_j x_j) / D$, where $D$ is the dimension size (e.g.,
hidden_size).
3. Variance Calculation: $\sigma^2 = (\sum_j x_j^2) / D - \mu^2$.17
4.2.2. Broadcast and Element-wise Normalization Phase
1. Broadcast Statistics: The computed $\mu_i$ and $\sigma_i^2$ for each instance $i$
(or, more efficiently, the inverse standard deviation
$s_i = 1/\sqrt{\sigma_i^2 + \epsilon}$) are effectively broadcast to all threads
that are processing elements of that instance $i$.
2. Normalization: Each thread $j$ working on instance $i$ applies the normalization:
$x'_{ij} = (x_{ij} - \mu_i) \cdot s_i$.
3. Affine Transformation: The learnable parameters $\gamma$ and $\beta$ (vectors of
length $D$) are applied: $y_{ij} = x'_{ij} \cdot \gamma_j + \beta_j$. The same
$\gamma_j$ and $\beta_j$ are used for the $j$-th feature across all instances in
the batch/sequence.17
4.3. Kernel Optimization Strategies
● Shared Memory for Reductions: This is crucial for efficiently calculating
$\sum_j x_j$ and $\sum_j x_j^2$ within a thread block. Threads load elements into
registers, compute local partial sums, and then write these partial sums to shared
memory. A parallel reduction algorithm (e.g., tree-based or iterative pairwise sum)
is then executed on the shared memory data to get the final sums for the
instance.18
● Coalesced Memory Access: When threads read the input x for computing sums
and for the final normalization step, and when writing the output y, memory
accesses should be coalesced to maximize global memory bandwidth.18 Reads of
γ and β should also be efficient.
● Fused Operations:
○ Modern hardware and kernel libraries often provide fused operations for parts
of LayerNorm. For example, specialized instructions or micro-coded
sequences might compute mean and variance statistics more efficiently (e.g.,
nki.isa.bn_stats and nki.isa.bn_aggr in AWS Neuron SDK, which reduce cycles
compared to separate mean and square operations).17 Similarly, the
shift-and-scale part of normalization (x - mean) * rsqrt(var + eps) can be
fused (e.g., nki.isa.tensor_scalar 17).
○ Fused LayerNorm with Bias/Residual: A significant optimization is to fuse
LayerNorm with common preceding operations like bias addition and/or
residual connections. Instead of tmp = bias + residual_alpha * residual + x;
output = LayerNorm(tmp), a fused kernel computes LayerNorm(bias +
residual_alpha * residual + x) in one go.6 This avoids writing the intermediate
tmp tensor to global memory and then reading it back, saving bandwidth and
latency. The γ and β parameters are loaded once and reused across
rows/tiles.17
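The sketch below combines these ideas: a single pass accumulating both the sum and the sum of squares, a shared-memory tree reduction, and a fused normalize-plus-affine step, with one thread block per instance. The block size and FP32 types are illustrative assumptions.
```cuda
// One thread block normalizes one instance (one row of length D).
// A single pass accumulates both sum(x) and sum(x^2); the statistics are then
// broadcast through shared memory and applied together with gamma/beta.
template <int BLOCK>
__global__ void layernorm_rows(const float* __restrict__ x,
                               const float* __restrict__ gamma,
                               const float* __restrict__ beta,
                               float* __restrict__ y, int D, float eps) {
    __shared__ float s_sum[BLOCK];
    __shared__ float s_sqsum[BLOCK];
    const float* row_x = x + blockIdx.x * D;
    float* row_y = y + blockIdx.x * D;
    int tid = threadIdx.x;

    // Single pass over the row: partial sum and sum of squares per thread.
    float sum = 0.0f, sqsum = 0.0f;
    for (int j = tid; j < D; j += BLOCK) {
        float v = row_x[j];          // coalesced: consecutive tid -> consecutive j
        sum += v;
        sqsum += v * v;
    }
    s_sum[tid] = sum;
    s_sqsum[tid] = sqsum;
    __syncthreads();

    // Tree reduction of both statistics in shared memory.
    for (int s = BLOCK / 2; s > 0; s >>= 1) {
        if (tid < s) {
            s_sum[tid] += s_sum[tid + s];
            s_sqsum[tid] += s_sqsum[tid + s];
        }
        __syncthreads();
    }
    float mean = s_sum[0] / D;
    float var = s_sqsum[0] / D - mean * mean;
    float rstd = rsqrtf(var + eps);  // inverse std-dev, broadcast to all threads

    // Fused normalize + affine transform, coalesced write.
    for (int j = tid; j < D; j += BLOCK)
        row_y[j] = (row_x[j] - mean) * rstd * gamma[j] + beta[j];
}
// Launch (illustrative): layernorm_rows<256><<<num_rows, 256>>>(x, g, b, y, D, 1e-5f);
```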
The structure of a LayerNorm kernel, involving per-sample reductions followed by
element-wise operations using these reduced statistics, makes it a good candidate
for shared memory optimization and instruction-level fusion. The broadcasting nature
of $\gamma$, $\beta$, and the per-sample statistics ($\mu_i$, $\sigma_i^2$) aligns
well with SIMT execution, where threads in a warp can process multiple features of
the same sample in parallel using these common values.
4.4. Brief Comparison: RMSNorm Kernels
Root Mean Square Normalization (RMSNorm) is a simplification of LayerNorm that has
gained popularity.
● Definition: $y = \frac{x}{\sqrt{\frac{1}{D}\sum_j x_j^2 + \epsilon}} \cdot \gamma$.
It omits the mean subtraction (re-centering) step of LayerNorm, only performing
re-scaling based on the root mean square of $x$.19
● Kernel Implementation: The reduction phase is simpler, as only the sum of
squares ($\sum_j x_j^2$) is needed to compute the RMS value. The subsequent
application of $\gamma$ (and optionally $\beta$, though RMSNorm is often used
without $\beta$) is similar to LayerNorm.
● Performance: Due to the reduced computation (no mean calculation and no
subtraction of mean from x), RMSNorm is generally faster than LayerNorm.
Reported speedups range from 7% to 64% in training and inference time, with
comparable model performance in many cases.19
Table 2: LayerNorm vs. RMSNorm Kernel Characteristics
| Feature | LayerNorm | RMSNorm |
| --- | --- | --- |
| Normalization Formula | $y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta$ | $y = \frac{x}{\mathrm{RMS}[x] + \epsilon} \cdot \gamma$ ($+\beta$ optional) |
| Statistics Computed | Mean ($\mathrm{E}[x]$), Variance ($\mathrm{Var}[x]$) | Root Mean Square ($\mathrm{RMS}[x] = \sqrt{\frac{1}{D}\sum_j x_j^2}$) |
| Re-centering (Mean Subtraction) | Yes | No |
| Re-scaling (Variance/RMS Division) | Yes | Yes |
| Computational Steps | Sum, Sum of Squares, Mean, Variance, Normalize, Affine Transform | Sum of Squares, RMS, Normalize, Affine Transform |
| Typical Performance | Baseline | Faster (e.g., 7-64% speedup reported 19) |
5. FlashAttention Forward Pass: IO-Aware Kernel Design
The attention mechanism, particularly the scaled dot-product attention, is a
fundamental component of Transformer models. However, its standard
implementation has a computational and memory complexity quadratic in the sequence
length $N$, i.e., $O(N^2 d + N^2)$ for computation and $O(N^2)$ for memory to store
the attention matrix. This makes it a significant bottleneck for processing long
sequences. FlashAttention is an IO-aware exact attention algorithm designed to
mitigate this bottleneck by avoiding the materialization of the full $N \times N$
attention matrix in GPU HBM.2
5.1. Motivation: The IO Bottleneck of Standard Attention
In a standard attention implementation, the $N \times N$ attention score matrix
$S = QK^T$ and the subsequent probability matrix $P = \mathrm{softmax}(S)$ are
explicitly computed and written to HBM. These large intermediate matrices are then
read back from HBM for further computations (e.g., $P$ is read to compute $O = PV$).
For large $N$, the size of $S$ and $P$ (e.g., for $N = 64\mathrm{k}$ in FP16,
$N^2 \approx 4 \times 10^9$ elements, roughly 8 GB per matrix) can exceed SRAM
capacity and lead to substantial HBM read/write traffic.2 Since GPU compute
capabilities have outpaced HBM bandwidth improvements, these memory operations,
rather than arithmetic operations, often limit performance.2
5.2. Core Principles of FlashAttention: Tiling, Recomputation, and Kernel Fusion
FlashAttention addresses the IO bottleneck by:
1. Tiling: The computation is broken down into blocks (tiles). Blocks of query (Q),
key (K), and value (V) matrices are loaded from HBM into fast on-chip SRAM.
2. Kernel Fusion: All steps of the attention computation for a given set of
tiles—$S_{ij} = Q_i K_j^T$, softmax, and multiplication by $V_j$ to update the
output $O_i$—are performed within a single fused kernel, operating on data in
SRAM. Only the final output block $O_i$ is written back to HBM.5 This avoids
writing the intermediate $S_{ij}$ or $P_{ij}$ blocks to HBM.
3. Recomputation (for backward pass): Instead of storing the large attention
matrix $S$ or $P$ from the forward pass for use in the backward pass, FlashAttention
recomputes them block by block during the backward pass. This trades increased
computation for significantly reduced memory usage ($O(N)$ instead of $O(N^2)$).21
The forward pass saves only the output $O$ and the normalization statistics
(log-sum-exponentials) needed for this recomputation.
5.3. Online Softmax (Running Softmax) within Tiled Computation
A key challenge in tiled attention is computing the softmax, which requires
normalization over an entire row of the $S$ matrix. FlashAttention uses an "online
softmax" or "running softmax" method that allows correct computation of softmax in a
tiled, iterative manner without needing the full row at once.23
For a given query block $Q_i$ (representing a block of rows of $Q$), the algorithm
iterates through blocks of keys $K_j$ and values $V_j$. Let $O_i$ be the corresponding
output block. The running statistics for each row $r$ within $Q_i$ are $m_r$ (the
current maximum score encountered for row $r$) and $l_r$ (the current sum of
$e^{s_{rk} - m_r}$ for row $r$).
When processing a new block of scores $S_{ij} = Q_i K_j^T$:
1. Compute the current block scores: $S_{ij} = Q_i K_j^T / \sqrt{d_k}$.
2. For each row $r$ in $S_{ij}$:
a. Find the new row maximum: $m_{r,new} = \max(m_{r,old}, \max_k(S_{ij,rk}))$.
b. Rescale the previous sum of exponentials:
$l_{r,rescaled} = l_{r,old} \cdot e^{m_{r,old} - m_{r,new}}$.
c. Compute exponentials for the current block's scores:
$P_{ij,rk} = e^{S_{ij,rk} - m_{r,new}}$.
d. Update the sum of exponentials: $l_{r,new} = l_{r,rescaled} + \sum_k P_{ij,rk}$.
e. Rescale the previous output block:
$O_{i,r,rescaled} = O_{i,r,old} \cdot (l_{r,rescaled} / l_{r,new})$.
f. Compute the current block's contribution to the output:
$\Delta O_{i,r} = (P_{ij,r} V_j) / l_{r,new}$.
g. Update the output: $O_{i,r,new} = O_{i,r,rescaled} + \Delta O_{i,r}$.
h. Set $m_{r,old} = m_{r,new}$ and $l_{r,old} = l_{r,new}$ for the next iteration.
This iterative update ensures that $O_i$ correctly accumulates the attention output as
if softmax were computed over the full row of $S$.20
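In code, the per-row update reduces to something like the scalar device-function sketch below, which follows steps (a)-(h) directly; real kernels vectorize this across a tile of rows and keep O, m, and l in registers or SRAM, and all names and shapes here are illustrative.
```cuda
// Simplified per-row online-softmax accumulation for one query row.
// scores[] : the Bc attention scores of the current K_j block (already scaled by 1/sqrt(d_k))
// v[]      : the Bc x d value block V_j
// o[]      : running output for this row (length d), already normalized by l_old
// m, l     : running max and running sum of exponentials for this row
__device__ void online_softmax_update(const float* scores, const float* v,
                                      float* o, float& m, float& l,
                                      int Bc, int d) {
    // a. New running maximum over the old max and this block's scores.
    float m_new = m;
    for (int k = 0; k < Bc; ++k) m_new = fmaxf(m_new, scores[k]);

    // b. Rescale the old sum of exponentials to the new max.
    float l_rescaled = l * __expf(m - m_new);

    // c./d. Exponentiate this block's scores and extend the sum.
    float l_new = l_rescaled;
    for (int k = 0; k < Bc; ++k) l_new += __expf(scores[k] - m_new);

    // e. Rescale the previously accumulated (normalized) output.
    for (int j = 0; j < d; ++j) o[j] *= l_rescaled / l_new;

    // f./g. Add this block's normalized contribution (P_ij,r * V_j) / l_new.
    for (int k = 0; k < Bc; ++k) {
        float p = __expf(scores[k] - m_new) / l_new;
        for (int j = 0; j < d; ++j) o[j] += p * v[k * d + j];
    }

    // h. Carry the statistics forward to the next K_j block.
    m = m_new;
    l = l_new;
}
```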
5.4. Causal Masking in Tiled FlashAttention
For autoregressive models like GPT, causal masking is necessary to prevent a query
$q_u$ from attending to a key $k_v$ if $v > u$.
● Standard Causal Masking: In a non-tiled approach, this involves adding $-\infty$ to
the elements $S_{uv}$ where $v > u$ before applying softmax.27
● Tiled Causal Masking: In FlashAttention, masking must be incorporated into the
tiled computation. The outer loop iterates over blocks of queries $Q_i$ (say, rows
$B_r \cdot i$ to $B_r \cdot (i+1) - 1$), and the inner loop iterates over blocks of
keys $K_j$ (columns $B_c \cdot j$ to $B_c \cdot (j+1) - 1$).
○ If the column block index $j$ is greater than the row block index $i$ (i.e.,
$j > i$), all key tokens in $K_j$ are "future" tokens relative to all query tokens
in $Q_i$. Thus, the entire score block $S_{ij}$ is masked, and its computation can
be skipped. These are "empty blocks".28
○ If $j < i$, all key tokens in $K_j$ are "past" tokens. No causal masking is
needed for the $S_{ij}$ block. These are "full blocks".28
○ If $j = i$, the $S_{ij}$ block is a diagonal block. An intra-block causal mask is
applied: for a score $S_{uv}$ within this block (where $u, v$ are local indices
within the block), if $v > u$ (and the global positions also satisfy this), it is
masked. These can be "right partial blocks" if not all elements are masked, or
"left partial blocks" if the masking only affects the right side of the block due
to the overall sequence position.28 The Kvax implementation, for example,
explicitly defines these block types and applies masking logic accordingly,
potentially skipping computations for empty blocks or applying only document
masks for left partial blocks, and both document and positional (causal) masks
for right partial blocks.28 The loops over key/value blocks effectively run only
for $j \le i$.
● Column-wise Mask Representations: For more general sparsity patterns beyond
simple causal masking, or for a more efficient representation of causal masks,
techniques like column-wise mask representation (e.g., FlashMask) can be used.
These represent the mask more compactly ($O(N)$ instead of $O(N^2)$) and allow
kernels to efficiently identify and skip computations for masked-out blocks or
elements.29
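At the kernel level, the block classification above collapses to loop bounds plus a small masking routine for the diagonal block, as in the illustrative sketch below: empty blocks are skipped by iterating only over $j \le i$, and full blocks need no mask at all.
```cuda
#include <math.h>

// Apply the intra-block causal mask to a diagonal score tile S (Br x Bc, row-major).
// block_i / block_j are the query/key block indices; the mask is only needed when
// block_j == block_i (blocks with block_j > block_i are skipped entirely by the
// loop bounds, and blocks with block_j < block_i need no mask).
__device__ void apply_causal_mask(float* S, int Br, int Bc,
                                  int block_i, int block_j) {
    if (block_j != block_i) return;
    for (int q = 0; q < Br; ++q) {
        int q_pos = block_i * Br + q;              // global query position
        for (int k = 0; k < Bc; ++k) {
            int k_pos = block_j * Bc + k;          // global key position
            if (k_pos > q_pos) S[q * Bc + k] = -INFINITY;  // future key: mask out
        }
    }
}
```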
5.5. FlashAttention-2/3 Enhancements
Subsequent versions, FlashAttention-2 and FlashAttention-3, introduced further
optimizations 21:
● Reduced Non-Matmul FLOPs: Algorithmic tweaks to the online softmax update
rules to minimize computationally expensive non-matmul operations (e.g.,
divisions, exponentials where possible).25
● Better Parallelism:
○ FlashAttention-2 parallelizes the computation over the sequence length
dimension (in addition to batch size and number of heads). For the forward
pass, different thread blocks can process different row blocks of Q
independently. For the backward pass, different thread blocks process
different column blocks of dS (gradient w.r.t scores), with atomic operations
for accumulating dQ.21 This improves GPU occupancy, especially for long
sequences with small batch sizes.
● Improved Work Partitioning Between Warps: Within a thread block, work is
partitioned more efficiently among warps to reduce shared memory
communication. For instance, in the forward pass, Q might be split across warps,
while K and V are shared, reducing inter-warp synchronization compared to
splitting K.25
● FlashAttention-3: Further leverages hardware asynchrony on newer GPUs (like
Hopper's Tensor Memory Accelerator - TMA and asynchronous Tensor Core
execution) to overlap data movement (HBM to SRAM using TMA) with
computation (matmul and softmax operations). It also introduces block
quantization for FP8 precision.31
5.6. Backward Pass Recomputation
A critical aspect of FlashAttention's memory efficiency is that it does not store the
$N \times N$ matrices $S$ or $P$ during the forward pass for use in the backward pass.
Instead, during the backward pass, the necessary blocks of $S$ and $P$ are recomputed
on the fly from the original input blocks $Q_i$, $K_j$, $V_j$ (which are small enough
to be loaded into SRAM) and the $m_i$, $l_i$ statistics saved from the forward pass.21
This recomputation allows the backward pass to also operate in a tiled manner with an
$O(N)$ memory footprint with respect to sequence length.
The design of FlashAttention, by meticulously considering the GPU memory hierarchy
and fusing operations, transforms attention from a memory-bound operation to one
that can better leverage the GPU's computational power. This has enabled training
models with significantly longer contexts than previously feasible. The evolution to
FlashAttention-2 and -3 further refines this by improving parallelism and adapting to
new hardware capabilities, underscoring the principle that optimal algorithms are
often co-designed with the hardware. The complexity introduced by integrating
features like causal masking into such a highly fused and tiled kernel highlights the
intricate nature of high-performance GPU programming.
Table 3: FlashAttention Forward Pass - Tiled Computation with Online Softmax
and Causal Masking
| Step | Description | Key Operations | Memory Level | Causal Masking Logic (for block $Q_i$, $K_j$) |
| --- | --- | --- | --- | --- |
| 1. Outer Loop Setup | Iterate over blocks of Query $Q_i$. Initialize $O_i = 0$, $m_i = -\infty$, $l_i = 0$. | Loop control, initialization. | Registers for $O_i$, $m_i$, $l_i$. | N/A |
| 2. Inner Loop Setup | Iterate over blocks of Key $K_j$ and Value $V_j$. | Loop control. | - | If $j > i$ (block index), skip the $K_j$, $V_j$ block (entire $S_{ij}$ is masked). |
| 3. Load $Q_i$ | Load the $i$-th block of Q from HBM to SRAM. | HBM Read. | SRAM | N/A |
| 4. Load $K_j$, $V_j$ | Load the $j$-th block of K and V from HBM to SRAM. | HBM Read. | SRAM | N/A (block-level check done in Step 2) |
| 5. Compute $S_{ij}$ | $S_{ij} = (Q_i K_j^T)/\sqrt{d_k}$. | MatMul. | SRAM (inputs), Registers/SRAM (output $S_{ij}$) | If $j = i$ (diagonal block), apply intra-block causal mask to $S_{ij}$ (elements where key_pos > query_pos set to $-\infty$). |
| 6. Update Online Softmax Stats | Compute $m_{i,new}$, $l_{i,new}$ using $m_{i,old}$, $l_{i,old}$ and $S_{ij}$. | Max, Exp, Sum, Scale. | Registers/SRAM | Masked elements in $S_{ij}$ do not contribute to $m_{i,new}$, $l_{i,new}$. |
| 7. Update Output $O_i$ | Rescale $O_{i,old}$ and add the contribution from $P_{ij} V_j$. | Scale, MatMul, Add. | Registers/SRAM | $P_{ij}$ is derived from the (masked) $S_{ij}$. |
| 8. Loop/Store | Continue the inner loop. After the inner loop, write $O_i$ from registers/SRAM to HBM. | HBM Write. | HBM | N/A |
6. Optimizing Element-wise Operation Kernels
Element-wise operations, such as vector addition (C=A+B), ReLU activation
(y=max(0,x)), and GeLU activation, are ubiquitous in neural networks. While
computationally simple, their efficient implementation on GPUs requires attention to
memory access patterns.
6.1. Parallel Implementation
The parallelization strategy for element-wise operations is typically straightforward:
● Each output element is computed independently. Thus, one GPU thread can be
assigned to compute one element of the output tensor.32
● For vector addition $C_i = A_i + B_i$, thread $i$ reads $A_i$ and $B_i$ from global
memory, performs the addition, and writes $C_i$ back to global memory.
● Similarly for ReLU, thread $i$ reads $x_i$, applies the $\max(0, \cdot)$ function,
and writes $y_i$. GeLU involves more complex arithmetic but follows the same
per-element independent computation pattern.
6.2. Memory-Bound Nature
Element-wise operations are characterized by very low arithmetic intensity, defined as
the ratio of arithmetic operations to bytes of data accessed from memory.
● For example, FP32 vector addition performs one floating-point addition for every
12 bytes of memory traffic (8 bytes read for the two inputs and 4 bytes written for
the output), an arithmetic intensity of roughly 1/12 FLOP per byte.
● Consequently, the execution time of element-wise kernels is almost always limited
by the bandwidth of global memory (HBM), not by the GPU's computational
capability.32 The GPU cores spend most of their time waiting for data to arrive
from or be written to HBM.
6.3. Optimization Focus
Given their memory-bound nature, optimizations for element-wise kernels primarily
target maximizing effective memory bandwidth:
● Memory Coalescing: This is the most critical optimization. Threads within a warp
must access contiguous memory locations for both reads (e.g., $A_i, A_{i+1}, \dots$)
and writes ($C_i, C_{i+1}, \dots$).4 Uncoalesced access can drastically reduce effective
bandwidth by causing multiple separate memory transactions for data that could
have been fetched in one.
● Sufficient Parallelism: Launching enough thread blocks to fully occupy all
Streaming Multiprocessors (SMs) on the GPU is necessary to hide memory
latency and maximize throughput. Each SM can typically handle multiple active
warps.
● Loop Unrolling (within a thread): If the number of elements to process is much
larger than the number of threads launched, each thread might process multiple
elements in a loop. Unrolling this loop (e.g., processing 2 or 4 elements per
iteration instead of 1) can reduce loop control overhead and potentially improve
instruction-level parallelism.32 This is akin to "loading two values at a time" as
mentioned in reduction optimization contexts.32
● Vectorized Loads/Stores: Modern GPUs and compilers can sometimes utilize
wider load/store instructions (e.g., float2, float4 loads/stores for 64-bit or 128-bit
transactions) if memory is properly aligned and accessed. This allows a single
instruction to move multiple data elements, further improving bandwidth
utilization.
● Kernel Fusion: While not an optimization of the element-wise kernel itself,
element-wise operations are prime candidates for fusion with their producer or
consumer kernels. For example, an element-wise activation function is often
fused with a preceding GEMM or convolution operation. This avoids writing the
intermediate pre-activation tensor to global memory and then reading it back,
which is a significant saving for memory-bound operations.
In summary, while element-wise operations appear simple, achieving high
performance requires adherence to fundamental GPU programming principles,
especially those related to efficient memory access. Their low arithmetic intensity
means that any inefficiency in memory operations will directly translate to
performance degradation. The most substantial gains often come from system-level
optimizations like kernel fusion, which minimize the number of times these
memory-bound operations need to access HBM.
7. High-Performance Convolution Kernels (Conv1D/Conv2D)
Convolutional layers are fundamental to many deep learning architectures, especially
in computer vision. Optimizing 1D and 2D convolutions on GPUs involves various
techniques, each with its own trade-offs in terms of performance, memory usage, and
applicability. Common approaches include transforming convolutions into GEMM
operations (im2col), using Winograd's algorithm for small filters, or employing
FFT-based methods for larger filters.
7.1. Implicit Im2col Transformation and GEMM-based Convolution
The im2col (image-to-column) technique transforms a multi-dimensional convolution
into a single large General Matrix Multiplication (GEMM), allowing it to leverage highly
optimized GEMM libraries and hardware accelerators like Tensor Cores.34
● Im2col Concept:
1. Input Patch Matrix (Lowered Feature Matrix): For each position of the
sliding filter, the corresponding input patch (receptive field) is "unrolled" or
"lowered" into a column vector. These column vectors, from all sliding
positions, are stacked together to form a large matrix. If the input has $C_{in}$
channels and the filter is $R \times S$, each column vector has
$C_{in} \cdot R \cdot S$ elements. The number of columns is equal to the number
of output pixels $H_{out} \cdot W_{out}$.
2. Filter Matrix: The $C_{out}$ filters, each of size $C_{in} \times R \times S$,
are unrolled into rows, forming a matrix of size
$C_{out} \times (C_{in} \cdot R \cdot S)$. The convolution is then equivalent to
multiplying this filter matrix by the input patch matrix.34
● Memory Overhead of Explicit Im2col: A naive "explicit" im2col approach, where
the input patch matrix is physically materialized in memory, leads to significant
memory duplication. This is because overlapping receptive fields in the input
tensor are replicated multiple times in the patch matrix. This can increase memory
requirements by a factor of up to filter height × filter width (R×S) compared to the
original input feature map.35 For example, AlexNet's 1.39MB input feature map
could expand to 14.57MB after explicit im2col.35
● Implicit Im2col: To avoid this memory blowup, modern DL libraries (like cuDNN)
and accelerators often use "implicit" im2col. Instead of creating the full im2col
matrix in HBM, specialized convolution kernels perform the data rearrangement
on-the-fly. Tiles of the input feature map are loaded into shared memory, and
threads then gather the necessary elements within shared memory to form the
"virtual" im2col patches needed for the GEMM-like computation for an output tile.
The actual GEMM might still be performed by Tensor Cores or standard MAC
units. This approach maintains the computational structure of GEMM without the
prohibitive memory cost of explicit im2col.35 The kernel logic becomes more
complex due to the intricate addressing required to fetch data from the original
input tensor according to convolutional strides, padding, and dilations.
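For reference, the lowering itself is just index arithmetic; the naive explicit im2col sketch below (single NCHW image, unit dilation, illustrative names) shows the addressing that implicit-GEMM kernels perform on the fly instead of materializing the column matrix in HBM.
```cuda
// Explicit im2col for one NCHW image: input [C x H x W] -> columns
// [(C*R*S) x (H_out*W_out)], one thread per output-pixel column.
__global__ void im2col_naive(const float* __restrict__ in, float* __restrict__ col,
                             int C, int H, int W, int R, int S,
                             int stride, int pad, int H_out, int W_out) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;   // which output pixel (column)
    if (idx >= H_out * W_out) return;
    int out_y = idx / W_out;
    int out_x = idx % W_out;

    // Gather the C*R*S receptive-field elements for this sliding position.
    for (int c = 0; c < C; ++c) {
        for (int r = 0; r < R; ++r) {
            for (int s = 0; s < S; ++s) {
                int in_y = out_y * stride - pad + r;
                int in_x = out_x * stride - pad + s;
                float v = 0.0f;                        // zero padding outside the image
                if (in_y >= 0 && in_y < H && in_x >= 0 && in_x < W)
                    v = in[(c * H + in_y) * W + in_x];
                int row = (c * R + r) * S + s;         // row in the lowered matrix
                col[row * (H_out * W_out) + idx] = v;
            }
        }
    }
}
```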
7.2. Winograd Convolution for Small Filters (e.g., 3x3)
Winograd's minimal filtering algorithms can significantly reduce the number of
arithmetic operations, especially multiplications, required for convolutions with small
filters.36
● Principle: The convolution is performed in a transformed domain where
element-wise multiplications suffice, similar in spirit to FFT-based convolution but
optimized for small tiles. The transformations themselves involve additions and
multiplications by pre-computed constants.
● Mechanism (e.g., F(2x2, 3x3) for a 3x3 filter producing a 2x2 output tile): The
2D Winograd algorithm $F(m \times m, r \times r)$ computes an $m \times m$ output
tile using an $r \times r$ filter. The core idea is
$Y = A^T \left[ (G g G^T) \odot (B^T d B) \right] A$, where:
○ $d$ is an $(m+r-1) \times (m+r-1)$ input tile.
○ $g$ is the $r \times r$ filter.
○ $G$, $B^T$, $A^T$ are transformation matrices of fixed values for a given $F(m, r)$.
○ $\odot$ denotes element-wise multiplication.
1. Filter Transform: $U = G g G^T$. This can be pre-computed if filters are static.
2. Input Transform: $V = B^T d B$. This is applied to each input tile.
3. Element-wise Multiplication: $M = U \odot V$. These are
$(m+r-1) \times (m+r-1)$ element-wise products.
4. Output Transform: $Y = A^T M A$. This transforms the result back to the spatial
domain.
For $F(2 \times 2, 3 \times 3)$, this method reduces the number of multiplications
from $2 \times 2 \times 3 \times 3 = 36$ (direct convolution) to
$(2+3-1)^2 = 4^2 = 16$ for the element-wise multiplication stage.37 This is a
$2.25\times$ reduction in multiplications.
● Kernel Implementation: A Winograd convolution kernel involves stages for each
transformation and the element-wise multiplication. The element-wise products
can be batched if processing multiple channels or input tiles. Efficient
implementation requires careful data layout and management in shared memory
for the intermediate transformed tiles.
● Limitations:
○ Numerical Error: Winograd transforms can introduce larger numerical errors
compared to direct convolution, especially as the tile size m increases or if the
transform points are not chosen carefully.36
○ Complexity: The control flow and addressing are more complex than direct
convolution or im2col.
○ Applicability: Most beneficial for small filters (typically 3x3, sometimes 5x5)
and certain tile sizes. For very small channels or batch sizes, the overhead of
transformations might outweigh the arithmetic savings.
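The 1D case F(2,3) makes the arithmetic savings concrete: two outputs of a 3-tap filter are computed from four inputs with 4 multiplications instead of 6. The host-side sketch below uses the standard F(2,3) transform constants and is for illustration only.
```cpp
// Winograd F(2,3): compute y[0..1] = conv(d[0..3], g[0..2]) with 4 multiplications.
// Direct computation would need 2*3 = 6 multiplications.
void winograd_f2_3(const float d[4], const float g[3], float y[2]) {
    // Filter transform U = G g (can be precomputed once per static filter).
    float u0 = g[0];
    float u1 = 0.5f * (g[0] + g[1] + g[2]);
    float u2 = 0.5f * (g[0] - g[1] + g[2]);
    float u3 = g[2];

    // Input transform V = B^T d.
    float v0 = d[0] - d[2];
    float v1 = d[1] + d[2];
    float v2 = d[2] - d[1];
    float v3 = d[1] - d[3];

    // Element-wise products (the 4 multiplications).
    float m0 = u0 * v0, m1 = u1 * v1, m2 = u2 * v2, m3 = u3 * v3;

    // Output transform Y = A^T m.
    y[0] = m0 + m1 + m2;
    y[1] = m1 - m2 - m3;
}
```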
7.3. FFT-based Convolution
The convolution theorem states that convolution in the spatial (or time) domain is
equivalent to element-wise multiplication in the frequency domain.40
● Mechanism:
1. Forward FFT: Transform the input feature maps and the filters (kernels) into
the frequency domain using Fast Fourier Transform (FFT). Filters are typically
padded to match the input size (or a padded input size) to perform linear
convolution using circular convolution properties.
2. Element-wise Multiplication: Multiply the transformed inputs and filters
element-wise in the frequency domain.
3. Inverse FFT: Transform the product back to the spatial domain using an
inverse FFT (IFFT) to get the output feature map.
● Kernel Structure: Requires highly optimized FFT and IFFT kernels, along with a
kernel for complex element-wise multiplication. Data is typically handled as
complex numbers in the frequency domain.
● Benefits: Asymptotically, FFT-based convolution is faster than direct convolution
for larger filters and input sizes, as FFT has O(NlogN) complexity compared to
O(N2) for direct 1D convolution.
● Limitations:
○ Overhead: FFT/IFFT computations have significant overhead, making this
method less efficient for small filters or small input sizes where the transform
costs dominate the arithmetic savings from element-wise multiplication.40
○ Numerical Precision: FFTs can introduce numerical precision issues,
especially with lower-precision floating-point numbers.
○ Padding: Requires careful padding to achieve linear convolution and manage
boundary effects, which can increase the effective data size.
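A host-side sketch of the three steps using cuFFT is shown below, for a single image and a single filter that the caller has already zero-padded to a common NX×NY size; production code would batch transforms, cache plans, check error codes, and crop the valid output region, so treat all names and sizes as illustrative.
```cuda
#include <cuda_runtime.h>
#include <cufft.h>

// Pointwise complex multiply in the frequency domain; the 1/(NX*NY) factor is folded
// in here because cuFFT transforms are unnormalized.
__global__ void complex_mul_scale(const cufftComplex* a, const cufftComplex* b,
                                  cufftComplex* out, int n, float scale) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    cufftComplex x = a[i], y = b[i];
    out[i].x = (x.x * y.x - x.y * y.y) * scale;   // real part
    out[i].y = (x.x * y.y + x.y * y.x) * scale;   // imaginary part
}

// d_img and d_filt: NX x NY real arrays (filter already padded by the caller).
// d_out: NX x NY real result of the circular convolution.
void fft_conv2d(float* d_img, float* d_filt, float* d_out, int NX, int NY) {
    int nfreq = NX * (NY / 2 + 1);                 // R2C keeps roughly half the spectrum
    cufftComplex *d_img_f, *d_filt_f;
    cudaMalloc(&d_img_f, nfreq * sizeof(cufftComplex));
    cudaMalloc(&d_filt_f, nfreq * sizeof(cufftComplex));

    cufftHandle plan_fwd, plan_inv;
    cufftPlan2d(&plan_fwd, NX, NY, CUFFT_R2C);
    cufftPlan2d(&plan_inv, NX, NY, CUFFT_C2R);

    // 1. Forward FFT of input and filter.
    cufftExecR2C(plan_fwd, d_img, d_img_f);
    cufftExecR2C(plan_fwd, d_filt, d_filt_f);

    // 2. Element-wise multiplication in the frequency domain.
    int threads = 256, blocks = (nfreq + threads - 1) / threads;
    complex_mul_scale<<<blocks, threads>>>(d_img_f, d_filt_f, d_img_f, nfreq,
                                           1.0f / (NX * NY));

    // 3. Inverse FFT back to the spatial domain.
    cufftExecC2R(plan_inv, d_img_f, d_out);

    cufftDestroy(plan_fwd);
    cufftDestroy(plan_inv);
    cudaFree(d_img_f);
    cudaFree(d_filt_f);
}
```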
The choice among these convolution algorithms is not static; it depends on factors
like filter size, input/output channel counts, image dimensions, batch size, and the
specific capabilities of the target GPU (e.g., Tensor Core support for GEMM,
availability of optimized FFT libraries). Deep learning libraries like cuDNN often include
heuristics to select the most performant algorithm for a given convolution
configuration. This dynamic selection underscores the complexity of achieving
optimal convolution performance. Furthermore, the potential for numerical
discrepancies between algorithms means that speed must be balanced with the need
to maintain model accuracy, a recurring theme in high-performance deep learning.
Table 4: Comparison of Convolution Algorithms for Kernel Development
| Algorithm | Principle | Arithmetic Complexity (qualitative, e.g., for 3x3 filter) | Memory Overhead (Transformation) | Best Suited For | Numerical Stability Considerations | Key Kernel Implementation Details |
| --- | --- | --- | --- | --- | --- | --- |
| Implicit Im2col+GEMM | Transform to matrix multiplication | Relies on GEMM complexity; efficient with Tensor Cores | Low (implicit transform avoids materializing the large intermediate matrix) 35 | General purpose, especially when highly optimized GEMM kernels are available; various filter sizes. | Inherits GEMM stability (e.g., FP32 accumulation in mixed precision). | Complex addressing for on-the-fly patch gathering; leverages GEMM tiling and memory optimization techniques. |
| Winograd F(2x2,3x3) | Minimal filtering algorithm in a transformed domain | Reduces multiplications significantly (e.g., 2.25x for F(2x2,3x3)) 39 | Moderate (for transformed tiles in shared memory) | Small filters (e.g., 3x3, 5x5), specific tile sizes.37 | Can have higher numerical error than direct convolution; sensitive to transform points.36 | Input/filter/output transform stages; element-wise product of transformed tiles. |
| FFT-based | Convolution theorem: multiplication in the frequency domain | O(N log N) for FFTs + O(N) for the element-wise product (N is the padded size) 40 | High (for padded inputs/filters and complex representations) | Large filters, large input sizes where FFT overhead is amortized. | Can have precision issues; requires careful handling of complex numbers. | Efficient FFT/IFFT kernels; element-wise complex multiplication; padding management. |
8. Conclusion
The pursuit of high-performance GPU kernels for deep learning operations is a
multifaceted discipline, demanding a synergistic understanding of algorithms, parallel
programming paradigms, and underlying hardware architectures. This report has
traversed several key DL operations, elucidating the intricate strategies employed to
maximize their efficiency on GPUs. A few overarching principles emerge from this
exploration.
First, effective management of the GPU memory hierarchy is paramount. The
performance of most deep learning kernels is predominantly dictated by memory
bandwidth rather than raw computational throughput. Techniques such as tiling for
data reuse in shared memory and registers (as seen in GEMM and FlashAttention),
coalesced global memory access (critical for element-wise operations and data
loading stages of all kernels), and minimizing HBM traffic through kernel fusion (e.g.,
fused LayerNorm, or FlashAttention's avoidance of materializing the N×N attention
matrix) are fundamental to mitigating the memory bottleneck.1
Second, exploiting the massive parallelism of GPUs at all levels is essential. This
involves not only launching a sufficient number of threads and thread blocks to
saturate the SMs but also carefully partitioning work among warps within a block to
optimize communication (e.g., shared memory usage in FlashAttention-2) and ensure
balanced workloads.3 Avoiding performance limiters like shared memory bank
conflicts and thread divergence through careful data layout and control flow design is
equally crucial.4
Third, the strategic use of reduced-precision arithmetic (mixed-precision
FP16/BF16 with FP32 accumulation, or INT8 quantization) offers substantial speedups
and memory savings. However, this introduces numerical challenges that necessitate
careful handling, such as master FP32 weights, loss scaling for FP16 training, and
precise scaling factor management for INT8 operations, to maintain model accuracy.11
Fourth, there is a clear trend towards algorithm-hardware co-design. Operations
like FlashAttention are explicitly designed with GPU memory characteristics (HBM vs.
SRAM, IO costs) in mind.2 Similarly, the proliferation of specialized hardware units like
NVIDIA's Tensor Cores has reshaped how operations like GEMM and, by extension,
im2col-based convolutions are implemented, with libraries like CUTLASS providing
abstractions to efficiently target these units.3 The evolution from FlashAttention to
FlashAttention-2 and FlashAttention-3, adapting to new hardware features like TMA,
further exemplifies this co-design principle.25
The development of high-performance kernels is an iterative process. As illustrated by
the progression of GEMM optimization techniques or the evolution of FlashAttention,
even highly optimized algorithms can often be further refined through deeper analysis
of hardware behavior, workload characteristics, and emerging architectural features.
There is rarely a one-size-fits-all solution; the optimal approach for a convolution, for
instance, depends heavily on parameters like filter size and batch size, leading to the
use of different algorithms (implicit GEMM, Winograd, FFT) in different scenarios.35
Looking ahead, the complexity of manual kernel optimization for an ever-expanding
array of DL models and rapidly evolving GPU architectures underscores the growing
importance of automated kernel generation and optimization through advanced
compilers and domain-specific languages. Tools like Triton, which allow expression
of parallel algorithms at a higher level while still enabling fine-grained control and
generating efficient code, represent a significant step in this direction.8 We can
anticipate continued research into more sophisticated compiler techniques, more
powerful hardware abstractions, and novel algorithms that are increasingly aware of
and adaptive to the underlying hardware, pushing the boundaries of deep learning
performance.
Works cited
1. Deep Dive into Matrix Optimization on AMD GPUs - seb-v, accessed June 9, 2025,
https://seb-v.github.io/optimization/update/2025/01/20/Fast-GPU-Matrix-multiplic
ation.html
2. Flash attention(Fast and Memory-Efficient Exact Attention with IO-Awareness): A
deep dive, accessed June 9, 2025,
https://towardsdatascience.com/flash-attention-fast-and-memory-efficient-exac
t-attention-with-io-awareness-a-deep-dive-724af489997b/
3. Efficient GEMM in CUDA — NVIDIA CUTLASS Documentation, accessed June 9,
2025, https://docs.nvidia.com/cutlass/media/docs/cpp/efficient_gemm.html
4. Optimizing GPU Kernels | Elijah's Notes, accessed June 9, 2025,
https://notes.elimelt.com/llm-serving-systems/optimizing-gpu-kernels.html
5. Kernel Case Study: Flash Attention | Towards Data Science, accessed June 9,
2025, https://towardsdatascience.com/kernel-case-study-flash-attention/
6. Layer Fusion - Aussie AI, accessed June 9, 2025,
https://www.aussieai.com/research/layer-fusion
7. fused_layer_norm-API Document-PaddlePaddle Deep Learning ..., accessed June
9, 2025,
https://www.paddlepaddle.org.cn/documentation/docs/en/api/paddle/incubate/nn
/functional/fused_layer_norm_en.html
8. GPU MODE Lecture 7: Advanced Quantization – Christian Mills, accessed June 9,
2025, https://christianjmills.com/posts/cuda-mode-notes/lecture-007/
9. Optimizing Matrix Multiplication: Unveiling the Power of Tiles - IndiaAI, accessed
June 9, 2025,
https://indiaai.gov.in/article/optimizing-matrix-multiplication-unveiling-the-power
-of-tiles
10.GEMM Optimization: Achieving Coalesced and Bank Conflict-Free Shared
Memory Access, accessed June 9, 2025,
https://forums.developer.nvidia.com/t/gemm-optimization-achieving-coalesced-
and-bank-conflict-free-shared-memory-access/319329
11. Automatic Mixed Precision Using PyTorch | DigitalOcean, accessed June 9, 2025,
https://www.digitalocean.com/community/tutorials/automatic-mixed-precision-us
ing-pytorch
12.Mixed Precision Training - Paperspace Blog, accessed June 9, 2025,
https://blog.paperspace.com/mixed-precision-training-overview/
13.CUTLASS GEMM API - NVIDIA Docs, accessed June 9, 2025,
https://docs.nvidia.com/cutlass/media/docs/cpp/gemm_api.html
14.Cublas-LT Int8 matrix multiplication - YouTube, accessed June 9, 2025,
https://www.youtube.com/watch?v=P-Cyjrgt2eY
15.Working with Quantized Types — NVIDIA TensorRT Documentation, accessed
June 9, 2025,
https://docs.nvidia.com/deeplearning/tensorrt/latest/inference-library/work-quant
ized-types.html
16.Numerically Stable Softmax - Brian Lester, accessed June 9, 2025,
https://blester125.com/blog/softmax.html
17.LayerNorm — AWS Neuron Documentation, accessed June 9, 2025,
https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/tutorials/l
ayernorm.html
18.Optimizing a Layer Normalization Kernel with CUDA: a Worklog, accessed June 9,
2025, https://aryagxr.com/blogs/cuda-optimizing-layernorm
19.Pre-RMSNorm and Pre-CRMSNorm Transformers: Equivalent and ..., accessed
June 9, 2025,
https://proceedings.neurips.cc/paper_files/paper/2023/file/8f1bacee31caf990a4f0
8d84f0ccb322-Paper-Conference.pdf
20.[2205.14135] FlashAttention: Fast and Memory-Efficient Exact ..., accessed June 9,
2025, https://ar5iv.labs.arxiv.org/html/2205.14135
21.FlashAttention: Fast Transformer training with long sequences - Adept AI,
accessed June 9, 2025, https://www.adept.ai/blog/flashier-attention
22.FlashAttention, accessed June 9, 2025,
https://llmsystem.github.io/llmsystem2024spring/assets/files/Group-FlashAttentio
n-0b70d553037a7729dd2a9af5e23d8b3e.pdf
23.Understanding Flash Attention: Writing the Algorithm from Scratch in Triton - Alex
Dremov, accessed June 9, 2025,
https://alexdremov.me/understanding-flash-attention-writing-the-algorithm-fro
m-scratch-in-triton/
24.Understanding Flash Attention: Writing the Algorithm from Scratch in Triton,
accessed June 9, 2025,
https://towardsdatascience.com/understanding-flash-attention-writing-the-algor
ithm-from-scratch-in-triton-5609f0b143ea/
25.FlashAttention-2: Faster Attention with Better Parallelism and Work ..., accessed
June 9, 2025, https://arxiv.org/pdf/2307.08691
26.FLASH-D: FlashAttention with Hidden Softmax Division - arXiv, accessed June 9,
2025, https://www.arxiv.org/pdf/2505.14201
27.Causal Self-Attention - Kaggle, accessed June 9, 2025,
https://www.kaggle.com/code/aisuko/causal-self-attention
28.Kvax: Fast and easy-to-use Flash Attention implementation for JAX - Nebius,
accessed June 9, 2025,
https://nebius.com/blog/posts/kvax-open-source-flash-attention-for-jax
29.FlashMask: Efficient and Rich Mask Extension of FlashAttention - arXiv, accessed
June 9, 2025, https://arxiv.org/html/2410.01359v1
30.FlashMask: Efficient and Rich Mask Extension of FlashAttention - OpenReview,
accessed June 9, 2025, https://openreview.net/forum?id=wUtXB43Chi
31.FlashAttention-3: Fast and Accurate Attention with Asynchrony and
Low-precision - arXiv, accessed June 9, 2025, https://arxiv.org/pdf/2407.08608
32.Optimizing the GPU kernel — CUDA training materials documentation, accessed
June 9, 2025, https://enccs.github.io/cuda/3.01_ParallelReduction/
33.a-hamdi/GPU: 100 days of building GPU kernels! - GitHub, accessed June 9, 2025,
https://github.com/a-hamdi/GPU
34.Im2col GEMM converted from the convolution in Fig. 1. The red boxed data show
duplicated accesses. - ResearchGate, accessed June 9, 2025,
https://www.researchgate.net/figure/m2col-GEMM-converted-from-the-convolu
tion-in-Fig-1-The-red-boxed-data-show-duplicated_fig2_332186100
35.Characterizing and Demystifying the Implicit Convolution Algorithm ..., accessed
June 9, 2025,
https://cs.sjtu.edu.cn/~leng-jw/resources/Files/zhou21iiswc-im2col.pdf
36.[2201.10369] Winograd Convolution for Deep Neural Networks: Efficient Point
Selection - arXiv, accessed June 9, 2025, https://arxiv.org/abs/2201.10369
37.Pruning of Winograd and FFT Based Convolution Algorithm - CS231n, accessed
June 9, 2025, http://cs231n.stanford.edu/reports/2016/pdfs/117_Report.pdf
38.Fast Algorithms For Convolutional Neural Networks: Andrew Lavin Scott Gray
Nervana Systems | PDF - Scribd, accessed June 9, 2025,
https://www.scribd.com/document/369456919/NervanaFastConv1509-09308v2
39.arXiv:1509.09308v2 [cs.NE] 10 Nov 2015, accessed June 9, 2025,
https://arxiv.org/pdf/1509.09308
40.machine learning - Convolutional neural network fast fourier ..., accessed June 9,
2025,
https://datascience.stackexchange.com/questions/17287/convolutional-neural-net
work-fast-fourier-transform
41.A fast Fourier convolutional deep neural network for accurate and explainable
discrimination of wheat yellow rust and nitrogen deficiency from Sentinel-2 time
series data - PMC - PubMed Central, accessed June 9, 2025,
https://pmc.ncbi.nlm.nih.gov/articles/PMC10582577/