[MOD-8198] Introduce INT8 distance functions #560
@@ -0,0 +1,77 @@
/*
 *Copyright Redis Ltd. 2021 - present
 *Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or
 *the Server Side Public License v1 (SSPLv1).
 */

#include "VecSim/spaces/space_includes.h"

static inline void InnerProductStep(int8_t *&pVect1, int8_t *&pVect2, __m512i &sum) {
    __m256i temp_a = _mm256_loadu_epi8(pVect1);
    __m512i va = _mm512_cvtepi8_epi16(temp_a);
    pVect1 += 32;

    __m256i temp_b = _mm256_loadu_epi8(pVect2);
    __m512i vb = _mm512_cvtepi8_epi16(temp_b);
    pVect2 += 32;

    // _mm512_dpwssd_epi32(src, a, b)
    // Multiply groups of 2 adjacent pairs of signed 16-bit integers in `a` with the corresponding
    // 16-bit integers in `b`, producing 2 intermediate signed 32-bit results. Sum these 2 results
    // with the corresponding 32-bit integer in `src`, and store the packed 32-bit results in `dst`.
    sum = _mm512_dpwssd_epi32(sum, va, vb);
}

template <unsigned char residual> // 0..64
static inline int INT8_InnerProductImp(const void *pVect1v, const void *pVect2v, size_t dimension) {
    int8_t *pVect1 = (int8_t *)pVect1v;
    int8_t *pVect2 = (int8_t *)pVect2v;

    const int8_t *pEnd1 = pVect1 + dimension;

    __m512i sum = _mm512_setzero_epi32();

    // Deal with the residual first. `dim` is at least 32, so we have at least one full 32-int8
    // block, and the masked load of the partial block is guaranteed to be safe.
    if constexpr (residual % 32) {
        __mmask32 mask = (1LU << (residual % 32)) - 1;
        __m256i temp_a = _mm256_maskz_loadu_epi8(mask, pVect1);
        __m512i va = _mm512_cvtepi8_epi16(temp_a);
        pVect1 += residual % 32;

        __m256i temp_b = _mm256_maskz_loadu_epi8(mask, pVect2);
        __m512i vb = _mm512_cvtepi8_epi16(temp_b);
        pVect2 += residual % 32;

        sum = _mm512_dpwssd_epi32(sum, va, vb);
    }

    if constexpr (residual >= 32) {
        InnerProductStep(pVect1, pVect2, sum);
    }

    // We dealt with the residual part. We are left with some multiple of 64 int8 elements.
    while (pVect1 < pEnd1) {
        InnerProductStep(pVect1, pVect2, sum);
        InnerProductStep(pVect1, pVect2, sum);
    }

    return _mm512_reduce_add_epi32(sum);
}

template <unsigned char residual> // 0..64
float INT8_InnerProductSIMD64_AVX512F_BW_VL_VNNI(const void *pVect1v, const void *pVect2v,
                                                 size_t dimension) {
    return 1 - INT8_InnerProductImp<residual>(pVect1v, pVect2v, dimension);
}

template <unsigned char residual> // 0..64
float INT8_CosineSIMD64_AVX512F_BW_VL_VNNI(const void *pVect1v, const void *pVect2v,
                                           size_t dimension) {
    float ip = INT8_InnerProductImp<residual>(pVect1v, pVect2v, dimension);
    float norm_v1 =
        *reinterpret_cast<const float *>(static_cast<const int8_t *>(pVect1v) + dimension);
    float norm_v2 =
        *reinterpret_cast<const float *>(static_cast<const int8_t *>(pVect2v) + dimension);
    return 1.0f - ip / (norm_v1 * norm_v2);
}
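For reference, here is a minimal scalar sketch of what these two kernels compute, assuming the storage layout implied by the cosine code above (each vector holds `dimension` int8 values with its precomputed float norm appended right after them). The function names are illustrative, not part of this PR:

```cpp
#include <cstdint>
#include <cstring>
#include <cstddef>

// Plain inner-product distance over int8 vectors: 1 - <v1, v2>.
static float naive_int8_ip(const void *pVect1v, const void *pVect2v, size_t dimension) {
    const int8_t *v1 = static_cast<const int8_t *>(pVect1v);
    const int8_t *v2 = static_cast<const int8_t *>(pVect2v);
    int sum = 0;
    for (size_t i = 0; i < dimension; i++) {
        sum += static_cast<int>(v1[i]) * static_cast<int>(v2[i]);
    }
    return 1.0f - sum;
}

// Cosine distance, reading the precomputed norms stored as a float right after the int8 payload.
static float naive_int8_cosine(const void *pVect1v, const void *pVect2v, size_t dimension) {
    const int8_t *v1 = static_cast<const int8_t *>(pVect1v);
    const int8_t *v2 = static_cast<const int8_t *>(pVect2v);
    int sum = 0;
    for (size_t i = 0; i < dimension; i++) {
        sum += static_cast<int>(v1[i]) * static_cast<int>(v2[i]);
    }
    float norm_v1, norm_v2;
    std::memcpy(&norm_v1, v1 + dimension, sizeof(float));
    std::memcpy(&norm_v2, v2 + dimension, sizeof(float));
    return 1.0f - sum / (norm_v1 * norm_v2);
}
```

The SIMD version above produces the same results; it only changes how the multiply-accumulate is vectorized (widening 32 int8 lanes to int16 and using `_mm512_dpwssd_epi32`).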
@@ -16,6 +16,7 @@
#include "VecSim/spaces/functions/AVX512BW_VBMI2.h"
#include "VecSim/spaces/functions/AVX512FP16_VL.h"
#include "VecSim/spaces/functions/AVX512BF16_VL.h"
#include "VecSim/spaces/functions/AVX512F_BW_VL_VNNI.h"
#include "VecSim/spaces/functions/AVX2.h"
#include "VecSim/spaces/functions/SSE3.h"

@@ -196,4 +197,58 @@ dist_func_t<float> IP_FP16_GetDistFunc(size_t dim, unsigned char *alignment, con
    return ret_dist_func;
}

dist_func_t<float> IP_INT8_GetDistFunc(size_t dim, unsigned char *alignment, const void *arch_opt) {
    unsigned char dummy_alignment;
    if (alignment == nullptr) {
        alignment = &dummy_alignment;
    }

    dist_func_t<float> ret_dist_func = INT8_InnerProduct;
    // Optimizations assume at least 32 int8 elements. If we have fewer, we use the naive implementation.
    if (dim < 32) {
        return ret_dist_func;
    }
#ifdef CPU_FEATURES_ARCH_X86_64
    auto features = (arch_opt == nullptr)
                        ? cpu_features::GetX86Info().features
                        : *static_cast<const cpu_features::X86Features *>(arch_opt);
#ifdef OPT_AVX512_F_BW_VL_VNNI
    if (features.avx512f && features.avx512bw && features.avx512vl && features.avx512vnni) {
        if (dim % 32 == 0) // no point in aligning if we have an offsetting residual
            *alignment = 32 * sizeof(int8_t); // align to 256 bits.
        return Choose_INT8_IP_implementation_AVX512F_BW_VL_VNNI(dim);
    }
#endif
#endif // __x86_64__
    return ret_dist_func;
}

dist_func_t<float> Cosine_INT8_GetDistFunc(size_t dim, unsigned char *alignment,
                                           const void *arch_opt) {
    unsigned char dummy_alignment;
    if (alignment == nullptr) {
        alignment = &dummy_alignment;
    }

    dist_func_t<float> ret_dist_func = INT8_Cosine;
    // Optimizations assume at least 32 int8 elements. If we have fewer, we use the naive implementation.
    if (dim < 32) {
        return ret_dist_func;
    }
#ifdef CPU_FEATURES_ARCH_X86_64
    auto features = (arch_opt == nullptr)
                        ? cpu_features::GetX86Info().features
                        : *static_cast<const cpu_features::X86Features *>(arch_opt);
#ifdef OPT_AVX512_F_BW_VL_VNNI
    if (features.avx512f && features.avx512bw && features.avx512vl && features.avx512vnni) {
        // Align to the vector memory size, including the norm stored at the end of the vector.
        if ((dim + sizeof(float)) % 32 == 0) // no point in aligning if we have an offsetting residual
            *alignment = 32 * sizeof(int8_t); // align to 256 bits.
        return Choose_INT8_Cosine_implementation_AVX512F_BW_VL_VNNI(dim);
    }
#endif
#endif // __x86_64__
    return ret_dist_func;
}
} // namespace spaces

Review discussion on Cosine_INT8_GetDistFunc:
- This function has the exact same logic as the one above, except for the call to …
- The only common code is the alignment initialization, which we duplicate in all the functions in this file; how is this function different?
- By the way, note to self: check coverage.
- Coverage: added cosine to the small-dimension spaces unit tests.
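For context, a rough sketch of how a caller might use one of these resolvers. This is hypothetical usage, not code from the PR: the header path and the `spaces::dist_func_t` qualification are assumptions based on the diff, and the real call sites live elsewhere in VecSim:

```cpp
#include <cstdint>
#include <cstddef>
#include "VecSim/spaces/IP_space.h" // assumed header exposing IP_INT8_GetDistFunc

// Hypothetical caller: resolve the best distance function for this CPU and dimension once,
// then invoke it on raw int8 buffers.
float int8_ip_distance(const int8_t *a, const int8_t *b, size_t dim) {
    unsigned char alignment = 0;
    // nullptr arch_opt means: detect CPU features at runtime via cpu_features.
    spaces::dist_func_t<float> f = spaces::IP_INT8_GetDistFunc(dim, &alignment, nullptr);
    // `alignment` now holds the preferred byte alignment for vector storage (unchanged if no
    // SIMD path was selected).
    return f(a, b, dim);
}
```

The resolver falls back to the naive `INT8_InnerProduct` when `dim < 32` or when the required AVX-512 feature set is unavailable, so the caller does not need to branch on CPU capabilities itself.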
@@ -0,0 +1,63 @@
/*
 *Copyright Redis Ltd. 2021 - present
 *Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or
 *the Server Side Public License v1 (SSPLv1).
 */

#include "VecSim/spaces/space_includes.h"

static inline void L2SqrStep(int8_t *&pVect1, int8_t *&pVect2, __m512i &sum) {
    __m256i temp_a = _mm256_loadu_epi8(pVect1);
    __m512i va = _mm512_cvtepi8_epi16(temp_a);
    pVect1 += 32;

    __m256i temp_b = _mm256_loadu_epi8(pVect2);
    __m512i vb = _mm512_cvtepi8_epi16(temp_b);
    pVect2 += 32;

    __m512i diff = _mm512_sub_epi16(va, vb);
    // _mm512_dpwssd_epi32(src, a, b)
    // Multiply groups of 2 adjacent pairs of signed 16-bit integers in `a` with the corresponding
    // 16-bit integers in `b`, producing 2 intermediate signed 32-bit results. Sum these 2 results
    // with the corresponding 32-bit integer in `src`, and store the packed 32-bit results in `dst`.
    sum = _mm512_dpwssd_epi32(sum, diff, diff);
}

template <unsigned char residual> // 0..64
float INT8_L2SqrSIMD64_AVX512F_BW_VL_VNNI(const void *pVect1v, const void *pVect2v,
                                          size_t dimension) {
    int8_t *pVect1 = (int8_t *)pVect1v;
    int8_t *pVect2 = (int8_t *)pVect2v;

    const int8_t *pEnd1 = pVect1 + dimension;

    __m512i sum = _mm512_setzero_epi32();

    // Deal with the residual first. `dim` is at least 32, so loading a full 32-int8 block here is
    // in bounds; the mask zeroes out the lanes beyond the residual in the subtraction below.
    if constexpr (residual % 32) {
        constexpr __mmask32 mask = (1LU << (residual % 32)) - 1;
        __m256i temp_a = _mm256_loadu_epi8(pVect1);
        __m512i va = _mm512_cvtepi8_epi16(temp_a);
        pVect1 += residual % 32;

        __m256i temp_b = _mm256_loadu_epi8(pVect2);
        __m512i vb = _mm512_cvtepi8_epi16(temp_b);
        pVect2 += residual % 32;

        __m512i diff = _mm512_maskz_sub_epi16(mask, va, vb);
        sum = _mm512_dpwssd_epi32(sum, diff, diff);
    }

    if constexpr (residual >= 32) {
        L2SqrStep(pVect1, pVect2, sum);
    }

    // We dealt with the residual part. We are left with some multiple of 64 int8 elements.
    while (pVect1 < pEnd1) {
        L2SqrStep(pVect1, pVect2, sum);
        L2SqrStep(pVect1, pVect2, sum);
    }

    return _mm512_reduce_add_epi32(sum);
}
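For reference, a minimal scalar sketch of the squared L2 distance this kernel computes (illustrative only, not part of the PR):

```cpp
#include <cstdint>
#include <cstddef>

// Scalar reference: sum of squared differences over int8 vectors, accumulated in 32-bit ints.
static float naive_int8_l2_sqr(const void *pVect1v, const void *pVect2v, size_t dimension) {
    const int8_t *v1 = static_cast<const int8_t *>(pVect1v);
    const int8_t *v2 = static_cast<const int8_t *>(pVect2v);
    int sum = 0;
    for (size_t i = 0; i < dimension; i++) {
        int diff = static_cast<int>(v1[i]) - static_cast<int>(v2[i]);
        sum += diff * diff;
    }
    return static_cast<float>(sum);
}
```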
Review thread (on widening the int8 inputs to int16 before `_mm512_dpwssd_epi32`):
- Why not use `_mm512_dpbusd_epi32`, avoiding loading only 32 elements and converting them to int16?
- Nice catch! Unfortunately, `_mm512_dpbusd_epi32` expects `a` to be an unsigned 8-bit integer vector and `b` to be a signed 8-bit vector, so although it is indeed faster, it gives the wrong results.
- Bummer... So there is no function for multiplying int8*int8 or uint8*uint8? Only an unsigned by a signed?
- Nope. We might want to consider using `_mm512_unpackhi_epi8` and `_mm512_unpacklo_epi8` for uint8_t.
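To make the mismatch concrete, here is a small scalar sketch (not from the PR) contrasting a true signed int8 product with the unsigned-by-signed product that `_mm512_dpbusd_epi32` performs per byte, per my reading of the Intel intrinsics documentation:

```cpp
#include <cstdint>
#include <cstdio>

// Scalar model of a single byte pair:
//   the PR's path widens to int16 and uses _mm512_dpwssd_epi32 -> signed16 * signed16;
//   _mm512_dpbusd_epi32 multiplies unsigned 8-bit lanes of `a` by signed 8-bit lanes of `b`.
int main() {
    int8_t a = -3, b = 5;

    // What the PR computes after widening: a true signed product.
    int signed_product = static_cast<int16_t>(a) * static_cast<int16_t>(b); // -15

    // What a dpbusd-style instruction would compute for the same byte pattern:
    // `a` is reinterpreted as unsigned (0xFD = 253), so the product is wrong for negative inputs.
    int dpbusd_like = static_cast<uint8_t>(a) * static_cast<int8_t>(b);     // 253 * 5 = 1265

    std::printf("signed: %d, dpbusd-style: %d\n", signed_product, dpbusd_like);
    return 0;
}
```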