-
Notifications
You must be signed in to change notification settings - Fork 19
Implement optimized BF16 support for ARM architecture - [MOD-9079] #623
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 9 commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
600a2e4
SVE implementation for bf16
GuyAv46 b0796d4
add required build flags and fix implementation
GuyAv46 0470abe
final fixes and implement benchmarks
GuyAv46 ffcc6dc
added tests
GuyAv46 ed833c5
implement neon bf16 distance functions
GuyAv46 558451a
implement build flow and benchmarks
GuyAv46 9e762f6
added test
GuyAv46 61af8ec
format
GuyAv46 bb46609
remove redundant check
GuyAv46 1932f17
typo fix
GuyAv46 cd6885a
fixes and cleanup
GuyAv46 c842869
fix build
GuyAv46 df2d2ca
fix svwhilelt_b16 calls
GuyAv46 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
/* | ||
*Copyright Redis Ltd. 2021 - present | ||
*Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or | ||
*the Server Side Public License v1 (SSPLv1). | ||
*/ | ||
|
||
#include <arm_neon.h> | ||
|
||
// Folds one 8-lane bf16 chunk of the dot product into `acc`, advancing both
// input pointers by 8 elements (they are taken by reference for that reason).
inline void InnerProduct_Step(const bfloat16_t *&vec1, const bfloat16_t *&vec2, float32x4_t &acc) {
    // Load 8 brain-float16 elements from each vector.
    const bfloat16x8_t lhs = vld1q_bf16(vec1);
    const bfloat16x8_t rhs = vld1q_bf16(vec2);
    // bf16 dot-product: each adjacent lane pair contributes to one f32 lane of `acc`.
    acc = vbfdotq_f32(acc, lhs, rhs);
    // Step to the next chunk.
    vec1 += 8;
    vec2 += 8;
}
|
||
/*
 * Inner-product distance (1 - dot) between two bfloat16 vectors using NEON bf16
 * intrinsics. `residual` = dimension % 32, resolved at compile time so every
 * branch below is constexpr-eliminated. The main loop consumes 32 elements
 * (4 chunks of 8) per iteration across four independent accumulators to hide
 * the vbfdot latency.
 *
 * pVect1v / pVect2v : pointers to `dimension` bfloat16_t elements each.
 * Returns 1.0f - <v1, v2> as a float.
 */
template <unsigned char residual> // 0..31
float BF16_InnerProduct_NEON(const void *pVect1v, const void *pVect2v, size_t dimension) {
    const auto *vec1 = static_cast<const bfloat16_t *>(pVect1v);
    const auto *vec2 = static_cast<const bfloat16_t *>(pVect2v);
    const auto *const v1End = vec1 + dimension;
    float32x4_t acc1 = vdupq_n_f32(0.0f);
    float32x4_t acc2 = vdupq_n_f32(0.0f);
    float32x4_t acc3 = vdupq_n_f32(0.0f);
    float32x4_t acc4 = vdupq_n_f32(0.0f);

    // First, handle the sub-chunk residual (1..7 leftover elements).
    if constexpr (residual % 8) {
        auto constexpr chunk_residual = residual % 8;
        // TODO: special-case small residuals and benchmark whether it's faster.
        // if constexpr (chunk_residual == 1) {
        //     float16x8_t v1 = vld1q_f16(Vec1);
        // } else if constexpr (chunk_residual == 2) {
        // } else if constexpr (chunk_residual == 3) {
        // } else {
        // }
        // Compile-time lane mask: 0xFFFF for the first `chunk_residual` lanes, 0 after.
        constexpr uint16x8_t mask = {
            0xFFFF,
            (chunk_residual >= 2) ? 0xFFFF : 0,
            (chunk_residual >= 3) ? 0xFFFF : 0,
            (chunk_residual >= 4) ? 0xFFFF : 0,
            (chunk_residual >= 5) ? 0xFFFF : 0,
            (chunk_residual >= 6) ? 0xFFFF : 0,
            (chunk_residual >= 7) ? 0xFFFF : 0,
            0,
        };

        // Load partial vectors.
        // NOTE(review): this reads a full 8 lanes even though only `chunk_residual`
        // are valid — assumes the allocator pads vectors so the over-read is safe;
        // TODO confirm against the vector allocation policy.
        bfloat16x8_t v1 = vld1q_bf16(vec1);
        bfloat16x8_t v2 = vld1q_bf16(vec2);

        // Zero the invalid lanes of both vectors so they contribute nothing to the dot.
        bfloat16x8_t masked_v1 =
            vreinterpretq_bf16_u16(vandq_u16(vreinterpretq_u16_bf16(v1), mask));
        bfloat16x8_t masked_v2 =
            vreinterpretq_bf16_u16(vandq_u16(vreinterpretq_u16_bf16(v2), mask));

        acc1 = vbfdotq_f32(acc1, masked_v1, masked_v2);

        // Advance pointers past the consumed residual elements only.
        vec1 += chunk_residual;
        vec2 += chunk_residual;
    }

    // Handle the remaining (residual - residual % 8) elements in chunks of 8 bfloat16.
    // Each goes to a different accumulator so the main loop keeps all four balanced.
    if constexpr (residual >= 8)
        InnerProduct_Step(vec1, vec2, acc2);
    if constexpr (residual >= 16)
        InnerProduct_Step(vec1, vec2, acc3);
    if constexpr (residual >= 24)
        InnerProduct_Step(vec1, vec2, acc4);

    // Process the rest of the vectors — the full 32-element chunks.
    while (vec1 < v1End) {
        // TODO: use `vld1q_f16_x4` for quad-loading?
        InnerProduct_Step(vec1, vec2, acc1);
        InnerProduct_Step(vec1, vec2, acc2);
        InnerProduct_Step(vec1, vec2, acc3);
        InnerProduct_Step(vec1, vec2, acc4);
    }

    // Combine the four accumulators into one vector.
    acc1 = vpaddq_f32(acc1, acc3);
    acc2 = vpaddq_f32(acc2, acc4);
    acc1 = vpaddq_f32(acc1, acc2);

    // Pairwise add to get the horizontal sum of the 4 lanes.
    float32x2_t folded = vadd_f32(vget_low_f32(acc1), vget_high_f32(acc1));
    folded = vpadd_f32(folded, folded);

    // Convert dot product to inner-product distance.
    return 1.0f - vget_lane_f32(folded, 0);
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
/* | ||
*Copyright Redis Ltd. 2021 - present | ||
*Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or | ||
*the Server Side Public License v1 (SSPLv1). | ||
*/ | ||
|
||
#include <arm_sve.h> | ||
|
||
// Folds one full SVE chunk of the dot product (starting at `offset`) into `acc`
// and bumps `offset` by `chunk` lanes for the next call.
inline void InnerProduct_Step(const bfloat16_t *vec1, const bfloat16_t *vec2, svfloat32_t &acc,
                              size_t &offset, const size_t chunk) {
    // All 16-bit lanes active: this step is only used on full chunks.
    const svbool_t pg_all = svptrue_b16();

    // Load one register's worth of brain-float16 elements from each vector.
    const svbfloat16_t lhs = svld1_bf16(pg_all, vec1 + offset);
    const svbfloat16_t rhs = svld1_bf16(pg_all, vec2 + offset);

    // Accumulate the bf16 dot product into the f32 accumulator.
    acc = svbfdot(acc, lhs, rhs);

    // Advance to the next chunk.
    offset += chunk;
}
|
||
/*
 * Inner-product distance (1 - dot) between two bfloat16 vectors using SVE bf16
 * intrinsics. Template parameters are resolved at compile time from the
 * dimension: `partial_chunk` says whether a predicated tail load is needed,
 * `additional_steps` (0..3) is the number of full chunks left after the main
 * loop, which consumes 4 chunks per iteration over independent accumulators.
 *
 * pVect1v / pVect2v : pointers to `dimension` bfloat16_t elements each.
 * Returns 1.0f - <v1, v2> as a float.
 */
template <bool partial_chunk, unsigned char additional_steps> // [t/f, 0..3]
float BF16_InnerProduct_SVE(const void *pVect1v, const void *pVect2v, size_t dimension) {
    const auto *vec1 = static_cast<const bfloat16_t *>(pVect1v);
    const auto *vec2 = static_cast<const bfloat16_t *>(pVect2v);
    const size_t chunk = svcnth(); // number of 16-bit elements in a register
    svfloat32_t acc1 = svdup_f32(0.0f);
    svfloat32_t acc2 = svdup_f32(0.0f);
    svfloat32_t acc3 = svdup_f32(0.0f);
    svfloat32_t acc4 = svdup_f32(0.0f);
    size_t offset = 0;

    // Process all full 4-chunk groups.
    const size_t full_iterations = dimension / chunk / 4;
    for (size_t iter = 0; iter < full_iterations; iter++) {
        InnerProduct_Step(vec1, vec2, acc1, offset, chunk);
        InnerProduct_Step(vec1, vec2, acc2, offset, chunk);
        InnerProduct_Step(vec1, vec2, acc3, offset, chunk);
        InnerProduct_Step(vec1, vec2, acc4, offset, chunk);
    }

    // Perform between 0 and 3 additional full-chunk steps, per `additional_steps`.
    if constexpr (additional_steps >= 1)
        InnerProduct_Step(vec1, vec2, acc1, offset, chunk);
    if constexpr (additional_steps >= 2)
        InnerProduct_Step(vec1, vec2, acc2, offset, chunk);
    if constexpr (additional_steps >= 3)
        InnerProduct_Step(vec1, vec2, acc3, offset, chunk);

    // Handle the tail with a residual predicate.
    if constexpr (partial_chunk) {
        // Use the explicitly-typed u64 variant: `offset` and `dimension` are
        // size_t, and the generic `svwhilelt_b16` overload set does not resolve
        // 64-bit unsigned arguments on all compilers.
        svbool_t pg = svwhilelt_b16_u64(offset, dimension);

        // Predicated loads: inactive elements are zero, according to the docs,
        // so they contribute nothing to the dot product.
        svbfloat16_t v1 = svld1_bf16(pg, vec1 + offset);
        svbfloat16_t v2 = svld1_bf16(pg, vec2 + offset);
        acc4 = svbfdot(acc4, v1, v2);
    }

    // Combine the four accumulators.
    acc1 = svadd_f32_x(svptrue_b32(), acc1, acc3);
    acc2 = svadd_f32_x(svptrue_b32(), acc2, acc4);
    acc1 = svadd_f32_x(svptrue_b32(), acc1, acc2);

    // Horizontal reduction of the accumulated sum, then convert to distance.
    float result = svaddv_f32(svptrue_b32(), acc1);
    return 1.0f - result;
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
/* | ||
*Copyright Redis Ltd. 2021 - present | ||
*Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or | ||
*the Server Side Public License v1 (SSPLv1). | ||
*/ | ||
|
||
#include <arm_neon.h> | ||
|
||
// Accumulates the squared difference of two 8-lane bf16 vectors into `acc`.
// The low/high widening split assumes little-endianness.
inline void L2Sqr_Op(float32x4_t &acc, bfloat16x8_t &v1, bfloat16x8_t &v2) {
    // Widen the low four bf16 lanes to f32, take the difference, and
    // accumulate its square with a fused multiply-add.
    const float32x4_t diff_lo = vsubq_f32(vcvtq_low_f32_bf16(v1), vcvtq_low_f32_bf16(v2));
    acc = vfmaq_f32(acc, diff_lo, diff_lo);

    // Same for the high four lanes.
    const float32x4_t diff_hi = vsubq_f32(vcvtq_high_f32_bf16(v1), vcvtq_high_f32_bf16(v2));
    acc = vfmaq_f32(acc, diff_hi, diff_hi);
}
|
||
// Folds one 8-lane bf16 chunk of the squared-L2 distance into `acc`, advancing
// both input pointers by 8 elements (they are taken by reference for that reason).
inline void L2Sqr_Step(const bfloat16_t *&vec1, const bfloat16_t *&vec2, float32x4_t &acc) {
    // Load 8 brain-float16 elements from each vector.
    const bfloat16x8_t lhs = vld1q_bf16(vec1);
    const bfloat16x8_t rhs = vld1q_bf16(vec2);
    // Accumulate the squared differences, then step to the next chunk.
    L2Sqr_Op(acc, const_cast<bfloat16x8_t &>(lhs), const_cast<bfloat16x8_t &>(rhs));
    vec1 += 8;
    vec2 += 8;
}
|
||
/*
 * Squared-L2 distance between two bfloat16 vectors using NEON bf16 intrinsics.
 * `residual` = dimension % 32, resolved at compile time so every branch below
 * is constexpr-eliminated. The main loop consumes 32 elements (4 chunks of 8)
 * per iteration across four independent accumulators to hide FMA latency.
 *
 * pVect1v / pVect2v : pointers to `dimension` bfloat16_t elements each.
 * Returns sum((v1[i] - v2[i])^2) as a float.
 */
template <unsigned char residual> // 0..31
float BF16_L2Sqr_NEON(const void *pVect1v, const void *pVect2v, size_t dimension) {
    const auto *vec1 = static_cast<const bfloat16_t *>(pVect1v);
    const auto *vec2 = static_cast<const bfloat16_t *>(pVect2v);
    const auto *const v1End = vec1 + dimension;
    float32x4_t acc1 = vdupq_n_f32(0.0f);
    float32x4_t acc2 = vdupq_n_f32(0.0f);
    float32x4_t acc3 = vdupq_n_f32(0.0f);
    float32x4_t acc4 = vdupq_n_f32(0.0f);

    // First, handle the sub-chunk residual (1..7 leftover elements).
    if constexpr (residual % 8) {
        auto constexpr chunk_residual = residual % 8;
        // TODO: special-case small residuals and benchmark whether it's faster.
        // if constexpr (chunk_residual == 1) {
        //     float16x8_t v1 = vld1q_f16(Vec1);
        // } else if constexpr (chunk_residual == 2) {
        // } else if constexpr (chunk_residual == 3) {
        // } else {
        // }
        // Compile-time lane mask: 0xFFFF for the first `chunk_residual` lanes, 0 after.
        constexpr uint16x8_t mask = {
            0xFFFF,
            (chunk_residual >= 2) ? 0xFFFF : 0,
            (chunk_residual >= 3) ? 0xFFFF : 0,
            (chunk_residual >= 4) ? 0xFFFF : 0,
            (chunk_residual >= 5) ? 0xFFFF : 0,
            (chunk_residual >= 6) ? 0xFFFF : 0,
            (chunk_residual >= 7) ? 0xFFFF : 0,
            0,
        };

        // Load partial vectors.
        // NOTE(review): this reads a full 8 lanes even though only `chunk_residual`
        // are valid — assumes the allocator pads vectors so the over-read is safe;
        // TODO confirm against the vector allocation policy.
        bfloat16x8_t v1 = vld1q_bf16(vec1);
        bfloat16x8_t v2 = vld1q_bf16(vec2);

        // Zero the invalid lanes of BOTH vectors so their difference is zero
        // and contributes nothing to the sum.
        bfloat16x8_t masked_v1 =
            vreinterpretq_bf16_u16(vandq_u16(vreinterpretq_u16_bf16(v1), mask));
        bfloat16x8_t masked_v2 =
            vreinterpretq_bf16_u16(vandq_u16(vreinterpretq_u16_bf16(v2), mask));

        L2Sqr_Op(acc1, masked_v1, masked_v2);

        // Advance pointers past the consumed residual elements only.
        vec1 += chunk_residual;
        vec2 += chunk_residual;
    }

    // Handle the remaining (residual - residual % 8) elements in chunks of 8 bfloat16.
    // Each goes to a different accumulator so the main loop keeps all four balanced.
    if constexpr (residual >= 8)
        L2Sqr_Step(vec1, vec2, acc2);
    if constexpr (residual >= 16)
        L2Sqr_Step(vec1, vec2, acc3);
    if constexpr (residual >= 24)
        L2Sqr_Step(vec1, vec2, acc4);

    // Process the rest of the vectors — the full 32-element chunks.
    while (vec1 < v1End) {
        // TODO: use `vld1q_f16_x4` for quad-loading?
        L2Sqr_Step(vec1, vec2, acc1);
        L2Sqr_Step(vec1, vec2, acc2);
        L2Sqr_Step(vec1, vec2, acc3);
        L2Sqr_Step(vec1, vec2, acc4);
    }

    // Combine the four accumulators into one vector.
    acc1 = vpaddq_f32(acc1, acc3);
    acc2 = vpaddq_f32(acc2, acc4);
    acc1 = vpaddq_f32(acc1, acc2);

    // Pairwise add to get the horizontal sum of the 4 lanes.
    float32x2_t folded = vadd_f32(vget_low_f32(acc1), vget_high_f32(acc1));
    folded = vpadd_f32(folded, folded);

    // Squared L2 distance (no square root, per the usual VecSim convention
    // — NOTE(review): confirm callers expect the squared value).
    return vget_lane_f32(folded, 0);
}
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.