Dorer SQ8 dist functions [MOD-9626] #673
Merged
Changes from all commits (54 commits, all authored by dor-forer):
- 69d63ac add sq8
- af85432 Change to IP_AVX512F
- b215799 Change
- 8b4188b vec1
- a1d1a16 float
- b5860bb finish
- 0d07d71 now
- 66c49e8 remove Choose_SQ8_Cosine_implementation_AVX512F
- aa26c71 in test
- 43b58a8 alignemnt
- 1e12fa3 back to bw
- 984a030 back again
- c3670a8 again
- 11303b7 optimization
- 7474c05 more BW
- 2cfd9b6 fix avx
- 3cdf05e add avx cosine test
- fc8bc7d avx
- 513839b add impl
- f676c1b add l2
- 9a899cc replace OPT_AVX512_F_BW_VL_VNNI
- 4fa5327 align
- 1379d6d Fix avx
- f7fdb2b add l2 sse
- 4fa88b2 Remove prints
- 4476833 sve2 l2
- 2a7477c add neon
- b1f502c fix sve
- dc154b5 add sq8 cosine test
- 25a9400 test utils
- 9ced0be static const
- 6028dd7 format
- 3c2ee11 change to uint
- 5c2952c Merge branch 'main' of https://github.com/RedisAI/VectorSimilarity in…
- ad3985e format
- 41216e6 Merge branch 'main' of https://github.com/RedisAI/VectorSimilarity in…
- 76d2fdd added fma avx2
- b47cc52 format
- 6566a0b remove opt.avx2
- d767ea9 fix OPT_AVX2 bm-spaces
- ea0ac00 pr chanes
- ef09ead format
- 7567730 change to _mm_cvtsi32_si128
- a767547 Change in the l2
- e6422dc PR changes
- 10a6098 added chunk to functions
- 767e190 diff squared
- 44be275 format
- 3a956bf chnage diff
- 5840e3f Remove align from tests improve sse4
- 2a89dd8 format
- e562a86 applied to l2
- 2a0b4e6 format
- ab18690 Remove alignment l2
New file (AVX2 + FMA SQ8 inner product and cosine distance functions):
@@ -0,0 +1,113 @@
/*
 * Copyright (c) 2006-Present, Redis Ltd.
 * All rights reserved.
 *
 * Licensed under your choice of the Redis Source Available License 2.0
 * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
 * GNU Affero General Public License v3 (AGPLv3).
 */
#include "VecSim/spaces/space_includes.h"
#include "VecSim/spaces/AVX_utils.h"

static inline void InnerProductStepSQ8_FMA(const float *&pVect1, const uint8_t *&pVect2,
                                           __m256 &sum256, const __m256 &min_val_vec,
                                           const __m256 &delta_vec) {
    // Load 8 float elements from pVect1
    __m256 v1 = _mm256_loadu_ps(pVect1);
    pVect1 += 8;

    // Load 8 uint8 elements from pVect2, convert to int32, then to float
    __m128i v2_128 = _mm_loadl_epi64((__m128i *)pVect2);
    pVect2 += 8;

    // Zero-extend uint8 to int32
    __m256i v2_256 = _mm256_cvtepu8_epi32(v2_128);

    // Convert int32 to float
    __m256 v2_f = _mm256_cvtepi32_ps(v2_256);

    // Dequantize and accumulate the dot product in two FMA steps:
    // v2_dequant = v2_f * delta + min_val
    // sum256     = v1 * v2_dequant + sum256
    __m256 v2_dequant = _mm256_fmadd_ps(v2_f, delta_vec, min_val_vec);
    sum256 = _mm256_fmadd_ps(v1, v2_dequant, sum256);
}

template <unsigned char residual> // 0..15
float SQ8_InnerProductImp_FMA(const void *pVect1v, const void *pVect2v, size_t dimension) {
    const float *pVect1 = static_cast<const float *>(pVect1v);
    // pVect2 is a quantized uint8_t vector
    const uint8_t *pVect2 = static_cast<const uint8_t *>(pVect2v);
    const float *pEnd1 = pVect1 + dimension;

    // Get dequantization parameters from the end of the quantized vector
    const float min_val = *reinterpret_cast<const float *>(pVect2 + dimension);
    const float delta = *reinterpret_cast<const float *>(pVect2 + dimension + sizeof(float));
    // Create broadcast vectors for SIMD operations
    __m256 min_val_vec = _mm256_set1_ps(min_val);
    __m256 delta_vec = _mm256_set1_ps(delta);

    __m256 sum256 = _mm256_setzero_ps();

    // Deal with 1-7 floats with mask loading, if needed. `dim` is >16, so we have at least one
    // 16-float block, so mask loading is guaranteed to be safe.
    if constexpr (residual % 8) {
        __mmask8 constexpr mask = (1 << (residual % 8)) - 1;
        __m256 v1 = my_mm256_maskz_loadu_ps<mask>(pVect1);
        pVect1 += residual % 8;

        // Load quantized values and dequantize
        __m128i v2_128 = _mm_loadl_epi64((__m128i *)pVect2);
        pVect2 += residual % 8;

        // Zero-extend uint8 to int32
        __m256i v2_256 = _mm256_cvtepu8_epi32(v2_128);

        // Convert int32 to float
        __m256 v2_f = _mm256_cvtepi32_ps(v2_256);

        // Dequantize using FMA: (val * delta) + min_val
        __m256 v2_dequant = _mm256_fmadd_ps(v2_f, delta_vec, min_val_vec);

        // Compute the dot product; masked-out lanes of v1 are zero, so they contribute nothing
        sum256 = _mm256_mul_ps(v1, v2_dequant);
    }

    // If the remainder is >=8, do another step of 8 floats
    if constexpr (residual >= 8) {
        InnerProductStepSQ8_FMA(pVect1, pVect2, sum256, min_val_vec, delta_vec);
    }

    // We dealt with the residual part. We are left with some multiple of 16 floats.
    // In each iteration we calculate 16 floats = 512 bits.
    do {
        InnerProductStepSQ8_FMA(pVect1, pVect2, sum256, min_val_vec, delta_vec);
        InnerProductStepSQ8_FMA(pVect1, pVect2, sum256, min_val_vec, delta_vec);
    } while (pVect1 < pEnd1);

    return my_mm256_reduce_add_ps(sum256);
}

template <unsigned char residual> // 0..15
float SQ8_InnerProductSIMD16_AVX2_FMA(const void *pVect1v, const void *pVect2v, size_t dimension) {
    return 1.0f - SQ8_InnerProductImp_FMA<residual>(pVect1v, pVect2v, dimension);
}

template <unsigned char residual> // 0..15
float SQ8_CosineSIMD16_AVX2_FMA(const void *pVect1v, const void *pVect2v, size_t dimension) {
    // Get the inverse norm from the end of the quantized vector;
    // it is stored after min_val and delta
    const uint8_t *pVect2 = static_cast<const uint8_t *>(pVect2v);
    const float inv_norm = *reinterpret_cast<const float *>(pVect2 + dimension + 2 * sizeof(float));

    // Calculate the inner product using the common implementation
    float ip = SQ8_InnerProductImp_FMA<residual>(pVect1v, pVect2v, dimension);

    // For cosine, account for the stored vector norm
    return 1.0f - ip * inv_norm;
}
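For readers unfamiliar with the SQ8 format, here is a minimal scalar sketch of the computation the kernel above vectorizes. It assumes the layout implied by the pointer arithmetic in the diff: `dimension` uint8 codes followed by `min_val`, `delta`, and `inv_norm` as packed floats. The function name `sq8_ip_distance_scalar` is illustrative and not part of this PR.

#include <cstddef>
#include <cstdint>
#include <cstring>

// Scalar reference for the SQ8 inner-product distance (illustrative only).
// Layout assumption: v2 = [dim uint8 codes][min_val][delta][inv_norm].
float sq8_ip_distance_scalar(const float *v1, const uint8_t *v2, size_t dim) {
    float min_val, delta;
    std::memcpy(&min_val, v2 + dim, sizeof(float));
    std::memcpy(&delta, v2 + dim + sizeof(float), sizeof(float));

    float ip = 0.0f;
    for (size_t i = 0; i < dim; ++i) {
        // Dequantize one code, then accumulate the dot product,
        // mirroring the two fused multiply-adds in the SIMD step.
        ip += v1[i] * (static_cast<float>(v2[i]) * delta + min_val);
    }
    return 1.0f - ip; // same distance form as SQ8_InnerProductSIMD16_AVX2_FMA
}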
New file (AVX2 SQ8 inner product and cosine distance functions, non-FMA variant):
@@ -0,0 +1,107 @@
/*
 * Copyright (c) 2006-Present, Redis Ltd.
 * All rights reserved.
 *
 * Licensed under your choice of the Redis Source Available License 2.0
 * (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
 * GNU Affero General Public License v3 (AGPLv3).
 */
#include "VecSim/spaces/space_includes.h"
#include "VecSim/spaces/AVX_utils.h"

static inline void InnerProductStepSQ8(const float *&pVect1, const uint8_t *&pVect2, __m256 &sum256,
                                       const __m256 &min_val_vec, const __m256 &delta_vec) {
    // Load 8 float elements from pVect1
    __m256 v1 = _mm256_loadu_ps(pVect1);
    pVect1 += 8;

    // Load 8 uint8 elements from pVect2, convert to int32, then to float
    __m128i v2_128 = _mm_loadl_epi64((__m128i *)pVect2);
    pVect2 += 8;

    // Zero-extend uint8 to int32
    __m256i v2_256 = _mm256_cvtepu8_epi32(v2_128);

    // Convert int32 to float
    __m256 v2_f = _mm256_cvtepi32_ps(v2_256);

    // Dequantize: (val * delta) + min_val
    __m256 v2_dequant = _mm256_add_ps(_mm256_mul_ps(v2_f, delta_vec), min_val_vec);

    // Compute the dot product and add it to the running sum
    sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2_dequant));
}

template <unsigned char residual> // 0..15
float SQ8_InnerProductImp(const void *pVect1v, const void *pVect2v, size_t dimension) {
    const float *pVect1 = static_cast<const float *>(pVect1v);
    // pVect2 is a quantized uint8_t vector
    const uint8_t *pVect2 = static_cast<const uint8_t *>(pVect2v);
    const float *pEnd1 = pVect1 + dimension;

    // Get dequantization parameters from the end of the quantized vector
    const float min_val = *reinterpret_cast<const float *>(pVect2 + dimension);
    const float delta = *reinterpret_cast<const float *>(pVect2 + dimension + sizeof(float));
    // Create broadcast vectors for SIMD operations
    __m256 min_val_vec = _mm256_set1_ps(min_val);
    __m256 delta_vec = _mm256_set1_ps(delta);

    __m256 sum256 = _mm256_setzero_ps();

    // Deal with 1-7 floats with mask loading, if needed. `dim` is >16, so we have at least one
    // 16-float block, so mask loading is guaranteed to be safe.
    if constexpr (residual % 8) {
        __mmask8 constexpr mask = (1 << (residual % 8)) - 1;
        __m256 v1 = my_mm256_maskz_loadu_ps<mask>(pVect1);
        pVect1 += residual % 8;

        // Load quantized values and dequantize
        __m128i v2_128 = _mm_loadl_epi64((__m128i *)pVect2);
        pVect2 += residual % 8;

        // Zero-extend uint8 to int32
        __m256i v2_256 = _mm256_cvtepu8_epi32(v2_128);

        // Convert int32 to float
        __m256 v2_f = _mm256_cvtepi32_ps(v2_256);

        // Dequantize: (val * delta) + min_val
        __m256 v2_dequant = _mm256_add_ps(_mm256_mul_ps(v2_f, delta_vec), min_val_vec);

        // Compute the dot product; masked-out lanes of v1 are zero, so they contribute nothing
        sum256 = _mm256_mul_ps(v1, v2_dequant);
    }

    // If the remainder is >=8, do another step of 8 floats
    if constexpr (residual >= 8) {
        InnerProductStepSQ8(pVect1, pVect2, sum256, min_val_vec, delta_vec);
    }

    // We dealt with the residual part. We are left with some multiple of 16 floats.
    // In each iteration we calculate 16 floats = 512 bits.
    do {
        InnerProductStepSQ8(pVect1, pVect2, sum256, min_val_vec, delta_vec);
        InnerProductStepSQ8(pVect1, pVect2, sum256, min_val_vec, delta_vec);
    } while (pVect1 < pEnd1);

    return my_mm256_reduce_add_ps(sum256);
}

template <unsigned char residual> // 0..15
float SQ8_InnerProductSIMD16_AVX2(const void *pVect1v, const void *pVect2v, size_t dimension) {
    return 1.0f - SQ8_InnerProductImp<residual>(pVect1v, pVect2v, dimension);
}

template <unsigned char residual> // 0..15
float SQ8_CosineSIMD16_AVX2(const void *pVect1v, const void *pVect2v, size_t dimension) {
    // Get the inverse norm from the end of the quantized vector;
    // it is stored after min_val and delta
    const uint8_t *pVect2 = static_cast<const uint8_t *>(pVect2v);
    const float inv_norm = *reinterpret_cast<const float *>(pVect2 + dimension + 2 * sizeof(float));

    // Calculate the inner product using the common implementation
    float ip = SQ8_InnerProductImp<residual>(pVect1v, pVect2v, dimension);

    // For cosine, account for the stored vector norm
    return 1.0f - ip * inv_norm;
}
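The encoding side is not part of this diff. A hypothetical encoder producing the layout these kernels read might look like the sketch below; the name `sq8_encode`, the rounding scheme, and the choice of delta = (max - min) / 255 are assumptions, not the repository's actual quantizer.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical SQ8 encoder (illustrative only). Output layout matches what
// the kernels above read: [dim uint8 codes][min_val][delta][inv_norm].
std::vector<uint8_t> sq8_encode(const std::vector<float> &v) {
    const float min_val = *std::min_element(v.begin(), v.end());
    const float max_val = *std::max_element(v.begin(), v.end());
    // 255 quantization steps between min and max; guard against a constant vector.
    const float delta = (max_val > min_val) ? (max_val - min_val) / 255.0f : 1.0f;

    std::vector<uint8_t> out(v.size() + 3 * sizeof(float));
    float norm_sq = 0.0f;
    for (size_t i = 0; i < v.size(); ++i) {
        norm_sq += v[i] * v[i];
        out[i] = static_cast<uint8_t>(std::lround((v[i] - min_val) / delta));
    }
    const float inv_norm = 1.0f / std::sqrt(norm_sq); // assumes a non-zero vector

    // Append the dequantization parameters after the codes.
    std::memcpy(out.data() + v.size(), &min_val, sizeof(float));
    std::memcpy(out.data() + v.size() + sizeof(float), &delta, sizeof(float));
    std::memcpy(out.data() + v.size() + 2 * sizeof(float), &inv_norm, sizeof(float));
    return out;
}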
Review comment: consider remodelling so the metadata is at the start of the vector.
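For illustration, the metadata-first layout the reviewer suggests could look like the hypothetical struct below (not code from this PR); the parameters would then sit at fixed offsets instead of at dimension-dependent positions.

#include <cstdint>

// Hypothetical metadata-first SQ8 layout, per the review suggestion above.
struct SQ8MetadataFirst {
    float min_val;    // dequantization offset
    float delta;      // dequantization scale
    float inv_norm;   // 1 / ||original vector||, used by the cosine kernels
    uint8_t codes[1]; // placeholder: dim quantized values follow in memory
};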