Enable Intel®-AMX/oneDNN to accelerate IndexFlatIP search #3266
Status: Open. guangzegu wants to merge 18 commits into facebookresearch:main from guangzegu:main (base: main).
Changes from all commits (18, all authored by guangzegu):

- 7db8ce6 Enable Intel®-AMX/oneDNN to accelerate IndexFlat search
- b35a0f2 formatted distances.cpp and onednn_utils.h
- 781f178 Add descriptions of Intel®-AMX/oneDNN optimization to INSTALL.md
- 2f3fdf9 Add oneDNN/AMX optimization for distance calculation using Blas for I…
- a15c5cc Merge branch 'facebookresearch:main' into main
- 116fc01 Restructure the AMX integration with faiss
- e3ea518 Merge remote-tracking branch 'upstream/main'
- 9e34323 Refactor and optimize the code structure to support AMX/OneDNN comput…
- f556407 Format distances_dnnl.h
- ed7b184 Merge branch 'main' into main
- 78857d9 Merge remote-tracking branch 'upstream/main'
- fc447da Merge branch 'facebookresearch:main' into main
- f43b0fa Merge branch 'facebookresearch:main' into main
- e14be69 Merge branch 'facebookresearch:main' into main
- eda99c0 Add DNNL compilation flags to support low-precision testing
- 53fa4ad Add unit tests for low-precision IndexFlatIP
- 6dd3a8a Skip certain high precision tests using DNNL compile option
- fd45d23 Merge branch 'facebookresearch:main' into main
Diff: C API header, adding block-size accessors for the oneDNN/AMX path.

```diff
@@ -102,6 +102,20 @@ void faiss_set_distance_compute_min_k_reservoir(int value);
 /// rather than a heap
 int faiss_get_distance_compute_min_k_reservoir();
 
+#ifdef ENABLE_DNNL
+/// Setter of the query block size for oneDNN/AMX distance computations
+void faiss_set_distance_compute_dnnl_query_bs(int value);
+
+/// Getter of the query block size for oneDNN/AMX distance computations
+int faiss_get_distance_compute_dnnl_query_bs();
+
+/// Setter of the database block size for oneDNN/AMX distance computations
+void faiss_set_distance_compute_dnnl_database_bs(int value);
+
+/// Getter of the database block size for oneDNN/AMX distance computations
+int faiss_get_distance_compute_dnnl_database_bs();
+#endif
+
 #ifdef __cplusplus
 }
 #endif
```
New file, `distances_dnnl.h` (named in commit f556407), +110 lines:

```cpp
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

/* Inner-product distance functions accelerated with oneDNN/AMX.
 * The underlying f32 -> bf16 -> f32 kernel is implemented in
 * onednn_utils.h.
 */

#pragma once

#include <faiss/cppcontrib/amx/onednn_utils.h>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/ResultHandler.h>
#include <faiss/impl/platform_macros.h>
#include <omp.h>

#include <algorithm> // std::min
#include <memory>    // std::unique_ptr

#ifndef FINTEGER
#define FINTEGER long
#endif

namespace faiss {

// block sizes for oneDNN/AMX distance computations
FAISS_API int distance_compute_dnnl_query_bs = 10240;
FAISS_API int distance_compute_dnnl_database_bs = 10240;

/* Find the nearest neighbors for nx queries in a set of ny vectors
 * using oneDNN/AMX, computing all nx * ny inner products in one call. */
template <class BlockResultHandler>
void exhaustive_inner_product_seq_dnnl(
        const float* x,
        const float* y,
        size_t d,
        size_t nx,
        size_t ny,
        BlockResultHandler& res) {
    using SingleResultHandler =
            typename BlockResultHandler::SingleResultHandler;
    // unused when OpenMP is disabled and the pragma below is ignored
    [[maybe_unused]] int nt = std::min(int(nx), omp_get_max_threads());

    std::unique_ptr<float[]> res_arr(new float[nx * ny]);

    comput_f32bf16f32_inner_product(
            nx,
            d,
            ny,
            d,
            const_cast<float*>(x),
            const_cast<float*>(y),
            res_arr.get());

#pragma omp parallel num_threads(nt)
    {
        SingleResultHandler resi(res);
#pragma omp for
        for (size_t i = 0; i < nx; i++) {
            resi.begin(i);
            for (size_t j = 0; j < ny; j++) {
                float ip = res_arr[i * ny + j];
                resi.add_result(ip, j);
            }
            resi.end();
        }
    }
}

/* Find the nearest neighbors for nx queries in a set of ny vectors
 * using oneDNN/AMX, processing the problem in blocks to bound the
 * size of the intermediate inner-product matrix. */
template <class BlockResultHandler>
void exhaustive_inner_product_blas_dnnl(
        const float* x,
        const float* y,
        size_t d,
        size_t nx,
        size_t ny,
        BlockResultHandler& res) {
    /* block sizes */
    const size_t bs_x = distance_compute_dnnl_query_bs;
    const size_t bs_y = distance_compute_dnnl_database_bs;
    std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);

    for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
        size_t i1 = i0 + bs_x;
        if (i1 > nx)
            i1 = nx;

        res.begin_multiple(i0, i1);

        for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
            size_t j1 = j0 + bs_y;
            if (j1 > ny)
                j1 = ny;
            /* compute the actual dot products */
            FINTEGER nyi = j1 - j0, nxi = i1 - i0;
            comput_f32bf16f32_inner_product(
                    nxi,
                    d,
                    nyi,
                    d,
                    const_cast<float*>(x + i0 * d),
                    const_cast<float*>(y + j0 * d),
                    ip_block.get());

            res.add_results(j0, j1, ip_block.get());
        }
        res.end_multiple();
        InterruptCallback::check();
    }
}

} // namespace faiss
```
New file, `faiss/cppcontrib/amx/onednn_utils.h` (path matching the include in distances_dnnl.h), +142 lines:

```cpp
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

/* oneDNN helpers: engine/stream setup, AMX-BF16 detection, and the
 * f32 -> bf16 -> f32 inner-product kernel used by the DNNL distance
 * functions in distances_dnnl.h.
 */

#pragma once
#include <stdlib.h>
#include <mutex>
#include "oneapi/dnnl/dnnl.hpp"

namespace faiss {

static dnnl::engine cpu_engine;
static dnnl::stream engine_stream;
static bool is_onednn_init = false;
static std::mutex init_mutex;

// CPUID leaf 7, subleaf 0: EDX bit 22 reports AMX-BF16 support
static bool is_amxbf16_supported() {
    unsigned int eax, ebx, ecx, edx;
    __asm__ __volatile__("cpuid"
                         : "=a"(eax), "=b"(ebx), "=c"(ecx), "=d"(edx)
                         : "a"(7), "c"(0));
    return edx & (1 << 22);
}

static void init_onednn() {
    std::unique_lock<std::mutex> lock(init_mutex);

    if (is_onednn_init) {
        return;
    }

    // init dnnl engine
    cpu_engine = dnnl::engine(dnnl::engine::kind::cpu, 0);
    engine_stream = dnnl::stream(cpu_engine);

    is_onednn_init = true;
}

// this function is automatically called when the library is loaded
__attribute__((constructor)) static void library_load() {
    init_onednn();
}

/**
 * @brief Compute float32 matrix inner product with bf16 intermediate results
 * to accelerate
 * @details The main idea is:
 * 1. Define float32 memory layout for input and output
 * 2. Create low precision bf16 memory descriptors as inner product input
 * 3. Generate inner product primitive descriptor
 * 4. Execute float32 => (reorder) => bf16 => (inner product) => float32
 *    chain operation, isolate different precision data, accelerate inner
 *    product
 * 5. Pipeline execution via streams for asynchronous scheduling
 *
 * @param xrow     Row number of input matrix X
 * @param xcol     Column number of input matrix X
 * @param yrow     Row number of weight matrix Y
 * @param ycol     Column number of weight matrix Y
 * @param in_f32_1 Input matrix pointer in float32 type
 * @param in_f32_2 Weight matrix pointer in float32 type
 * @param out_f32  Output matrix pointer for result in float32 type
 * @return None
 */
static void comput_f32bf16f32_inner_product(
        uint32_t xrow,
        uint32_t xcol,
        uint32_t yrow,
        uint32_t ycol,
        float* in_f32_1,
        float* in_f32_2,
        float* out_f32) {
    dnnl::memory::desc f32_md1 = dnnl::memory::desc(
            {xrow, xcol},
            dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::ab);
    dnnl::memory::desc f32_md2 = dnnl::memory::desc(
            {yrow, ycol},
            dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::ab);
    dnnl::memory::desc f32_dst_md2 = dnnl::memory::desc(
            {xrow, yrow},
            dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::ab);

    dnnl::memory f32_mem1 = dnnl::memory(f32_md1, cpu_engine, in_f32_1);
    dnnl::memory f32_mem2 = dnnl::memory(f32_md2, cpu_engine, in_f32_2);
    dnnl::memory f32_dst_mem = dnnl::memory(f32_dst_md2, cpu_engine, out_f32);

    // inner memory bf16
    dnnl::memory::desc bf16_md1 = dnnl::memory::desc(
            {xrow, xcol},
            dnnl::memory::data_type::bf16,
            dnnl::memory::format_tag::any);
    dnnl::memory::desc bf16_md2 = dnnl::memory::desc(
            {yrow, ycol},
            dnnl::memory::data_type::bf16,
            dnnl::memory::format_tag::any);

    dnnl::inner_product_forward::primitive_desc inner_product_pd =
            dnnl::inner_product_forward::primitive_desc(
                    cpu_engine,
                    dnnl::prop_kind::forward_training,
                    bf16_md1,
                    bf16_md2,
                    f32_dst_md2);

    dnnl::inner_product_forward inner_product_prim =
            dnnl::inner_product_forward(inner_product_pd);

    dnnl::memory bf16_mem1 =
            dnnl::memory(inner_product_pd.src_desc(), cpu_engine);
    dnnl::reorder(f32_mem1, bf16_mem1)
            .execute(engine_stream, f32_mem1, bf16_mem1);

    dnnl::memory bf16_mem2 =
            dnnl::memory(inner_product_pd.weights_desc(), cpu_engine);
    dnnl::reorder(f32_mem2, bf16_mem2)
            .execute(engine_stream, f32_mem2, bf16_mem2);

    inner_product_prim.execute(
            engine_stream,
            {{DNNL_ARG_SRC, bf16_mem1},
             {DNNL_ARG_WEIGHTS, bf16_mem2},
             {DNNL_ARG_DST, f32_dst_mem}});

    // Wait for the computation to finalize.
    engine_stream.wait();
}

} // namespace faiss
```
Review comment on the new C API declarations:

"Is it possible for you to move these to cpi/cppcontrib/amx/distances_dnnl_c.h and if not feasible, gate it behind a compilation flag?"