diff --git a/test/quantization/test_dynamic_float8_linear_cpu.py b/test/quantization/test_dynamic_float8_linear_cpu.py new file mode 100644 index 0000000000..2ccfca1dc7 --- /dev/null +++ b/test/quantization/test_dynamic_float8_linear_cpu.py @@ -0,0 +1,110 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import unittest + +import torch +from torch.testing._internal import common_utils +from torch.testing._internal.common_utils import ( + TestCase, + run_tests, +) + +from torchao import quantize_ +from torchao.dtypes import ( + Float8DynamicActFloat8WeightCPULayout, + PlainLayout, +) +from torchao.quantization import PerRow +from torchao.quantization.quant_api import ( + Float8DynamicActivationFloat8WeightConfig, +) +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_6, +) + + +class ToyLinearModel(torch.nn.Module): + def __init__(self, K=64, N=32, bias=False): + super().__init__() + self.linear1 = torch.nn.Linear(K, N, bias=bias).to(torch.float) + self.linear2 = torch.nn.Linear(N, K, bias=bias).to(torch.float) + + def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"): + return ( + torch.randn( + batch_size, self.linear1.in_features, dtype=dtype, device=device + ), + ) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +class TestDynamicFloat8Linear(TestCase): + @unittest.skipIf( + "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"), + reason="cpp kernels not built", + ) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") + @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) + @common_utils.parametrize("x_dim", [2, 3]) + @common_utils.parametrize("bias", [True, False]) + @common_utils.parametrize("bs", [1, 160]) + def test_dynamic_float8_linear_cpu(self, dtype, x_dim, bias, bs): + device = "cpu" + m = ToyLinearModel(256, 256, bias=bias).eval().to(dtype).to(device) + m2 = copy.deepcopy(m) + example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device) + if x_dim == 3: + example_inputs = (example_inputs[0].unsqueeze(0),) + + with torch.no_grad(): + quantize_( + m, + Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + layout=Float8DynamicActFloat8WeightCPULayout(), + ), + ) + y, code = torch._inductor.utils.run_and_get_code( + torch.compile(m, fullgraph=True, dynamic=True), + *example_inputs, + ) + # ensure the expected op is in the code + assert "torch.ops.torchao.float8_linear_cpu.default" in code[0] + quantize_( + m2, + Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + layout=PlainLayout(), + ), + ) + torch._dynamo.reset() # may segfault without this + y2 = torch.compile(m2, fullgraph=True, dynamic=True)(*example_inputs) + atol, rtol = 1e-6, 1e-6 + if dtype == torch.bfloat16: + atol, rtol = 1.6e-2, 3e-3 + elif dtype == torch.half: + atol, rtol = 6e-3, 2e-3 + assert torch.allclose(y, y2, atol=atol, rtol=rtol) + # Test get_plain by dequantize() + dqw1 = m.linear1.weight.original_weight_tensor.dequantize() + dqw2 = m.linear2.weight.original_weight_tensor.dequantize() + dqw1_ref = m2.linear1.weight.original_weight_tensor.dequantize() + dqw2_ref = m2.linear2.weight.original_weight_tensor.dequantize() + assert torch.allclose(dqw1, dqw1_ref) + assert torch.allclose(dqw2, dqw2_ref) + + 
+common_utils.instantiate_parametrized_tests(TestDynamicFloat8Linear) + + +if __name__ == "__main__": + run_tests() diff --git a/torchao/csrc/cpu/float8_linear.cpp b/torchao/csrc/cpu/float8_linear.cpp new file mode 100644 index 0000000000..c211168069 --- /dev/null +++ b/torchao/csrc/cpu/float8_linear.cpp @@ -0,0 +1,523 @@ +#include +#include +#include +#include + +namespace torchao { + +namespace { + +#define BLOCK_N 32 + +static bool cpublas_checked = false; +static bool cpublas_can_pack = false; + +bool cpublas_could_pack() { + // the could_pack check requires AMX support implicitly + if (cpublas_checked) { + return cpublas_can_pack; + } + cpublas_can_pack = at::native::cpublas::could_pack(at::kBFloat16); + cpublas_checked = true; + return cpublas_can_pack; +} + +/* +return: packed_weight, packed_scales +*/ +std::tuple +float8_linear_prepack_impl( + const at::Tensor& weight, + const at::Tensor& scales) { + // weight shape = [N, K] + // scales shape = [N, G] + TORCH_CHECK(weight.dim() == 2, + "Float8 linear CPU: Weight should be a 2D tensor for packing"); + TORCH_CHECK(weight.size(1) % 2 == 0, + "Float8 linear CPU: Weight should have even number of columns for packing"); + + auto new_scales = scales; + if (new_scales.dim() == 1) { + new_scales.unsqueeze_(1); + } + new_scales = new_scales.to(at::kFloat); + int N = weight.size(0); + int K = weight.size(1); + int G = scales.size(1); + int group_size = K / G; + int block_k = group_size > 128 ? 128 : group_size; + while (K % block_k != 0) { + block_k /= 2; + } + TORCH_CHECK(block_k > 0 && block_k <= group_size, + "Float8 linear CPU: Invalid block_k size, should be in (0, group_size]"); + constexpr int block_n = BLOCK_N; + int Nc = N / block_n; + int Kc = K / block_k; + + // Reorder weight to [N/block_n, K/block_k, block_k, block_n] + // Reorder scales to [N/block_n, G, block_n] + auto weight_view = weight.view({Nc, block_n, Kc, block_k}); + at::Tensor weight_reordered = weight_view.permute({0, 2, 3, 1}).contiguous(); + at::Tensor blocked_weight; + at::Tensor blocked_scales = new_scales.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous(); + +#if defined(CPU_CAPABILITY_AVX512) + if (cpublas_could_pack()) { + constexpr int vnni_size = 2; // for float16 + blocked_weight = at::empty({Nc, Kc, block_k, block_n}, weight.options()); + auto weight_ptr = reinterpret_cast(weight_reordered.data_ptr()); + auto blocked_weight_ptr = reinterpret_cast(blocked_weight.data_ptr()); + int64_t num_blocks = Nc * Kc; + at::parallel_for(0, num_blocks, 1, [&](int64_t begin, int64_t end) { + for (const auto i : c10::irange(begin, end)) { + auto in_ptr = weight_ptr + i * block_k * block_n; + auto out_ptr = blocked_weight_ptr + i * block_k * block_n; + + // Reorder weight block to VNNI + // plain shape = [block_k, block_n] + // packed shape = [block_k / VNNI_SIZE, block_n, VNNI_SIZE] viewed as [block_k, block_n] + constexpr int n_group_size = 8; + constexpr int n_group = block_n / n_group_size; // 4 + for (int nb = 0; nb < n_group; ++nb) { + for (int k = 0; k < block_k; k += vnni_size) { + for (int ni = 0; ni < n_group_size; ++ni) { + for (int ki = 0; ki < vnni_size; ++ki) { + int src_idx = nb * n_group_size + ni + (k + ki) * block_n; + int dst_idx = (nb * n_group_size + ni) * vnni_size + k * block_n + ki; + *(out_ptr + dst_idx) = *(in_ptr + src_idx); + } + } + } + } + } + }); + } else +#endif + { + blocked_weight = weight_reordered; + } + + return std::make_tuple(std::move(blocked_weight), std::move(blocked_scales)); +} + +#if defined(CPU_CAPABILITY_AVX512) 
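+// Descriptive note on the conversion strategy below: scalar e4m3 -> 16-bit
+// float conversion is slow, so we build a 256-entry lookup table (one entry
+// per fp8 bit pattern) holding the converted 16-bit bit pattern, then use
+// AVX-512 gathers over that table to dequantize 64 fp8 values per iteration.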
+alignas(64) static uint16_t e4m3_to_16bit[256]; + +template +static void initialize_e4m3_to_16bit_tables() { + // run only once + static bool initialized_16bit = false; + if (!initialized_16bit) { + for (uint8_t u8 = 0; u8 < 256; ++u8) { + auto value = static_cast(c10::bit_cast(u8)); + uint16_t value_bits = c10::bit_cast(value); + e4m3_to_16bit[u8] = value_bits; + if (u8 == 255) { + break; + } + } + initialized_16bit = true; + } +} + +template +static void cvt_e4m3_16bit_intrinsic_lut( + const at::Float8_e4m3fn* __restrict__ in, + T* out, + int64_t len) { + for (size_t i = 0; i < len; i += 64) { + __m512i fp8_vec = _mm512_loadu_si512((__m512i*)&in[i]); + __m128i group0 = _mm512_castsi512_si128(fp8_vec); + __m128i group1 = _mm512_extracti32x4_epi32(fp8_vec, 1); + __m128i group2 = _mm512_extracti32x4_epi32(fp8_vec, 2); + __m128i group3 = _mm512_extracti32x4_epi32(fp8_vec, 3); + + __m512i indices0 = _mm512_cvtepu8_epi32(group0); + __m512i indices1 = _mm512_cvtepu8_epi32(group1); + __m512i indices2 = _mm512_cvtepu8_epi32(group2); + __m512i indices3 = _mm512_cvtepu8_epi32(group3); + + // Gather BF16 conversion results from the lookup table. + __m512i bf16_i32_vec0 = _mm512_i32gather_epi32(indices0, e4m3_to_16bit, 2); + __m512i bf16_i32_vec1 = _mm512_i32gather_epi32(indices1, e4m3_to_16bit, 2); + __m512i bf16_i32_vec2 = _mm512_i32gather_epi32(indices2, e4m3_to_16bit, 2); + __m512i bf16_i32_vec3 = _mm512_i32gather_epi32(indices3, e4m3_to_16bit, 2); + + // Helper lambda: Convert 16 32-bit ints (in a __m512i) to 16 16-bit ints. + auto convert_32_to_16 = [](__m512i vec) -> __m256i { + return _mm512_cvtepi32_epi16(vec); + }; + + __m256i bf16_i16_vec0 = convert_32_to_16(bf16_i32_vec0); + __m256i bf16_i16_vec1 = convert_32_to_16(bf16_i32_vec1); + __m256i bf16_i16_vec2 = convert_32_to_16(bf16_i32_vec2); + __m256i bf16_i16_vec3 = convert_32_to_16(bf16_i32_vec3); + + _mm256_storeu_si256((__m256i*)(out + i + 0), bf16_i16_vec0); + _mm256_storeu_si256((__m256i*)(out + i + 16), bf16_i16_vec1); + _mm256_storeu_si256((__m256i*)(out + i + 32), bf16_i16_vec2); + _mm256_storeu_si256((__m256i*)(out + i + 48), bf16_i16_vec3); + } +} + +static void _convert_B_to_bf16( + const at::Float8_e4m3fn* __restrict__ B, + at::BFloat16* dqB, + int64_t len) { + initialize_e4m3_to_16bit_tables(); + int tail = len % 64; + cvt_e4m3_16bit_intrinsic_lut(B, dqB, len - tail); + for (int i = len - tail; i < len; ++i) { + dqB[i] = (at::BFloat16)B[i]; + } +} + +static void _convert_A_to_bf16( + const at::Float8_e4m3fn* __restrict__ A, + at::BFloat16* dqA, + int64_t M, + int64_t K, + int64_t lda) { + initialize_e4m3_to_16bit_tables(); + for (int m = 0; m < M; ++m) { + int tail = K % 64; + int body = K - tail; + cvt_e4m3_16bit_intrinsic_lut(A + m * lda, dqA + m * K, body); + for (int k = body; k < K; ++k) { + dqA[m * K + k] = (at::BFloat16)A[m * lda + k]; + } + } +} + +template +static void _dequant_and_store( + float* __restrict__ output, + const float* __restrict__ input, + const float* __restrict__ scale_a, + const float* __restrict__ scale_b, + int M, + int ldi, + int ldo, + int ldsa = 1) { + for (int m = 0; m < M; ++m) { + float a_scale = *(scale_a + m * ldsa); + __m512 va_scale = _mm512_set1_ps(a_scale); + int n = 0; +#pragma GCC unroll 2 + for (; n < N; n += 16) { + __m512 vc_f = _mm512_loadu_ps(input + m * ldi + n); + __m512 vc_f_mul = _mm512_mul_ps(vc_f, va_scale); + __m512 vb_s = _mm512_loadu_ps(scale_b + n); + vc_f_mul = _mm512_mul_ps(vc_f_mul, vb_s); + if constexpr (accum) { + __m512 vo = _mm512_loadu_ps(output + m * ldo + n); 
+ _mm512_storeu_ps(output + m * ldo + n, _mm512_add_ps(vo, vc_f_mul)); + } else { + _mm512_storeu_ps(output + m * ldo + n, vc_f_mul); + } + } + for (; n < N; ++n) { + float dq_val = input[m * ldi + n] * a_scale * scale_b[n]; + if constexpr (accum) { + output[m * ldo + n] += dq_val; + } else { + output[m * ldo + n] = dq_val; + } + } + } +} + +#else +static void _convert_B_to_bf16( + const at::Float8_e4m3fn* B, + at::BFloat16* dqB, + int64_t len) { + for (int i = 0; i < len; ++i) { + dqB[i] = (at::BFloat16)B[i]; + } +} + +static void _convert_A_to_bf16( + const at::Float8_e4m3fn* __restrict__ A, + at::BFloat16* dqA, + int64_t M, + int64_t K, + int64_t lda) { + for (int m = 0; m < M; ++m) { + for (int k = 0; k < K; ++k) { + dqA[m * K + k] = (at::BFloat16)A[m * lda + k]; + } + } +} +#endif + +template +void _dequant_gemm_accum( + float* C, + const at::Float8_e4m3fn* A, + const float* scales_a, + const at::Float8_e4m3fn* B, + const float* scales_b, + int64_t M, + int64_t K, + int64_t lda, + int64_t ldc) { + // Compute GEMM fp8 * fp8 -> fp32 + // Then apply scales and store results + at::BFloat16 dqB[K * N]; + _convert_B_to_bf16(B, dqB, K * N); + at::BFloat16 dqA[M * K]; + _convert_A_to_bf16(A, dqA, M, K, lda); +#if defined(CPU_CAPABILITY_AVX512) + if constexpr (cpublas_can_pack) { + float C_f32[M * N]; + at::native::cpublas::brgemm( + M, + N, + K, + K /*lda*/, + N /*ldb*/, + N /*ldc*/, + false /* add_C */, + dqA, + dqB, + C_f32, + true /* is_vnni */); + _mm_prefetch(B + N * K, _MM_HINT_T0); + _mm_prefetch(A + K, _MM_HINT_T0); + _dequant_and_store( + C, + C_f32, + scales_a, + scales_b, + M, + N /*ldi*/, + ldc, + 1 /*ldsa*/); + } else +#endif + { + for (int64_t i = 0; i < M; ++i) { + for (int64_t j = 0; j < N; ++j) { + float sum = 0; + for (int64_t k = 0; k < K; ++k) { + sum += ((float)dqA[i * K + k] * dqB[k * N + j]); + } + C[i * ldc + j] += sum * scales_a[i] * scales_b[j]; + } + } + } +} + +template +inline void copy_bias(const float* bias_ptr, float* y_buf, int64_t m) { + if (bias_ptr) { + for (int i = 0; i < m; ++i) { + int j = 0; +#if defined(CPU_CAPABILITY_AVX512) +#pragma GCC unroll 2 + for (; j < N; j += 16) { + __m512 bias_vec = _mm512_loadu_ps(bias_ptr + j); + _mm512_storeu_ps(y_buf + i * N + j, bias_vec); + } +#endif + for (; j < N; ++j) { + y_buf[i * N + j] = bias_ptr[j]; + } + } + } else { // initialize to zero + for (int i = 0; i < m; ++i) { + int j = 0; +#if defined(CPU_CAPABILITY_AVX512) +#pragma GCC unroll 2 + for (; j < N; j += 16) { + __m512 zero_vec = _mm512_setzero_ps(); + _mm512_storeu_ps(y_buf + i * N + j, zero_vec); + } +#endif + for (; j < N; ++j) { + y_buf[i * N + j] = 0; + } + } + } +} + +template +inline void store_out(const float* y_buf, out_dtype* c_ptr, int64_t m, /* int64_t n, */ int64_t lda) { + for (int i = 0; i < m; ++i) { + int j = 0; + if constexpr (std::is_same::value) { +#if defined(CPU_CAPABILITY_AVX512) +#pragma GCC unroll 2 + for (; j < N; j += 16) { + __m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j); + _mm512_storeu_ps(c_ptr + i * lda + j, y_vec); + } +#endif + for (; j < N; ++j) { + c_ptr[i * lda + j] = y_buf[i * N + j]; + } + } else if constexpr (std::is_same::value) { +#if defined(CPU_CAPABILITY_AVX512) +#pragma GCC unroll 2 + for (; j < N; j += 16) { + __m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j); + __m256i y_bf16_vec = at::vec::cvtfp32_bf16(y_vec); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j), y_bf16_vec); + } +#endif + for (; j < N; ++j) { + c_ptr[i * lda + j] = at::BFloat16(y_buf[i * N + j]); + } + } else if 
constexpr (std::is_same::value) { +#if defined(CPU_CAPABILITY_AVX512) +#pragma GCC unroll 2 + for (; j < N; j += 16) { + __m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j); + __m256i y_fp16_vec = at::vec::cvtfp32_fp16(y_vec); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j), y_fp16_vec); + } +#endif + for (; j < N; ++j) { + c_ptr[i * lda + j] = at::Half(y_buf[i * N + j]); + } + } else { + TORCH_CHECK(false, "Unsupported output dtype"); + } + } +} + +template +void _float8_linear_impl( + const at::Tensor& input, + const at::Tensor& input_scales, + const at::Tensor& weight, + const at::Tensor& weight_scales, + const std::optional& bias, + at::Tensor& output) { + // input shape = [..., K] + // input is per token quantized + int64_t K = input.size(-1); + auto input_view = input.view({-1, K}); + int64_t M = input_view.size(0); + TORCH_CHECK(input_scales.numel() == M, "Float8 linear: unexpected input scales shape"); + + // weight shape = [Nc, Kc, block_k, block_n] + // scales shape = [Nc, G, block_n] + int64_t Nc = weight.size(0); + int64_t Kc = weight.size(1); + int64_t block_k = weight.size(2); + constexpr int64_t block_n = BLOCK_N; + TORCH_CHECK(weight.size(3) == block_n, "Float8 linear: unexpected weight shape"); + int64_t N = Nc * block_n; + TORCH_CHECK(K == Kc * block_k, "Float8 linear: weight and input shapes mismatch"); + int64_t block_m = [&]() -> long { + if (M <= 48) { + return M; + } else if (M < 64) { + return 32; + } else if (M < 96) { + return 64; + } else { + return 128; + } + }(); + int64_t Mc = (M + block_m - 1) / block_m; + bool parallel_on_M = M > 128; + int64_t num_blocks = parallel_on_M ? Mc * Nc : Nc; + + // scales shape = [Nc, G, block_n] + int64_t num_groups = weight_scales.size(1); + TORCH_CHECK(K % num_groups == 0, "K should be divisible by num_groups"); + int64_t group_size = K / num_groups; + TORCH_CHECK(group_size % block_k == 0, + "Float8 linear: group_size should be divisible by block_k"); + int64_t block_per_group = group_size / block_k; + + const at::Float8_e4m3fn* a_ptr = input_view.data_ptr(); + const float* a_scales_ptr = input_scales.data_ptr(); + const at::Float8_e4m3fn* b_ptr = weight.data_ptr(); + const float* b_scales_ptr = weight_scales.data_ptr(); + out_dtype* c_ptr = output.data_ptr(); + const float* bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr; + + at::parallel_for(0, num_blocks, 1, [&](int64_t begin, int64_t end) { + for (const auto i : c10::irange(begin, end)) { + int64_t mc = parallel_on_M ? i / Nc : 0; + int64_t nc = parallel_on_M ? i % Nc : i; + int64_t mc_end = parallel_on_M ? mc + 1 : Mc; + + for (int mci = mc; mci < mc_end; ++mci) { + int64_t m_size = mci * block_m + block_m > M ? M - mci * block_m : block_m; + alignas(64) float y_buf[m_size][block_n]; + // copy bias to y_buf if bias is not None + auto bias_data = bias_ptr ? 
bias_ptr + nc * block_n : nullptr; + copy_bias(bias_data, y_buf[0], m_size); + for (int kci = 0; kci < Kc; ++kci) { + _dequant_gemm_accum( + y_buf[0] /*C*/, + a_ptr + mci * block_m * K + kci * block_k /*A*/, + a_scales_ptr + mci * block_m /*scales_a*/, + b_ptr + (nc * Kc + kci) * block_n * block_k /*B*/, + b_scales_ptr + nc * block_n * num_groups + kci / block_per_group * block_n /*scales_b*/, + m_size /*M*/, + block_k /*K*/, + K /*lda*/, + block_n /*ldc*/); + } + // store y_buf to output with dtype conversion + store_out( + y_buf[0], + c_ptr + mci * block_m * N + nc * block_n, + m_size, + N /*lda*/); + } + } + if constexpr (cpublas_can_pack) { + at::native::cpublas::brgemm_release(); + } + }); +} + +at::Tensor float8_linear_impl( + const at::Tensor& input, + const at::Tensor& input_scales, + const at::Tensor& weight, + const at::Tensor& weight_scales, + const std::optional& bias, + at::ScalarType output_dtype) { + static bool cpublas_can_pack = cpublas_could_pack(); + auto out_sizes = input.sizes().vec(); + int64_t N = weight.size(0) * weight.size(-1); + out_sizes.back() = N; + auto output = at::empty(out_sizes, input.options().dtype(output_dtype)); + +#define call__float8_linear_impl(cpublas_can_pack) \ + AT_DISPATCH_FLOATING_TYPES_AND2( \ + at::ScalarType::BFloat16, at::ScalarType::Half, output_dtype, "float8_linear_cpu", [&] { \ + _float8_linear_impl( \ + input, \ + input_scales, \ + weight, \ + weight_scales, \ + bias, \ + output); \ + }); + + if (cpublas_can_pack) { + call__float8_linear_impl(true); + } else { + call__float8_linear_impl(false); + } + return output; +} + +} // anonymous namespace + +TORCH_LIBRARY_IMPL(torchao, CPU, m) { + m.impl("torchao::float8_linear_prepack_cpu", &float8_linear_prepack_impl); + m.impl("torchao::float8_linear_cpu", &float8_linear_impl); +} + +} // namespace torchao diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index d6b1b9c440..476df2aace 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -12,6 +12,7 @@ from .fbgemm_int4_tensor import FbgemmInt4Tensor, to_fbgemm_int4 from .floatx import ( CutlassSemiSparseLayout, + Float8DynamicActFloat8WeightCPULayout, Float8Layout, ) from .nf4tensor import NF4Tensor, to_nf4 @@ -70,4 +71,5 @@ "FbgemmFp8Tensor", "Int8DynamicActInt4WeightCPULayout", "Int4GroupwisePreshuffleTensor", + "Float8DynamicActFloat8WeightCPULayout", ] diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 8b028352e4..a28e764cb8 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -15,6 +15,10 @@ _linear_fp8_act_fp8_weight_sparse_cutlass_check, _linear_fp8_act_fp8_weight_sparse_cutlass_impl, ) +from torchao.dtypes.floatx.dyn_float8_act_float8_wei_cpu_layout import ( + _float8_linear_cpu_check, + _float8_linear_cpu_impl, +) from torchao.dtypes.floatx.float8_layout import ( _linear_fp8_act_fp8_weight_check, _linear_fp8_act_fp8_weight_impl, @@ -255,6 +259,10 @@ def _register_aqt_quantized_linear_dispatches(): _linear_int8_act_int4_weight_cpu_check, _linear_int8_act_int4_weight_cpu_impl, ), + ( + _float8_linear_cpu_check, + _float8_linear_cpu_impl, + ), ]: register_aqt_quantized_linear_dispatch(dispatch_condition, impl) diff --git a/torchao/dtypes/floatx/__init__.py b/torchao/dtypes/floatx/__init__.py index 7e634a5211..05744e6b50 100644 --- a/torchao/dtypes/floatx/__init__.py +++ b/torchao/dtypes/floatx/__init__.py @@ -1,6 +1,9 @@ from .cutlass_semi_sparse_layout import ( 
CutlassSemiSparseLayout, ) +from .dyn_float8_act_float8_wei_cpu_layout import ( + Float8DynamicActFloat8WeightCPULayout, +) from .float8_layout import Float8Layout from .floatx_tensor_core_layout import ( FloatxTensorCoreLayout, @@ -14,4 +17,5 @@ "from_scaled_tc_floatx", "Float8Layout", "CutlassSemiSparseLayout", + "Float8DynamicActFloat8WeightCPULayout", ] diff --git a/torchao/dtypes/floatx/dyn_float8_act_float8_wei_cpu_layout.py b/torchao/dtypes/floatx/dyn_float8_act_float8_wei_cpu_layout.py new file mode 100644 index 0000000000..7b581bf6a6 --- /dev/null +++ b/torchao/dtypes/floatx/dyn_float8_act_float8_wei_cpu_layout.py @@ -0,0 +1,284 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, +) + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) +from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout, is_device +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_6, + fill_defaults, +) + +from ..uintx.int4_cpu_layout import ( + _is_float, +) + +aten = torch.ops.aten + + +@dataclass(frozen=True) +class Float8DynamicActFloat8WeightCPULayout(Layout): + """Layout class for float8 da8w8 CPU layout for affine quantized tensor""" + + pass + + +@register_layout(Float8DynamicActFloat8WeightCPULayout) +class Float8DynActFloat8WeiCpuAQTTensorImpl(AQTTensorImpl): + """TensorImpl for float8 da8w8 CPU layout for affine quantized tensor""" + + def __new__( + cls, + packed_weight: torch.Tensor, + scales: torch.Tensor, + transposed: bool, + _layout: Layout, + ): + kwargs = {} + kwargs["device"] = packed_weight.device + kwargs["layout"] = ( + kwargs.get("layout") + if kwargs.get("layout", False) + else packed_weight.layout + ) + kwargs["dtype"] = packed_weight.dtype + kwargs["requires_grad"] = False + shape = packed_weight.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + packed_weight: torch.Tensor, + scales: torch.Tensor, + transposed: bool, + _layout: Layout, + ): + self.packed_weight = packed_weight + self.scales = scales + self.transposed = transposed + self._layout = _layout + + def __tensor_flatten__(self): + return ["packed_weight", "scales"], [ + self.transposed, + self._layout, + ] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + packed_weight, scales = ( + tensor_data_dict["packed_weight"], + tensor_data_dict["scales"], + ) + (_layout, transposed) = tensor_attributes + return cls(packed_weight, scales, transposed, _layout) + + @classmethod + def from_plain( + cls, + data: torch.Tensor, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + _layout: Layout, + ): + assert isinstance(_layout, Float8DynamicActFloat8WeightCPULayout) + assert data.dtype == torch.float8_e4m3fn, ( + "Float8 DA8W8 CPU: expects float8_e4m3fn weight" + ) + if scale.dim() == 1: + scale.unsqueeze_(-1) + scale = scale.to(torch.float) + + K = data.size(-1) + if K % 32 == 0: + # Pack weight from [N, K] to [N / block_n, K / block_k, block_k, block_n]. + # Pack inner blocks [block_k, block_n] to VNNI layout if AMX is available. 
+ # Pack scales from [N, num_groups] to [N / block_n, num_groups, block_n]. + weight_packed, scales = torch.ops.torchao.float8_linear_prepack_cpu( + data, scale + ) + else: + weight_packed = data + scales = scale + _layout = PlainLayout() + return cls(weight_packed, scales, False, _layout) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.packed_weight), + fn(self.scales), + self.transposed, + self._layout, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + if func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + if func is aten.slice.Tensor: + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + if dim in [0, 1]: + assert step == 1, "Only step == 1 is supported in slicing right now" + data, scale = self.get_plain() + data_len = data.shape[dim] + scale_len = scale.shape[dim] + ratio = data_len / scale_len + start_scale = int(start / ratio) + end_scale = int(end / ratio) + + data = aten.slice.Tensor(data, dim, start, end, step) + scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step) + # this is to handle padding + data, scale = self._layout.post_process(data, scale, self.block_size) + sliced = self.from_plain(data, scale, self._layout) + return return_and_correct_aliasing(func, args, kwargs, sliced) + else: + raise NotImplementedError( + f"{cls.__name__} dispatch: attempting to run {func}, with dim={dim}, that is not supported" + ) + + raise NotImplementedError( + f"{cls.__name__} dispatch: attempting to run {func}, this is not supported" + ) + + __torch_function__ = torch._C._disabled_torch_function_impl + + @property + def block_size(self): + assert len(self.packed_weight.shape) == 2 + weight_shape = self.packed_weight.shape + N = weight_shape[0] + K = weight_shape[1] + groups = self.scales.numel() // N + group_size = K // groups + return (1, group_size) + + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if self._layout == PlainLayout: + # If the layout is PlainLayout, return the packed weight and scales directly + return ( + self.packed_weight, + self.scales, + torch.zeros_like(self.scales), + ) + # Unpack weight by linear(eye(K), packed_weight).t() + packed_w_shape = self.packed_weight.shape + if len(packed_w_shape) == 4: + K = packed_w_shape[1] * packed_w_shape[2] + else: + K = packed_w_shape[1] + x = torch.eye(K).to(torch.float8_e4m3fn) + x_scale = torch.ones(K).float() + w_scale = torch.ones_like(self.scales).float() + plain_weight = torch.ops.torchao.float8_linear_cpu.default( + x, + x_scale, + self.packed_weight, + w_scale, + None, # bias + torch.float, # out_dtype + ) + plain_weight = plain_weight.t().contiguous() + plain_weight = plain_weight.to(torch.float8_e4m3fn) + + if self.scales.dim() == 2: + plain_scales = self.scales + else: + assert self.scales.dim() == 3 + packed_shape = self.scales.shape # [Nc, G, block_n] + plain_scales = ( + self.scales.permute([0, 2, 1]).contiguous().view([-1, packed_shape[1]]) + ) + + return plain_weight, plain_scales, torch.zeros_like(plain_scales) + + +def _aqt_is_float8e4m3(aqt): + """Check if an AffineQuantizedTensor is float8_e4m3fn quantized Tensor""" + return aqt.tensor_impl.dtype == torch.float8_e4m3fn + + +def _float8_linear_cpu_check(input_tensor, weight_tensor, 
bias): + return ( + TORCH_VERSION_AT_LEAST_2_6 + and is_device(input_tensor.device.type, "cpu") + and is_device(weight_tensor.device.type, "cpu") + and (bias is None or is_device(bias.device.type, "cpu")) + and isinstance(input_tensor, AffineQuantizedTensor) + and _aqt_is_float8e4m3(input_tensor) + and _is_float(input_tensor.dtype) + and isinstance(input_tensor._layout, PlainLayout) + and isinstance(weight_tensor, AffineQuantizedTensor) + and _aqt_is_float8e4m3(weight_tensor) + and _is_float(weight_tensor.dtype) + and isinstance(weight_tensor._layout, Float8DynamicActFloat8WeightCPULayout) + ) + + +def _float8_linear_cpu_impl(input_tensor, weight_tensor, bias): + assert TORCH_VERSION_AT_LEAST_2_6, ( + f"Requires PyTorch version at least 2.6, but got: {torch.__version__}" + ) + assert is_device(input_tensor.device.type, "cpu"), ( + f"For CPU device only but got: {input_tensor.device}" + ) + assert weight_tensor.block_size[0] == 1, ( + f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" + ) + assert input_tensor.shape[-1] == weight_tensor.shape[1], ( + f"need input_tensor shape: {input_tensor.shape} final" + f"dim to match weight_tensor shape: {weight_tensor.shape} second dim " + ) + + act_mat = input_tensor + act = act_mat.tensor_impl.int_data + act_scales = act_mat.tensor_impl.scale + + packed_weight = weight_tensor.tensor_impl.packed_weight + wei_scales = weight_tensor.tensor_impl.scales + + orig_act_size = act_mat.size() + orig_dtype = act_mat.dtype + + # reshape to 2D + act = act.reshape(-1, act.shape[-1]) + + y = torch.ops.torchao.float8_linear_cpu.default( + act.contiguous(), + act_scales, + packed_weight, + wei_scales, + bias.float() if bias is not None else bias, # requires bias to be float + torch.float, # out_dtype + ) + + # remove out_feature padding + orig_out_features = weight_tensor.shape[-2] + y = y[:, :orig_out_features] + y = y.reshape(*orig_act_size[:-1], orig_out_features) + + return y.to(orig_dtype) diff --git a/torchao/ops.py b/torchao/ops.py index babe5506c0..178e98f589 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -70,6 +70,12 @@ lib.define( "da8w4_linear_cpu(Tensor input, Tensor input_scales, Tensor input_qzeros, Tensor weight, Tensor weight_scales, Tensor weight_qzeros, Tensor compensation, Tensor? bias, ScalarType output_dtype) -> Tensor" ) +lib.define( + "float8_linear_prepack_cpu(Tensor weight, Tensor scales) -> (Tensor, Tensor)" +) +lib.define( + "float8_linear_cpu(Tensor input, Tensor input_scales, Tensor weight, Tensor weight_scales, Tensor? bias, ScalarType output_dtype) -> Tensor" +) def register_custom_op(name): @@ -1106,3 +1112,67 @@ def _( assert weight.dim() == 4 N = weight.size(0) * weight.size(3) * 2 return input.new_empty(*input.shape[:-1], N, dtype=out_dtype) + + +def float8_linear_prepack_cpu( + weight: Tensor, + scales: Tensor, +) -> Tensor: + """ + Prepack weights for float8 linear operator on CPU. + Args: + weight: weight tensor. + scales: scales for weight tensor. + Returns: + packed weight, packed scales + """ + return torch.ops.torchao.float8_linear_prepack_cpu.default(weight, scales) + + +@register_custom_op("torchao::float8_linear_prepack_cpu") +def _(weight: Tensor, scales: Tensor) -> Tensor: + return weight, scales + + +def float8_linear_cpu( + input: Tensor, + input_scales: Tensor, + weight: Tensor, + weight_scales: Tensor, + bias: Optional[Tensor], + out_dtype: torch.dtype, +): + """ + float8 linear operator on CPU. + Args: + input: input tensor. + input_scales: scales for input tensor. 
+ weight: weight tensor. + weight_scales: scales for weight tensor. + bias: optional bias tensor. + out_dtype: output data type. + Returns: + output tensor in out_dtype. + """ + return torch.ops.torchao.float8_linear_cpu.default( + input, + input_scales, + weight, + weight_scales, + bias, + out_dtype, + ) + + +@register_custom_op("torchao::float8_linear_cpu") +def _( + input: Tensor, + input_scales: Tensor, + weight: Tensor, + weight_scales: Tensor, + bias: Optional[Tensor], + out_dtype: torch.dtype, +) -> Tensor: + assert weight.dim() == 4 + N = weight.size(0) * weight.size(3) + return input.new_empty(*input.shape[:-1], N, dtype=out_dtype) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 4662e20fc9..35773c9b0e 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -82,6 +82,7 @@ TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, _is_fbgemm_genai_gpu_available, + check_cpu_version, is_MI300, is_sm_at_least_89, is_sm_at_least_90, @@ -1561,6 +1562,22 @@ def _input_activation_quant_func_fp8( return activation +def _input_activation_quant_cpu_fp8( + x: torch.Tensor, + activation_granularity: FP8Granularity, + activation_dtype: torch.dtype, +): + """Dynamic quantize activation to fp8 for CPU.""" + block_size = get_block_size(x.shape, activation_granularity) + return to_affine_quantized_floatx( + input_float=x, + block_size=block_size, + target_dtype=activation_dtype, + scale_dtype=torch.float32, + _layout=PlainLayout(), + ) + + def _fp8_mm_compat(weight: torch.Tensor) -> bool: """ Check if a weight tensor meets float8 quantization requirements. @@ -1611,10 +1628,13 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig): granularity: Optional[Union[FP8Granularity, List[FP8Granularity]]] = None mm_config: Optional[Float8MMConfig] = None set_inductor_config: bool = True + layout: Optional[Layout] = None def __post_init__(self): - if self.mm_config is None: - self.mm_config = Float8MMConfig(use_fast_accum=True) + if self.layout is None: + if self.mm_config is None: + self.mm_config = Float8MMConfig(use_fast_accum=True) + self.layout = Float8Layout(self.mm_config) activation_granularity, weight_granularity = _normalize_granularity( self.granularity @@ -1630,17 +1650,23 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): activation_dtype = config.activation_dtype weight_dtype = config.weight_dtype granularity = config.granularity - mm_config = config.mm_config # Ensure works on device - _check_hardware_support(granularity) activation_granularity, weight_granularity = granularity + is_cpu = weight.device.type == "cpu" + if is_cpu: + assert not ( + isinstance(activation_granularity, PerTensor) + or isinstance(weight_granularity, PerTensor) + ), "PerTensor quantization is not supported for CPU float8 quantization" + else: + _check_hardware_support(granularity) - if not _fp8_mm_compat(weight): + if not is_cpu and not _fp8_mm_compat(weight): # TODO(future PR): this should really throw an exception instead of silently # not doing what the user asked return weight - if isinstance(weight_granularity, PerRow): + if not is_cpu and isinstance(weight_granularity, PerRow): assert weight.dtype == torch.bfloat16, ( "PerRow quantization only works for bfloat16 precision input weight" ) @@ -1653,10 +1679,14 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): block_size=block_size, target_dtype=weight_dtype, scale_dtype=torch.float32, - 
_layout=Float8Layout(mm_config=mm_config), + _layout=config.layout, ) - input_quant_func = _input_activation_quant_func_fp8 + input_quant_func = ( + _input_activation_quant_func_fp8 + if isinstance(config.layout, Float8Layout) + else _input_activation_quant_cpu_fp8 + ) input_quant_kwargs = { "activation_granularity": activation_granularity, "activation_dtype": activation_dtype, @@ -1672,16 +1702,21 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): def _float8_dynamic_activation_float8_weight_transform( module: torch.nn.Module, config: Float8DynamicActivationFloat8WeightConfig ): - assert is_sm_at_least_89() or is_MI300(), ( - "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+" - ) - if config.set_inductor_config: - torchao.quantization.utils.recommended_inductor_config_setter() - assert hasattr(module, "weight"), ( "applying float8 dynamic activation quant requires module to have weight attribute" + f"but {module} does not have one" ) + + assert ( + check_cpu_version(module.weight.device, "2.6.0") + or is_sm_at_least_89() + or is_MI300() + ), ( + "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+ or on CPU with PyTorch >= 2.6.0" + ) + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + quantized_weight = _float8_dynamic_activation_float8_weight_quantize_tensor( module.weight, config )
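Usage note (for reference only, not part of the patch): a minimal sketch of how the new CPU layout is exercised, mirroring the new test above. It assumes torchao is built with the CPU kernels added in this patch.

import torch
from torchao import quantize_
from torchao.dtypes import Float8DynamicActFloat8WeightCPULayout
from torchao.quantization import PerRow
from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig

model = torch.nn.Sequential(torch.nn.Linear(256, 256, bias=False)).eval()
x = torch.randn(16, 256)

with torch.no_grad():
    # Per-row dynamic fp8 activations + fp8 weights packed via the new CPU layout
    quantize_(
        model,
        Float8DynamicActivationFloat8WeightConfig(
            granularity=PerRow(),
            layout=Float8DynamicActFloat8WeightCPULayout(),
        ),
    )
    # torch.compile lowers the quantized linears to torch.ops.torchao.float8_linear_cpu
    y = torch.compile(model, fullgraph=True, dynamic=True)(x)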