An Introduction to CUDA-GEMM Optimization
General matrix multiplication (GEMM) is one of the core operations in deep learning and scientific computing, and optimizing GEMM performance on NVIDIA GPUs is essential for overall computational efficiency. This article walks through CUDA GEMM optimization techniques, starting from a basic implementation and progressively introducing more advanced strategies.
Optimization Fundamentals
Before diving into the optimizations, we need a few basic CUDA programming concepts:
- Thread hierarchy: CUDA organizes work into a hierarchy of grids (Grid), thread blocks (Block), and threads (Thread).
- Memory hierarchy: global memory, shared memory, registers, and so on.
- Memory access patterns: coalesced access has a large impact on performance (a small illustration follows below).
These fundamentals form the theoretical basis for the GEMM optimizations that follow.
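To make the last point concrete, the following pair of toy copy kernels (illustrative only, not part of the GEMM code) contrasts a coalesced access pattern with a strided one; the strided version is typically much slower because each warp's accesses splinter into many separate memory transactions:

// Coalesced: consecutive threads touch consecutive addresses, so a warp's
// 32 accesses combine into a small number of memory transactions.
__global__ void copy_coalesced(const float* in, float* out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = in[i];
}

// Strided: consecutive threads touch addresses `stride` apart, so the
// same warp may need one transaction per access.
__global__ void copy_strided(const float* in, float* out, int n, int stride) {
    int i = (blockIdx.x * blockDim.x + threadIdx.x) * stride;
    if (i < n) out[i] = in[i];
}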
GEMM Optimization Strategies
1. Memory Access Optimization
A naive GEMM implementation often suffers from uncoalesced global memory accesses. The first step is to ensure coalesced access by mapping threadIdx.x to the column index of C, so that consecutive threads in a warp read consecutive elements of B and write consecutive elements of C:
__global__ void gemm_kernel_v01(const float* A, const float* B, float* C, int M, int N, int K) {
    // threadIdx.x maps to columns, so consecutive threads access
    // consecutive elements of B and C (coalesced).
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row < M && col < N) {
        float sum = 0.0f;
        for (int k = 0; k < K; ++k) {
            sum += A[row * K + k] * B[k * N + col];
        }
        C[row * N + col] = sum;
    }
}
This version achieves fully coalesced loads from B and stores to C, and performs significantly better than an uncoalesced baseline.
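A typical launch maps one thread to each element of C. A sketch follows, where the 16 × 16 block shape and the device pointers dA, dB, dC are illustrative assumptions:

// Hypothetical host-side launch: one thread per element of C.
dim3 block(16, 16);
dim3 grid((N + block.x - 1) / block.x,   // columns of C
          (M + block.y - 1) / block.y);  // rows of C
gemm_kernel_v01<<<grid, block>>>(dA, dB, dC, M, N, K);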
2. Two-Dimensional Block Tiling
Next, we introduce two-dimensional block tiling: each thread block stages a BLOCK_SIZE × BLOCK_SIZE tile of A and of B in shared memory and marches along the K dimension one tile at a time:
template <int BLOCK_SIZE>
__global__ void gemm_kernel_v02(const float* A, const float* B, float* C, int M, int N, int K) {
    // Requires blockDim == (BLOCK_SIZE, BLOCK_SIZE).
    __shared__ float As[BLOCK_SIZE][BLOCK_SIZE];
    __shared__ float Bs[BLOCK_SIZE][BLOCK_SIZE];
    int bx = blockIdx.x, by = blockIdx.y;
    int tx = threadIdx.x, ty = threadIdx.y;
    int row = by * BLOCK_SIZE + ty;
    int col = bx * BLOCK_SIZE + tx;
    float sum = 0.0f;
    for (int i = 0; i < (K + BLOCK_SIZE - 1) / BLOCK_SIZE; ++i) {
        // Stage one tile of A and one tile of B in shared memory,
        // zero-padding out-of-range elements.
        if (row < M && i * BLOCK_SIZE + tx < K)
            As[ty][tx] = A[row * K + i * BLOCK_SIZE + tx];
        else
            As[ty][tx] = 0.0f;
        if (col < N && i * BLOCK_SIZE + ty < K)
            Bs[ty][tx] = B[(i * BLOCK_SIZE + ty) * N + col];
        else
            Bs[ty][tx] = 0.0f;
        __syncthreads();
        // Each staged element is reused BLOCK_SIZE times from shared memory.
        for (int k = 0; k < BLOCK_SIZE; ++k)
            sum += As[ty][k] * Bs[k][tx];
        __syncthreads();
    }
    if (row < M && col < N)
        C[row * N + col] = sum;
}
By staging tiles in shared memory, this version reduces global memory traffic: each element fetched from global memory is reused BLOCK_SIZE times, improving computational efficiency.
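Because the kernel indexes the shared tiles directly with threadIdx, the launch must use exactly BLOCK_SIZE × BLOCK_SIZE threads per block. A minimal launch sketch, assuming BLOCK_SIZE = 16 and the same hypothetical device pointers as above:

// Hypothetical host-side launch for gemm_kernel_v02.
constexpr int BS = 16;   // assumed tile size
dim3 block(BS, BS);      // must match the template parameter exactly
dim3 grid((N + BS - 1) / BS, (M + BS - 1) / BS);
gemm_kernel_v02<BS><<<grid, block>>>(dA, dB, dC, M, N, K);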
3. Thread Tiling
Building on block tiling, we can further apply thread tiling, where each thread computes a THREAD_SIZE_Y × THREAD_SIZE_X patch of output elements instead of a single one:
template <int BLOCK_SIZE, int THREAD_SIZE_X, int THREAD_SIZE_Y>
__global__ void gemm_kernel_v04(const float* A, const float* B, float* C, int M, int N, int K) {
    // Requires blockDim == (BLOCK_SIZE / THREAD_SIZE_X, BLOCK_SIZE / THREAD_SIZE_Y);
    // each thread computes a THREAD_SIZE_Y x THREAD_SIZE_X patch of C.
    __shared__ float As[BLOCK_SIZE][BLOCK_SIZE];
    __shared__ float Bs[BLOCK_SIZE][BLOCK_SIZE];
    int bx = blockIdx.x, by = blockIdx.y;
    int tx = threadIdx.x, ty = threadIdx.y;
    int row = by * BLOCK_SIZE + ty;
    int col = bx * BLOCK_SIZE + tx;
    float sum[THREAD_SIZE_Y][THREAD_SIZE_X] = {0.0f};
    for (int i = 0; i < (K + BLOCK_SIZE - 1) / BLOCK_SIZE; ++i) {
        // Each thread stages THREAD_SIZE_Y x THREAD_SIZE_X elements of the A tile,
        // zero-padding out-of-range elements.
        for (int m = 0; m < THREAD_SIZE_Y; ++m)
            for (int n = 0; n < THREAD_SIZE_X; ++n)
                if (row + m * blockDim.y < M && i * BLOCK_SIZE + tx + n * blockDim.x < K)
                    As[ty + m * blockDim.y][tx + n * blockDim.x] = A[(row + m * blockDim.y) * K + i * BLOCK_SIZE + tx + n * blockDim.x];
                else
                    As[ty + m * blockDim.y][tx + n * blockDim.x] = 0.0f;
        // Likewise for the B tile.
        for (int m = 0; m < THREAD_SIZE_Y; ++m)
            for (int n = 0; n < THREAD_SIZE_X; ++n)
                if (col + n * blockDim.x < N && i * BLOCK_SIZE + ty + m * blockDim.y < K)
                    Bs[ty + m * blockDim.y][tx + n * blockDim.x] = B[(i * BLOCK_SIZE + ty + m * blockDim.y) * N + col + n * blockDim.x];
                else
                    Bs[ty + m * blockDim.y][tx + n * blockDim.x] = 0.0f;
        __syncthreads();
        // Accumulate the THREAD_SIZE_Y x THREAD_SIZE_X partial products.
        for (int k = 0; k < BLOCK_SIZE; ++k)
            for (int m = 0; m < THREAD_SIZE_Y; ++m)
                for (int n = 0; n < THREAD_SIZE_X; ++n)
                    sum[m][n] += As[ty + m * blockDim.y][k] * Bs[k][tx + n * blockDim.x];
        __syncthreads();
    }
    for (int m = 0; m < THREAD_SIZE_Y; ++m)
        for (int n = 0; n < THREAD_SIZE_X; ++n)
            if (row + m * blockDim.y < M && col + n * blockDim.x < N)
                C[(row + m * blockDim.y) * N + col + n * blockDim.x] = sum[m][n];
}
By letting each thread compute multiple output elements, this version raises the arithmetic intensity per thread and makes better use of registers.
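The launch configuration must now satisfy blockDim.x == BLOCK_SIZE / THREAD_SIZE_X and blockDim.y == BLOCK_SIZE / THREAD_SIZE_Y so that the strided loads and stores cover the whole tile. A sketch with assumed parameter values:

// Hypothetical host-side launch for gemm_kernel_v04.
constexpr int BS = 64, TX = 4, TY = 4;   // assumed tile sizes
dim3 block(BS / TX, BS / TY);            // 16 x 16 = 256 threads
dim3 grid((N + BS - 1) / BS, (M + BS - 1) / BS);
gemm_kernel_v04<BS, TX, TY><<<grid, block>>>(dA, dB, dC, M, N, K);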
4. Matrix Transpose Optimization
To further optimize the memory access pattern, we can pre-transpose the input matrix B. The kernel below assumes B has already been transposed into an N × K layout, so each thread walks contiguously along the K dimension:
template <int BLOCK_SIZE, int THREAD_SIZE_X, int THREAD_SIZE_Y>
__global__ void gemm_kernel_v05(const float* A, const float* B, float* C, int M, int N, int K) {
    // B is assumed to be pre-transposed into an N x K layout, so B[col * K + k]
    // addresses the element that B[k * N + col] would in the original layout.
    __shared__ float As[BLOCK_SIZE][BLOCK_SIZE];
    __shared__ float Bs[BLOCK_SIZE][BLOCK_SIZE];   // stored as [col][k]
    int bx = blockIdx.x, by = blockIdx.y;
    int tx = threadIdx.x, ty = threadIdx.y;
    int row = by * BLOCK_SIZE + ty;
    int col = bx * BLOCK_SIZE + tx;
    float sum[THREAD_SIZE_Y][THREAD_SIZE_X] = {0.0f};
    for (int i = 0; i < (K + BLOCK_SIZE - 1) / BLOCK_SIZE; ++i) {
        for (int m = 0; m < THREAD_SIZE_Y; ++m)
            for (int n = 0; n < THREAD_SIZE_X; ++n)
                if (row + m * blockDim.y < M && i * BLOCK_SIZE + tx + n * blockDim.x < K)
                    As[ty + m * blockDim.y][tx + n * blockDim.x] = A[(row + m * blockDim.y) * K + i * BLOCK_SIZE + tx + n * blockDim.x];
                else
                    As[ty + m * blockDim.y][tx + n * blockDim.x] = 0.0f;
        // The B tile is read from the transposed layout and kept transposed
        // in shared memory as well.
        for (int m = 0; m < THREAD_SIZE_Y; ++m)
            for (int n = 0; n < THREAD_SIZE_X; ++n)
                if (col + n * blockDim.x < N && i * BLOCK_SIZE + ty + m * blockDim.y < K)
                    Bs[tx + n * blockDim.x][ty + m * blockDim.y] = B[(col + n * blockDim.x) * K + i * BLOCK_SIZE + ty + m * blockDim.y];
                else
                    Bs[tx + n * blockDim.x][ty + m * blockDim.y] = 0.0f;
        __syncthreads();
        // Both As and Bs are now traversed along the K dimension.
        for (int k = 0; k < BLOCK_SIZE; ++k)
            for (int m = 0; m < THREAD_SIZE_Y; ++m)
                for (int n = 0; n < THREAD_SIZE_X; ++n)
                    sum[m][n] += As[ty + m * blockDim.y][k] * Bs[tx + n * blockDim.x][k];
        __syncthreads();
    }
    for (int m = 0; m < THREAD_SIZE_Y; ++m)
        for (int n = 0; n < THREAD_SIZE_X; ++n)
            if (row + m * blockDim.y < M && col + n * blockDim.x < N)
                C[(row + m * blockDim.y) * N + col + n * blockDim.x] = sum[m][n];
}
By transposing B, this version changes the access pattern so that both operands are traversed along the K dimension, improving locality and cache behavior.
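Since the kernel consumes B in an N × K layout, B must be transposed once on the device before the GEMM runs, and that one-time cost has to be amortized over the multiplication. A minimal transpose helper follows; it is a hypothetical, untiled sketch, and a production version would stage through shared memory so that both the reads and the writes stay coalesced:

// Naive transpose: builds Bt (N x K) from B (K x N); illustrative helper only.
__global__ void transpose_naive(const float* B, float* Bt, int K, int N) {
    int k = blockIdx.y * blockDim.y + threadIdx.y;   // row index into B
    int n = blockIdx.x * blockDim.x + threadIdx.x;   // column index into B
    if (k < K && n < N)
        Bt[n * K + k] = B[k * N + n];
}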
5. Warp-Level Optimization
Finally, we can introduce warp-level organization. The idea is to map each warp to a contiguous sub-tile of the block tile so that the lanes of a warp work on adjacent data. The version below restructures v05 accordingly: warps are arranged in a grid over the block tile, each lane owns a THREAD_SIZE_Y × THREAD_SIZE_X patch inside its warp's sub-tile, and the shared-memory tiles are filled cooperatively with linear-thread-ID strided, coalesced loads:
template <int BLOCK_SIZE, int WARP_SIZE, int THREAD_SIZE_X, int THREAD_SIZE_Y>
__global__ void gemm_kernel_v06(const float* A, const float* B, float* C, int M, int N, int K) {
    // As in v05, B is assumed to be pre-transposed into an N x K layout.
    // Requires blockDim == (BLOCK_SIZE / THREAD_SIZE_X, BLOCK_SIZE / THREAD_SIZE_Y).
    __shared__ float As[BLOCK_SIZE][BLOCK_SIZE];
    __shared__ float Bs[BLOCK_SIZE][BLOCK_SIZE];   // stored as [col][k], as in v05
    // Lanes of a warp form a LANE_ROWS x LANE_COLS grid; each warp owns a
    // WARP_TILE_Y x WARP_TILE_X sub-tile, and these sub-tiles must exactly
    // cover the BLOCK_SIZE x BLOCK_SIZE block tile.
    constexpr int LANE_COLS = 8;
    constexpr int LANE_ROWS = WARP_SIZE / LANE_COLS;
    constexpr int WARP_TILE_X = LANE_COLS * THREAD_SIZE_X;
    constexpr int WARP_TILE_Y = LANE_ROWS * THREAD_SIZE_Y;
    constexpr int WARPS_PER_ROW = BLOCK_SIZE / WARP_TILE_X;
    int tid = threadIdx.y * blockDim.x + threadIdx.x;
    int warpId = tid / WARP_SIZE;
    int laneId = tid % WARP_SIZE;
    int warpRow = warpId / WARPS_PER_ROW;
    int warpCol = warpId % WARPS_PER_ROW;
    // This thread's THREAD_SIZE_Y x THREAD_SIZE_X patch inside the block tile.
    int tileRow = warpRow * WARP_TILE_Y + (laneId / LANE_COLS) * THREAD_SIZE_Y;
    int tileCol = warpCol * WARP_TILE_X + (laneId % LANE_COLS) * THREAD_SIZE_X;
    int numThreads = blockDim.x * blockDim.y;
    float sum[THREAD_SIZE_Y][THREAD_SIZE_X] = {0.0f};
    for (int i = 0; i < (K + BLOCK_SIZE - 1) / BLOCK_SIZE; ++i) {
        // Cooperative loads: the whole block strides over both tiles, with
        // consecutive thread IDs touching consecutive addresses (coalesced).
        for (int idx = tid; idx < BLOCK_SIZE * BLOCK_SIZE; idx += numThreads) {
            int r = idx / BLOCK_SIZE, c = idx % BLOCK_SIZE;
            int gK = i * BLOCK_SIZE + c;
            int gRow = blockIdx.y * BLOCK_SIZE + r;
            As[r][c] = (gRow < M && gK < K) ? A[gRow * K + gK] : 0.0f;
            int gCol = blockIdx.x * BLOCK_SIZE + r;
            Bs[r][c] = (gCol < N && gK < K) ? B[gCol * K + gK] : 0.0f;
        }
        __syncthreads();
        for (int k = 0; k < BLOCK_SIZE; ++k)
            for (int m = 0; m < THREAD_SIZE_Y; ++m)
                for (int n = 0; n < THREAD_SIZE_X; ++n)
                    sum[m][n] += As[tileRow + m][k] * Bs[tileCol + n][k];
        __syncthreads();
    }
    for (int m = 0; m < THREAD_SIZE_Y; ++m)
        for (int n = 0; n < THREAD_SIZE_X; ++n) {
            int gRow = blockIdx.y * BLOCK_SIZE + tileRow + m;
            int gCol = blockIdx.x * BLOCK_SIZE + tileCol + n;
            if (gRow < M && gCol < N)
                C[gRow * N + gCol] = sum[m][n];
        }
}
This version organizes both the data movement and the computation around warps: the shared-memory tiles are filled with fully coalesced strided loads, and each warp computes a contiguous sub-tile of C, improving parallel efficiency within the block.
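The warp decomposition only closes up when the template parameters are consistent: with WARP_SIZE = 32 and the 4 × 8 lane grid above, each warp covers a (4 · THREAD_SIZE_Y) × (8 · THREAD_SIZE_X) sub-tile, and those sub-tiles must exactly cover the BLOCK_SIZE × BLOCK_SIZE block tile. A launch sketch with assumed values, where dBt is the pre-transposed copy of B:

// Hypothetical host-side launch for gemm_kernel_v06.
constexpr int BS = 64, WS = 32, TX = 4, TY = 4;  // assumed parameters
// (BS/TX) x (BS/TY) = 16 x 16 = 256 threads = 8 warps; each warp covers a
// 16 x 32 sub-tile, so a 4 x 2 grid of warps tiles the 64 x 64 block tile.
dim3 block(BS / TX, BS / TY);
dim3 grid((N + BS - 1) / BS, (M + BS - 1) / BS);
gemm_kernel_v06<BS, WS, TX, TY><<<grid, block>>>(dA, dBt, dC, M, N, K);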