Commit e07f68a
feat: support float32 on cuda (#41)
Parent: c202507

File tree: 17 files changed (+418, -258 lines)


Cargo.lock

Lines changed: 3 additions & 3 deletions
Generated file; diff not rendered by default.

Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ authors = ["Olivier Dehaene"]
 homepage = "https://github.com/huggingface/text-embeddings-inference"

 [patch.crates-io]
-cudarc = { git = "https://github.com/OlivierDehaene/cudarc", rev = "4c8e6d36a4a4c31e2e4649ae5246226452a01fc1" }
+cudarc = { git = "https://github.com/OlivierDehaene/cudarc", rev = "8be6ff46e4a2014fb563570e0d206c09aea88152" }
 candle = { git = "https://github.com/OlivierDehaene/candle", rev = "9f2b4081b83a0e47ec1b12caa71d3cac7cc2161e", package = "candle-core" }
 candle-nn = { git = "https://github.com/OlivierDehaene/candle", rev = "9f2b4081b83a0e47ec1b12caa71d3cac7cc2161e", package = "candle-nn" }
 candle-transformers = { git = "https://github.com/OlivierDehaene/candle", rev = "9f2b4081b83a0e47ec1b12caa71d3cac7cc2161e", package = "candle-transformers" }

README.md

Lines changed: 1 addition & 1 deletion
@@ -132,7 +132,7 @@ Options:
           If `dtype` is not set, it defaults to float32 on accelerate, and float16 for all other architectures

           [env: DTYPE=]
-          [possible values: float16]
+          [possible values: float16, float32]

       --pooling <POOLING>
           Optionally control the pooling method.
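With this change, float32 can be requested explicitly on CUDA builds, either through the dtype option shown above (most likely `--dtype float32`, inferred from the clap help format) or via the `DTYPE=float32` environment variable; per the help text, float16 remains the default on non-accelerate architectures when `dtype` is not set.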

backends/Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ tracing = "^0.1"
 clap = ["dep:clap", "text-embeddings-backend-core/clap"]
 python = ["dep:text-embeddings-backend-python"]
 candle = ["dep:text-embeddings-backend-candle"]
+cuda = ["text-embeddings-backend-candle?/cuda"]
 mkl = ["text-embeddings-backend-candle?/mkl"]
 mkl-dynamic = ["text-embeddings-backend-candle?/mkl-dynamic"]
 accelerate = ["text-embeddings-backend-candle?/accelerate"]
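The `text-embeddings-backend-candle?/cuda` entry uses Cargo's weak dependency feature syntax: enabling `cuda` only forwards the candle backend's `cuda` feature when the optional `candle` backend is itself enabled, e.g. a build selecting both features (`--features candle,cuda`).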

backends/candle/Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ candle-nn = { version = "0.3.0" }
 candle-transformers = { version = "0.3.0" }
 candle-flash-attn = { version = "0.3.0", optional = true }
 candle-flash-attn-v1 = { git = "https://github.com/huggingface/candle-flash-attn-v1", rev = "62b75f1ea4e0961fad7b983ee8d723ed6fd68be5", optional = true }
-candle-cublaslt = { git = "https://github.com/huggingface/candle-cublaslt", rev = "07e1a5490211e25ed0d096a2b21d3c607666eaae", optional = true }
+candle-cublaslt = { git = "https://github.com/huggingface/candle-cublaslt", rev = "ffd246552c266640fab217f964a83960e07a66ec", optional = true }
 candle-layer-norm = { git = "https://github.com/huggingface/candle-layer-norm", rev = "5ed96012a693dff9685320765dd55a57fdaecdd6", optional = true }
 lazy_static = "^1.4"
 text-embeddings-backend-core = { path = "../core" }

backends/candle/src/flash_attn.rs

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 use crate::compute_cap::RUNTIME_COMPUTE_CAP;
 use candle::Tensor;

-#[allow(clippy::too_many_arguments)]
+#[allow(clippy::too_many_arguments, unused)]
 pub(crate) fn flash_attn_varlen(
     q: &Tensor,
     k: &Tensor,

backends/candle/src/layers.rs

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+#[allow(dead_code, unused)]
+mod cublaslt;
+mod layer_norm;
+mod linear;
+
+pub use cublaslt::CUBLASLT;
+pub use layer_norm::LayerNorm;
+pub use linear::{HiddenAct, Linear};
backends/candle/src/layers/cublaslt.rs

Lines changed: 104 additions & 0 deletions

@@ -0,0 +1,104 @@
+use crate::layers::HiddenAct;
+use candle::{Device, Result, Tensor};
+use lazy_static::lazy_static;
+
+#[cfg(feature = "cuda")]
+use candle_cublaslt::{fused_batch_matmul, fused_matmul, Activation, CublasLt};
+
+lazy_static! {
+    pub static ref CUBLASLT: Option<CublasLtWrapper> = {
+        match Device::cuda_if_available(0) {
+            Ok(device) => {
+                #[cfg(feature = "cuda")]
+                {
+                    Some(CublasLtWrapper {
+                        cublaslt: CublasLt::new(&device).unwrap(),
+                    })
+                }
+                #[cfg(not(feature = "cuda"))]
+                {
+                    None
+                }
+            }
+            Err(_) => None,
+        }
+    };
+}
+
+#[derive(Debug, Clone)]
+pub struct CublasLtWrapper {
+    #[cfg(feature = "cuda")]
+    pub cublaslt: CublasLt,
+}
+
+impl CublasLtWrapper {
+    #[allow(clippy::too_many_arguments)]
+    pub fn matmul(
+        &self,
+        a: &Tensor,
+        b: &Tensor,
+        out: Option<&Tensor>,
+        alpha: Option<f32>,
+        beta: Option<f32>,
+        bias: Option<&Tensor>,
+        act: Option<HiddenAct>,
+    ) -> Result<Tensor> {
+        #[cfg(feature = "cuda")]
+        {
+            let act = act.clone().map(|a| match a {
+                HiddenAct::Gelu => Activation::Gelu,
+                HiddenAct::Relu => Activation::Relu,
+            });
+
+            fused_matmul(
+                &a,
+                &b,
+                out,
+                alpha,
+                beta,
+                bias,
+                act.clone(),
+                self.cublaslt.clone(),
+            )
+        }
+        #[cfg(not(feature = "cuda"))]
+        {
+            candle::bail!("`cuda` feature is not enabled")
+        }
+    }
+
+    #[allow(clippy::too_many_arguments)]
+    pub fn batch_matmul(
+        &self,
+        a: &Tensor,
+        b: &Tensor,
+        out: Option<&Tensor>,
+        alpha: Option<f32>,
+        beta: Option<f32>,
+        bias: Option<&Tensor>,
+        act: Option<HiddenAct>,
+    ) -> Result<Tensor> {
+        #[cfg(feature = "cuda")]
+        {
+            let act = act.clone().map(|a| match a {
+                HiddenAct::Gelu => Activation::Gelu,
+                HiddenAct::Relu => Activation::Relu,
+            });
+
+            fused_batch_matmul(
+                &a,
+                &b,
+                out,
+                alpha,
+                beta,
+                bias,
+                act.clone(),
+                self.cublaslt.clone(),
+            )
+        }
+        #[cfg(not(feature = "cuda"))]
+        {
+            candle::bail!("`cuda` feature is not enabled")
+        }
+    }
+}
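For orientation, below is a minimal sketch (not part of the commit; the helper name and shapes are illustrative) of how callers are expected to use this wrapper, following the convention visible in `Linear::forward` further down: `matmul(&weight, &x, …)` with `weight: [out, in]` and a 2-D `x: [tokens, in]` yields `[tokens, out]`, matching the CPU fallback `x.matmul(&weight.t())` plus optional bias and activation.

    use candle::{Device, Result, Tensor};

    use crate::layers::{HiddenAct, CUBLASLT};

    // Illustrative helper: route a GELU projection through the fused
    // cuBLASLt path when a handle exists, otherwise use plain candle ops.
    fn project(weight: &Tensor, bias: &Tensor, x: &Tensor) -> Result<Tensor> {
        match (x.device(), &*CUBLASLT) {
            (Device::Cuda(_), Some(cublaslt)) => {
                // weight: [out, in], x: [tokens, in] -> [tokens, out]
                cublaslt.matmul(
                    weight,
                    x,
                    None,
                    None,
                    None,
                    Some(bias),
                    Some(HiddenAct::Gelu),
                )
            }
            _ => {
                // Fallback: x @ W^T + b, then GELU.
                x.matmul(&weight.t()?)?.broadcast_add(bias)?.gelu()
            }
        }
    }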
backends/candle/src/layers/layer_norm.rs

Lines changed: 74 additions & 0 deletions

@@ -0,0 +1,74 @@
+use candle::{DType, Device, Result, Tensor, D};
+use candle_nn::VarBuilder;
+
+#[derive(Debug)]
+pub struct LayerNorm {
+    weight: Tensor,
+    bias: Tensor,
+    epsilon: f32,
+    span: tracing::Span,
+}
+
+impl LayerNorm {
+    pub fn load(vb: VarBuilder, hidden_size: usize, epsilon: f32) -> Result<Self> {
+        Ok(Self {
+            weight: vb
+                .get(hidden_size, "weight")
+                .or_else(|_| vb.get(hidden_size, "gamma"))?,
+            bias: vb
+                .get(hidden_size, "bias")
+                .or_else(|_| vb.get(hidden_size, "beta"))?,
+            epsilon,
+            span: tracing::span!(tracing::Level::TRACE, "layer-norm"),
+        })
+    }
+
+    pub fn forward(&self, hidden_states: &Tensor, residual: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+
+        match hidden_states.device() {
+            Device::Cpu => {
+                let hidden_states = hidden_states.add(residual)?;
+                let hidden_states_dtype = hidden_states.dtype();
+                let internal_dtype = match hidden_states_dtype {
+                    DType::F16 | DType::BF16 => DType::F32,
+                    d => d,
+                };
+                let hidden_size = hidden_states.dim(D::Minus1)?;
+                let hidden_states = hidden_states.to_dtype(internal_dtype)?;
+                let mean_hidden_states =
+                    (hidden_states.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+                let hidden_states = hidden_states.broadcast_sub(&mean_hidden_states)?;
+                let norm_hidden_states =
+                    (hidden_states.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+                let hidden_states_normed = hidden_states
+                    .broadcast_div(&(norm_hidden_states + self.epsilon as f64)?.sqrt()?)?;
+                let hidden_states = hidden_states_normed
+                    .to_dtype(hidden_states_dtype)?
+                    .broadcast_mul(&self.weight)?;
+                hidden_states.broadcast_add(&self.bias)
+            }
+            Device::Cuda(_) => {
+                #[cfg(feature = "cuda")]
+                {
+                    use candle_layer_norm::fused_add_layer_norm;
+
+                    let original_shape = hidden_states.shape();
+                    let hidden_states = hidden_states.flatten_to(D::Minus2)?;
+                    let residual = residual.flatten_to(D::Minus2)?;
+
+                    let result = fused_add_layer_norm(
+                        &hidden_states,
+                        &residual,
+                        &self.weight,
+                        &self.bias,
+                        self.epsilon,
+                    )?;
+                    result.reshape(original_shape)
+                }
+                #[cfg(not(feature = "cuda"))]
+                candle::bail!("`cuda` feature is not enabled")
+            }
+        }
+    }
+}
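A small usage sketch (illustrative, not from the commit): on CPU, `forward` fuses the residual add with layer norm, computing `(h + r - mean) / sqrt(var + eps) * weight + bias` and upcasting f16/bf16 inputs to f32 for the reductions; on CUDA it defers to `candle_layer_norm::fused_add_layer_norm`. Assuming a zero-initialised `VarBuilder` purely for demonstration:

    use candle::{DType, Device, Result, Tensor};
    use candle_nn::VarBuilder;

    use crate::layers::LayerNorm;

    fn demo() -> Result<()> {
        let device = Device::Cpu;
        // Zero-filled builder: "weight" and "bias" resolve to zero tensors,
        // which is enough to exercise shapes and dtypes.
        let vb = VarBuilder::zeros(DType::F32, &device);
        let ln = LayerNorm::load(vb, 4, 1e-12)?;

        let hidden = Tensor::randn(0f32, 1f32, (2, 3, 4), &device)?;
        let residual = Tensor::zeros((2, 3, 4), DType::F32, &device)?;

        // Output keeps the (batch, seq_len, hidden_size) shape of the input.
        let out = ln.forward(&hidden, &residual)?;
        assert_eq!(out.dims(), hidden.dims());
        Ok(())
    }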

backends/candle/src/layers/linear.rs

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+use crate::layers::CUBLASLT;
+use candle::{Device, Result, Tensor, D};
+use serde::Deserialize;
+
+#[derive(Debug, Deserialize, PartialEq, Clone)]
+#[serde(rename_all = "lowercase")]
+pub enum HiddenAct {
+    Gelu,
+    Relu,
+}
+
+#[derive(Debug)]
+pub struct Linear {
+    weight: Tensor,
+    bias: Option<Tensor>,
+    act: Option<HiddenAct>,
+    span: tracing::Span,
+}
+
+impl Linear {
+    pub fn new(weight: Tensor, bias: Option<Tensor>, act: Option<HiddenAct>) -> Self {
+        let span = tracing::span!(tracing::Level::TRACE, "linear");
+
+        Self {
+            weight,
+            bias,
+            act,
+            span,
+        }
+    }
+
+    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+
+        #[allow(unused)]
+        if let (Device::Cuda(_), Some(cublaslt)) = (x.device(), &*CUBLASLT) {
+            // fused matmul requires x to be dims2
+            let mut final_shape = x.dims().to_vec();
+            final_shape.pop();
+            final_shape.push(self.weight.dims()[0]);
+
+            let x = x.flatten_to(D::Minus2)?;
+            let result = cublaslt.matmul(
+                &self.weight,
+                &x,
+                None,
+                None,
+                None,
+                self.bias.as_ref(),
+                self.act.clone(),
+            )?;
+            result.reshape(final_shape)
+        } else {
+            let w = match x.dims() {
+                &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
+                _ => self.weight.t()?,
+            };
+            let x = x.matmul(&w)?;
+            let x = match &self.bias {
+                None => Ok(x),
+                Some(bias) => x.broadcast_add(bias),
+            }?;
+            if let Some(act) = &self.act {
+                match act {
+                    HiddenAct::Gelu => x.gelu(),
+                    HiddenAct::Relu => x.relu(),
+                }
+            } else {
+                Ok(x)
+            }
+        }
+    }
+}
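And a brief usage sketch for the new `Linear` (illustrative, not from the commit): on CUDA with the `cuda` feature, `forward` folds bias and activation into a single fused cuBLASLt call; otherwise it falls back to `matmul` + `broadcast_add` + activation. With `weight: [out_features, in_features]`, an input of shape `[batch, seq_len, in_features]` comes back as `[batch, seq_len, out_features]`:

    use candle::{DType, Device, Result, Tensor};

    use crate::layers::{HiddenAct, Linear};

    fn demo() -> Result<()> {
        let device = Device::Cpu; // or Device::cuda_if_available(0)?
        let (in_features, out_features) = (4, 8);

        // Weight is stored as [out_features, in_features], as forward() expects.
        let weight = Tensor::randn(0f32, 1f32, (out_features, in_features), &device)?;
        let bias = Tensor::zeros((out_features,), DType::F32, &device)?;
        let linear = Linear::new(weight, Some(bias), Some(HiddenAct::Gelu));

        let x = Tensor::randn(0f32, 1f32, (2, 3, in_features), &device)?;
        let y = linear.forward(&x)?;
        assert_eq!(y.dims(), &[2, 3, out_features]);
        Ok(())
    }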
