[MOD-8198] Introduce INT8 distance functions #560

Status: Merged (26 commits) on Dec 17, 2024
5 changes: 5 additions & 0 deletions cmake/x86_64InstructionFlags.cmake
@@ -13,6 +13,7 @@ if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "(x86_64)|(AMD64|amd64)|(^i.86$)")
CHECK_CXX_COMPILER_FLAG(-mavx512vbmi2 CXX_AVX512VBMI2)
CHECK_CXX_COMPILER_FLAG(-mavx512fp16 CXX_AVX512FP16)
CHECK_CXX_COMPILER_FLAG(-mavx512f CXX_AVX512F)
CHECK_CXX_COMPILER_FLAG(-mavx512vnni CXX_AVX512VNNI)
CHECK_CXX_COMPILER_FLAG(-mavx2 CXX_AVX2)
CHECK_CXX_COMPILER_FLAG(-mavx CXX_AVX)
CHECK_CXX_COMPILER_FLAG(-mf16c CXX_F16C)
@@ -48,6 +49,10 @@ if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "(x86_64)|(AMD64|amd64)|(^i.86$)")
add_compile_definitions(OPT_AVX512_BW_VBMI2)
endif()

if(CXX_AVX512F AND CXX_AVX512BW AND CXX_AVX512VL AND CXX_AVX512VNNI)
add_compile_definitions(OPT_AVX512_F_BW_VL_VNNI)
endif()

if(CXX_F16C AND CXX_FMA AND CXX_AVX)
add_compile_definitions(OPT_F16C)
endif()
6 changes: 6 additions & 0 deletions src/VecSim/spaces/CMakeLists.txt
@@ -44,6 +44,12 @@ if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "(x86_64)|(AMD64|amd64)|(^i.86$)")
list(APPEND OPTIMIZATIONS functions/AVX512F.cpp)
endif()

if(CXX_AVX512F AND CXX_AVX512BW AND CXX_AVX512VL AND CXX_AVX512VNNI)
message("Building with AVX512F, AVX512BW, AVX512VL and AVX512VNNI")
set_source_files_properties(functions/AVX512F_BW_VL_VNNI.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512bw -mavx512vl -mavx512vnni")
list(APPEND OPTIMIZATIONS functions/AVX512F_BW_VL_VNNI.cpp)
endif()

if(CXX_AVX2)
message("Building with AVX2")
set_source_files_properties(functions/AVX2.cpp PROPERTIES COMPILE_FLAGS -mavx2)
24 changes: 24 additions & 0 deletions src/VecSim/spaces/IP/IP.cpp
@@ -66,3 +66,27 @@ float FP16_InnerProduct(const void *pVect1, const void *pVect2, size_t dimension
}
return 1.0f - res;
}

static inline int INT8_InnerProductImp(const void *pVect1v, const void *pVect2v, size_t dimension) {
int8_t *pVect1 = (int8_t *)pVect1v;
int8_t *pVect2 = (int8_t *)pVect2v;

int res = 0;
for (size_t i = 0; i < dimension; i++) {
res += pVect1[i] * pVect2[i];
}
return res;
}

float INT8_InnerProduct(const void *pVect1v, const void *pVect2v, size_t dimension) {
return 1 - INT8_InnerProductImp(pVect1v, pVect2v, dimension);
}

float INT8_Cosine(const void *pVect1v, const void *pVect2v, size_t dimension) {
// We expect each vector's norm to be stored at the end of the vector.
float norm_v1 =
*reinterpret_cast<const float *>(static_cast<const int8_t *>(pVect1v) + dimension);
float norm_v2 =
*reinterpret_cast<const float *>(static_cast<const int8_t *>(pVect2v) + dimension);
return 1.0f - float(INT8_InnerProductImp(pVect1v, pVect2v, dimension)) / (norm_v1 * norm_v2);
}
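
Editor's note: for readers unfamiliar with the norm-at-the-end layout that INT8_Cosine above relies on, here is a minimal caller-side sketch. The `prepare_int8_query` helper and the use of std::vector are illustrative assumptions, not part of this PR; only the blob layout ([dim x int8][float norm]) comes from the code above.

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>
#include <vector>

// Illustrative only: pack an int8 vector followed by its L2 norm as a float,
// matching the layout INT8_Cosine expects: [dim x int8][float norm].
static std::vector<char> prepare_int8_query(const int8_t *values, size_t dim) {
    std::vector<char> blob(dim * sizeof(int8_t) + sizeof(float));
    std::memcpy(blob.data(), values, dim * sizeof(int8_t));

    float norm = 0.0f;
    for (size_t i = 0; i < dim; i++) {
        norm += float(values[i]) * float(values[i]);
    }
    norm = std::sqrt(norm);
    std::memcpy(blob.data() + dim * sizeof(int8_t), &norm, sizeof(float));
    return blob;
}

// Usage (assuming INT8_Cosine from IP.cpp above is linked in):
//   auto v1 = prepare_int8_query(a, dim);
//   auto v2 = prepare_int8_query(b, dim);
//   float dist = INT8_Cosine(v1.data(), v2.data(), dim);
```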
3 changes: 3 additions & 0 deletions src/VecSim/spaces/IP/IP.h
@@ -16,3 +16,6 @@ float FP16_InnerProduct(const void *pVect1, const void *pVect2, size_t dimension

float BF16_InnerProduct_LittleEndian(const void *pVect1v, const void *pVect2v, size_t dimension);
float BF16_InnerProduct_BigEndian(const void *pVect1v, const void *pVect2v, size_t dimension);

float INT8_InnerProduct(const void *pVect1, const void *pVect2, size_t dimension);
float INT8_Cosine(const void *pVect1, const void *pVect2, size_t dimension);
77 changes: 77 additions & 0 deletions src/VecSim/spaces/IP/IP_AVX512F_BW_VL_VNNI_INT8.h
@@ -0,0 +1,77 @@
/*
*Copyright Redis Ltd. 2021 - present
*Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or
*the Server Side Public License v1 (SSPLv1).
*/

#include "VecSim/spaces/space_includes.h"

static inline void InnerProductStep(int8_t *&pVect1, int8_t *&pVect2, __m512i &sum) {
__m256i temp_a = _mm256_loadu_epi8(pVect1);
__m512i va = _mm512_cvtepi8_epi16(temp_a);
pVect1 += 32;

__m256i temp_b = _mm256_loadu_epi8(pVect2);
__m512i vb = _mm512_cvtepi8_epi16(temp_b);
pVect2 += 32;

// _mm512_dpwssd_epi32(src, a, b)
// Multiply groups of 2 adjacent pairs of signed 16-bit integers in `a` with corresponding
// 16-bit integers in `b`, producing 2 intermediate signed 32-bit results. Sum these 2 results
// with the corresponding 32-bit integer in src, and store the packed 32-bit results in dst.
sum = _mm512_dpwssd_epi32(sum, va, vb);
[Review thread on the line above]

Collaborator:
Why not use _mm512_dpbusd_epi32, avoiding loading only 32 elements and converting them to int16?

Collaborator (author):
Nice catch! Unfortunately, _mm512_dpbusd_epi32 expects `a` to be an unsigned 8-bit integer vector and `b` to be a signed 8-bit integer vector:

"Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding signed 8-bit integers in b"

So although it is indeed faster, it gives the wrong results (see the scalar sketch after this file's diff):

Benchmark                                  Time             CPU   Iterations UserCounters...
--------------------------------------------------------------------------------------------
IP_dpwssd/Dim:512                       5.41 ns         5.41 ns    128425185 dist=-1.024k
IP_dpbusd/Dim:512                       3.26 ns         3.26 ns    218138693 dist=130.048k
IP_sanity_bm/Dim:512/iterations:1        490 ns          360 ns            1 dist=-1.024k

Collaborator:
Bummer... So there is no function for multiplying int8 by int8 or uint8 by uint8? Only an unsigned by a signed?

Collaborator (author):
Nope. We might want to consider using _mm512_unpackhi_epi8 and _mm512_unpacklo_epi8 for uint8_t.

}

template <unsigned char residual> // 0..64
static inline int INT8_InnerProductImp(const void *pVect1v, const void *pVect2v, size_t dimension) {
int8_t *pVect1 = (int8_t *)pVect1v;
int8_t *pVect2 = (int8_t *)pVect2v;

const int8_t *pEnd1 = pVect1 + dimension;

__m512i sum = _mm512_setzero_epi32();

// Deal with the remainder first. `dim` is at least 32, so we have at least one full
// 32-int8 block, so the masked load is guaranteed to be safe.
if constexpr (residual % 32) {
__mmask32 mask = (1LU << (residual % 32)) - 1;
__m256i temp_a = _mm256_maskz_loadu_epi8(mask, pVect1);
__m512i va = _mm512_cvtepi8_epi16(temp_a);
pVect1 += residual % 32;

__m256i temp_b = _mm256_maskz_loadu_epi8(mask, pVect2);
__m512i vb = _mm512_cvtepi8_epi16(temp_b);
pVect2 += residual % 32;

sum = _mm512_dpwssd_epi32(sum, va, vb);
}

if constexpr (residual >= 32) {
InnerProductStep(pVect1, pVect2, sum);
}

// We dealt with the residual part. We are left with a multiple of 64 int8 elements.
while (pVect1 < pEnd1) {
InnerProductStep(pVect1, pVect2, sum);
InnerProductStep(pVect1, pVect2, sum);
}

return _mm512_reduce_add_epi32(sum);
}

template <unsigned char residual> // 0..64
float INT8_InnerProductSIMD64_AVX512F_BW_VL_VNNI(const void *pVect1v, const void *pVect2v,
size_t dimension) {

return 1 - INT8_InnerProductImp<residual>(pVect1v, pVect2v, dimension);
}
template <unsigned char residual> // 0..64
float INT8_CosineSIMD64_AVX512F_BW_VL_VNNI(const void *pVect1v, const void *pVect2v,
size_t dimension) {
float ip = INT8_InnerProductImp<residual>(pVect1v, pVect2v, dimension);
float norm_v1 =
*reinterpret_cast<const float *>(static_cast<const int8_t *>(pVect1v) + dimension);
float norm_v2 =
*reinterpret_cast<const float *>(static_cast<const int8_t *>(pVect2v) + dimension);
return 1.0f - ip / (norm_v1 * norm_v2);
}
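
Editor's note: to make the inline discussion above about _mm512_dpwssd_epi32 vs _mm512_dpbusd_epi32 concrete, here is a scalar sketch of what each instruction computes per 32-bit accumulator lane. It is an illustration based on the intrinsic semantics quoted in the thread, not code from this PR.

```cpp
#include <cstdint>
#include <cstdio>

// dpwssd: pairs of signed 16-bit x signed 16-bit products summed into a 32-bit
// lane. This matches what the kernel above does after sign-extending int8 -> int16.
static int dpwssd_pair(int16_t a0, int16_t a1, int16_t b0, int16_t b1, int acc) {
    return acc + int(a0) * b0 + int(a1) * b1;
}

// dpbusd: four UNSIGNED 8-bit x signed 8-bit products summed into a 32-bit lane.
// Feeding signed int8 data into the unsigned operand reinterprets negative values
// (e.g. -1 becomes 255), which is why the dpbusd benchmark above reports
// dist=130.048k instead of the correct -1.024k.
static int dpbusd_quad(const int8_t a[4], const int8_t b[4], int acc) {
    for (int i = 0; i < 4; i++) {
        acc += int(uint8_t(a[i])) * int(b[i]); // `a` is treated as unsigned
    }
    return acc;
}

int main() {
    int8_t a[4] = {-1, -1, -1, -1};
    int8_t b[4] = {1, 1, 1, 1};
    // Correct signed dot product of a and b: -4.
    printf("dpwssd-style: %d\n",
           dpwssd_pair(a[2], a[3], b[2], b[3], dpwssd_pair(a[0], a[1], b[0], b[1], 0)));
    // dpbusd-style reads each -1 as 255 and yields 1020.
    printf("dpbusd-style: %d\n", dpbusd_quad(a, b, 0));
    return 0;
}
```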
56 changes: 56 additions & 0 deletions src/VecSim/spaces/IP_space.cpp
@@ -16,6 +16,7 @@
#include "VecSim/spaces/functions/AVX512BW_VBMI2.h"
#include "VecSim/spaces/functions/AVX512FP16_VL.h"
#include "VecSim/spaces/functions/AVX512BF16_VL.h"
#include "VecSim/spaces/functions/AVX512F_BW_VL_VNNI.h"
#include "VecSim/spaces/functions/AVX2.h"
#include "VecSim/spaces/functions/SSE3.h"

@@ -196,4 +197,59 @@ dist_func_t<float> IP_FP16_GetDistFunc(size_t dim, unsigned char *alignment, con
return ret_dist_func;
}

dist_func_t<float> IP_INT8_GetDistFunc(size_t dim, unsigned char *alignment, const void *arch_opt) {
unsigned char dummy_alignment;
if (alignment == nullptr) {
alignment = &dummy_alignment;
}

dist_func_t<float> ret_dist_func = INT8_InnerProduct;
// Optimizations assume at least 32 int8. If we have less, we use the naive implementation.
if (dim < 32) {
return ret_dist_func;
}
#ifdef CPU_FEATURES_ARCH_X86_64
auto features = (arch_opt == nullptr)
? cpu_features::GetX86Info().features
: *static_cast<const cpu_features::X86Features *>(arch_opt);
#ifdef OPT_AVX512_F_BW_VL_VNNI
if (features.avx512f && features.avx512bw && features.avx512vl && features.avx512vnni) {
if (dim % 32 == 0) // no point in aligning if we have an offsetting residual
*alignment = 32 * sizeof(int8_t); // align to 256 bits.
return Choose_INT8_IP_implementation_AVX512F_BW_VL_VNNI(dim);
}
#endif
#endif // __x86_64__
return ret_dist_func;
}

dist_func_t<float> Cosine_INT8_GetDistFunc(size_t dim, unsigned char *alignment,
[Review thread on the function signature above]

Collaborator:
This function has the exact same logic as the one above, except for the call to Choose_INT8_Cosine_implementation_AVX512F_BW_VL_VNNI. Consider consolidating the common logic.

Collaborator (author), @meiravgri, Dec 15, 2024:
The only common code is the alignment initialization, which we duplicate in all the functions in this file; how is this function different?

Collaborator (author):
BTW, note to self: check coverage.

Collaborator (author), @meiravgri, Dec 16, 2024:
Coverage: added cosine to the small-dimension spaces unit tests.

const void *arch_opt) {
unsigned char dummy_alignment;
if (alignment == nullptr) {
alignment = &dummy_alignment;
}

dist_func_t<float> ret_dist_func = INT8_Cosine;
// Optimizations assume at least 32 int8. If we have less, we use the naive implementation.
if (dim < 32) {
return ret_dist_func;
}
#ifdef CPU_FEATURES_ARCH_X86_64
auto features = (arch_opt == nullptr)
? cpu_features::GetX86Info().features
: *static_cast<const cpu_features::X86Features *>(arch_opt);
#ifdef OPT_AVX512_F_BW_VL_VNNI
if (features.avx512f && features.avx512bw && features.avx512vl && features.avx512vnni) {
// For int8 vectors with cosine distance, the extra float for the norm shifts alignment to
// `(dim + sizeof(float)) % 32`.
// Vectors satisfying this have a residual, causing offset loads during calculation.
// To avoid complexity, we skip alignment here, assuming the performance impact is
// negligible.
return Choose_INT8_Cosine_implementation_AVX512F_BW_VL_VNNI(dim);
}
#endif
#endif // __x86_64__
return ret_dist_func;
}
} // namespace spaces
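
Editor's note: as a usage sketch of the two dispatchers added above. The wrapper functions are illustrative assumptions; the norm-at-the-end layout for the cosine path follows the IP.cpp comment earlier in this diff.

```cpp
#include <cstddef>

#include "VecSim/spaces/IP_space.h"

// Resolve the best available INT8 inner-product kernel for this dimension/CPU
// once, then call it like any other distance function.
float int8_ip_distance(const void *v1, const void *v2, size_t dim) {
    unsigned char alignment = 0;
    auto ip_func = spaces::IP_INT8_GetDistFunc(dim, &alignment);
    // `alignment` is only a hint; it is set to 32 when the AVX512 path is
    // chosen and dim % 32 == 0.
    return ip_func(v1, v2, dim);
}

// Cosine variant: both blobs must carry the float L2 norm right after the
// dim int8 elements, as described in the IP.cpp comment above.
float int8_cosine_distance(const void *v1_with_norm, const void *v2_with_norm, size_t dim) {
    auto cosine_func = spaces::Cosine_INT8_GetDistFunc(dim);
    return cosine_func(v1_with_norm, v2_with_norm, dim);
}
```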
4 changes: 4 additions & 0 deletions src/VecSim/spaces/IP_space.h
@@ -16,4 +16,8 @@ dist_func_t<float> IP_BF16_GetDistFunc(size_t dim, unsigned char *alignment = nu
const void *arch_opt = nullptr);
dist_func_t<float> IP_FP16_GetDistFunc(size_t dim, unsigned char *alignment = nullptr,
const void *arch_opt = nullptr);
dist_func_t<float> IP_INT8_GetDistFunc(size_t dim, unsigned char *alignment = nullptr,
const void *arch_opt = nullptr);
dist_func_t<float> Cosine_INT8_GetDistFunc(size_t dim, unsigned char *alignment = nullptr,
const void *arch_opt = nullptr);
} // namespace spaces
14 changes: 14 additions & 0 deletions src/VecSim/spaces/L2/L2.cpp
@@ -70,3 +70,17 @@ float FP16_L2Sqr(const void *pVect1, const void *pVect2, size_t dimension) {
}
return res;
}

float INT8_L2Sqr(const void *pVect1v, const void *pVect2v, size_t dimension) {
int8_t *pVect1 = (int8_t *)pVect1v;
int8_t *pVect2 = (int8_t *)pVect2v;

int res = 0;
for (size_t i = 0; i < dimension; i++) {
int16_t a = pVect1[i];
int16_t b = pVect2[i];
int16_t diff = a - b;
res += diff * diff;
}
return float(res);
}
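
Editor's note: a quick worked example of the scalar kernel above, as an editor's sanity-check sketch. It assumes the translation unit links against L2.cpp; the declaration is copied from L2.h.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

float INT8_L2Sqr(const void *pVect1v, const void *pVect2v, size_t dimension); // from L2.h

int main() {
    int8_t a[4] = {1, -2, 3, 127};
    int8_t b[4] = {4, 0, -5, -128};
    // diffs: -3, -2, 8, 255 -> squares: 9 + 4 + 64 + 65025 = 65102.
    // The int16_t diff matters: 127 - (-128) = 255 does not fit in int8_t.
    assert(INT8_L2Sqr(a, b, 4) == 65102.0f);
    return 0;
}
```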
2 changes: 2 additions & 0 deletions src/VecSim/spaces/L2/L2.h
@@ -16,3 +16,5 @@ float BF16_L2Sqr_LittleEndian(const void *pVect1v, const void *pVect2v, size_t d
float BF16_L2Sqr_BigEndian(const void *pVect1v, const void *pVect2v, size_t dimension);

float FP16_L2Sqr(const void *pVect1, const void *pVect2, size_t dimension);

float INT8_L2Sqr(const void *pVect1v, const void *pVect2v, size_t dimension);
63 changes: 63 additions & 0 deletions src/VecSim/spaces/L2/L2_AVX512F_BW_VL_VNNI_INT8.h
@@ -0,0 +1,63 @@
/*
*Copyright Redis Ltd. 2021 - present
*Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or
*the Server Side Public License v1 (SSPLv1).
*/

#include "VecSim/spaces/space_includes.h"

static inline void L2SqrStep(int8_t *&pVect1, int8_t *&pVect2, __m512i &sum) {
__m256i temp_a = _mm256_loadu_epi8(pVect1);
__m512i va = _mm512_cvtepi8_epi16(temp_a);
pVect1 += 32;

__m256i temp_b = _mm256_loadu_epi8(pVect2);
__m512i vb = _mm512_cvtepi8_epi16(temp_b);
pVect2 += 32;

__m512i diff = _mm512_sub_epi16(va, vb);
// _mm512_dpwssd_epi32(src, a, b)
// Multiply groups of 2 adjacent pairs of signed 16-bit integers in `a` with corresponding
// 16-bit integers in `b`, producing 2 intermediate signed 32-bit results. Sum these 2 results
// with the corresponding 32-bit integer in src, and store the packed 32-bit results in dst.
sum = _mm512_dpwssd_epi32(sum, diff, diff);
}

template <unsigned char residual> // 0..64
float INT8_L2SqrSIMD64_AVX512F_BW_VL_VNNI(const void *pVect1v, const void *pVect2v,
size_t dimension) {
int8_t *pVect1 = (int8_t *)pVect1v;
int8_t *pVect2 = (int8_t *)pVect2v;

const int8_t *pEnd1 = pVect1 + dimension;

__m512i sum = _mm512_setzero_epi32();

// Deal with the remainder first. `dim` is at least 32, so we have at least one full
// 32-int8 block; the unmasked 32-element loads below stay within the vectors, and the
// masked subtraction zeroes the lanes beyond the residual.
if constexpr (residual % 32) {
constexpr __mmask32 mask = (1LU << (residual % 32)) - 1;
__m256i temp_a = _mm256_loadu_epi8(pVect1);
__m512i va = _mm512_cvtepi8_epi16(temp_a);
pVect1 += residual % 32;

__m256i temp_b = _mm256_loadu_epi8(pVect2);
__m512i vb = _mm512_cvtepi8_epi16(temp_b);
pVect2 += residual % 32;

__m512i diff = _mm512_maskz_sub_epi16(mask, va, vb);
sum = _mm512_dpwssd_epi32(sum, diff, diff);
}

if constexpr (residual >= 32) {
L2SqrStep(pVect1, pVect2, sum);
}

// We dealt with the residual part. We are left with a multiple of 64 int8 elements.
while (pVect1 < pEnd1) {
L2SqrStep(pVect1, pVect2, sum);
L2SqrStep(pVect1, pVect2, sum);
}

return _mm512_reduce_add_epi32(sum);
}
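
Editor's note: the `residual` template parameter above is resolved at runtime from `dim`. The Choose_INT8_L2_implementation_AVX512F_BW_VL_VNNI helper referenced in L2_space.cpp is not part of this diff, so the dispatcher below is only a hypothetical sketch of how that mapping could look (residual = dim % 64, with exact multiples of 64 using residual 0), not the repository's actual implementation.

```cpp
#include <cstddef>

#include "VecSim/spaces/L2/L2_AVX512F_BW_VL_VNNI_INT8.h"

using int8_dist_func = float (*)(const void *, const void *, size_t);

// Hypothetical dispatch from the runtime `dim` to the compile-time `residual`
// parameter of the kernel above. Only a few cases are spelled out; a full
// dispatcher would cover all 64 residuals.
inline int8_dist_func choose_int8_l2_sketch(size_t dim) {
    switch (dim % 64) {
    case 0:  return INT8_L2SqrSIMD64_AVX512F_BW_VL_VNNI<0>;
    case 1:  return INT8_L2SqrSIMD64_AVX512F_BW_VL_VNNI<1>;
    case 31: return INT8_L2SqrSIMD64_AVX512F_BW_VL_VNNI<31>;
    case 32: return INT8_L2SqrSIMD64_AVX512F_BW_VL_VNNI<32>;
    case 63: return INT8_L2SqrSIMD64_AVX512F_BW_VL_VNNI<63>;
    // ... remaining cases (2..30, 33..62) would be generated the same way ...
    default: return nullptr; // unreachable once all cases are present
    }
}
```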
27 changes: 27 additions & 0 deletions src/VecSim/spaces/L2_space.cpp
@@ -15,6 +15,7 @@
#include "VecSim/spaces/functions/SSE.h"
#include "VecSim/spaces/functions/AVX512BW_VBMI2.h"
#include "VecSim/spaces/functions/AVX512FP16_VL.h"
#include "VecSim/spaces/functions/AVX512F_BW_VL_VNNI.h"
#include "VecSim/spaces/functions/AVX2.h"
#include "VecSim/spaces/functions/SSE3.h"

@@ -189,4 +190,30 @@ dist_func_t<float> L2_FP16_GetDistFunc(size_t dim, unsigned char *alignment, con
return ret_dist_func;
}

dist_func_t<float> L2_INT8_GetDistFunc(size_t dim, unsigned char *alignment, const void *arch_opt) {
unsigned char dummy_alignment;
if (alignment == nullptr) {
alignment = &dummy_alignment;
}

dist_func_t<float> ret_dist_func = INT8_L2Sqr;
// Optimizations assume at least 32 int8. If we have less, we use the naive implementation.
if (dim < 32) {
return ret_dist_func;
}
#ifdef CPU_FEATURES_ARCH_X86_64
auto features = (arch_opt == nullptr)
? cpu_features::GetX86Info().features
: *static_cast<const cpu_features::X86Features *>(arch_opt);
#ifdef OPT_AVX512_F_BW_VL_VNNI
if (features.avx512f && features.avx512bw && features.avx512vl && features.avx512vnni) {
if (dim % 32 == 0) // no point in aligning if we have an offsetting residual
*alignment = 32 * sizeof(int8_t); // align to 256 bits.
return Choose_INT8_L2_implementation_AVX512F_BW_VL_VNNI(dim);
}
#endif
#endif // __x86_64__
return ret_dist_func;
}

} // namespace spaces
2 changes: 2 additions & 0 deletions src/VecSim/spaces/L2_space.h
@@ -16,4 +16,6 @@ dist_func_t<float> L2_BF16_GetDistFunc(size_t dim, unsigned char *alignment = nu
const void *arch_opt = nullptr);
dist_func_t<float> L2_FP16_GetDistFunc(size_t dim, unsigned char *alignment = nullptr,
const void *arch_opt = nullptr);
dist_func_t<float> L2_INT8_GetDistFunc(size_t dim, unsigned char *alignment = nullptr,
const void *arch_opt = nullptr);
} // namespace spaces
4 changes: 2 additions & 2 deletions src/VecSim/spaces/computer/calculator.h
@@ -26,10 +26,10 @@ class IndexCalculatorInterface : public VecsimBaseObject {
/**
* This object purpose is to calculate the distance between two vectors.
* It extends the IndexCalculatorInterface class' type to hold the distance function.
- * Every specific implmentation of the distance claculater should hold by refrence or by value the
+ * Every specific implementation of the distance calculator should hold by reference or by value the
* parameters required for the calculation. The distance calculation API of all DistanceCalculator
* classes is: calc_dist(v1,v2,dim). Internally it calls the distance function according the
- * template signature, allowing fexability in the distance function arguments.
+ * template signature, allowing flexibility in the distance function arguments.
*/
template <typename DistType, typename DistFuncType>
class DistanceCalculatorInterface : public IndexCalculatorInterface<DistType> {
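
Editor's note: to connect the calculator abstraction above with the new INT8 functions, here is a hypothetical, simplified wrapper illustrating the calc_dist(v1, v2, dim) API described in the comment. The class shape is an editor's assumption, not the repository's DistanceCalculatorInterface implementation.

```cpp
#include <cstddef>

// Hypothetical wrapper: holds a distance function by value and exposes the
// calc_dist(v1, v2, dim) API mentioned in the comment above.
template <typename DistType>
class SimpleDistanceCalculator {
public:
    using dist_func = DistType (*)(const void *, const void *, size_t);

    explicit SimpleDistanceCalculator(dist_func f) : dist_func_(f) {}

    DistType calc_dist(const void *v1, const void *v2, size_t dim) const {
        return dist_func_(v1, v2, dim);
    }

private:
    dist_func dist_func_;
};

// e.g. SimpleDistanceCalculator<float> calc(spaces::L2_INT8_GetDistFunc(dim));
```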