[Doc][Polish] gemm optimize by vectorize #57

Merged
merged 1 commit into from Dec 17, 2024
5 changes: 0 additions & 5 deletions docs/07_optimize_matmul/matmul_shared.cu
@@ -107,11 +107,6 @@ int main()
cudaMalloc((void **)&d_B, k * n * sizeof(float));
cudaMalloc((void **)&d_C, m * n * sizeof(float));

// Copy data to device
cudaMalloc((void **)&d_A, m * k * sizeof(float));
cudaMalloc((void **)&d_B, k * n * sizeof(float));
cudaMalloc((void **)&d_C, m * n * sizeof(float));

// Copy matrices to device
cudaMemcpy(d_A, A, m * k * sizeof(float), cudaMemcpyHostToDevice);
cudaMemcpy(d_B, B, k * n * sizeof(float), cudaMemcpyHostToDevice);
5 changes: 0 additions & 5 deletions docs/07_optimize_matmul/matmul_tiled.cu
@@ -142,11 +142,6 @@ int main()
cudaMalloc((void **)&d_B, k * n * sizeof(float));
cudaMalloc((void **)&d_C, m * n * sizeof(float));

// Copy data to device
cudaMalloc((void **)&d_A, m * k * sizeof(float));
cudaMalloc((void **)&d_B, k * n * sizeof(float));
cudaMalloc((void **)&d_C, m * n * sizeof(float));

// Copy matrices to device
cudaMemcpy(d_A, A, m * k * sizeof(float), cudaMemcpyHostToDevice);
cudaMemcpy(d_B, B, k * n * sizeof(float), cudaMemcpyHostToDevice);
5 changes: 0 additions & 5 deletions docs/11_gemm_optimize/01_tiled2d/sgemm_tiled2d.cu
@@ -187,11 +187,6 @@ int main(int argc, char *argv[])
cudaMalloc((void **)&d_B, k * n * sizeof(float));
cudaMalloc((void **)&d_C, m * n * sizeof(float));

// Copy data to device
cudaMalloc((void **)&d_A, m * k * sizeof(float));
cudaMalloc((void **)&d_B, k * n * sizeof(float));
cudaMalloc((void **)&d_C, m * n * sizeof(float));

// Copy matrices to device
cudaMemcpy(d_A, A, m * k * sizeof(float), cudaMemcpyHostToDevice);
cudaMemcpy(d_B, B, k * n * sizeof(float), cudaMemcpyHostToDevice);
32 changes: 17 additions & 15 deletions docs/11_gemm_optimize/02_vectorize_smem_and_gmem_accesses/README.md
@@ -1,6 +1,8 @@
# Vectorized Memory Access

Vectorized memory access means combining multiple memory access operations into a single one. This reduces the number of memory transactions and improves memory-access efficiency. In this section, we will look at how vectorized memory access can improve the performance of matrix multiplication.
Vectorized memory access means combining multiple memory access operations into a single one. This reduces the number of memory transactions and improves memory-access efficiency. In high-performance computing this is also known as SIMD (Single Instruction, Multiple Data). On the CPU side there are the SSE instruction set for Intel platforms and the Neon instruction set for Arm; on the GPU side, besides CUDA there is also the OpenCL framework. All of these tools support vectorized loads and vectorized computation.

In this section, we will look at how vectorized memory access can be used to improve the performance of matrix multiplication.
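
To make the idea concrete, here is a minimal sketch (not taken from the kernels in this repository) that contrasts four scalar accesses with a single `float4` access per thread; it assumes the buffers are 16-byte aligned so that the vector load is legal.

```cpp
// Scalar copy: each thread issues four 32-bit loads and four 32-bit stores.
__global__ void copy_scalar(const float *in, float *out, int n)
{
    int i = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
    if (i + 3 < n)
    {
        out[i] = in[i];
        out[i + 1] = in[i + 1];
        out[i + 2] = in[i + 2];
        out[i + 3] = in[i + 3];
    }
}

// Vectorized copy: each thread issues one 128-bit load and one 128-bit store.
__global__ void copy_vectorized(const float *in, float *out, int n)
{
    int i = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
    if (i + 3 < n)
    {
        reinterpret_cast<float4 *>(out)[i / 4] =
            reinterpret_cast<const float4 *>(in)[i / 4];
    }
}
```

Both kernels move the same data; the difference is only in how many memory transactions each thread issues.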

## 1. Optimization Approach

@@ -29,7 +31,7 @@ The LDS.128 instruction can read 4 float values in a single access.

The overall flow of the algorithm is as follows:

![picture 0](images/05eee538f6394ffc2ffffc2947edc8c888175af7152a150d697bfefb47db7a98.jpg)

The main difference between this kernel and the previous one is how data is loaded into shared memory. The loading of matrix A is shown in the figure below:

@@ -144,12 +146,12 @@ for (uint dot_idx = 0; dot_idx < BK; ++dot_idx)
{
    for (int m = 0; m < TM; m += 4)
    {
        FETCH_FLOAT4(reg_a[m]) =
            FETCH_FLOAT4(smem_a[OFFSET(dot_idx, thread_row + m, BM)]);
    }
    for (int n = 0; n < TN; n += 4)
    {
        FETCH_FLOAT4(reg_b[n])
            = FETCH_FLOAT4(smem_b[OFFSET(dot_idx, thread_col + n, BN)]);
    }

@@ -178,29 +180,29 @@ for (int m = 0; m < TM; m++)
Compile and run the code:

```bash
nvcc -o sgemm_vectorize sgemm_vectorize.cu
./sgemm_vectorize 256 256 256
```
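
If you want to confirm that the 128-bit instructions are actually emitted (this step is not part of the original text, and the exact instruction names can vary across GPU architectures), one option is to dump the SASS of the compiled binary:

```bash
# Look for 128-bit shared/global loads in the generated SASS
cuobjdump -sass sgemm_vectorize | grep -E "LDS\.128|LDG\.E\.128"
```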

## 3. Performance Comparison

We compare the performance of this kernel with the previous kernels, measuring matrix multiplication time (us) for 256x256, 512x512, 1024x1024 and 2048x2048 matrices (Matrix 1, Matrix 2, Matrix 3 and Matrix 4). Running on a 1080 Ti, the results are as follows:


| Algorithm                  | Matrix 1 | Matrix 2 | Matrix 3 | Matrix 4 |
| -------------------------- | -------- | -------- | -------- | -------- |
| Naive                      | 95.5152  | 724.396  | 28424    | 228681   |
| Shared memory cache block  | 40.5293  | 198.374  | 8245.68  | 59048.4  |
| 1D Thread Tile             | 35.215   | 174.731  | 894.779  | 5880.03  |
| 2D Thread Tile             | 34.708   | 92.946   | 557.829  | 3509.920 |
| Vectorized memory access   | 36.567   | 90.745   | 427.701  | 2901.475 |
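
Reading from the table, on the largest size (Matrix 4) the vectorized kernel is roughly 79x faster than the naive kernel (228681 / 2901.475) and about 1.21x faster than the 2D Thread Tile kernel (3509.920 / 2901.475); on the smallest size (Matrix 1) it is slightly slower than the 2D Thread Tile kernel, presumably because the extra staging through registers is not amortized on such a small problem.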


## 4. Summary

In this article we introduced another way to optimize matrix multiplication: vectorized memory access. Vectorized memory access merges multiple memory accesses into a single access, reducing the number of memory transactions and improving memory efficiency. Using a concrete example, we showed how to apply it to matrix multiplication. The experimental results show that the vectorized kernel outperforms the 2D Thread Tile kernel, so vectorized access is an effective way to improve GEMM performance. It does, however, place stricter requirements on the input: the relevant dimensions must be multiples of 4. In practice, the appropriate optimization should therefore be chosen according to the situation at hand.
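
As a hedged illustration of that constraint (this wrapper is not part of the PR; `run_sgemm_scalar` is a hypothetical fallback, while `run_sgemm_vectorize` is the launcher used in `main()` of `sgemm_vectorize.cu`), a host-side dispatcher might select the vectorized kernel only when FLOAT4 accesses are safe:

```cpp
// Launcher declarations: run_sgemm_vectorize comes from sgemm_vectorize.cu,
// run_sgemm_scalar is a hypothetical non-vectorized fallback.
void run_sgemm_vectorize(float *d_A, float *d_B, float *d_C, int m, int n, int k);
void run_sgemm_scalar(float *d_A, float *d_B, float *d_C, int m, int n, int k);

// Sketch only: use the vectorized kernel when every dimension is a multiple
// of 4 (so the FLOAT4 loads stay in bounds); otherwise fall back.
void run_sgemm(float *d_A, float *d_B, float *d_C, int m, int n, int k)
{
    if (m % 4 == 0 && n % 4 == 0 && k % 4 == 0)
    {
        run_sgemm_vectorize(d_A, d_B, d_C, m, n, k);
    }
    else
    {
        run_sgemm_scalar(d_A, d_B, d_C, m, n, k);
    }
}
```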

## Reference

1. https://siboehm.com/articles/22/CUDA-MMM
2. https://github.com/siboehm/SGEMM_CUDA
@@ -7,10 +7,26 @@
#include <cuda.h>
#include <cstdlib>

#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))
#define OFFSET(row, col, ld) ((row) * (ld) + (col))
#define FETCH_FLOAT4(pointer) (reinterpret_cast<float4 *>(&(pointer))[0])
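// FETCH_FLOAT4 reinterprets the address of its argument as float4 *, so that a
// single 128-bit access (e.g. LDS.128 for shared memory) can be issued instead
// of four 32-bit accesses; this requires the address to be 16-byte aligned.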

void free_resource(float *&ptr, int is_cuda = 1)
{
    if (nullptr != ptr)
    {
        if (is_cuda)
        {
            cudaFree(ptr);
        }
        else
        {
            delete[] ptr;
        }
    }
    // ptr is taken by reference so that resetting it here is visible to the caller
    ptr = nullptr;
}

void sgemm_naive_cpu(float *A, float *B, float *C, int M, int N, int K)
{
for (int x = 0; x < M; x++)
@@ -36,8 +52,8 @@ __global__ void __launch_bounds__((BM * BN) / (TM * TN), 1) sgemm_vectorize_kern
const uint c_row = blockIdx.y;
const uint c_col = blockIdx.x;

const int block_row_thread = BN / TN;
const int block_col_thread = BM / TM;
const int block_row_thread = BM / TM;
const int block_col_thread = BN / TN;
// each thread computes TM * TN elements of the block
const int thread_num = block_row_thread * block_col_thread;

@@ -73,8 +89,8 @@ __global__ void __launch_bounds__((BM * BN) / (TM * TN), 1) sgemm_vectorize_kern
C += c_row * BM * N + c_col * BN;

float thread_results[TM * TN] = {0.0};
// each thread copies in ldg_a_num rounds, caching ldg_a_num float4 elements in registers, used to transpose the As matrix
float ldg_reg_a[4 * ldg_a_num] = {0.};
// for the transpose, an array of size 4 is sufficient
float ldg_reg_a[4] = {0.};
float reg_a[TM] = {0.0}; // 缓存 smem_a
float reg_b[TN] = {0.0}; // 缓存 smem_b

@@ -83,13 +99,12 @@ __global__ void __launch_bounds__((BM * BN) / (TM * TN), 1) sgemm_vectorize_kern
{
for (int i = 0; i < BM; i += stride_a)
{
int ldg_index = i / stride_a * 4;
FETCH_FLOAT4(ldg_reg_a[ldg_index]) = FETCH_FLOAT4(A[OFFSET(i + inner_row_a, inner_col_a, K)]);
FETCH_FLOAT4(ldg_reg_a[0]) = FETCH_FLOAT4(A[OFFSET(i + inner_row_a, inner_col_a, K)]);
// smem_a is stored transposed; ldg_reg_a acts as an intermediate buffer so that the data can later be read as FLOAT4
smem_a[OFFSET(inner_col_a, i + inner_row_a, BM)] = ldg_reg_a[ldg_index];
smem_a[OFFSET(inner_col_a + 1, i + inner_row_a, BM)] = ldg_reg_a[ldg_index + 1];
smem_a[OFFSET(inner_col_a + 2, i + inner_row_a, BM)] = ldg_reg_a[ldg_index + 2];
smem_a[OFFSET(inner_col_a + 3, i + inner_row_a, BM)] = ldg_reg_a[ldg_index + 3];
smem_a[OFFSET(inner_col_a, i + inner_row_a, BM)] = ldg_reg_a[0];
smem_a[OFFSET(inner_col_a + 1, i + inner_row_a, BM)] = ldg_reg_a[1];
smem_a[OFFSET(inner_col_a + 2, i + inner_row_a, BM)] = ldg_reg_a[2];
smem_a[OFFSET(inner_col_a + 3, i + inner_row_a, BM)] = ldg_reg_a[3];
}

for (int i = 0; i < BK; i += stride_b)
@@ -166,7 +181,7 @@ int main(int argc, char *argv[])

// Allocate memory for matrices
float *A, *B, *C, *C_ref;
float *d_A, *d_B, *d_C, *d_C_ref;
float *d_A, *d_B, *d_C;

A = new float[m * k];
B = new float[k * n];
@@ -183,17 +198,10 @@ int main(int argc, char *argv[])
cudaMalloc((void **)&d_B, k * n * sizeof(float));
cudaMalloc((void **)&d_C, m * n * sizeof(float));

// Copy data to device
cudaMalloc((void **)&d_A, m * k * sizeof(float));
cudaMalloc((void **)&d_B, k * n * sizeof(float));
cudaMalloc((void **)&d_C, m * n * sizeof(float));
cudaMalloc((void **)&d_C_ref, m * n * sizeof(float));

// Copy matrices to device
cudaMemcpy(d_A, A, m * k * sizeof(float), cudaMemcpyHostToDevice);
cudaMemcpy(d_B, B, k * n * sizeof(float), cudaMemcpyHostToDevice);
cudaMemcpy(d_C, C, m * n * sizeof(float), cudaMemcpyHostToDevice);
cudaMemcpy(d_C_ref, C_ref, m * n * sizeof(float), cudaMemcpyHostToDevice);

run_sgemm_vectorize(d_A, d_B, d_C, m, n, k);

@@ -230,5 +238,15 @@ int main(int argc, char *argv[])
cudaEventElapsedTime(&elapsed_time, start, stop);
float avg_run_time = elapsed_time * 1000 / 100;
printf("Average run time: %f us\n", avg_run_time);

free_resource(A, 0);
free_resource(B, 0);
free_resource(C, 0);
free_resource(C_ref, 0);

free_resource(d_A, 1);
free_resource(d_B, 1);
free_resource(d_C, 1);

return 0;
}