diff --git a/setup.py b/setup.py index 1c708741a4..726a3882b8 100644 --- a/setup.py +++ b/setup.py @@ -269,7 +269,7 @@ def get_extensions(): extra_link_args = [] extra_compile_args = { "cxx": [f"-DPy_LIMITED_API={PY3_9_HEXCODE}"], - "nvcc": ["-O3" if not debug_mode else "-O0", "-t=0", "-std=c++17"], + "nvcc": ["-O3" if not debug_mode else "-O0", "-std=c++17", "-mllvm", "-amdgpu-early-inline-all=true", "-Wignore"], } if not IS_WINDOWS: @@ -363,8 +363,8 @@ def get_extensions(): # TOOD: Remove this and use what CUDA has once we fix all the builds. if IS_ROCM and use_cuda: # Add ROCm GPU architecture check - gpu_arch = torch.cuda.get_device_properties(0).name - if gpu_arch != "gfx942": + gpu_arch = torch.cuda.get_device_properties(0).gcnArchName + if "gfx942" not in gpu_arch: print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}") print( "Currently only gfx942 is supported. Skipping compilation of ROCm extensions" diff --git a/torchao/_models/llama/benchmark_results.txt b/torchao/_models/llama/benchmark_results.txt index d59c5f552e..dfc707256c 100644 --- a/torchao/_models/llama/benchmark_results.txt +++ b/torchao/_models/llama/benchmark_results.txt @@ -50,3 +50,7 @@ OTHER BENCHMARKS 20240910010056, tok/s= 47.85, mem/s= 213.24 GB/s, peak_mem=11.85 GB, model_size= 4.46 GB quant: uintx-4-64, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization uintx-4-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240910010647, tok/s= 34.83, mem/s= 261.42 GB/s, peak_mem=14.99 GB, model_size= 7.51 GB quant: uintx-2-8, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization uintx-2-8 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240910110958, tok/s=223.95, mem/s= 682.88 GB/s, peak_mem= 5.59 GB, model_size= 3.05 GB quant: sparse-marlin, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --quantization sparse-marlin --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 + +20250305151148, tok/s=156.37, tok/s_decode=159.58, ttft=0.0252, mem/s=2347.12 GB/s, peak_mem=16.36 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250305151355, tok/s= 98.66, tok/s_decode=100.88, ttft=0.0329, mem/s= 741.92 GB/s, peak_mem=10.56 GB, model_size= 7.52 GB quant: int8wo, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 
--compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 +20250305151549, tok/s= 47.66, tok/s_decode= 47.98, ttft=0.0272, mem/s= 201.20 GB/s, peak_mem= 6.57 GB, model_size= 4.22 GB quant: int4wo-64, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 \ No newline at end of file diff --git a/torchao/_models/llama/benchmarks.sh b/torchao/_models/llama/benchmarks.sh index 2f50e47dcd..e28e84a64d 100644 --- a/torchao/_models/llama/benchmarks.sh +++ b/torchao/_models/llama/benchmarks.sh @@ -1,113 +1,12 @@ export CHECKPOINT_PATH=../../../checkpoints # path to checkpoints folder -# README BENCHMARKS -export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16 -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-int4 --write_result benchmark_results.txt - -export MODEL_REPO=meta-llama/Meta-Llama-3-8B -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16 -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-int4 --write_result benchmark_results.txt export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt +#python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt +#python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile 
--quantization int8wo --write_result benchmark_results.txt +#python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt # Runs on H100, float8 is not supported on CUDA arch < 8.9 -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization float8wo --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization float8dq-tensor --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization float8dq-wo --write_result benchmark_results.txt - -# OTHER BENCHMARKS - -# kv cache quantization -export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 8192 -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 8192 --kv_cache_quantization -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 8192 --kv_cache_quantization --linear_causal_mask -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 16384 -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 16384 --kv_cache_quantization -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 16384 --kv_cache_quantization --linear_causal_mask -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 32768 -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 32768 --kv_cache_quantization -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 32768 --kv_cache_quantization --linear_causal_mask -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 65536 -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 65536 --kv_cache_quantization -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 65536 --kv_cache_quantization --linear_causal_mask -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 131072 -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 131072 --kv_cache_quantization -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 131072 --kv_cache_quantization --linear_causal_mask - -export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile 
--compile_prefill --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --sparsity semi-structured --precision float16 --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt - -export MODEL_REPO=meta-llama/Meta-Llama-3-8B -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16 -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --sparsity semi-structured --precision float16 --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt - -# Different Batch Size Benchmarks -export MODEL_REPO=meta-llama/Meta-Llama-3-8B -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt --batch_size 1 -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt --batch_size 32 -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt --batch_size 128 - -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt --batch_size 1 -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt --batch_size 32 -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt --batch_size 128 - -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 1 -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 32 -python generate.py 
--checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 128 - -# TTFT benchmarks -export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt --prefill_size 8000 -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8dq --write_result benchmark_results.txt --prefill_size 8000 -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8wo --write_result benchmark_results.txt --prefill_size 8000 -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8dq --sparsity semi-structured --write_result benchmark_results.txt --prefill_size 8000 -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization float8dq --write_result benchmark_results.txt --prefill_size 8000 -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization float8wo --write_result benchmark_results.txt --prefill_size 8000 -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int4wo-64 --write_result benchmark_results.txt --prefill_size 8000 -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization sparse-marlin --write_result benchmark_results.txt --prefill_size 8000 --precision float16 --sparsity semi-structured - -# gemlite benchmarks -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-4-64 --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-4-64 --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-4-None --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-4-None --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-8-None --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-8-None --write_result benchmark_results.txt - -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-4-64 --write_result benchmark_results.txt --batch_size 32 -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-4-64 --write_result benchmark_results.txt --batch_size 32 -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-4-None --write_result benchmark_results.txt --batch_size 32 -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-4-None 
--write_result benchmark_results.txt --batch_size 32 -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-8-None --write_result benchmark_results.txt --batch_size 32 -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-8-None --write_result benchmark_results.txt --batch_size 32 - -# 2:4 sparse model -export MODEL_REPO=nm-testing/SparseLlama-3-8B-pruned_50.2of4 -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --sparsity semi-structured --precision float16 --write_result benchmark_results.txt +#python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization float8wo --write_result benchmark_results.txt +#python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization float8dq-tensor --write_result benchmark_results.txt +#python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization float8dq-wo --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --sparsity semi-structured --precision float16 --write_result benchmark_results.txt diff --git a/torchao/csrc/cuda/sparse_marlin/base.h b/torchao/csrc/cuda/sparse_marlin/base.h index 513c53df3a..a96f23d274 100644 --- a/torchao/csrc/cuda/sparse_marlin/base.h +++ b/torchao/csrc/cuda/sparse_marlin/base.h @@ -26,14 +26,12 @@ constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } // corresponding index accesses must be compile-time constants, which is why we // extensively use `#pragma unroll` throughout the kernel code to guarantee // this. 
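For context, a minimal self-contained sketch of the register-array idiom the comment above describes (illustrative CUDA, not part of this patch; RegVec and axpy4 are made-up names): once every index into the per-thread array is a compile-time constant, which #pragma unroll guarantees, the compiler can keep the whole array in registers instead of spilling it to local memory.

// Illustrative only -- mirrors the Vec idiom, not code from this diff.
template <typename T, int n>
struct RegVec {
  T elems[n];
  __device__ T &operator[](int i) { return elems[i]; }
};

__global__ void axpy4(const float *x, const float *y, float *out, float a, int len) {
  RegVec<float, 4> acc;  // lives entirely in registers
  int base = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
#pragma unroll  // every acc[i] below is indexed by a compile-time constant
  for (int i = 0; i < 4; i++)
    acc[i] = (base + i < len) ? a * x[base + i] + y[base + i] : 0.0f;
#pragma unroll
  for (int i = 0; i < 4; i++)
    if (base + i < len)
      out[base + i] = acc[i];
}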
-template -struct Vec { +template struct Vec { T elems[n]; - __device__ T& operator[](int i) { return elems[i]; } + __device__ T &operator[](int i) { return elems[i]; } }; -template -struct ShapeBase { +template struct ShapeBase { static constexpr int M = M_, N = N_, K = K_; }; @@ -42,10 +40,11 @@ using I4 = Vec; // Matrix fragments for tensor core instructions; their precise layout is // documented here: // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type + using FragA = Vec; using FragB = Vec; using FragM = Vec; using FragC = Vec; -using FragS = Vec; // quantization scales +using FragS = Vec; // quantization scales -} // namespace torchao +} // namespace torchao diff --git a/torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu b/torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu index 4f6980f29a..b40a770c02 100644 --- a/torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu +++ b/torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu @@ -34,10 +34,11 @@ #include "mem.h" #include "mma.h" -template -inline std::string str(T x) { - return std::to_string(x); -} +#ifdef USE_ROCM +#include +#endif + +template inline std::string str(T x) { return std::to_string(x); } namespace torchao { @@ -54,69 +55,67 @@ static constexpr int max_par = 64; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 && !defined(USE_ROCM) -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale > __global__ void Marlin_24( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - const int4* __restrict__ meta, // 2bit metadata information about 2:4 - // format on B - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ s, // fp16 quantization scales of shape - // (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization + const int4 *__restrict__ A, // fp16 input matrix of shape mxk + const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn + const int4 *__restrict__ meta, // 2bit metadata information about 2:4 + // format on B + int4 *__restrict__ C, // fp16 output buffer of shape mxn + const int4 *__restrict__ s, // fp16 quantization scales of shape + // (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int *locks // extra global storage for barrier synchronization ) {} -torch::Tensor marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_meta, - torch::Tensor& b_scales, - torch::Tensor& workspace, int64_t num_bits, - int64_t size_m, int64_t size_n, - int64_t size_k) { - TORCH_CHECK_NOT_IMPLEMENTED( - false, "marlin_24_gemm(..) requires CUDA_ARCH >= 8.0"); +torch::Tensor marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, + torch::Tensor &b_meta, torch::Tensor &b_scales, + torch::Tensor &workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, int64_t size_k) { + TORCH_CHECK_NOT_IMPLEMENTED(false, + "marlin_24_gemm(..) 
requires CUDA_ARCH >= 8.0"); return torch::empty({1, 1}); } #else -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale > __global__ void Marlin_24( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - const int4* __restrict__ meta, // 2bit metadata information about 2:4 - // format on B - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ s, // fp16 quantization scales of shape - // (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization + const int4 *__restrict__ A, // fp16 input matrix of shape mxk + const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn + const int4 *__restrict__ meta, // 2bit metadata information about 2:4 + // format on B + int4 *__restrict__ C, // fp16 output buffer of shape mxn + const int4 *__restrict__ s, // fp16 quantization scales of shape + // (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int *locks // extra global storage for barrier synchronization ) { // Each threadblock processes one "stripe" of the B matrix with (roughly) the // same size, which might involve multiple column "slices" (of width 16 * @@ -176,22 +175,27 @@ __global__ void Marlin_24( auto init_slice = [&]() { slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) + slice_iters = 0; + if (slice_iters == 0) + return; + if (slice_row + slice_iters > k_tiles) + slice_iters = k_tiles - slice_row; slice_count = 1; slice_idx = 0; int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); if (col_first <= k_tiles * (slice_col_par + 1)) { int col_off = col_first - k_tiles * slice_col_par; slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; + if (col_off > 0) + slice_count++; int delta_first = iters * blockIdx.x - col_first; if (delta_first < 0 || (col_off == 0 && delta_first == 0)) slice_idx = slice_count - 1; else { slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; + if (col_off > 0) + slice_idx--; } } if (slice_col == n_tiles) { @@ -204,7 +208,7 @@ __global__ void Marlin_24( init_slice(); // RLC: 8 is vec_size -> 128-bit instructions, 8 fp16 elements - int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory // stride of an A matrix tile in shared memory constexpr int a_sh_stride = 32 * thread_k_blocks / 8; @@ -236,9 +240,9 @@ __global__ void Marlin_24( constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - int m_gl_stride = 2 * prob_n / 8; // (16*2*4 / 8) = 16 + int m_gl_stride = 2 * prob_n / 8; // (16*2*4 / 8) = 16 constexpr int m_sh_stride = - (16 * thread_n_blocks) / 4; // #warps n-dim * threads/warp + (16 * thread_n_blocks) / 
4; // #warps n-dim * threads/warp int m_gl_rd_delta_o = m_gl_stride * thread_k_blocks; int m_gl_rd_delta_i = m_gl_stride * (threads / m_sh_stride); constexpr int m_sh_wr_delta = threads / 2; @@ -302,7 +306,7 @@ __global__ void Marlin_24( // needed if there are more threads than required for a certain tilesize or // when the batchsize is not a multiple of 16. bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll +#pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) { a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; } @@ -322,13 +326,13 @@ __global__ void Marlin_24( // loop unrolls, all shared memory accesses are static, we simply precompute // both transformed reads and writes. int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll +#pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); int a_sh_rd_trans[2][b_sh_wr_iters][thread_m_blocks]; - #pragma unroll +#pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll +#pragma unroll for (int j = 0; j < thread_m_blocks; j++) { a_sh_rd_trans[0][i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); @@ -341,23 +345,23 @@ __global__ void Marlin_24( // runtime; we break dependencies between subsequent accesses with a tile by // maintining multiple pointers (we have enough registers), a tiny // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll + const int4 *B_ptr[b_sh_wr_iters]; +#pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; bool m_sh_wr_pred = threadIdx.x < m_sh_wr_delta; - const int4* meta_ptr[m_sh_iters]; - #pragma unroll + const int4 *meta_ptr[m_sh_iters]; +#pragma unroll for (int i = 0; i < m_sh_iters; i++) meta_ptr[i] = meta + m_gl_rd_delta_i * i + m_gl_rd; extern __shared__ int4 sh[]; // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_s = sh_b + (stages * b_sh_stage); - int4* sh_m = sh_s + (stages * s_sh_stage); + int4 *sh_a = sh; + int4 *sh_b = sh_a + (stages * a_sh_stage); + int4 *sh_s = sh_b + (stages * b_sh_stage); + int4 *sh_m = sh_s + (stages * s_sh_stage); // Register storage for double buffer of shared memory reads. FragA frag_a[2][thread_m_blocks][2]; I4 frag_b_quant[2][b_thread_vecs]; @@ -367,34 +371,34 @@ __global__ void Marlin_24( // Zero accumulators. auto zero_accums = [&]() { - #pragma unroll +#pragma unroll for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; + reinterpret_cast(frag_c)[i] = 0; }; // Asynchronously fetch the next A, B and s tile from global to the next // shared memory pipeline location. 
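Before the fetch_to_shared helper that follows, here is a hedged sketch of the staged global->shared prefetch pattern the surrounding comments describe, written with the standard CUDA pipeline intrinsics rather than the kernel's hand-written cp.async asm (STAGES, TILE and staged_scale are illustrative names, not part of this kernel; launch with blockDim.x == TILE):

// Illustrative sketch of a STAGES-deep global->shared fetch pipeline.
#include <cuda_pipeline.h>

constexpr int STAGES = 4;   // depth of the async fetch pipeline
constexpr int TILE   = 128; // floats fetched per stage, one per thread

__global__ void staged_scale(const float *__restrict__ in,
                             float *__restrict__ out, int n_tiles) {
  __shared__ float buf[STAGES][TILE];

  // Prime the pipeline: start STAGES - 1 fetches before consuming anything.
  for (int s = 0; s < STAGES - 1; ++s) {
    if (s < n_tiles)
      __pipeline_memcpy_async(&buf[s][threadIdx.x],
                              &in[s * TILE + threadIdx.x], sizeof(float));
    __pipeline_commit();  // commit a (possibly empty) group to keep counts regular
  }

  for (int t = 0; t < n_tiles; ++t) {
    // At most STAGES - 2 groups still in flight => the tile for iteration t has landed.
    __pipeline_wait_prior(STAGES - 2);
    __syncthreads();
    out[t * TILE + threadIdx.x] = 2.0f * buf[t % STAGES][threadIdx.x];
    __syncthreads();

    // Kick off the fetch that is STAGES - 1 tiles ahead, reusing the slot freed last iteration.
    int nxt = t + STAGES - 1;
    if (nxt < n_tiles)
      __pipeline_memcpy_async(&buf[nxt % STAGES][threadIdx.x],
                              &in[nxt * TILE + threadIdx.x], sizeof(float));
    __pipeline_commit();
  }
}

The kernel in this diff implements the same idea with predicated cp.async/commit/wait sequences (cp_async4_pred, cp_async_fence, cp_async_wait), which additionally lets it zero-fill partial tiles instead of branching around them.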
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) { cp_async4_pred( &sh_a_stage[a_sh_wr_trans[i]], &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], a_sh_wr_pred[i]); } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; +#pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll +#pragma unroll for (int j = 0; j < b_thread_vecs; j++) { cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); } B_ptr[i] += b_gl_rd_delta_o; } - int4* sh_meta_stage = sh_m + m_sh_stage * pipe; - #pragma unroll + int4 *sh_meta_stage = sh_m + m_sh_stage * pipe; +#pragma unroll for (int i = 0; i < m_sh_iters; i++) { if (m_sh_wr_pred) cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], meta_ptr[i]); @@ -433,13 +437,13 @@ __global__ void Marlin_24( // theoretically better attempts have lead to bad instruction ordering by // the compiler and correspondingly a noticeable drop in performance. if constexpr (group_blocks != -1) { - int4* sh_s_stage = + int4 *sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; } - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll for (int i = 0; i < thread_m_blocks; i++) { ldsm4(frag_a[k % 2][i][0], &sh_a_stage[a_sh_rd_trans[0][k % b_sh_wr_iters][i]]); @@ -447,24 +451,24 @@ __global__ void Marlin_24( &sh_a_stage[a_sh_rd_trans[1][k % b_sh_wr_iters][i]]); } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; +#pragma unroll for (int i = 0; i < b_thread_vecs; i++) { - frag_b_quant[k % 2][i] = *reinterpret_cast( + frag_b_quant[k % 2][i] = *reinterpret_cast( &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); } // Load meta with ldsm4 - int4* sh_m_stage = sh_m + m_sh_stage * pipe; + int4 *sh_m_stage = sh_m + m_sh_stage * pipe; ldsm4_m(frag_m[k % 2][0], &sh_m_stage[m_sh_rd_delta * (k % m_sh_iters) + m_sh_rd]); }; // Execute the actual tensor core matmul of a sub-tile. auto matmul = [&](int k) { - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll +// We have the m dimension as the inner loop in order to encourage overlapping +// dequantization and matmul operations. 
+#pragma unroll for (int j = 0; j < 4; j++) { FragB frag_b0; FragB frag_b1; @@ -477,7 +481,7 @@ __global__ void Marlin_24( frag_b1 = dequant_4bit(b_quant_shift); } else { - int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + int *frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; @@ -494,7 +498,7 @@ __global__ void Marlin_24( scale(frag_b1, frag_s[k % 2][j], 1); } - #pragma unroll +#pragma unroll for (int i = 0; i < thread_m_blocks; i++) { mma_sp(frag_b0, frag_b1, frag_a[k % 2][i][0], frag_c[i][j][0], frag_m[k % 2][j / 2], j % 2); @@ -515,41 +519,41 @@ __global__ void Marlin_24( int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads); - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. - #pragma unroll +// Parallel logarithmic shared memory reduction. We make sure to avoid any +// unnecessary read or write iterations, e.g., for two warps we write only +// once by warp 1 and read only once by warp 0. +#pragma unroll for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll +#pragma unroll for (int i = red_off; i > 0; i /= 2) { if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll +#pragma unroll for (int j = 0; j < 4 * 2; j++) { int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll + float *c_rd = reinterpret_cast( + &sh[red_sh_delta * j + red_sh_rd]); + float *c_wr = reinterpret_cast(&sh[red_sh_wr]); +#pragma unroll for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; } sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; } } __syncthreads(); } if (red_idx == 0) { - #pragma unroll +#pragma unroll for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll + float *c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); +#pragma unroll for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; } } @@ -571,7 +575,7 @@ __global__ void Marlin_24( int c_gl_stride = prob_n / 8; int c_gl_wr_delta_o = 2 * 4 * c_gl_stride; int c_gl_wr_delta_i = - c_gl_stride; // 8 threads (e.g., 0,4,8,12,16,20,24,28) + c_gl_stride; // 8 threads (e.g., 0,4,8,12,16,20,24,28) int c_gl_wr = 2 * c_gl_stride * (threadIdx.x % 4) + 8 * (threadIdx.x / 32) + (threadIdx.x % 32) / 4; c_gl_wr += (2 * thread_n_blocks) * slice_col; @@ -581,10 +585,10 @@ __global__ void Marlin_24( int col = 2 * ((threadIdx.x % 32) % 4); if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll +// Interestingly, doing direct global accesses here really seems to mess up +// the compiler and lead to slowdowns, hence we also use async-copies even +// though these fetches are not actually asynchronous. 
+#pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + @@ -596,32 +600,32 @@ __global__ void Marlin_24( cp_async_wait<0>(); } - #pragma unroll +#pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + col + (i % 2) < prob_m) { if (!first) { int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll +#pragma unroll for (int j2 = 0; j2 < 2; j2++) { - #pragma unroll +#pragma unroll for (int j1 = 0; j1 < 4; j1++) { - reinterpret_cast( + reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + 4 * ((i % 4) / 2) + i % 2] += - __half2float( - reinterpret_cast<__half*>(&c_red)[(j2 * 4 + j1)]); + ::__half2float( + reinterpret_cast<::__half *>(&c_red)[(j2 * 4 + j1)]); } } } if (!last) { int4 c; - #pragma unroll +#pragma unroll for (int j2 = 0; j2 < 2; j2++) { - #pragma unroll +#pragma unroll for (int j1 = 0; j1 < 4; j1++) { - reinterpret_cast<__half*>(&c)[(j2 * 4 + j1)] = - __float2half(reinterpret_cast( + reinterpret_cast<::__half *>(&c)[(j2 * 4 + j1)] = + ::__float2half(reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + 4 * ((i % 4) / 2) + i % 2]); } @@ -640,9 +644,9 @@ __global__ void Marlin_24( auto write_result = [&]() { int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks; // RLC: - constexpr int c_sh_stride_2 = 2 * c_sh_stride + 2; // RLC: - constexpr int c_sh_stride_3 = 2 * (2 * thread_n_blocks) + 2; // RLC: + constexpr int c_sh_stride = 2 * thread_n_blocks; // RLC: + constexpr int c_sh_stride_2 = 2 * c_sh_stride + 2; // RLC: + constexpr int c_sh_stride_3 = 2 * (2 * thread_n_blocks) + 2; // RLC: int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); @@ -651,22 +655,22 @@ __global__ void Marlin_24( c_gl_wr += (2 * thread_n_blocks) * slice_col; int c_sh_wr = c_sh_stride_2 * ((threadIdx.x % 32) % 4) + - ((threadIdx.x % 32) / 4); // RLC: - c_sh_wr += 8 * (threadIdx.x / 32); // 128/4(half4) + ((threadIdx.x % 32) / 4); // RLC: + c_sh_wr += 8 * (threadIdx.x / 32); // 128/4(half4) constexpr int c_sh_rd_delta = - c_sh_stride_3 * (threads / (2 * 2 * thread_n_blocks)); // RLC: + c_sh_stride_3 * (threads / (2 * 2 * thread_n_blocks)); // RLC: int c_sh_rd = c_sh_stride_3 * (threadIdx.x / (2 * 2 * thread_n_blocks)) + (threadIdx.x % (2 * 2 * thread_n_blocks)); int c_gl_wr_end = c_gl_stride * prob_m; - auto write = [&](int idx, float c0, float c1, float c2, float c3, FragS& s0, - float c4, float c5, float c6, float c7, FragS& s1) { + auto write = [&](int idx, float c0, float c1, float c2, float c3, FragS &s0, + float c4, float c5, float c6, float c7, FragS &s1) { uint2 res[2]; res[0] = to_half4(c0, c1, c2, c3); res[1] = to_half4(c4, c5, c6, c7); - half2* tmp = (half2*)&res; + ::__half2 *tmp = (::__half2 *)&res; // for per-column quantization we finally apply the scale here if constexpr (group_blocks == -1 && num_bits == 4) { tmp[0] = __hmul2(tmp[0], s0[0]); @@ -674,12 +678,12 @@ __global__ void Marlin_24( tmp[2] = __hmul2(tmp[2], s1[0]); tmp[3] = __hmul2(tmp[3], s1[1]); } - ((int4*)sh)[idx] = *((int4*)&res[0]); + ((int4 *)sh)[idx] = *((int4 *)&res[0]); }; // RLC: only warp 0 and 1 baseline example if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll +#pragma unroll for (int i = 0; i < thread_m_blocks; i++) { int wr = c_sh_wr; write(wr, frag_c[i][0][0][0], frag_c[i][1][0][0], frag_c[i][2][0][0], @@ -704,7 +708,7 @@ __global__ void Marlin_24( } __syncthreads(); - #pragma unroll 
+#pragma unroll for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { @@ -718,8 +722,9 @@ __global__ void Marlin_24( // Start global fetch and register load pipelines. auto start_pipes = [&]() { - #pragma unroll - for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); +#pragma unroll + for (int i = 0; i < stages - 1; i++) + fetch_to_shared(i, i, i < slice_iters); zero_accums(); wait_for_stage(); fetch_to_registers(0, 0); @@ -729,10 +734,10 @@ __global__ void Marlin_24( // Main loop. while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines have - // even length meaning that the next iteration will always start at index 0. - #pragma unroll +// We unroll over both the global fetch and the register load pipeline to +// ensure all shared memory accesses are static. Note that both pipelines have +// even length meaning that the next iteration will always start at index 0. +#pragma unroll for (int pipe = 0; pipe < stages;) { fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); @@ -743,7 +748,8 @@ __global__ void Marlin_24( pipe++; slice_iters--; - if (slice_iters == 0) break; + if (slice_iters == 0) + break; } a_gl_rd += a_gl_rd_delta_o * stages; @@ -757,11 +763,13 @@ __global__ void Marlin_24( // write-out if constexpr (group_blocks == -1) { if constexpr (num_bits == 8) { - if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); + if (s_sh_wr_pred) + cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); cp_async_fence(); } else { if (last) { - if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); + if (s_sh_wr_pred) + cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); cp_async_fence(); } } @@ -773,14 +781,14 @@ __global__ void Marlin_24( cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { - *(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]); + *(float4 *)(frag_s) = *(float4 *)(&sh_s[s_sh_rd]); } } else { if (last) { cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { - *(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]); + *(float4 *)(frag_s) = *(float4 *)(&sh_s[s_sh_rd]); } } } @@ -791,7 +799,7 @@ __global__ void Marlin_24( // overflow in fp16) if constexpr (group_blocks == -1 && num_bits == 8) { if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll +#pragma unroll for (int i = 0; i < thread_m_blocks; i++) { scale_floats(&frag_c[i][0][0][0], &frag_c[i][1][0][0], &frag_c[i][2][0][0], &frag_c[i][3][0][0], frag_s[0][0], @@ -820,13 +828,13 @@ __global__ void Marlin_24( } } - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice barrier_acquire(&locks[slice_col], slice_idx); global_reduce(slice_idx == 0, last); barrier_release(&locks[slice_col], last); } - if (last) // only the last block in a slice actually writes the result + if (last) // only the last block in a slice actually writes the result write_result(); slice_row = 0; @@ -836,17 +844,19 @@ __global__ void Marlin_24( if (slice_iters) { a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll +#pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - #pragma unroll +#pragma unroll for (int i = 0; i < m_sh_iters; i++) meta_ptr[i] += (m_sh_stride)-m_gl_rd_delta_o * 
k_tiles; if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - #pragma unroll - for (int i = 0; i < m_sh_iters; i++) meta_ptr[i] -= m_gl_stride; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] -= b_gl_stride; +#pragma unroll + for (int i = 0; i < m_sh_iters; i++) + meta_ptr[i] -= m_gl_stride; } s_gl_rd = s_sh_stride * slice_col + threadIdx.x; start_pipes(); @@ -855,26 +865,28 @@ __global__ void Marlin_24( } } -#define CALL_IF_2_4(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, GROUP_BLOCKS) \ - else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - group_blocks == GROUP_BLOCKS) { \ - cudaFuncSetAttribute( \ - Marlin_24, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - Marlin_24 \ - <<>>(A_ptr, B_ptr, meta_ptr, \ - C_ptr, s_ptr, prob_n, \ - prob_m, prob_k, locks); \ +#define CALL_IF_2_4(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, GROUP_BLOCKS) \ + else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + group_blocks == GROUP_BLOCKS) { \ + const void *kernel_ptr = reinterpret_cast( \ + &Marlin_24); \ + cudaFuncSetAttribute(kernel_ptr, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + max_shared_mem); \ + Marlin_24 \ + <<>>(A_ptr, B_ptr, meta_ptr, \ + C_ptr, s_ptr, prob_n, \ + prob_m, prob_k, locks); \ } -void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C, - void* s, int prob_m, int prob_n, int prob_k, - void* workspace, int num_bits, int groupsize = -1, +void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C, + void *s, int prob_m, int prob_n, int prob_k, + void *workspace, int num_bits, int groupsize = -1, int dev = 0, cudaStream_t stream = 0, int thread_k = -1, int thread_m = -1, int sms = -1, int max_par = 16) { int tot_n = prob_n; @@ -893,8 +905,8 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C, if (thread_k == -1 || thread_m == -1) { if (prob_n <= 16) { - // For small batchizes, better partitioningif is slightly more important - // than better compute utilization + // For small batchizes, better partitioningif is slightly more + // important than better compute utilization thread_k = 128; thread_m = 128; } else if (prob_n <= 256) { @@ -906,7 +918,7 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C, } } - int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction + int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction int thread_m_blocks = thread_m / 16; int group_blocks = (groupsize == -1) ? 
-1 : groupsize / 16; int blocks = sms; @@ -923,52 +935,53 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C, TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); - const int4* A_ptr = (const int4*)A; - const int4* B_ptr = (const int4*)B; - const int4* meta_ptr = (const int4*)meta; - int4* C_ptr = (int4*)C; - const int4* s_ptr = (const int4*)s; + const int4 *A_ptr = (const int4 *)A; + const int4 *B_ptr = (const int4 *)B; + const int4 *meta_ptr = (const int4 *)meta; + int4 *C_ptr = (int4 *)C; + const int4 *s_ptr = (const int4 *)s; constexpr int max_m_blocks = 4; - int* locks = (int*)workspace; + int *locks = (int *)workspace; for (int i = 0; i < tot_n_blocks; i += max_m_blocks) { int thread_n_blocks = tot_n_blocks - i; prob_n = tot_n - 16 * i; int par = 1; if (thread_n_blocks > max_m_blocks) { - // Note that parallel > 1 currently only works for inputs without any - // padding + // Note that parallel > 1 currently only works for inputs without + // any padding par = (16 * thread_n_blocks - pad) / (max_m_blocks * 16); - if (par > max_par) par = max_par; + if (par > max_par) + par = max_par; prob_n = (max_m_blocks * 16) * par; i += max_m_blocks * (par - 1); thread_n_blocks = max_m_blocks; } - // For compilation speed, we only define the kernel configurations that have - // seemed useful (in terms of performance) in our testing, however many more - // are, in principle, possible. + // For compilation speed, we only define the kernel configurations + // that have seemed useful (in terms of performance) in our testing, + // however many more are, in principle, possible. // the false is start of the CALL_IF macros if (false) { - } // BMxBNxBK, group + } // BMxBNxBK, group // 4-bit - CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128 - CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64 + CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128 + CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64 - CALL_IF_2_4(4, 16, 1, 2, -1) // e.g., 16x256x64 - CALL_IF_2_4(4, 16, 1, 2, 4) // e.g., 16x256x64, 64 - CALL_IF_2_4(4, 16, 2, 2, -1) // e.g.. 32x256x64 + CALL_IF_2_4(4, 16, 1, 2, -1) // e.g., 16x256x64 + CALL_IF_2_4(4, 16, 1, 2, 4) // e.g., 16x256x64, 64 + CALL_IF_2_4(4, 16, 2, 2, -1) // e.g.. 32x256x64 CALL_IF_2_4(4, 16, 2, 2, 4) CALL_IF_2_4(4, 16, 3, 2, -1) CALL_IF_2_4(4, 16, 3, 2, 4) CALL_IF_2_4(4, 16, 4, 2, -1) CALL_IF_2_4(4, 16, 4, 2, 4) - CALL_IF_2_4(4, 32, 1, 1, -1) // e.g., 16x256x64 - CALL_IF_2_4(4, 32, 1, 1, 4) // e.g., 16x256x64, 64 - CALL_IF_2_4(4, 32, 2, 1, -1) // e.g.. 32x256x64 + CALL_IF_2_4(4, 32, 1, 1, -1) // e.g., 16x256x64 + CALL_IF_2_4(4, 32, 1, 1, 4) // e.g., 16x256x64, 64 + CALL_IF_2_4(4, 32, 2, 1, -1) // e.g.. 32x256x64 CALL_IF_2_4(4, 32, 2, 1, 4) CALL_IF_2_4(4, 32, 3, 1, -1) CALL_IF_2_4(4, 32, 3, 1, 4) @@ -976,21 +989,21 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C, CALL_IF_2_4(4, 32, 4, 1, 4) // 8-bit - CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128 - CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64 + CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128 + CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64 - CALL_IF_2_4(8, 16, 1, 2, -1) // e.g., 16x256x64 - CALL_IF_2_4(8, 16, 1, 2, 4) // e.g., 16x256x64, 64 - CALL_IF_2_4(8, 16, 2, 2, -1) // e.g.. 32x256x64 + CALL_IF_2_4(8, 16, 1, 2, -1) // e.g., 16x256x64 + CALL_IF_2_4(8, 16, 1, 2, 4) // e.g., 16x256x64, 64 + CALL_IF_2_4(8, 16, 2, 2, -1) // e.g.. 
32x256x64 CALL_IF_2_4(8, 16, 2, 2, 4) CALL_IF_2_4(8, 16, 3, 2, -1) CALL_IF_2_4(8, 16, 3, 2, 4) CALL_IF_2_4(8, 16, 4, 2, -1) CALL_IF_2_4(8, 16, 4, 2, 4) - CALL_IF_2_4(8, 32, 1, 1, -1) // e.g., 16x256x64 - CALL_IF_2_4(8, 32, 1, 1, 4) // e.g., 16x256x64, 64 - CALL_IF_2_4(8, 32, 2, 1, -1) // e.g.. 32x256x64 + CALL_IF_2_4(8, 32, 1, 1, -1) // e.g., 16x256x64 + CALL_IF_2_4(8, 32, 1, 1, 4) // e.g., 16x256x64, 64 + CALL_IF_2_4(8, 32, 2, 1, -1) // e.g.. 32x256x64 CALL_IF_2_4(8, 32, 2, 1, 4) CALL_IF_2_4(8, 32, 3, 1, -1) CALL_IF_2_4(8, 32, 3, 1, 4) @@ -1010,12 +1023,10 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C, } } -torch::Tensor marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_meta, - torch::Tensor& b_scales, - torch::Tensor& workspace, int64_t num_bits, - int64_t size_m, int64_t size_n, - int64_t size_k) { +torch::Tensor marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, + torch::Tensor &b_meta, torch::Tensor &b_scales, + torch::Tensor &workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, int64_t size_k) { // Verify num_bits TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); @@ -1048,9 +1059,9 @@ torch::Tensor marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, " is not divisible by tile_size = " + str(torchao::tile_size)); int actual_size_n = (b_q_weight.size(1) / torchao::tile_size) * pack_factor; - TORCH_CHECK( - size_n == actual_size_n, - "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n)); + TORCH_CHECK(size_n == actual_size_n, + "size_n = " + str(size_n) + + ", actual_size_n = " + str(actual_size_n)); // Verify meta TORCH_CHECK(b_meta.size(0) == size_k / 8 / 2 / 2, @@ -1092,7 +1103,7 @@ torch::Tensor marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, ", is not divisible by b_scales.size(0) = " + str(b_scales.size(0))); groupsize = size_k / b_scales.size(0); - groupsize /= 2; // Because of 24 + groupsize /= 2; // Because of 24 } // Verify groupsize @@ -1100,12 +1111,11 @@ torch::Tensor marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, "Unexpected groupsize = " + str(groupsize)); // Verify workspace size - TORCH_CHECK(size_n % torchao::min_thread_n == 0, - "size_n = " + str(size_n) + - ", is not divisible by min_thread_n = " + - str(torchao::min_thread_n)); - int min_workspace_size = - (size_n / torchao::min_thread_n) * torchao::max_par; + TORCH_CHECK( + size_n % torchao::min_thread_n == 0, + "size_n = " + str(size_n) + + ", is not divisible by min_thread_n = " + str(torchao::min_thread_n)); + int min_workspace_size = (size_n / torchao::min_thread_n) * torchao::max_par; TORCH_CHECK(workspace.numel() >= min_workspace_size, "workspace.numel = " + str(workspace.numel()) + " is below min_workspace_size = " + str(min_workspace_size)); @@ -1126,4 +1136,4 @@ TORCH_LIBRARY_IMPL(torchao, CUDA, m) { m.impl("torchao::marlin_24_gemm", &marlin_24_gemm); } -} // namespace torchao +} // namespace torchao diff --git a/torchao/csrc/cuda/sparse_marlin/mem.h b/torchao/csrc/cuda/sparse_marlin/mem.h index 1569e3cdda..c062286cab 100644 --- a/torchao/csrc/cuda/sparse_marlin/mem.h +++ b/torchao/csrc/cuda/sparse_marlin/mem.h @@ -21,78 +21,96 @@ namespace torchao { #ifdef USE_ROCM +#include +#include #include - // Convert generic pointer to shared memory address for ROCm -template -__device__ __forceinline__ uint32_t cvta_to_shared(const T* ptr) { - // First get the address as a size_t to handle all pointer sizes - size_t addr = reinterpret_cast(ptr); 
- - // Extract the lower 32 bits which represent the shared memory offset - // This is safe because shared memory addresses are always within 32-bit range - return static_cast(addr & 0xFFFFFFFF); +template +__device__ __forceinline__ uint32_t cvta_to_shared(const T *ptr) { + // First get the address as a size_t to handle all pointer sizes + size_t addr = reinterpret_cast(ptr); + + // Extract the lower 32 bits which represent the shared memory offset + // This is safe because shared memory addresses are always within 32-bit range + return static_cast(addr & 0xFFFFFFFF); } #else // For CUDA, use the native intrinsic -template -__device__ __forceinline__ uint32_t cvta_to_shared(const T* ptr) { - return static_cast(__cvta_generic_to_shared(ptr)); +template +__device__ __forceinline__ uint32_t cvta_to_shared(const T *ptr) { + return static_cast(__cvta_generic_to_shared(ptr)); } #endif // Predicated asynchronous global->shared copy; used for inputs A where we apply // predication to handle batchsizes that are not multiples of 16. -__device__ inline void cp_async4_pred_zfill(void* smem_ptr, - const void* glob_ptr, +__device__ inline void cp_async4_pred_zfill(void *smem_ptr, + const void *glob_ptr, bool pred = true, const bool zfill = false) { const int BYTES = 16; int src_in_bytes = (zfill ? 0 : BYTES); uint32_t smem = cvta_to_shared(smem_ptr); - #ifdef USE_ROCM - __builtin_amdgcn_global_load_lds(static_cast(glob_ptr), &smem, BYTES, 0, 0); - #else - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes)); - #endif +#ifdef USE_ROCM + // Simple approach using standard C++ operations + if (pred) { + // Load from global memory + uint4 data; + if (!zfill) { + data = *reinterpret_cast(glob_ptr); + } else { + data = make_uint4(0, 0, 0, 0); + } + + // Store to shared memory + *reinterpret_cast(smem_ptr) = data; + + // Ensure visibility + __threadfence_block(); + } + // ::int_amdgcn_global_load_lds(static_cast(glob_ptr), + // &smem, + // BYTES, 0, 0); +#else + asm volatile("{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes)); +#endif } -__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, +__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr, bool pred = true) { const int BYTES = 16; uint32_t smem = cvta_to_shared(smem_ptr); - #ifdef USE_ROCM - __builtin_amdgcn_global_load_lds(static_cast(glob_ptr), &smem, BYTES, 0, 0); - #else - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); - #endif +#ifdef USE_ROCM + // __builtin_amdgcn_global_load_lds(static_cast(glob_ptr), + // &smem, BYTES, 0, 0); +#else + asm volatile("{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +#endif } // Asynchronous global->shared copy -__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { +__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) { const int BYTES = 16; uint32_t smem = cvta_to_shared(smem_ptr); - #ifdef USE_ROCM - __builtin_amdgcn_global_load_lds(static_cast(glob_ptr), &smem, BYTES, 0, 0); - 
#else - asm volatile( - "{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); - #endif +#ifdef USE_ROCM + // __builtin_amdgcn_global_load_lds(static_cast(glob_ptr), + // &smem, BYTES, 0, 0); +#else + asm volatile("{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +#endif } // Async copy fence. @@ -105,8 +123,7 @@ __device__ inline void cp_async_fence() { } // Wait until at most `n` async copy stages are still pending. -template -__device__ inline void cp_async_wait() { +template __device__ inline void cp_async_wait() { #ifdef USE_ROCM // For AMD GPUs, we use s_waitcnt // This waits for all outstanding memory operations to complete @@ -119,80 +136,77 @@ __device__ inline void cp_async_wait() { // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); +__device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) { + uint32_t *a = reinterpret_cast(&frag_a); uint32_t smem = cvta_to_shared(smem_ptr); - #ifdef USE_ROCM - asm volatile( - "ds_read_b128 %0, %1 offset:0\n" - "ds_read_b128 %2, %1 offset:16\n" - : "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3]) - : "v"(smem)); - #else +#ifdef USE_ROCM + asm volatile("ds_read_b128 %0, %4 offset:0\n" + "ds_read_b128 %2, %4 offset:16\n" + : "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3]) + : "v"(smem)); +#else asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) : "r"(smem)); - #endif +#endif } -__device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_m); +__device__ inline void ldsm4_m(FragM &frag_m, const void *smem_ptr) { + uint32_t *a = reinterpret_cast(&frag_m); uint32_t smem = cvta_to_shared(smem_ptr); - #ifdef USE_ROCM - asm volatile( - "ds_read_b64 %0, %2 offset:0\n" - : "=v"(a[0]), "=v"(a[1]) - : "v"(smem)); - #else +#ifdef USE_ROCM + asm volatile("ds_read_b64 %0, %2 offset:0\n" + : "=v"(a[0]), "=v"(a[1]) + : "v"(smem)); +#else asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" : "=r"(a[0]), "=r"(a[1]) : "r"(smem)); - #endif +#endif } // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. -__device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); +__device__ inline void ldsm4_t(FragA &frag_a, const void *smem_ptr) { + uint32_t *a = reinterpret_cast(&frag_a); uint32_t smem = cvta_to_shared(smem_ptr); - #ifdef USE_ROCM - asm volatile( - "ds_read_b128 %0, %1 offset:0\n" - "ds_read_b128 %2, %1 offset:16\n" - : "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3]) - : "v"(smem)); - #else +#ifdef USE_ROCM + asm volatile("ds_read_b128 %0, %1 offset:0\n" + "ds_read_b128 %2, %1 offset:16\n" + : "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3]) + : "v"(smem)); +#else asm volatile( "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n" : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) : "r"(smem)); - #endif +#endif } // Wait until barrier reaches `count`, then lock for current threadblock. 
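As a reading aid for the barrier diff that follows, here is a portable sketch of the same acquire/release protocol using plain CUDA atomics and fences (illustrative only; the patch itself keeps the inline-asm versions shown in the hunk):

// Illustrative sketch of the lock-based global barrier semantics.
__device__ inline void barrier_acquire_sketch(int *lock, int count) {
  if (threadIdx.x == 0) {
    // Spin until `count` producer blocks have released the barrier.
    while (atomicAdd(lock, 0) != count) {  // atomic read of the lock value
    }
    __threadfence();  // make the producers' global writes visible to this block
  }
  __syncthreads();
}

__device__ inline void barrier_release_sketch(int *lock, bool reset = false) {
  __syncthreads();
  if (threadIdx.x == 0) {
    if (reset) {
      *lock = 0;
      return;
    }
    __threadfence();     // publish this block's global writes first
    atomicAdd(lock, 1);  // then bump the visitation count
  }
}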
 // Wait until barrier reaches `count`, then lock for current threadblock.
-__device__ inline void barrier_acquire(int* lock, int count) {
+__device__ inline void barrier_acquire(int *lock, int count) {
   if (threadIdx.x == 0) {
     int state = -1;
     do {
-      // Guarantee that subsequent writes by this threadblock will be visible
-      // globally.
-      #ifdef USE_ROCM
+// Guarantee that subsequent writes by this threadblock will be visible
+// globally.
+#ifdef USE_ROCM
       asm volatile("flat_load_dword %0, %1 glc\n\t"
                    "s_waitcnt vmcnt(0) & lgkmcnt(0)\n\t"
                    : "=v"(state)
                    : "v"(lock));
-      #else
+#else
       asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
                    : "=r"(state)
                    : "l"(lock));
-      #endif
+#endif
     } while (state != count);
   }
   __syncthreads();
 }

 // Release barrier and increment visitation count.
-__device__ inline void barrier_release(int* lock, bool reset = false) {
+__device__ inline void barrier_release(int *lock, bool reset = false) {
   __syncthreads();
   if (threadIdx.x == 0) {
     if (reset) {
@@ -200,21 +214,21 @@ __device__ inline void barrier_release(int* lock, bool reset = false) {
       return;
     }
     int val = 1;
-    // Make sure that all writes since acquiring this barrier are visible
-    // globally, while releasing the barrier.
-    #ifdef USE_ROCM
+// Make sure that all writes since acquiring this barrier are visible
+// globally, while releasing the barrier.
+#ifdef USE_ROCM
     asm volatile("s_waitcnt vmcnt(0) & lgkmcnt(0)\n\t"
                  "s_memrealtime\n\t"
                  "s_waitcnt vmcnt(0) & lgkmcnt(0)\n\t"
                  "flat_atomic_add_i32 %0, %1\n\t"
                  : "+v"(*lock)
                  : "v"(val));
-    #else
+#else
     asm volatile("fence.acq_rel.gpu;\n");
     asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
                  :
                  : "l"(lock), "r"(val));
-    #endif
+#endif
   }
 }

-}  // namespace torchao
+} // namespace torchao
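The barrier_acquire/barrier_release pair above is a global-memory spin lock: one leader thread spins on an acquire load until the counter reaches the expected value, and the releasing block bumps the counter behind a release fence. A minimal sketch of that protocol with std::atomic (illustration only, not part of the patch; the _model names are hypothetical):

    #include <atomic>

    // Spin until the producer has published `count` (models the acquire-load
    // loop that barrier_acquire runs in thread 0 of the block).
    inline void barrier_acquire_model(std::atomic<int> *lock, int count) {
      while (lock->load(std::memory_order_acquire) != count) {
        // spin
      }
    }

    // Publish completion: make prior writes visible, then bump the visitation
    // count (models the fence + relaxed reduction add / flat_atomic_add path).
    inline void barrier_release_model(std::atomic<int> *lock, bool reset = false) {
      if (reset) {
        lock->store(0, std::memory_order_relaxed);
        return;
      }
      lock->fetch_add(1, std::memory_order_release);
    }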
diff --git a/torchao/csrc/cuda/sparse_marlin/mma.h b/torchao/csrc/cuda/sparse_marlin/mma.h
index 9e9a9be519..a8334974fa 100644
--- a/torchao/csrc/cuda/sparse_marlin/mma.h
+++ b/torchao/csrc/cuda/sparse_marlin/mma.h
@@ -32,26 +32,57 @@ namespace torchao {
 // | reduced performance on some future architectures
 #if defined(USE_ROCM)
-  // HIP ISA doesn't have an equivalent for ordered_metadata, so we'll use the standard mma instruction
-  #define MMA_SP_INST "v_mfma_f32_16x16x16f16 "
+// HIP ISA doesn't have an equivalent for ordered_metadata, so we'll use the
+// standard mma instruction
+#define MMA_SP_INST "v_mfma_f32_16x16x16f16 "
 #elif defined(CUDA_VERSION) && CUDA_VERSION >= 12050
-  #define MMA_SP_INST \
-    "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
+#define MMA_SP_INST \
+  "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
 #else
-  #define MMA_SP_INST "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
+#define MMA_SP_INST "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
 #endif

 // m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32
 // output/accumulation.
-__device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
-                              const FragA& frag_b, FragC& frag_c, FragM& frag_m,
+__device__ inline void mma_sp(const FragB &a_frag0, const FragB &a_frag1,
+                              const FragA &frag_b, FragC &frag_c, FragM &frag_m,
                               const int psel) {
-  const uint32_t* a0 = reinterpret_cast<const uint32_t*>(&a_frag0);
-  const uint32_t* a1 = reinterpret_cast<const uint32_t*>(&a_frag1);
-  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
-  const uint32_t* e = reinterpret_cast<const uint32_t*>(&frag_m);
+  const uint32_t *a0 = reinterpret_cast<const uint32_t *>(&a_frag0);
+  const uint32_t *a1 = reinterpret_cast<const uint32_t *>(&a_frag1);
+  const uint32_t *b = reinterpret_cast<const uint32_t *>(&frag_b);
+  const uint32_t *e = reinterpret_cast<const uint32_t *>(&frag_m);

-  float* c = reinterpret_cast<float*>(&frag_c);
+  float *c = reinterpret_cast<float *>(&frag_c);
+#if defined(USE_ROCM)
+  // AMD implementation
+  if (psel == 0) {
+    asm volatile(MMA_SP_INST "%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, "
+                             "%11, %12, %13, %14, %15, %16, 0\n"
+                 : "=v"(c[0]), "=v"(c[1]), "=v"(c[2]), "=v"(c[3])
+                 : "v"(a0[0]), "v"(a1[0]), "v"(a0[1]), "v"(a1[1]), "v"(b[0]),
+                   "v"(b[2]), "v"(b[4]), "v"(b[6]), "v"(c[0]), "v"(c[1]),
+                   "v"(c[2]), "v"(c[3]), "v"(e[0]));
+    asm volatile(MMA_SP_INST "%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, "
+                             "%11, %12, %13, %14, %15, %16, 0\n"
+                 : "=v"(c[4]), "=v"(c[5]), "=v"(c[6]), "=v"(c[7])
+                 : "v"(a0[0]), "v"(a1[0]), "v"(a0[1]), "v"(a1[1]), "v"(b[1]),
+                   "v"(b[3]), "v"(b[5]), "v"(b[7]), "v"(c[4]), "v"(c[5]),
+                   "v"(c[6]), "v"(c[7]), "v"(e[0]));
+  } else {
+    asm volatile(MMA_SP_INST "%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, "
+                             "%11, %12, %13, %14, %15, %16, 1\n"
+                 : "=v"(c[0]), "=v"(c[1]), "=v"(c[2]), "=v"(c[3])
+                 : "v"(a0[0]), "v"(a1[0]), "v"(a0[1]), "v"(a1[1]), "v"(b[0]),
+                   "v"(b[2]), "v"(b[4]), "v"(b[6]), "v"(c[0]), "v"(c[1]),
+                   "v"(c[2]), "v"(c[3]), "v"(e[0]));
+    asm volatile(MMA_SP_INST "%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, "
+                             "%11, %12, %13, %14, %15, %16, 1\n"
+                 : "=v"(c[4]), "=v"(c[5]), "=v"(c[6]), "=v"(c[7])
+                 : "v"(a0[0]), "v"(a1[0]), "v"(a0[1]), "v"(a1[1]), "v"(b[1]),
+                   "v"(b[3]), "v"(b[5]), "v"(b[7]), "v"(c[4]), "v"(c[5]),
+                   "v"(c[6]), "v"(c[7]), "v"(e[0]));
+  }
+#else
   if (psel == 0) {
     asm volatile(MMA_SP_INST
                  "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
@@ -83,35 +114,43 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
                  "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
                  "f"(c[6]), "f"(c[7]), "r"(e[0]));
   }
+#endif
 }
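mma_sp consumes operand A in 2:4 structured-sparse form: only two values out of every four along K are stored, and the metadata fragment (frag_m / e) records which of the four positions each kept value occupies, with psel selecting the metadata group in use. A scalar reference of that contraction for a single output element, assuming a 2-bit position index per kept value (illustration only, not part of the patch; sparse_dot_2to4, a_vals, a_meta and b_col are hypothetical names, and the real register/lane layout of the tensor-core instruction is not modelled):

    #include <cstdint>

    // a_vals: the kept (nonzero) values of one row of A, two per group of four.
    // a_meta: a 2-bit position (0..3) for each kept value, telling which column
    //         of its 4-wide group it came from. b_col: a dense column of B.
    inline float sparse_dot_2to4(const float *a_vals, const uint8_t *a_meta,
                                 const float *b_col, int k_compressed) {
      float acc = 0.0f;
      for (int i = 0; i < k_compressed; ++i) {
        int group = i / 2;                 // two kept values per 4-wide group
        int pos_in_group = a_meta[i] & 3;  // 2-bit position inside the group
        acc += a_vals[i] * b_col[group * 4 + pos_in_group];
      }
      return acc;
    }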
 // Lookup-table based 3-input logical operation; explicitly used for
 // dequantization as the compiler does not seem to automatically recognize it in
 // all cases.
-template <int lut>
-__device__ inline int lop3(int a, int b, int c) {
+template <int lut> __device__ inline int lop3(int a, int b, int c) {
   int res;
-  #ifdef USE_ROCM
-    // AMD GPUs don't have a direct equivalent to lop3, so we implement it using bitwise operations
+#ifdef USE_ROCM
+  // AMD GPUs don't have a direct equivalent to lop3, so we implement it using
+  // bitwise operations
   res = (a & b & c) | (a & b & ~c) | (a & ~b & c) | (~a & b & c);
   // Apply the LUT
   res = (res & lut) | (~res & ~lut);
-  #else
+#else
   asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(res)
               : "r"(a), "r"(b), "r"(c), "n"(lut));
-  #endif
+#endif
   return res;
 }

 __device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2,
                                           float c3) {
   uint2 r;
-  #ifdef USE_ROCM
-    // AMD implementation
-    r.x = __builtin_amdgcn_cvt_pkrtz(c0, c1);
-    r.y = __builtin_amdgcn_cvt_pkrtz(c2, c3);
-  #else
+#ifdef USE_ROCM
+  // AMD implementation - properly handle the vector type conversion
+  typedef __attribute__((ext_vector_type(2))) __fp16 half2_t;
+
+  // Convert float pairs to half2 vectors
+  half2_t packed_x = __builtin_amdgcn_cvt_pkrtz(c0, c1);
+  half2_t packed_y = __builtin_amdgcn_cvt_pkrtz(c2, c3);
+
+  // Properly convert the vector type to uint32_t via bit-level reinterpretation
+  r.x = *reinterpret_cast<uint32_t *>(&packed_x);
+  r.y = *reinterpret_cast<uint32_t *>(&packed_y);
+#else
   // NVIDIA implementation
   asm("{\n\t"
       ".reg .f16 a, b, c, d; \n\t"
@@ -124,7 +163,7 @@ __device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2,
       "}"
       : "=r"(r.x), "=r"(r.y)
       : "f"(c0), "f"(c1), "f"(c2), "f"(c3));
-  #endif
+#endif
   return r;
 }

@@ -133,16 +172,17 @@ __device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2,
 template <int start_byte, int mask>
 __device__ inline uint32_t prmt(uint32_t a) {
   uint32_t res;
-  #ifdef USE_ROCM
+#ifdef USE_ROCM
   // AMD implementation
-  res = ((a & 0xFF) << 24) | ((a & 0xFF00) << 8) | ((a & 0xFF0000) >> 8) | ((a & 0xFF000000) >> 24);
+  res = ((a & 0xFF) << 24) | ((a & 0xFF00) << 8) | ((a & 0xFF0000) >> 8) |
+        ((a & 0xFF000000) >> 24);
   res = (res >> (start_byte * 8)) & mask;
-  #else
+#else
   // NVIDIA implementation
   asm volatile("prmt.b32 %0, %1, %2, %3;\n"
               : "=r"(res)
               : "r"(a), "n"(start_byte), "n"(mask));
-  #endif
+#endif
   return res;
 }

@@ -164,24 +204,24 @@ __device__ inline FragB dequant_4bit(int q) {
   const int ADD = 0xd480d480;
   FragB frag_b;
-  #ifdef USE_ROCM
+#ifdef USE_ROCM
   // AMD implementation
-  __half2* lo_ptr = reinterpret_cast<__half2*>(&lo);
-  __half2* hi_ptr = reinterpret_cast<__half2*>(&hi);
-  const __half2* SUB_ptr = reinterpret_cast<const __half2*>(&SUB);
-  const __half2* MUL_ptr = reinterpret_cast<const __half2*>(&MUL);
-  const __half2* ADD_ptr = reinterpret_cast<const __half2*>(&ADD);
+  ::__half2 *lo_ptr = reinterpret_cast<::__half2 *>(&lo);
+  ::__half2 *hi_ptr = reinterpret_cast<::__half2 *>(&hi);
+  const ::__half2 *SUB_ptr = reinterpret_cast<const ::__half2 *>(&SUB);
+  const ::__half2 *MUL_ptr = reinterpret_cast<const ::__half2 *>(&MUL);
+  const ::__half2 *ADD_ptr = reinterpret_cast<const ::__half2 *>(&ADD);

-  frag_b[0] = __hsub(*lo_ptr, *SUB_ptr);
-  frag_b[1] = __hfma(*hi_ptr, *MUL_ptr, *ADD_ptr);
-  #else
+  frag_b[0] = __hsub2(*lo_ptr, *SUB_ptr);
+  frag_b[1] = __hfma2(*hi_ptr, *MUL_ptr, *ADD_ptr);
+#else
   // NVIDIA implementation
-  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
-                      *reinterpret_cast<const half2*>(&SUB));
-  frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
-                      *reinterpret_cast<const half2*>(&MUL),
-                      *reinterpret_cast<const half2*>(&ADD));
-  #endif
+  frag_b[0] = __hsub2(*reinterpret_cast<half2 *>(&lo),
+                      *reinterpret_cast<const half2 *>(&SUB));
+  frag_b[1] = __hfma2(*reinterpret_cast<half2 *>(&hi),
+                      *reinterpret_cast<const half2 *>(&MUL),
+                      *reinterpret_cast<const half2 *>(&ADD));
+#endif
   return frag_b;
 }
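lop3.b32 evaluates an arbitrary three-input boolean function: for every bit position, the corresponding bits of a, b and c form a 3-bit index into the 8-entry truth table lut, and the indexed bit becomes the result bit. A portable reference of that semantics (illustration only, not part of the patch; lop3_reference is a hypothetical name and not a drop-in replacement for the fast paths above):

    #include <cstdint>

    // Per-bit lookup: index = (a_bit << 2) | (b_bit << 1) | c_bit,
    // result bit = bit `index` of the 8-bit truth table `lut`.
    inline uint32_t lop3_reference(uint32_t a, uint32_t b, uint32_t c, uint8_t lut) {
      uint32_t res = 0;
      for (int bit = 0; bit < 32; ++bit) {
        unsigned idx = (((a >> bit) & 1u) << 2) |
                       (((b >> bit) & 1u) << 1) |
                       ((c >> bit) & 1u);
        res |= static_cast<uint32_t>((lut >> idx) & 1u) << bit;
      }
      return res;
    }

Comparing such a reference against the ROCm bitwise emulation for the specific lut constants the dequant routines use is a cheap way to validate that fallback path.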
@@ -200,55 +240,57 @@ __device__ inline FragB dequant_8bit(int q) {
   static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
   FragB frag_b;
-  #ifdef USE_ROCM
+#ifdef USE_ROCM
   // AMD implementation
-  __half2* lo_ptr = reinterpret_cast<__half2*>(&lo);
-  __half2* hi_ptr = reinterpret_cast<__half2*>(&hi);
-  const __half2* magic_num_ptr = reinterpret_cast<const __half2*>(&I8s_TO_F16s_MAGIC_NUM);
+  ::__half2 *lo_ptr = reinterpret_cast<::__half2 *>(&lo);
+  ::__half2 *hi_ptr = reinterpret_cast<::__half2 *>(&hi);
+  const ::__half2 *magic_num_ptr =
+      reinterpret_cast<const ::__half2 *>(&I8s_TO_F16s_MAGIC_NUM);

-  frag_b[0] = __hsub(*lo_ptr, *magic_num_ptr);
-  frag_b[1] = __hsub(*hi_ptr, *magic_num_ptr);
-  #else
+  frag_b[0] = __hsub2(*lo_ptr, *magic_num_ptr);
+  frag_b[1] = __hsub2(*hi_ptr, *magic_num_ptr);
+#else
   // NVIDIA implementation
-  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
-                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
-  frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
-                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
-  #endif
+  frag_b[0] = __hsub2(*reinterpret_cast<half2 *>(&lo),
+                      *reinterpret_cast<const half2 *>(&I8s_TO_F16s_MAGIC_NUM));
+  frag_b[1] = __hsub2(*reinterpret_cast<half2 *>(&hi),
+                      *reinterpret_cast<const half2 *>(&I8s_TO_F16s_MAGIC_NUM));
+#endif
   return frag_b;
 }

 // Multiply dequantized values by the corresponding quantization scale; used
 // only for grouped quantization.
-__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
-  #ifdef USE_ROCM
+__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) {
+#ifdef USE_ROCM
   // AMD implementation
-  __half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
-  frag_b[0] = __hmul(frag_b[0], s);
-  frag_b[1] = __hmul(frag_b[1], s);
-  #else
+  // ::__half2 s = __half2half2(reinterpret_cast<::__half *>(&frag_s)[i]);
+  // frag_b[0] = ::__hmul2(frag_b[0], s);
+  // frag_b[1] = ::__hmul2(frag_b[1], s);
+#else
   // NVIDIA implementation
-  half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
-  frag_b[0] = __hmul2(frag_b[0], s);
-  frag_b[1] = __hmul2(frag_b[1], s);
-  #endif
+  half2 s = __half2half2(reinterpret_cast<__half *>(&frag_s)[i]);
+  frag_b[0] = ::__hmul2(frag_b[0], s);
+  frag_b[1] = ::__hmul2(frag_b[1], s);
+#endif
 }
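The 0x64806480 magic number in dequant_8bit is the usual fp16 bias trick: each quantized byte is placed in the low 8 bits of a half whose upper byte is 0x64, which encodes 1024 + q exactly (the fp16 ulp is 1 over [1024, 2048)), so subtracting 1152.0 (bit pattern 0x6480) leaves q - 128. A small arithmetic check of that identity (illustrative only, not part of the patch; fp16_bits_in_1024_2048 and dequant8_reference are hypothetical helpers):

    #include <cstdint>

    // Valid only for bit patterns 0x6400..0x67FF, where the fp16 exponent is
    // 2^10 and the ulp is exactly 1, so the value is 1024 plus the low 10
    // mantissa bits.
    inline float fp16_bits_in_1024_2048(uint16_t bits) {
      return 1024.0f + static_cast<float>(bits & 0x03FF);
    }

    // What the byte shuffle plus magic-number subtraction amounts to for one
    // uint8 value.
    inline float dequant8_reference(uint8_t q) {
      uint16_t assembled = static_cast<uint16_t>(0x6400 | q);  // 0x64 high byte, q low byte
      const float magic = fp16_bits_in_1024_2048(0x6480);      // 1152.0f
      return fp16_bits_in_1024_2048(assembled) - magic;        // equals q - 128
    }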
-__device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3,
-                                    FragS& s0, float* c4, float* c5, float* c6,
-                                    float* c7, FragS& s1) {
-  #ifdef USE_ROCM
+__device__ inline void scale_floats(float *c0, float *c1, float *c2, float *c3,
+                                    FragS &s0, float *c4, float *c5, float *c6,
+                                    float *c7, FragS &s1) {
+#ifdef USE_ROCM
   // AMD implementation
-  *c0 = __builtin_amdgcn_fmul_legacy(*c0, __half2float(s0[0].x));
-  *c1 = __builtin_amdgcn_fmul_legacy(*c1, __half2float(s0[0].y));
-  *c2 = __builtin_amdgcn_fmul_legacy(*c2, __half2float(s0[1].x));
-  *c3 = __builtin_amdgcn_fmul_legacy(*c3, __half2float(s0[1].y));
+  // *c0 = __fmul_rn(*c0, device_half2float(s0[0].x));
+  // *c1 = __fmul_rn(*c1, device_half2float(s0[0].y));
+  // *c2 = __fmul_rn(*c2, device_half2float(s0[1].x));
+  // *c3 = __fmul_rn(*c3, device_half2float(s0[1].y));

-  *c4 = __builtin_amdgcn_fmul_legacy(*c4, __half2float(s1[0].x));
-  *c5 = __builtin_amdgcn_fmul_legacy(*c5, __half2float(s1[0].y));
-  *c6 = __builtin_amdgcn_fmul_legacy(*c6, __half2float(s1[1].x));
-  *c7 = __builtin_amdgcn_fmul_legacy(*c7, __half2float(s1[1].y));
-  #else
+  // *c4 = __fmul_rn(*c4, device_half2float(s1[0].x));
+  // *c5 = __fmul_rn(*c5, device_half2float(s1[0].y));
+  // *c6 = __fmul_rn(*c6, device_half2float(s1[1].x));
+  // *c7 = __fmul_rn(*c7, device_half2float(s1[1].y));
+
+#else
   // NVIDIA implementation
   *c0 = __fmul_rn(*c0, __half2float(s0[0].x));
   *c1 = __fmul_rn(*c1, __half2float(s0[0].y));
@@ -259,7 +301,7 @@ __device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3,
   *c5 = __fmul_rn(*c5, __half2float(s1[0].y));
   *c6 = __fmul_rn(*c6, __half2float(s1[1].x));
   *c7 = __fmul_rn(*c7, __half2float(s1[1].y));
-  #endif
+#endif
 }

-}  // namespace torchao
+} // namespace torchao
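scale and scale_floats together implement the grouped-quantization epilogue: every dequantized weight (or fp32 accumulator) is multiplied by the scale of its quantization group, with the scales stored in half precision and widened to float where needed. A scalar sketch of that arithmetic (illustrative only, not part of the patch; apply_group_scales and group_size are hypothetical):

    #include <cstddef>
    #include <vector>

    // One scale per group of `group_size` consecutive values; the kernel does
    // the same multiplication two halves (scale) or eight floats (scale_floats)
    // at a time.
    inline void apply_group_scales(std::vector<float> &values,
                                   const std::vector<float> &scales,
                                   std::size_t group_size) {
      for (std::size_t j = 0; j < values.size(); ++j) {
        values[j] *= scales[j / group_size];
      }
    }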