Skip to content

Commit db4ff2b

Browse files
fix: remove impossible dtypes
1 parent 26faa40 commit db4ff2b

File tree

4 files changed: +38 −10 lines changed

4 files changed

+38
-10
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -128,7 +128,7 @@ Options:
128128
If `dtype` is not set, it defaults to float32 on accelerate, and float16 for all other architectures
129129
130130
[env: DTYPE=]
131-
[possible values: float16, float32]
131+
[possible values: float16]
132132
133133
--pooling <POOLING>
134134
Optionally control the pooling method.

backends/src/dtype.rs

Lines changed: 28 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -6,9 +6,21 @@ use clap::ValueEnum;
66
#[derive(Debug, PartialEq)]
77
#[cfg_attr(feature = "clap", derive(Clone, ValueEnum))]
88
pub enum DType {
9-
#[cfg(any(feature = "python", feature = "candle"))]
9+
// Float16 is not available on accelerate
10+
#[cfg(any(
11+
feature = "python",
12+
all(feature = "candle", not(feature = "accelerate"))
13+
))]
1014
Float16,
11-
#[cfg(any(feature = "python", feature = "candle"))]
15+
// Float32 is not available on candle cuda
16+
#[cfg(any(
17+
feature = "python",
18+
all(
19+
feature = "candle",
20+
not(feature = "flash-attn"),
21+
not(feature = "flash-attn-v1")
22+
)
23+
))]
1224
Float32,
1325
// #[cfg(feature = "candle")]
1426
// Q6K,
@@ -17,9 +29,21 @@ pub enum DType {
1729
impl fmt::Display for DType {
1830
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1931
match self {
20-
#[cfg(any(feature = "python", feature = "candle"))]
32+
// Float16 is not available on accelerate
33+
#[cfg(any(
34+
feature = "python",
35+
all(feature = "candle", not(feature = "accelerate"))
36+
))]
2137
DType::Float16 => write!(f, "float16"),
22-
#[cfg(any(feature = "python", feature = "candle"))]
38+
// Float32 is not available on candle cuda
39+
#[cfg(any(
40+
feature = "python",
41+
all(
42+
feature = "candle",
43+
not(feature = "flash-attn"),
44+
not(feature = "flash-attn-v1")
45+
)
46+
))]
2347
DType::Float32 => write!(f, "float32"),
2448
// #[cfg(feature = "candle")]
2549
// DType::Q6K => write!(f, "q6k"),

router/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -54,6 +54,6 @@ mkl = ["text-embeddings-backend/mkl"]
5454
accelerate = ["text-embeddings-backend/accelerate"]
5555
python = ["text-embeddings-backend/python"]
5656
candle = ["text-embeddings-backend/candle"]
57-
candle-cuda = ["text-embeddings-backend/candle", "text-embeddings-backend/flash-attn"]
58-
candle-cuda-turing = ["text-embeddings-backend/candle", "text-embeddings-backend/flash-attn-v1"]
57+
candle-cuda = ["candle", "text-embeddings-backend/flash-attn"]
58+
candle-cuda-turing = ["candle", "text-embeddings-backend/flash-attn-v1"]
5959
static-linking = ["text-embeddings-backend/static-linking"]

router/src/main.rs

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -232,10 +232,14 @@ async fn main() -> Result<()> {
232232

233233
// Get dtype
234234
let dtype = args.dtype.unwrap_or_else(|| {
235-
if cfg!(feature = "accelerate") {
236-
return DType::Float32;
235+
#[cfg(feature = "accelerate")]
236+
{
237+
DType::Float32
238+
}
239+
#[cfg(not(feature = "accelerate"))]
240+
{
241+
DType::Float16
237242
}
238-
DType::Float16
239243
});
240244

241245
// Create backend

0 commit comments

Comments
 (0)