Somik's Blog

WorkLog: Optimizing Convolution

Convolution is a fundamental mathematical operation used extensively in signal processing, image processing, and deep learning (particularly in Convolutional Neural Networks).

Mathematical Definition

Continuous-time Convolution

For two functions f(t) and g(t):

$$(f * g)(t) = \int_{-\infty}^{\infty} f(\tau)\, g(t - \tau)\, d\tau$$

where f is the input signal, g is the kernel (filter), and τ is the variable of integration.

Discrete-time Convolution

$$y[n] = \sum_{k=-\infty}^{\infty} x[k]\, h[n - k]$$

where x[n] is the input sequence, h[n] is the kernel (impulse response), and y[n] is the output.

In this post, I benchmark CUDA implementations of convolution and analyze their performance across 1D, 2D, and 3D convolution workloads.

1D Convolution

Constraints

Benchmarking

cuDNN

  1. Setup / Description Phase
/* cuDNN setup */
cudnnHandle_t cudnn; // Declare cuDNN execution context required for all cuDNN operations
CUDNN_CHECK(cudnnCreate(&cudnn));
cudnnTensorDescriptor_t xDesc, yDesc; // Tensor descriptors: they do not store data, only describe how it is laid out in memory
cudnnFilterDescriptor_t filterDesc;
cudnnConvolutionDescriptor_t convDesc;
CUDNN_CHECK(cudnnCreateTensorDescriptor(&xDesc));
CUDNN_CHECK(cudnnCreateTensorDescriptor(&yDesc));
CUDNN_CHECK(cudnnCreateFilterDescriptor(&filterDesc));
CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&convDesc));
  2. Configuration Phase
// Set tensor and filter descriptors: a 1D signal of length N is mapped to NCHW as (n=1, c=1, h=N, w=1)
CUDNN_CHECK(cudnnSetTensor4dDescriptor(xDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, 1, N, 1));
CUDNN_CHECK(cudnnSetTensor4dDescriptor(yDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, 1, outSize, 1));
CUDNN_CHECK(cudnnSetFilter4dDescriptor(filterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, 1, 1, K, 1));
// Set convolution descriptor
CUDNN_CHECK(cudnnSetConvolution2dDescriptor(convDesc,
                                            0, 0, // pad height, width
                                            1, 1, // vertical, horizontal stride
                                            1, 1, // dilation height, width
                                            CUDNN_CROSS_CORRELATION,
                                            CUDNN_DATA_FLOAT));

Mode: CUDNN_CROSS_CORRELATION (the deep-learning convention: unlike the mathematical definition above, the kernel is not flipped)

Compute Type: CUDNN_DATA_FLOAT

  3. Execution Phase
cudnnConvolutionFwdAlgoPerf_t perf;
int returnedAlgoCount = 0;

CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm_v7(
            cudnn,
            xDesc,
            filterDesc,
            convDesc,
            yDesc,
            1, // max number of algorithms to return
            &returnedAlgoCount,
            &perf
));
cudnnConvolutionFwdAlgo_t algo = perf.algo;

// Workspace 
size_t workspaceSize = perf.memory;
void* d_workspace = nullptr;
if(workspaceSize > 0){
    CUDA_CHECK(cudaMalloc(&d_workspace, workspaceSize));
}
Fields of cudnnConvolutionFwdAlgoPerf_t:

| Field    | Meaning                         |
|----------|---------------------------------|
| algo     | Which algorithm was chosen      |
| time     | Estimated execution time (ms)   |
| memory   | Workspace size required         |
| status   | Whether this algorithm is valid |
| mathType | Tensor Core / default math      |

Actual Execution Point

const float alpha = 1.0f, beta = 0.0f; // y = alpha * conv(x, w) + beta * y
CUDNN_CHECK(cudnnConvolutionForward(
                    cudnn,
                    &alpha,
                    xDesc,
                    d_input,
                    filterDesc,
                    d_filter,
                    convDesc,
                    algo,
                    d_workspace,
                    workspaceSize,
                    &beta,
                    yDesc,
                    d_output
                ));

Custom CUDA Kernels

Naive method

// Naive 1D convolution: one thread per output element, all reads from global memory
__global__ void convolution_1d_kernel(const float* input, const float* kernel, float* output,
                                      int input_size, int kernel_size) {
    int tid = threadIdx.x + blockDim.x * blockIdx.x; // global output index
    if (tid < input_size - kernel_size + 1) {        // valid (no-padding) output range
        float temp = 0.0f;
        for (int i = 0; i < kernel_size; i++) {
            temp += input[tid + i] * kernel[i];      // sliding dot product (cross-correlation)
        }
        output[tid] = temp;
    }
}

Using shared memory

// Tiled 1D convolution: each block stages its input window into shared memory once,
// then every thread reads its kernel_size inputs from fast on-chip memory.
__global__ void convolution_1d_kernel_shared_mem(const float* input, const float* kernel, float* output,
                                                 int input_size, int kernel_size) {
    extern __shared__ float sI[]; // dynamically sized: blockDim.x + kernel_size - 1 floats
    int tid = threadIdx.x;
    int base = blockDim.x * blockIdx.x; // first input index this block needs

    // Cooperative load of the tile plus its (kernel_size - 1)-element halo, zero-padded past the end
    for (int i = tid; i < blockDim.x + kernel_size - 1; i += blockDim.x) {
        sI[i] = (base + i < input_size) ? input[base + i] : 0.0f;
    }
    __syncthreads(); // all loads must finish before any thread reads sI

    int gid = base + tid;
    if (gid < input_size - kernel_size + 1) {
        float temp = 0.0f;
        #pragma unroll
        for (int i = 0; i < kernel_size; i++) {
            temp += kernel[i] * sI[tid + i];
        }
        output[gid] = temp;
    }
}

The same benchmarking methodology is extended to 2D and 3D convolutions, enabling a consistent comparison across dimensions.

Constraints for 2D Convolution

Constraints for 3D Convolution

Results

Figure 1: Latency Comparison of 1D convolution for different kernel sizes
Figure 2: Latency Comparison of 2D convolution for different kernel sizes
Figure 3: Latency Comparison of 3D convolution for different kernel sizes

Comments on the above results

Future Exploration

Future work will involve profiling with NVIDIA Nsight Systems and NVIDIA Nsight Compute to break down execution time into kernel launches, memory stalls, and compute utilization, providing deeper insight into why cuDNN underperforms compared to custom kernels in low-channel convolution settings.

References

#1D #2D #3D #CUDA #convolution