Implement optimized BF16 support for ARM architecture - [MOD-9079] #623

Merged: 13 commits, Apr 6, 2025

8 changes: 8 additions & 0 deletions cmake/aarch64InstructionFlags.cmake
@@ -8,6 +8,8 @@ CHECK_CXX_COMPILER_FLAG("-march=armv7-a+neon" CXX_ARMV7_NEON)
CHECK_CXX_COMPILER_FLAG("-march=armv8-a" CXX_ARMV8A)
CHECK_CXX_COMPILER_FLAG("-march=armv8-a+sve" CXX_SVE)
CHECK_CXX_COMPILER_FLAG("-march=armv9-a+sve2" CXX_SVE2)
CHECK_CXX_COMPILER_FLAG("-march=armv8.2-a+bf16" CXX_NEON_BF16)
CHECK_CXX_COMPILER_FLAG("-march=armv8.2-a+sve+bf16" CXX_SVE_BF16)

# Only use ARMv9 if both compiler and CPU support it
if(CXX_SVE2)
@@ -17,6 +19,12 @@ endif()
if (CXX_ARMV8A OR CXX_ARMV7_NEON)
add_compile_definitions(OPT_NEON)
endif()
if (CXX_NEON_BF16)
add_compile_definitions(OPT_NEON_BF16)
endif()
if (CXX_SVE)
add_compile_definitions(OPT_SVE)
endif()
if (CXX_SVE_BF16)
add_compile_definitions(OPT_SVE_BF16)
endif()
14 changes: 14 additions & 0 deletions src/VecSim/spaces/CMakeLists.txt
@@ -91,13 +91,27 @@ if (CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "(aarch64)|(arm64)|(ARM64)|(armv.*)")
list(APPEND OPTIMIZATIONS functions/NEON.cpp)
endif()

# NEON bfloat16 support
if (CXX_NEON_BF16)
message("Building with NEON + BF16")
set_source_files_properties(functions/NEON_BF16.cpp PROPERTIES COMPILE_FLAGS "-march=armv8.2-a+bf16")
list(APPEND OPTIMIZATIONS functions/NEON_BF16.cpp)
endif()

# SVE support
if (CXX_SVE)
message("Building with SVE")
set_source_files_properties(functions/SVE.cpp PROPERTIES COMPILE_FLAGS "-march=armv8-a+sve")
list(APPEND OPTIMIZATIONS functions/SVE.cpp)
endif()

# SVE with BF16 support
if (CXX_SVE_BF16)
message("Building with SVE + BF16")
set_source_files_properties(functions/SVE_BF16.cpp PROPERTIES COMPILE_FLAGS "-march=armv8.2-a+sve+bf16")
list(APPEND OPTIMIZATIONS functions/SVE_BF16.cpp)
endif()

# SVE2 support
if (CXX_SVE2)
message("Building with ARMV9A and SVE2")
89 changes: 89 additions & 0 deletions src/VecSim/spaces/IP/IP_NEON_BF16.h
@@ -0,0 +1,89 @@
/*
*Copyright Redis Ltd. 2021 - present
*Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or
*the Server Side Public License v1 (SSPLv1).
*/

#include <arm_neon.h>

inline void InnerProduct_Step(const bfloat16_t *&vec1, const bfloat16_t *&vec2, float32x4_t &acc) {
// Load brain-half-precision vectors
bfloat16x8_t v1 = vld1q_bf16(vec1);
bfloat16x8_t v2 = vld1q_bf16(vec2);
vec1 += 8;
vec2 += 8;
// Compute multiplications and add to the accumulator
acc = vbfdotq_f32(acc, v1, v2);
}

template <unsigned char residual> // 0..31
float BF16_InnerProduct_NEON(const void *pVect1v, const void *pVect2v, size_t dimension) {
const auto *vec1 = static_cast<const bfloat16_t *>(pVect1v);
const auto *vec2 = static_cast<const bfloat16_t *>(pVect2v);
const auto *const v1End = vec1 + dimension;
float32x4_t acc1 = vdupq_n_f32(0.0f);
float32x4_t acc2 = vdupq_n_f32(0.0f);
float32x4_t acc3 = vdupq_n_f32(0.0f);
float32x4_t acc4 = vdupq_n_f32(0.0f);

// First, handle the partial chunk residual
if constexpr (residual % 8) {
auto constexpr chunk_residual = residual % 8;
// TODO: special cases for some residuals and benchmark if it's better
constexpr uint16x8_t mask = {
0xFFFF,
(chunk_residual >= 2) ? 0xFFFF : 0,
(chunk_residual >= 3) ? 0xFFFF : 0,
(chunk_residual >= 4) ? 0xFFFF : 0,
(chunk_residual >= 5) ? 0xFFFF : 0,
(chunk_residual >= 6) ? 0xFFFF : 0,
(chunk_residual >= 7) ? 0xFFFF : 0,
0,
};

// Load partial vectors
bfloat16x8_t v1 = vld1q_bf16(vec1);
bfloat16x8_t v2 = vld1q_bf16(vec2);

// Apply mask to both vectors
bfloat16x8_t masked_v1 =
vreinterpretq_bf16_u16(vandq_u16(vreinterpretq_u16_bf16(v1), mask));
bfloat16x8_t masked_v2 =
vreinterpretq_bf16_u16(vandq_u16(vreinterpretq_u16_bf16(v2), mask));

acc1 = vbfdotq_f32(acc1, masked_v1, masked_v2);

// Advance pointers
vec1 += chunk_residual;
vec2 += chunk_residual;
}

// Handle (residual - (residual % 8)) in chunks of 8 bfloat16
if constexpr (residual >= 8)
InnerProduct_Step(vec1, vec2, acc2);
if constexpr (residual >= 16)
InnerProduct_Step(vec1, vec2, acc3);
if constexpr (residual >= 24)
InnerProduct_Step(vec1, vec2, acc4);

// Process the rest of the vectors (the full chunks part)
while (vec1 < v1End) {
// TODO: use `vld1q_bf16_x4` for quad-loading?
InnerProduct_Step(vec1, vec2, acc1);
InnerProduct_Step(vec1, vec2, acc2);
InnerProduct_Step(vec1, vec2, acc3);
InnerProduct_Step(vec1, vec2, acc4);
}

// Accumulate accumulators
acc1 = vpaddq_f32(acc1, acc3);
acc2 = vpaddq_f32(acc2, acc4);
acc1 = vpaddq_f32(acc1, acc2);

// Pairwise add to get horizontal sum
float32x2_t folded = vadd_f32(vget_low_f32(acc1), vget_high_f32(acc1));
folded = vpadd_f32(folded, folded);

// Extract result
return 1.0f - vget_lane_f32(folded, 0);
}
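
Note (not part of the diff): the `residual` template parameter above is the number of elements left over after the 4x-unrolled main loop, i.e. dim % 32 (4 accumulation steps of 8 BF16 lanes each), so a dispatcher needs one instantiation per possible residual. The sketch below is an assumption about how such a table could be built, not repository code; `pick_neon_bf16_ip` is an illustrative name, the include follows the style used in IP_space.cpp, and the translation unit would need the same -march=armv8.2-a+bf16 flag the CMake change sets.

#include <array>
#include <cstddef>
#include <utility>
#include "VecSim/spaces/IP/IP_NEON_BF16.h" // header added by this PR

using dist_fn = float (*)(const void *, const void *, size_t);

// Build a compile-time table with one kernel instantiation per residual value.
template <std::size_t... R>
constexpr std::array<dist_fn, sizeof...(R)> make_neon_bf16_ip_table(std::index_sequence<R...>) {
    return {BF16_InnerProduct_NEON<static_cast<unsigned char>(R)>...};
}

// Pick the instantiation whose residual matches the vector dimension.
dist_fn pick_neon_bf16_ip(std::size_t dim) {
    static constexpr auto table = make_neon_bf16_ip_table(std::make_index_sequence<32>{});
    return table[dim % 32];
}
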
71 changes: 71 additions & 0 deletions src/VecSim/spaces/IP/IP_SVE_BF16.h
@@ -0,0 +1,71 @@
/*
*Copyright Redis Ltd. 2021 - present
*Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or
*the Server Side Public License v1 (SSPLv1).
*/

#include <arm_sve.h>

inline void InnerProduct_Step(const bfloat16_t *vec1, const bfloat16_t *vec2, svfloat32_t &acc,
size_t &offset, const size_t chunk) {
svbool_t all = svptrue_b16();

// Load brain-half-precision vectors.
svbfloat16_t v1 = svld1_bf16(all, vec1 + offset);
svbfloat16_t v2 = svld1_bf16(all, vec2 + offset);
// Compute multiplications and add to the accumulator
acc = svbfdot(acc, v1, v2);

// Move to next chunk
offset += chunk;
}

template <bool partial_chunk, unsigned char additional_steps> // [t/f, 0..3]
float BF16_InnerProduct_SVE(const void *pVect1v, const void *pVect2v, size_t dimension) {
const auto *vec1 = static_cast<const bfloat16_t *>(pVect1v);
const auto *vec2 = static_cast<const bfloat16_t *>(pVect2v);
const size_t chunk = svcnth(); // number of 16-bit elements in a register
svfloat32_t acc1 = svdup_f32(0.0f);
svfloat32_t acc2 = svdup_f32(0.0f);
svfloat32_t acc3 = svdup_f32(0.0f);
svfloat32_t acc4 = svdup_f32(0.0f);
size_t offset = 0;

// Process all full vectors
const size_t full_iterations = dimension / chunk / 4;
for (size_t iter = 0; iter < full_iterations; iter++) {
InnerProduct_Step(vec1, vec2, acc1, offset, chunk);
InnerProduct_Step(vec1, vec2, acc2, offset, chunk);
InnerProduct_Step(vec1, vec2, acc3, offset, chunk);
InnerProduct_Step(vec1, vec2, acc4, offset, chunk);
}

// Perform between 0 and 3 additional steps, according to `additional_steps` value
if constexpr (additional_steps >= 1)
InnerProduct_Step(vec1, vec2, acc1, offset, chunk);
if constexpr (additional_steps >= 2)
InnerProduct_Step(vec1, vec2, acc2, offset, chunk);
if constexpr (additional_steps >= 3)
InnerProduct_Step(vec1, vec2, acc3, offset, chunk);

// Handle the tail with the residual predicate
if constexpr (partial_chunk) {
svbool_t pg = svwhilelt_b16_u64(offset, dimension);

// Load brain-half-precision vectors.
// Inactive elements are zeros, according to the docs
svbfloat16_t v1 = svld1_bf16(pg, vec1 + offset);
svbfloat16_t v2 = svld1_bf16(pg, vec2 + offset);
// Compute multiplications and add to the accumulator.
acc4 = svbfdot(acc4, v1, v2);
}

// Accumulate accumulators
acc1 = svadd_f32_x(svptrue_b32(), acc1, acc3);
acc2 = svadd_f32_x(svptrue_b32(), acc2, acc4);
acc1 = svadd_f32_x(svptrue_b32(), acc1, acc2);

// Reduce the accumulated sum.
float result = svaddv_f32(svptrue_b32(), acc1);
return 1.0f - result;
}
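
Note (not part of the diff): unlike the NEON kernel, the chunk size here is only known at runtime, since `svcnth()` reports how many 16-bit lanes the hardware's SVE registers hold, so the template parameters have to be derived from both the dimension and the vector length. The sketch below shows one plausible mapping, consistent with the arithmetic of the kernel above; `pick_sve_bf16_ip` is an illustrative name, not the repository's chooser, and the translation unit would need the -march=armv8.2-a+sve+bf16 flag set in the CMake change.

#include <cstddef>
#include <arm_sve.h>
#include "VecSim/spaces/IP/IP_SVE_BF16.h" // header added by this PR

using dist_fn = float (*)(const void *, const void *, size_t);

dist_fn pick_sve_bf16_ip(size_t dim) {
    const size_t chunk = svcnth();                 // BF16 lanes per SVE register
    const bool partial = (dim % chunk) != 0;       // tail handled by the predicated load
    const unsigned char extra = (dim / chunk) % 4; // full chunks left after the 4x-unrolled loop
    switch (extra) {
    case 0: return partial ? BF16_InnerProduct_SVE<true, 0> : BF16_InnerProduct_SVE<false, 0>;
    case 1: return partial ? BF16_InnerProduct_SVE<true, 1> : BF16_InnerProduct_SVE<false, 1>;
    case 2: return partial ? BF16_InnerProduct_SVE<true, 2> : BF16_InnerProduct_SVE<false, 2>;
    default: return partial ? BF16_InnerProduct_SVE<true, 3> : BF16_InnerProduct_SVE<false, 3>;
    }
}
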
20 changes: 18 additions & 2 deletions src/VecSim/spaces/IP_space.cpp
@@ -20,7 +20,9 @@
#include "VecSim/spaces/functions/AVX2.h"
#include "VecSim/spaces/functions/SSE3.h"
#include "VecSim/spaces/functions/NEON.h"
#include "VecSim/spaces/functions/NEON_BF16.h"
#include "VecSim/spaces/functions/SVE.h"
#include "VecSim/spaces/functions/SVE_BF16.h"
#include "VecSim/spaces/functions/SVE2.h"

using bfloat16 = vecsim_types::bfloat16;
@@ -134,13 +136,27 @@ dist_func_t<float> IP_BF16_GetDistFunc(size_t dim, unsigned char *alignment, con
if (!is_little_endian()) {
return BF16_InnerProduct_BigEndian;
}
auto features = getCpuOptimizationFeatures(arch_opt);

#if defined(CPU_FEATURES_ARCH_AARCH64)
#ifdef OPT_SVE_BF16
if (features.svebf16) {
return Choose_BF16_IP_implementation_SVE_BF16(dim);
}
#endif
#ifdef OPT_NEON_BF16
if (features.bf16 && dim >= 8) { // Optimization assumes at least 8 BF16s (full chunk)
return Choose_BF16_IP_implementation_NEON_BF16(dim);
}
#endif
#endif // AARCH64

#if defined(CPU_FEATURES_ARCH_X86_64)
// Optimizations assume at least 32 bfloats. If we have less, we use the naive implementation.
if (dim < 32) {
return ret_dist_func;
}

#ifdef CPU_FEATURES_ARCH_X86_64
auto features = getCpuOptimizationFeatures(arch_opt);
#ifdef OPT_AVX512_BF16_VL
if (features.avx512_bf16 && features.avx512vl) {
if (dim % 32 == 0) // no point in aligning if we have an offsetting residual
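
Note (not part of the diff): whichever branch above is taken, the selected function computes the same BF16 inner-product distance, so the SIMD paths can differ from a plain scalar evaluation only by floating-point rounding (BFDOT, for instance, widens adjacent BF16 pairs to float, multiplies them, and adds each pair sum into a 32-bit accumulator lane). A scalar sketch of that reference quantity follows; `bf16_to_float` and `bf16_ip_distance_ref` are illustrative names, not repository code.

#include <cstddef>
#include <cstdint>
#include <cstring>

// Widen one bfloat16 (raw 16-bit pattern) to float: its bits become the high
// half of an IEEE-754 binary32 and the low 16 bits are zero.
static inline float bf16_to_float(uint16_t raw) {
    uint32_t bits = static_cast<uint32_t>(raw) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

// Scalar reference for the "1 - inner product" distance returned by the
// optimized BF16 kernels.
float bf16_ip_distance_ref(const uint16_t *vec1, const uint16_t *vec2, size_t dim) {
    float sum = 0.0f;
    for (size_t i = 0; i < dim; i++) {
        sum += bf16_to_float(vec1[i]) * bf16_to_float(vec2[i]);
    }
    return 1.0f - sum;
}
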
103 changes: 103 additions & 0 deletions src/VecSim/spaces/L2/L2_NEON_BF16.h
@@ -0,0 +1,103 @@
/*
*Copyright Redis Ltd. 2021 - present
*Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or
*the Server Side Public License v1 (SSPLv1).
*/

#include <arm_neon.h>

// Assumes little-endianness
inline void L2Sqr_Op(float32x4_t &acc, bfloat16x8_t &v1, bfloat16x8_t &v2) {
float32x4_t v1_lo = vcvtq_low_f32_bf16(v1);
float32x4_t v2_lo = vcvtq_low_f32_bf16(v2);
float32x4_t diff_lo = vsubq_f32(v1_lo, v2_lo);

acc = vfmaq_f32(acc, diff_lo, diff_lo);

float32x4_t v1_hi = vcvtq_high_f32_bf16(v1);
float32x4_t v2_hi = vcvtq_high_f32_bf16(v2);
float32x4_t diff_hi = vsubq_f32(v1_hi, v2_hi);

acc = vfmaq_f32(acc, diff_hi, diff_hi);
}

inline void L2Sqr_Step(const bfloat16_t *&vec1, const bfloat16_t *&vec2, float32x4_t &acc) {
// Load brain-half-precision vectors
bfloat16x8_t v1 = vld1q_bf16(vec1);
bfloat16x8_t v2 = vld1q_bf16(vec2);
vec1 += 8;
vec2 += 8;
L2Sqr_Op(acc, v1, v2);
}

template <unsigned char residual> // 0..31
float BF16_L2Sqr_NEON(const void *pVect1v, const void *pVect2v, size_t dimension) {
const auto *vec1 = static_cast<const bfloat16_t *>(pVect1v);
const auto *vec2 = static_cast<const bfloat16_t *>(pVect2v);
const auto *const v1End = vec1 + dimension;
float32x4_t acc1 = vdupq_n_f32(0.0f);
float32x4_t acc2 = vdupq_n_f32(0.0f);
float32x4_t acc3 = vdupq_n_f32(0.0f);
float32x4_t acc4 = vdupq_n_f32(0.0f);

// First, handle the partial chunk residual
if constexpr (residual % 8) {
auto constexpr chunk_residual = residual % 8;
// TODO: special cases for some residuals and benchmark if it's better
constexpr uint16x8_t mask = {
0xFFFF,
(chunk_residual >= 2) ? 0xFFFF : 0,
(chunk_residual >= 3) ? 0xFFFF : 0,
(chunk_residual >= 4) ? 0xFFFF : 0,
(chunk_residual >= 5) ? 0xFFFF : 0,
(chunk_residual >= 6) ? 0xFFFF : 0,
(chunk_residual >= 7) ? 0xFFFF : 0,
0,
};

// Load partial vectors
bfloat16x8_t v1 = vld1q_bf16(vec1);
bfloat16x8_t v2 = vld1q_bf16(vec2);

// Apply mask to both vectors
bfloat16x8_t masked_v1 =
vreinterpretq_bf16_u16(vandq_u16(vreinterpretq_u16_bf16(v1), mask));
bfloat16x8_t masked_v2 =
vreinterpretq_bf16_u16(vandq_u16(vreinterpretq_u16_bf16(v2), mask));

L2Sqr_Op(acc1, masked_v1, masked_v2);

// Advance pointers
vec1 += chunk_residual;
vec2 += chunk_residual;
}

// Handle (residual - (residual % 8)) in chunks of 8 bfloat16
if constexpr (residual >= 8)
L2Sqr_Step(vec1, vec2, acc2);
if constexpr (residual >= 16)
L2Sqr_Step(vec1, vec2, acc3);
if constexpr (residual >= 24)
L2Sqr_Step(vec1, vec2, acc4);

// Process the rest of the vectors (the full chunks part)
while (vec1 < v1End) {
// TODO: use `vld1q_bf16_x4` for quad-loading?
L2Sqr_Step(vec1, vec2, acc1);
L2Sqr_Step(vec1, vec2, acc2);
L2Sqr_Step(vec1, vec2, acc3);
L2Sqr_Step(vec1, vec2, acc4);
}

// Accumulate accumulators
acc1 = vpaddq_f32(acc1, acc3);
acc2 = vpaddq_f32(acc2, acc4);
acc1 = vpaddq_f32(acc1, acc2);

// Pairwise add to get horizontal sum
float32x2_t folded = vadd_f32(vget_low_f32(acc1), vget_high_f32(acc1));
folded = vpadd_f32(folded, folded);

// Extract result
return vget_lane_f32(folded, 0);
}
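
Note (not part of the diff): a scalar sketch of the squared-L2 distance this kernel returns. Per lane, `vcvtq_low_f32_bf16`/`vcvtq_high_f32_bf16` widen a bfloat16 to float by placing its 16 bits in the high half of a binary32; the difference is then taken in float and the squared differences are accumulated with FMA. Helper names below are illustrative, not repository code.

#include <cstddef>
#include <cstdint>
#include <cstring>

// Widen one bfloat16 (raw 16-bit pattern) to float by shifting it into the
// high half of a binary32; the low 16 mantissa bits become zero.
static inline float bf16_widen(uint16_t raw) {
    uint32_t bits = static_cast<uint32_t>(raw) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

// Scalar reference for the squared L2 distance returned by BF16_L2Sqr_NEON,
// up to floating-point rounding differences.
float bf16_l2sqr_ref(const uint16_t *vec1, const uint16_t *vec2, size_t dim) {
    float sum = 0.0f;
    for (size_t i = 0; i < dim; i++) {
        float diff = bf16_widen(vec1[i]) - bf16_widen(vec2[i]);
        sum += diff * diff;
    }
    return sum;
}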