Commit 85e44a2

feat: add integration tests (#101)
1 parent 2c50a0e commit 85e44a2

43 files changed: +9450 -510 lines

Cargo.lock

Lines changed: 195 additions & 35 deletions
Some generated files are not rendered by default.

Makefile

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+integration-tests:
+	cargo test --release
+
+cuda-integration-tests:
+	cargo test -F text-embeddings-backend-candle/cuda -F text-embeddings-backend-candle/flash-attn -F text-embeddings-router/candle-cuda --release
+
+integration-tests-review:
+	cargo insta test --review --release
+
+cuda-integration-tests-review:
+	cargo insta test --review --features "text-embeddings-backend-candle/cuda text-embeddings-backend-candle/flash-attn text-embeddings-router/candle-cuda" --release
+

backends/candle/Cargo.toml

Lines changed: 8 additions & 1 deletion
@@ -15,7 +15,6 @@ candle-flash-attn = { version = "0.3.0", optional = true }
 candle-flash-attn-v1 = { git = "https://github.com/huggingface/candle-flash-attn-v1", rev = "62b75f1ea4e0961fad7b983ee8d723ed6fd68be5", optional = true }
 candle-cublaslt = { git = "https://github.com/huggingface/candle-cublaslt", rev = "58684e116aae248c353f87846ddf0b2a8a7ed855", optional = true }
 candle-layer-norm = { git = "https://github.com/huggingface/candle-layer-norm", rev = "5ed96012a693dff9685320765dd55a57fdaecdd6", optional = true }
-lazy_static = "^1.4"
 text-embeddings-backend-core = { path = "../core" }
 tracing = "^0.1"
 safetensors = "^0.4"
@@ -24,6 +23,14 @@ serde = { version = "^1.0", features = ["serde_derive"] }
 serde_json = "^1.0"
 memmap2 = "^0.9"
 
+[dev-dependencies]
+insta = { git = "https://github.com/OlivierDehaene/insta", rev = "f4f98c0410b91fb5a28b10df98e4422955be9c2c", features = ["yaml"] }
+is_close = "0.1.3"
+hf-hub = "0.3.2"
+anyhow = "1.0.75"
+tokenizers = { version = "^0.15.0", default-features = false, features = ["onig", "esaxx_fast"] }
+serial_test = "2.0.0"
+
 [build-dependencies]
 anyhow = { version = "1", features = ["backtrace"] }
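
The new [dev-dependencies] back the integration-test targets added to the Makefile above (make integration-tests-review runs cargo insta test --review --release). As a rough, hypothetical sketch only, not one of the tests added by this commit, a snapshot test using insta's yaml feature together with serial_test might look like:

use serial_test::serial;

#[test]
#[serial]
fn embeddings_snapshot_sketch() {
    // Hypothetical output; the commit's real tests exercise the candle backend models.
    let embeddings = vec![vec![0.1_f32, 0.2, 0.3]];
    // The first run records a snapshot; `cargo insta test --review` then lets you
    // accept or reject any changes interactively.
    insta::assert_yaml_snapshot!(embeddings);
}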

backends/candle/src/compute_cap.rs

Lines changed: 35 additions & 15 deletions
@@ -2,20 +2,40 @@ use candle::cuda_backend::cudarc::driver::sys::CUdevice_attribute::{
     CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
 };
 use candle::cuda_backend::cudarc::driver::CudaDevice;
-use lazy_static::lazy_static;
+use std::sync::Once;
 
-lazy_static! {
-    pub static ref RUNTIME_COMPUTE_CAP: usize = {
-        let device = CudaDevice::new(0).expect("cuda is not available");
-        let major = device
-            .attribute(CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)
-            .unwrap();
-        let minor = device
-            .attribute(CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR)
-            .unwrap();
-        (major * 10 + minor) as usize
-    };
-    pub static ref COMPILE_COMPUTE_CAP: usize = env!("CUDA_COMPUTE_CAP").parse::<usize>().unwrap();
+static INIT: Once = Once::new();
+static mut RUNTIME_COMPUTE_CAP: usize = 0;
+static mut COMPILE_COMPUTE_CAP: usize = 0;
+
+fn init_compute_caps() {
+    unsafe {
+        INIT.call_once(|| {
+            let device = CudaDevice::new(0).expect("cuda is not available");
+            let major = device
+                .attribute(CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)
+                .unwrap();
+            let minor = device
+                .attribute(CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR)
+                .unwrap();
+            RUNTIME_COMPUTE_CAP = (major * 10 + minor) as usize;
+            COMPILE_COMPUTE_CAP = env!("CUDA_COMPUTE_CAP").parse::<usize>().unwrap();
+        });
+    }
+}
+
+pub fn get_compile_compute_cap() -> usize {
+    unsafe {
+        init_compute_caps();
+        COMPILE_COMPUTE_CAP
+    }
+}
+
+pub fn get_runtime_compute_cap() -> usize {
+    unsafe {
+        init_compute_caps();
+        RUNTIME_COMPUTE_CAP
+    }
 }
 
 fn compute_cap_matching(runtime_compute_cap: usize, compile_compute_cap: usize) -> bool {
@@ -30,8 +50,8 @@ fn compute_cap_matching(runtime_compute_cap: usize, compile_compute_cap: usize)
 }
 
 pub fn incompatible_compute_cap() -> bool {
-    let compile_compute_cap = *COMPILE_COMPUTE_CAP;
-    let runtime_compute_cap = *RUNTIME_COMPUTE_CAP;
+    let compile_compute_cap = get_compile_compute_cap();
+    let runtime_compute_cap = get_runtime_compute_cap();
     !compute_cap_matching(runtime_compute_cap, compile_compute_cap)
 }
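
This replaces the lazy_static globals with std::sync::Once guarding two static mut values that are filled in on first access. Purely for comparison, and not what the commit uses, the same one-time initialization can be written without unsafe via std::sync::OnceLock; a minimal sketch with the CUDA device query swapped for a placeholder value:

use std::sync::OnceLock;

static RUNTIME_COMPUTE_CAP: OnceLock<usize> = OnceLock::new();

fn runtime_compute_cap() -> usize {
    // The closure runs at most once; later calls return the cached value.
    *RUNTIME_COMPUTE_CAP.get_or_init(|| {
        // Placeholder for the CudaDevice attribute query in the real code.
        let (major, minor) = (8, 0);
        (major * 10 + minor) as usize
    })
}

fn main() {
    assert_eq!(runtime_compute_cap(), 80);
}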

backends/candle/src/flash_attn.rs

Lines changed: 7 additions & 5 deletions
@@ -1,4 +1,4 @@
-use crate::compute_cap::RUNTIME_COMPUTE_CAP;
+use crate::compute_cap::get_runtime_compute_cap;
 use candle::Tensor;
 
 #[allow(clippy::too_many_arguments, unused)]
@@ -13,7 +13,9 @@ pub(crate) fn flash_attn_varlen(
     softmax_scale: f32,
     causal: bool,
 ) -> Result<Tensor, candle::Error> {
-    if *RUNTIME_COMPUTE_CAP == 75 {
+    let runtime_compute_cap = get_runtime_compute_cap();
+
+    if runtime_compute_cap == 75 {
         #[cfg(feature = "flash-attn-v1")]
         {
             use candle_flash_attn_v1::flash_attn_varlen;
@@ -31,7 +33,7 @@ pub(crate) fn flash_attn_varlen(
         }
         #[cfg(not(feature = "flash-attn-v1"))]
         candle::bail!("Flash attention v1 is not installed. Use `flash-attn-v1` feature.")
-    } else if (80..90).contains(&*RUNTIME_COMPUTE_CAP) {
+    } else if (80..90).contains(&runtime_compute_cap) {
         #[cfg(feature = "flash-attn")]
         {
             use candle_flash_attn::flash_attn_varlen;
@@ -49,7 +51,7 @@ pub(crate) fn flash_attn_varlen(
         }
         #[cfg(not(feature = "flash-attn"))]
         candle::bail!("Flash attention is not installed. Use `flash-attn-v1` feature.")
-    } else if *RUNTIME_COMPUTE_CAP == 90 {
+    } else if runtime_compute_cap == 90 {
         #[cfg(feature = "flash-attn")]
         {
             use candle_flash_attn::flash_attn_varlen;
@@ -70,6 +72,6 @@ pub(crate) fn flash_attn_varlen(
     }
     candle::bail!(
         "GPU with CUDA capability {} is not supported",
-        *RUNTIME_COMPUTE_CAP
+        runtime_compute_cap
     );
 }
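
The dispatch above boils down to: compute capability 75 (Turing) goes through the candle-flash-attn-v1 fork, 80 through 89 and 90 use candle-flash-attn, and anything else hits the bail! at the end; each branch is additionally gated on the corresponding Cargo feature. A standalone restatement of just the capability check, as an illustration only with no CUDA involved:

// Mirrors the branch structure of flash_attn_varlen above.
fn flash_attn_backend(runtime_compute_cap: usize) -> Option<&'static str> {
    if runtime_compute_cap == 75 {
        Some("candle-flash-attn-v1")
    } else if (80..90).contains(&runtime_compute_cap) {
        Some("candle-flash-attn")
    } else if runtime_compute_cap == 90 {
        Some("candle-flash-attn")
    } else {
        None // flash_attn_varlen bails: "GPU with CUDA capability {} is not supported"
    }
}

fn main() {
    assert_eq!(flash_attn_backend(86), Some("candle-flash-attn"));
    assert_eq!(flash_attn_backend(70), None);
}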

backends/candle/src/layers.rs

Lines changed: 1 addition & 1 deletion
@@ -3,6 +3,6 @@ mod cublaslt;
 mod layer_norm;
 mod linear;
 
-pub use cublaslt::CUBLASLT;
+pub use cublaslt::get_cublas_lt_wrapper;
 pub use layer_norm::LayerNorm;
 pub use linear::{HiddenAct, Linear};

backends/candle/src/layers/cublaslt.rs

Lines changed: 24 additions & 18 deletions
@@ -1,28 +1,34 @@
 use crate::layers::HiddenAct;
 use candle::{Device, Result, Tensor};
-use lazy_static::lazy_static;
+use std::sync::Once;
 
 #[cfg(feature = "cuda")]
 use candle_cublaslt::{fused_batch_matmul, fused_matmul, Activation, CublasLt};
 
-lazy_static! {
-    pub static ref CUBLASLT: Option<CublasLtWrapper> = {
-        match Device::cuda_if_available(0) {
-            Ok(device) => {
-                #[cfg(feature = "cuda")]
-                {
-                    Some(CublasLtWrapper {
-                        cublaslt: CublasLt::new(&device).unwrap(),
-                    })
-                }
-                #[cfg(not(feature = "cuda"))]
-                {
-                    None
+static INIT: Once = Once::new();
+static mut CUBLASLT: Option<CublasLtWrapper> = None;
+
+pub fn get_cublas_lt_wrapper() -> Option<&'static CublasLtWrapper> {
+    unsafe {
+        INIT.call_once(|| {
+            CUBLASLT = match Device::cuda_if_available(0) {
+                Ok(device) => {
+                    #[cfg(feature = "cuda")]
+                    {
+                        Some(CublasLtWrapper {
+                            cublaslt: CublasLt::new(&device).unwrap(),
+                        })
+                    }
+                    #[cfg(not(feature = "cuda"))]
+                    {
+                        None
+                    }
                 }
-            }
-            Err(_) => None,
-        }
-    };
+                Err(_) => None,
+            };
+        });
+        CUBLASLT.as_ref()
+    }
 }
 
 #[derive(Debug, Clone)]
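
Same lazy-initialization move as in compute_cap.rs, except the global is an Option and callers get Option<&'static CublasLtWrapper>, None when no CUDA device is available or the cuda feature is off. Again as a point of comparison rather than the commit's code, that shape of API can also be expressed with OnceLock over an Option:

use std::sync::OnceLock;

struct Handle; // stand-in for CublasLtWrapper

static HANDLE: OnceLock<Option<Handle>> = OnceLock::new();

fn get_handle() -> Option<&'static Handle> {
    // Initialized once; None here stands in for Device::cuda_if_available(0)
    // failing or the `cuda` feature being disabled.
    HANDLE.get_or_init(|| None).as_ref()
}

fn main() {
    assert!(get_handle().is_none());
}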

backends/candle/src/layers/linear.rs

Lines changed: 23 additions & 19 deletions
@@ -1,5 +1,5 @@
-use crate::layers::CUBLASLT;
-use candle::{Device, Result, Tensor, D};
+use crate::layers::cublaslt::get_cublas_lt_wrapper;
+use candle::{Device, Result, Tensor};
 use serde::Deserialize;
 
 #[derive(Debug, Deserialize, PartialEq, Clone)]
@@ -33,23 +33,27 @@ impl Linear {
         let _enter = self.span.enter();
 
         #[allow(unused)]
-        if let (Device::Cuda(_), Some(cublaslt)) = (x.device(), &*CUBLASLT) {
-            // fused matmul requires x to be dims2
-            let mut final_shape = x.dims().to_vec();
-            final_shape.pop();
-            final_shape.push(self.weight.dims()[0]);
-
-            let x = x.flatten_to(D::Minus2)?;
-            let result = cublaslt.matmul(
-                &self.weight,
-                &x,
-                None,
-                None,
-                None,
-                self.bias.as_ref(),
-                self.act.clone(),
-            )?;
-            result.reshape(final_shape)
+        if let (Device::Cuda(_), Some(cublaslt)) = (x.device(), get_cublas_lt_wrapper()) {
+            match x.dims() {
+                &[bsize, _, _] => cublaslt.batch_matmul(
+                    &self.weight.broadcast_left(bsize)?,
+                    x,
+                    None,
+                    None,
+                    None,
+                    self.bias.as_ref(),
+                    self.act.clone(),
+                ),
+                _ => cublaslt.matmul(
+                    &self.weight,
+                    x,
+                    None,
+                    None,
+                    None,
+                    self.bias.as_ref(),
+                    self.act.clone(),
+                ),
+            }
        } else {
            let w = match x.dims() {
                &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
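
The rewritten CUDA branch no longer flattens a 3-D input to 2-D and reshapes the result; it matches on x.dims() and, for a [batch, seq, hidden] input, broadcasts the weight across the batch before a batched matmul, mirroring the non-cuBLASLt else branch shown in the trailing context. A CPU-only sketch of that shape handling with hypothetical sizes (here `candle` refers to the candle crate as this workspace names it):

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Hypothetical sizes: batch 2, sequence length 4, hidden size 8, output features 16.
    let x = Tensor::zeros((2, 4, 8), DType::F32, &device)?;
    let weight = Tensor::zeros((16, 8), DType::F32, &device)?;
    // Broadcast the weight over the batch dimension and transpose its last two
    // dims, then matmul: (2, 4, 8) x (2, 8, 16) -> (2, 4, 16).
    let w = weight.broadcast_left(2usize)?.t()?;
    let y = x.matmul(&w)?;
    assert_eq!(y.dims(), &[2, 4, 16]);
    Ok(())
}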

backends/candle/src/lib.rs

Lines changed: 4 additions & 2 deletions
@@ -7,7 +7,9 @@ mod layers;
 mod models;
 
 #[cfg(feature = "cuda")]
-use crate::compute_cap::{incompatible_compute_cap, COMPILE_COMPUTE_CAP, RUNTIME_COMPUTE_CAP};
+use crate::compute_cap::{
+    get_compile_compute_cap, get_runtime_compute_cap, incompatible_compute_cap,
+};
 #[cfg(feature = "cuda")]
 use crate::models::FlashBertModel;
 use crate::models::{BertModel, JinaBertModel, Model, PositionEmbeddingType};
@@ -94,7 +96,7 @@ impl CandleBackend {
         #[cfg(feature = "cuda")]
         {
             if incompatible_compute_cap() {
-                return Err(BackendError::Start(format!("Runtime compute cap {} is not compatible with compile time compute cap {}", *RUNTIME_COMPUTE_CAP, *COMPILE_COMPUTE_CAP)));
+                return Err(BackendError::Start(format!("Runtime compute cap {} is not compatible with compile time compute cap {}", get_runtime_compute_cap(), get_compile_compute_cap())));
             }
 
             if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))

backends/candle/src/models/bert.rs

Lines changed: 4 additions & 2 deletions
@@ -1,4 +1,4 @@
-use crate::layers::{HiddenAct, LayerNorm, Linear, CUBLASLT};
+use crate::layers::{get_cublas_lt_wrapper, HiddenAct, LayerNorm, Linear};
 use crate::models::Model;
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
 use candle_nn::{Embedding, VarBuilder};
@@ -185,7 +185,9 @@ impl BertAttention {
         let value_layer = &qkv[2];
 
         #[allow(unused_variables)]
-        let context_layer = if let (Device::Cuda(_), Some(cublaslt)) = (device, &*CUBLASLT) {
+        let context_layer = if let (Device::Cuda(_), Some(cublaslt)) =
+            (device, get_cublas_lt_wrapper())
+        {
            #[cfg(feature = "cuda")]
            {
                // cuBLASLt batch matmul implementation requires inputs to be dims3
