
Commit a89bddd

Matmul_nbits kernel for mlas sqnbits to support Fp16 inputs (microsoft#21807)
1 parent 7e2c722 commit a89bddd

File tree

11 files changed, +341 -116 lines changed

cmake/onnxruntime_mlas.cmake

Lines changed: 2 additions & 2 deletions
@@ -580,10 +580,10 @@ message(STATUS "CMAKE_CXX_COMPILER_VERSION: ${CMAKE_CXX_COMPILER_VERSION}")
 
   if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "11")
     message(STATUS "Using -mavx2 -mfma -mavxvnni flags")
-    set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mavxvnni")
+    set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c -mavxvnni")
   else()
     message(STATUS "Using -mavx2 -mfma flags")
-    set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma")
+    set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c")
   endif()
   set(mlas_platform_srcs_avx512f
     ${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx512F.S
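Note: the added -mf16c flag enables the F16C conversion intrinsics that the new AVX2 cast kernels (see sqnbitgemm_kernel_avx2.cpp below) rely on when building with GCC/Clang. A minimal sketch of what the flag unlocks; the function name here is illustrative, not from the commit:

#include <immintrin.h>

// Requires -mf16c with GCC/Clang; without it, F16C intrinsics such as
// _cvtsh_ss and _mm256_cvtph_ps fail to compile in this translation unit.
float HalfBitsToFloat(unsigned short half_bits) {
    return _cvtsh_ss(half_bits);  // scalar FP16 -> FP32 conversion
}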

docs/OperatorKernels.md

Lines changed: 1 addition & 1 deletion
@@ -488,7 +488,7 @@ Do not modify directly.*
 |MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int64)|
 |MatMulInteger16|*in* A:**T1**<br> *in* B:**T2**<br> *out* Y:**T3**|1+|**T1** = tensor(int16)<br/> **T2** = tensor(int16)<br/> **T3** = tensor(int32)|
 |MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float)|
-|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(float), tensor(uint8)<br/> **T4** = tensor(int32)|
+|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(float), tensor(float16), tensor(uint8)<br/> **T4** = tensor(int32)|
 |MaxpoolWithMask|*in* X:**T**<br> *in* M:**tensor(int32)**<br> *out* Y:**T**|1+|**T** = tensor(float)|
 |MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float)|
 |MurmurHash3|*in* X:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)<br/> **T2** = tensor(int32), tensor(uint32)|

onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc

Lines changed: 181 additions & 65 deletions
Large diffs are not rendered by default.

onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc

Lines changed: 8 additions & 3 deletions
@@ -54,12 +54,12 @@ void Dequantize4BitsKernelReOrder(
       T scale = *(scale_data + n_idx * scales_shape_x + rid);
       float zp_f = 8;
       if (zero_points) {
-        if constexpr (std::is_same_v<zeroT, T>) {
-          zp_f = *(zero_points + n_idx * scales_shape_x + rid);
-        } else {
+        if constexpr (std::is_same_v<zeroT, uint8_t>) {
           uint8_t zp = 8;
           zp = zero_points[n_idx * zero_point_shape_x + rid / 2];
           zp = (rid & 0x01) ? (zp >> 4) : (zp & 0x0f);
+        } else {
+          zp_f = *(zero_points + static_cast<uint64_t>(n_idx) * static_cast<uint64_t>(scales_shape_x) + static_cast<uint64_t>(rid));
         }
       }

@@ -112,5 +112,10 @@ template void DequantizeBlockwise<float, float>(
     const float* zero_points, const int32_t* reorder_idx, int32_t block_size,
     bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool);
 
+template void DequantizeBlockwise<float, MLFloat16>(
+    float* output, const uint8_t* quant_data, const float* scales_data,
+    const MLFloat16* zero_points, const int32_t* reorder_idx, int32_t block_size,
+    bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool);
+
 }  // namespace contrib
 }  // namespace onnxruntime
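For context, a minimal standalone sketch of the nibble unpacking used in the uint8_t branch above (two 4-bit zero points packed per byte, low nibble for even rows); the helper name is illustrative, not from the commit:

#include <cstdint>

// Hypothetical helper mirroring the logic above: zero_points stores two 4-bit
// values per byte; even rows take the low nibble, odd rows the high nibble.
inline uint8_t Unpack4BitZeroPoint(const uint8_t* zero_points, int row) {
    uint8_t packed = zero_points[row / 2];
    return (row & 0x01) ? static_cast<uint8_t>(packed >> 4)
                        : static_cast<uint8_t>(packed & 0x0F);
}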

onnxruntime/core/mlas/inc/mlas.h

Lines changed: 23 additions & 13 deletions
@@ -20,6 +20,7 @@ Module Name:
 #include <cstddef>
 #include <cstdlib>
 #include <cstdint>
+#include <stdexcept>
 
 //
 // Define the calling convention for Windows targets.

@@ -1025,18 +1026,6 @@ MlasComputeTanh(
     size_t N
     );
 
-//
-// Half-precision floating-point routines.
-//
-
-void
-MLASCALL
-MlasConvertHalfToFloatBuffer(
-    const unsigned short* Source,
-    float* Destination,
-    size_t Count
-);
-
 //
 // Transpose routines.
 //

@@ -1426,7 +1415,27 @@ using MLAS_FP16 = onnxruntime::MLFloat16;
 
 constexpr size_t FP16_SIZE = sizeof(uint16_t);
 
-/**
+//
+// Half-precision floating-point routines.
+//
+
+void
+MLASCALL
+MlasConvertHalfToFloatBuffer(
+    const MLAS_FP16* Source,
+    float* Destination,
+    size_t Count
+);
+
+void
+MLASCALL
+MlasConvertFloatToHalfBuffer(
+    const float* Source,
+    MLAS_FP16* Destination,
+    size_t Count
+);
+
+/**
  * @brief Whether current CPU supports FP16 acceleration.
  */
 bool MLASCALL

@@ -1787,6 +1796,7 @@ MlasTranspose(
         M, N);
 }
 
+
 #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
 /**
  * @brief Max Pooling for fp16 NHWC
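A hedged usage sketch of the two buffer-conversion routines declared above; the vector names and the round-trip scenario are illustrative, not part of the commit:

#include <vector>
// #include "mlas.h"  // declares MLAS_FP16, MlasConvertHalfToFloatBuffer, MlasConvertFloatToHalfBuffer

// Convert an fp16 buffer to fp32, compute in fp32, then convert back to fp16.
void Fp16RoundTrip(const std::vector<MLAS_FP16>& src, std::vector<MLAS_FP16>& dst) {
    std::vector<float> tmp(src.size());
    MlasConvertHalfToFloatBuffer(src.data(), tmp.data(), tmp.size());
    // ... fp32 computation on tmp ...
    dst.resize(tmp.size());
    MlasConvertFloatToHalfBuffer(tmp.data(), dst.data(), dst.size());
}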

onnxruntime/core/mlas/lib/cast.cpp

Lines changed: 20 additions & 22 deletions
@@ -23,37 +23,35 @@ union fp32_bits {
 void
 MLASCALL
 MlasConvertHalfToFloatBuffer(
-    const unsigned short* Source,
+    const MLAS_FP16* Source,
     float* Destination,
     size_t Count
 )
 {
-
     if (GetMlasPlatform().CastF16ToF32Kernel == nullptr) {
-        // If there is no kernel use the reference implementation, adapted from mlas_float16.h.
-        constexpr fp32_bits magic = {113 << 23};
-        constexpr uint32_t shifted_exp = 0x7c00 << 13;  // exponent mask after shift
+        for (size_t i = 0; i < Count; ++i) {
+            Destination[i] = Source[i].ToFloat();
+        }
+    } else {
+        // If the kernel is available, use it to perform the conversion.
+        GetMlasPlatform().CastF16ToF32Kernel(reinterpret_cast<const unsigned short*>(Source), Destination, Count);
+    }
+}
 
+void
+MLASCALL
+MlasConvertFloatToHalfBuffer(
+    const float* Source,
+    MLAS_FP16* Destination,
+    size_t Count
+)
+{
+    if (GetMlasPlatform().CastF32ToF16Kernel == nullptr) {
         for (size_t i = 0; i < Count; ++i) {
-            fp32_bits o;
-            o.u = (Source[i] & 0x7fff) << 13;  // exponent/mantissa bits
-            uint32_t exp = shifted_exp & o.u;  // just the exponent
-            o.u += (127 - 15) << 23;           // exponent adjust
-
-            // handle exponent special cases
-            if (exp == shifted_exp) {    // Inf/NaN?
-                o.u += (128 - 16) << 23; // extra exp adjust
-            } else if (exp == 0) {       // Zero/Denormal?
-                o.u += 1 << 23;          // extra exp adjust
-                o.f -= magic.f;          // renormalize
-            }
-
-            o.u |= (Source[i] & 0x8000) << 16; // sign bit
-            Destination[i] = o.f;
+            Destination[i] = MLAS_FP16(Source[i]);
         }
-
     } else {
         // If the kernel is available, use it to perform the conversion.
-        GetMlasPlatform().CastF16ToF32Kernel(Source, Destination, Count);
+        GetMlasPlatform().CastF32ToF16Kernel(Source, reinterpret_cast<unsigned short*>(Destination), Count);
     }
 }

onnxruntime/core/mlas/lib/mlasi.h

Lines changed: 10 additions & 1 deletion
@@ -610,13 +610,19 @@ void
     size_t N
     );
 
-typedef
+typedef
 void(MLASCALL MLAS_CAST_F16_TO_F32_KERNEL)(
     const unsigned short* Source,
     float* Destination,
     size_t Count
 );
 
+typedef void(MLASCALL MLAS_CAST_F32_TO_F16_KERNEL)(
+    const float* Source,
+    unsigned short* Destination,
+    size_t Count
+);
+
 typedef
 void
 (MLASCALL MLAS_QLINEAR_BINARY_OP_S8_KERNEL)(

@@ -880,6 +886,8 @@ extern "C" {
 #if defined(MLAS_TARGET_AMD64)
     MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelSse;
     MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelAvx;
+    MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelAvx2;
+    MLAS_CAST_F32_TO_F16_KERNEL MlasCastF32ToF16KernelAvx2;
 #endif
 
 }

@@ -1165,6 +1173,7 @@ struct MLAS_PLATFORM {
     const MLAS_SQNBIT_GEMM_DISPATCH* SQNBitGemmDispatch{nullptr};
 
     MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel;
+    MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel;
 };
 
 inline

onnxruntime/core/mlas/lib/platform.cpp

Lines changed: 4 additions & 0 deletions
@@ -245,6 +245,7 @@ Return Value:
     this->ConvDepthwiseS8S8Kernel = MlasConvDepthwiseKernel<int8_t, int8_t>;
     this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernel<int8_t, uint8_t>;
     this->CastF16ToF32Kernel = nullptr;
+    this->CastF32ToF16Kernel = nullptr;
 
 #if defined(MLAS_TARGET_AMD64_IX86)
 

@@ -387,6 +388,9 @@ Return Value:
         this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernelAvx2<int8_t, uint8_t>;
         this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelFma3;
         this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2;
+        this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelAvx2;
+        this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelAvx2;
+
 
         //
         // Check if the processor supports Hybrid core architecture.

onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp

Lines changed: 45 additions & 0 deletions
@@ -29,6 +29,51 @@ Module Name:
 #include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h"
 #include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h"
 
+void
+MlasCastF16ToF32KernelAvx2(const unsigned short* src_fp16, float* dst_fp32, size_t size)
+{
+    size_t i = 0;
+
+    // Process 16 elements at a time using AVX2
+    for (; i + 15 < size; i += 16) {
+        // Load 16 FP16 values into an AVX2 register
+        __m256i fp16_values = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_fp16 + i));
+
+        // Convert FP16 values to FP32
+        __m256 fp32_values1 = _mm256_cvtph_ps(_mm256_castsi256_si128(fp16_values));
+        __m256 fp32_values2 = _mm256_cvtph_ps(_mm256_extracti128_si256(fp16_values, 1));
+
+        // Store the converted FP32 values into the output vector
+        _mm256_storeu_ps(dst_fp32 + i, fp32_values1);
+        _mm256_storeu_ps(dst_fp32 + i + 8, fp32_values2);
+    }
+
+    // Process any remaining elements
+    const MLAS_FP16* fp16 = reinterpret_cast<const MLAS_FP16*>(src_fp16);
+    for (; i < size; ++i) {
+        dst_fp32[i] = fp16[i].ToFloat();
+    }
+}
+
+void
+MlasCastF32ToF16KernelAvx2(const float* src_fp32, unsigned short* dst_fp16, size_t size)
+{
+    size_t i = 0;
+
+    // Process 8 elements at a time using AVX2
+    for (; i + 8 <= size; i += 8) {
+        __m256 fp32_chunk = _mm256_loadu_ps(&src_fp32[i]);
+        __m128i fp16_chunk = _mm256_cvtps_ph(fp32_chunk, _MM_FROUND_TO_NEAREST_INT);
+        _mm_storeu_si128(reinterpret_cast<__m128i*>(&dst_fp16[i]), fp16_chunk);
+    }
+
+    // Process any remaining elements
+    for (; i < size; ++i) {
+        MLAS_FP16 fp16(src_fp32[i]);
+        dst_fp16[i] = fp16.val;
+    }
+}
+
 MLAS_FORCEINLINE
 __m256
 load_float_n_avx2(const float* data, int n)
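An illustrative consistency check (not part of the commit; the helper below is hypothetical) showing what the SIMD path and the scalar tail of MlasCastF16ToF32KernelAvx2 are both expected to compute:

#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical test helper: the AVX2 kernel and a scalar MLAS_FP16::ToFloat()
// loop should agree element-wise (treating any NaN results as equal).
bool CastF16ToF32Agrees(const std::vector<unsigned short>& fp16_bits) {
    std::vector<float> fast(fp16_bits.size()), ref(fp16_bits.size());
    MlasCastF16ToF32KernelAvx2(fp16_bits.data(), fast.data(), fp16_bits.size());
    const MLAS_FP16* view = reinterpret_cast<const MLAS_FP16*>(fp16_bits.data());
    for (size_t i = 0; i < fp16_bits.size(); ++i) {
        ref[i] = view[i].ToFloat();
        bool both_nan = std::isnan(fast[i]) && std::isnan(ref[i]);
        if (fast[i] != ref[i] && !both_nan) {
            return false;
        }
    }
    return true;
}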

onnxruntime/core/providers/cpu/tensor/cast_op.cc

Lines changed: 1 addition & 1 deletion
@@ -258,7 +258,7 @@ struct TensorCaster<MLFloat16, float> {
     auto out_data = out.MutableData<float>();
     auto in_data = in.Data<MLFloat16>();
     const size_t shape_size = narrow<size_t>(shape.Size());
-    MlasConvertHalfToFloatBuffer(&in_data[0].val, out_data, shape_size);
+    MlasConvertHalfToFloatBuffer(in_data, out_data, shape_size);
   }
 };
 
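The typed pointer can now be passed directly because MLAS_FP16 is an alias for onnxruntime::MLFloat16 (see the mlas.h hunk above), so the old &in_data[0].val reinterpretation is no longer needed. A minimal compile-time check, illustrative only:

#include <type_traits>

static_assert(std::is_same_v<MLAS_FP16, onnxruntime::MLFloat16>,
              "MlasConvertHalfToFloatBuffer accepts MLFloat16 data directly");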
