Somik's Blog

WorkLog: Optimizing GEMM

This is the first post in a series on going under the hood of GPU programming: implementing custom CUDA kernels and exploring ways to optimize them. As a case study, I will apply various methods to optimize GEMM (GEneral Matrix-Matrix Multiplication). The custom CUDA kernels will be benchmarked against cuBLAS (NVIDIA's official BLAS library).

Background

GEMM: GEneral Matrix-Matrix Multiplication

C ← α·A·B + β·C

where A is an M × K matrix, B is K × N, C is M × N, and α, β are scalar coefficients.

For the purpose of this exercise, we will be using single-precision (FP32) matrices.

Benchmark Setup

Before diving into kernel implementations, let's look at two key concepts:

Compute-bound v/s Memory-bound

| Compute-bound | Memory-bound |
| --- | --- |
| The GPU spends most of its time performing arithmetic operations | The GPU spends most of its time waiting for data from global memory / VRAM |
| Performance is limited by how fast the GPU cores can execute FLOPs | Performance is limited by memory bandwidth |

Arithmetic Intensity (AI)

Arithmetic intensity measures how many floating-point operations are performed per byte of memory transferred. In other words,

AI = FLOPs / Bytes transferred

For SGEMM (Single-Precision GEMM):
a) Total FLOPs

FLOPs=2·M·N·K

b) Total bytes accessed

Total memory traffic (reading A and B, reading and writing C, at 4 bytes per float):

Bytes = 4·(M·K + K·N + 2·M·N)

c) Theoretical AI for SGEMM

AI = (2·M·N·K) / (4·(M·K + K·N + 2·M·N))
Note: Each benchmark is measured over 10 runs after a warm-up phase to minimize the impact of GPU frequency scaling, system scheduling, CUDA runtime overhead, and cache effects.

Individual run times may fluctuate due to these factors, so we report the minimum execution time rather than the average.

The minimum time represents the run with the least system interference, providing a clearer estimate of the kernel’s true performance potential.

Timing is performed using CUDA events to measure kernel execution time on the device.

Benchmark: cuBLAS implementation

cublasHandle_t handle;
CUBLAS_CHECK(cublasCreate(&handle));

// cuBLAS is column-major; passing B first with swapped dimensions
// computes C^T = B^T * A^T, which in memory is the row-major product A * B.
CUBLAS_CHECK(cublasSgemm(
            handle,
            CUBLAS_OP_N, CUBLAS_OP_N,    // no transpose flags
            N, M, K,                     // dimensions of the column-major view
            &alpha,
            d_B, N,                      // leading dimension of B (row-major K x N)
            d_A, K,                      // leading dimension of A (row-major M x K)
            &beta,
            d_C, N                       // leading dimension of C (row-major M x N)
        ));

Note: cuBLAS also provides a transpose flag, CUBLAS_OP_T, which tells cuBLAS to treat an input as transposed. In principle, this can be used to avoid explicitly transposing a matrix before the call. However, for the purpose of this exercise, I have used explicit transposition to keep things straightforward, and will revisit CUBLAS_OP_T in a future article.

Kernel 1: Using Naive Method

In the first approach, each thread computes exactly one element of C: the dot product of one row of A with one column of B, read directly from global memory.

In other words, the kernel computes

C_ij = Σ_{k=0}^{K−1} A_ik · B_kj
__global__ void GEMM(const float* A, const float* B, float* C, int M, int N, int K, float alpha, float beta){
    // Each thread computes one element of C.
    int row = threadIdx.y + blockDim.y * blockIdx.y;
    int col = threadIdx.x + blockDim.x * blockIdx.x;
    if(row < M && col < N){
        float temp = 0;
        // Dot product of row `row` of A with column `col` of B,
        // reading both operands from global memory on every iteration.
        for(int i = 0; i < K; i++){
            temp += A[row*K + i] * B[i*N + col];
        }
        C[row*N + col] = alpha * temp + beta * C[row*N + col];
    }
}
Naive Kernel

The above implementation is highly inefficient on modern GPUs because it is memory-bound rather than compute-bound: each inner-loop iteration performs just two floating-point operations (one multiply, one add) while issuing two 4-byte loads from global memory, with no explicit data reuse.

Kernel 2: Shared Memory & Tiling

In this kernel, each thread block cooperatively loads TILE_SIZE × TILE_SIZE tiles of A and B into fast on-chip shared memory, so that every element fetched from global memory is reused TILE_SIZE times by the threads of the block.

__global__ void GEMM(const float* A, const float* B, float* C, int M, int N, int K, float alpha, float beta){
    int row = threadIdx.y + blockDim.y * blockIdx.y;
    int col = threadIdx.x + blockDim.x * blockIdx.x;

    // Pad the second dimension by 1 to avoid shared-memory bank conflicts.
    __shared__ float sA[TILE_SIZE][TILE_SIZE + 1];
    __shared__ float sB[TILE_SIZE][TILE_SIZE + 1];

    float temp = 0.0f;
    for(int i = 0;i < (K + TILE_SIZE - 1)/TILE_SIZE; i++)
    {
        // Load a tile of A into shared memory, zero-padding out-of-range elements.
        if(row < M && threadIdx.x + TILE_SIZE * i < K)
        {
            sA[threadIdx.y][threadIdx.x] = (A[row * K + threadIdx.x + TILE_SIZE*i]);
        }
        else{
            sA[threadIdx.y][threadIdx.x] = 0.0f;
        }

        // Load a tile of B into shared memory.
        if(col < N && threadIdx.y + TILE_SIZE * i < K)
        {
            sB[threadIdx.y][threadIdx.x] = (B[(threadIdx.y + TILE_SIZE * i)*N + col]);
        }
        else{
            sB[threadIdx.y][threadIdx.x] = 0.0f;
        }

        __syncthreads();

        #pragma unroll
        for(int j = 0;j < TILE_SIZE; j++)
        {
            temp += sA[threadIdx.y][j] * sB[j][threadIdx.x];
        }
        __syncthreads();

    } 
    if(row < M && col < N){
        C[row*N + col] = (alpha * temp + beta* (C[row*N + col]));
    }

}
Shared Memory Tiling Kernel

Kernel 3: 1D Thread Tiling

In this kernel, each thread computes THREAD_TILE consecutive elements of one row of C, so each value of A read from shared memory is reused THREAD_TILE times in registers.

The kernel implementation is as follows:

#define THREAD_TILE 4 // 1 thread → computes 4 elements of C
__global__ void GEMM(const float* A, const float* B, float* C, int M, int N, int K, float alpha, float beta){
    int row = threadIdx.y + blockIdx.y * TILE_SIZE;
    int col = threadIdx.x * THREAD_TILE + blockIdx.x * TILE_SIZE;

    __shared__ float sA[TILE_SIZE][TILE_SIZE + 1];
    __shared__ float sB[TILE_SIZE][TILE_SIZE + 1];

    float temp[THREAD_TILE];
    for (int i = 0; i < THREAD_TILE; i++)
        temp[i] = 0.0f;

    for(int i = 0; i < (K + TILE_SIZE - 1)/TILE_SIZE; i++){
        // LOAD MATRIX A
        for(int j = 0; j < THREAD_TILE; j++){
            int gCol = i * TILE_SIZE + threadIdx.x * THREAD_TILE + j;
            if(row < M && gCol < K){
                sA[threadIdx.y][threadIdx.x * THREAD_TILE + j] = A[row * K + gCol];
            }
            else{
                sA[threadIdx.y][threadIdx.x * THREAD_TILE + j] = (0.0f);
            }

            int gRow = i * TILE_SIZE + threadIdx.y; 
            if(gRow < K && col + j< N){
                sB[threadIdx.y][threadIdx.x * THREAD_TILE + j] = B[gRow * N + col + j];
            }
            else{
                sB[threadIdx.y][threadIdx.x * THREAD_TILE + j] = (0.0f);
            }

        } 

        __syncthreads();

        for(int j = 0; j < TILE_SIZE;j++){
            float a_val = (sA[threadIdx.y][j]);
            for(int k = 0; k < THREAD_TILE; k++){
                float b_val = (sB[j][threadIdx.x * THREAD_TILE + k]);
                temp[k] += a_val * b_val;
            }
        }
        __syncthreads();
    }

    for(int j = 0; j < THREAD_TILE;j++){
        if(row < M && col + j< N){
            C[row * N + col + j] = (alpha * temp[j] + beta * (C[row * N + col + j]));
        }
    }
}
Figure 1: 1D thread tiling for different TILE_SIZE, THREAD_TILE

Kernel 4: 2D Thread Tiling

1D Thread Tiling

Each thread, per k-step:
C[i][j] += A[i][k] * B[k][j]
C[i][j+1] += A[i][k] * B[k][j+1]
C[i][j+2] += A[i][k] * B[k][j+2]
C[i][j+3] += A[i][k] * B[k][j+3]

2D Thread Tiling

Each thread, per k-step:
C[i    ][j    ] += A[i    ][k] * B[k][j    ]
C[i    ][j + 1] += A[i    ][k] * B[k][j + 1]
C[i + 1][j    ] += A[i + 1][k] * B[k][j    ]
C[i + 1][j + 1] += A[i + 1][k] * B[k][j + 1]
Each loaded value of A and B is reused multiple times, increasing the amount of computation performed per memory load.
__global__ void GEMM_2D_thread_tiled(const float* A, const float* B, float* C, int M, int N, int K, float alpha, float beta){
    int row = threadIdx.y * THREAD_TILE_M + blockIdx.y * TILE_SIZE;
    int col = threadIdx.x * THREAD_TILE_N + blockIdx.x * TILE_SIZE;

    __shared__ float sA[TILE_SIZE][TILE_SIZE + 1];
    __shared__ float sB[TILE_SIZE][TILE_SIZE + 1];

    float temp[THREAD_TILE_M][THREAD_TILE_N];
    for (int i = 0; i < THREAD_TILE_M; i++)
    {
        for(int j = 0;j < THREAD_TILE_N;j++)
        {
            temp[i][j] = 0.0f;
        }
    }
    for(int i = 0; i < (K + TILE_SIZE - 1)/TILE_SIZE; i++){
        
        // LOAD MATRIX A
        for(int j = 0; j < THREAD_TILE_N; j++){
            for(int k = 0; k < THREAD_TILE_M; k++){
                int sRow = threadIdx.y * THREAD_TILE_M + k;
                int sCol = threadIdx.x * THREAD_TILE_N + j;
                int gRow = blockIdx.y * TILE_SIZE + sRow;
                int gCol = i  * TILE_SIZE + sCol;
                if(gRow < M && gCol < K){
                    sA[sRow][sCol] = A[gRow * K + gCol];
                }
                else{
                    sA[sRow][sCol] = (0.0f);
                }
            }
        } 

        // LOAD MATRIX B
        for(int j = 0; j < THREAD_TILE_N; j++)
        {
            for(int k = 0;k < THREAD_TILE_M; k++){
                int sRow = threadIdx.y * THREAD_TILE_M + k;
                int sCol = threadIdx.x * THREAD_TILE_N + j;
                int gRow = i * TILE_SIZE + sRow;
                int gCol = blockIdx.x * TILE_SIZE + sCol;
                if(gRow < K && gCol < N){
                    sB[sRow][sCol] = B[gRow * N + gCol];
                }
                else{
                    sB[sRow][sCol] =(0.0f);
                }
            }
        }

        __syncthreads();

        for(int m = 0; m < TILE_SIZE; m++){
            for(int k = 0; k < THREAD_TILE_M; k++){
                // Hoist the A value: it is reused for all THREAD_TILE_N columns.
                float a_val = sA[threadIdx.y * THREAD_TILE_M + k][m];
                for(int l = 0; l < THREAD_TILE_N; l++){
                    float b_val = sB[m][threadIdx.x * THREAD_TILE_N + l];
                    temp[k][l] += a_val * b_val;
                }
            }
        }

        __syncthreads();
    }

    for(int i = 0; i < THREAD_TILE_M; i++){
        for(int j = 0; j < THREAD_TILE_N; j++){
            if(row + i < M && col + j < N)
                C[(row + i)*N + col + j] = alpha * temp[i][j] + beta * C[(row + i) * N + col + j];
        }
    }
}
2D Thread Tiling Diagram
Figure 2: 2D thread tiling for different TILE_SIZE, THREAD_TILE_M, THREAD_TILE_N

Benchmark Results

Figure 3: Comparison of Custom CUDA Kernels with cuBLAS

Matrix A: 8192 × 4096, Matrix B: 4096 × 8192

| Kernel | GFLOP/s |
| --- | --- |
| Naive | 692.4 |
| Shared Memory | 946.5 |
| 1D Thread Tiling (TILE=64, TT=4) | 1583.3 |
| 2D Thread Tiling (TILE=64, TT_M=4, TT_N=4) | 2783.6 |
| cuBLAS | 5538 |

Matrix A: 4096 × 8192, Matrix B: 8192 × 4096

| Kernel | GFLOP/s |
| --- | --- |
| Naive | 682.7 |
| Shared Memory | 944 |
| 1D Thread Tiling (TILE=64, TT=4) | 1578.4 |
| 2D Thread Tiling (TILE=64, TT_M=4, TT_N=4) | 2777.5 |
| cuBLAS | 7910 |

Work in Progress

I plan to explore vectorized memory loading, warp-level optimizations, and double buffering to try to get close to cuBLAS performance.

Resources and References

Thank you for reading!!

#CUDA #GEMM