
Commit 26faa40

feat: default to float32 on accelerate
1 parent: 7af5bd6

2 files changed: +65 -18 lines

README.md

Lines changed: 53 additions & 14 deletions
@@ -33,7 +33,8 @@ Benchmark for [BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1
 - [API Documentation](#api-documentation)
 - [Using a private or gated model](#using-a-private-or-gated-model)
 - [Distributed Tracing](#distributed-tracing)
-- [Local Install](#local-install)
+- [Local Install](#local-install)
+- [Docker Build](#docker-build)
 
 - No compilation step
 - Dynamic shapes
@@ -89,7 +90,7 @@ curl 127.0.0.1:8080/embed \
 ```
 
 **Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html).
-We also recommend using NVIDIA drivers with CUDA version 12 or higher.
+We also recommend using NVIDIA drivers with CUDA version 12.2 or higher.
 
 To see all options to serve your models:
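
The Container Toolkit setup can be sanity-checked before pulling any TEI image; a minimal sketch, assuming a recent NVIDIA driver (the `nvidia/cuda` base image tag below is only an example, not something this README prescribes):

```shell
# Sketch: verify that Docker can see the GPU through the NVIDIA Container Toolkit
docker run --rm --gpus all nvidia/cuda:12.2.0-base-ubuntu22.04 nvidia-smi
```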

@@ -123,9 +124,10 @@ Options:
 
       --dtype <DTYPE>
           The dtype to be forced upon the model
+
+          If `dtype` is not set, it defaults to float32 on accelerate, and float16 for all other architectures
 
           [env: DTYPE=]
-          [default: float16]
           [possible values: float16, float32]
 
       --pooling <POOLING>
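
As a usage sketch of the option described above (not taken from the diff), the dtype can still be pinned explicitly when you do not want the new feature-based default; the model id is just the one from the benchmark section of this README:

```shell
# Sketch only: force a dtype instead of relying on the accelerate/float16 default
model=BAAI/bge-base-en-v1.5

# via the CLI flag
text-embeddings-router --model-id $model --dtype float32 --port 8080

# or via the DTYPE environment variable listed in the help output
DTYPE=float32 text-embeddings-router --model-id $model --port 8080
```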
@@ -217,13 +219,14 @@ Options:
 
 Text Embeddings Inference ships with multiple Docker images that you can use to target a specific backend:
 
-| Architecture | Image |
-|--------------|-------------------------------------------------------------|
-| CPU | ghcr.io/huggingface/text-embeddings-inference:cpu-latest |
-| Turing | ghcr.io/huggingface/text-embeddings-inference:turing-latest |
-| Ampere 80 | ghcr.io/huggingface/text-embeddings-inference:latest |
-| Ampere 86 | ghcr.io/huggingface/text-embeddings-inference:86-latest |
-| Hopper | ghcr.io/huggingface/text-embeddings-inference:hopper-latest |
+| Architecture                      | Image                                                        |
+|-----------------------------------|--------------------------------------------------------------|
+| CPU                               | ghcr.io/huggingface/text-embeddings-inference:cpu-latest     |
+| Volta                             | NOT SUPPORTED                                                |
+| Turing (T4, RTX 2000 series, ...) | ghcr.io/huggingface/text-embeddings-inference:turing-latest  |
+| Ampere 80 (A100, A30)             | ghcr.io/huggingface/text-embeddings-inference:latest         |
+| Ampere 86 (A10, A40, ...)         | ghcr.io/huggingface/text-embeddings-inference:86-latest      |
+| Hopper (H100)                     | ghcr.io/huggingface/text-embeddings-inference:hopper-latest  |
 
 ### API documentation
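
A hedged usage sketch for the expanded table (the command mirrors the `docker run` example elsewhere in this README; the model id is only an illustration): on a T4 you would pick the Turing image rather than the default `latest` tag.

```shell
# Sketch only: serve a model on a Turing GPU (e.g. T4) with the image from the table above
model=BAAI/bge-base-en-v1.5

docker run --gpus all -p 8080:80 \
    ghcr.io/huggingface/text-embeddings-inference:turing-latest \
    --model-id $model
```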

@@ -256,9 +259,9 @@ docker run --gpus all -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/da
 `text-embeddings-inference` is instrumented with distributed tracing using OpenTelemetry. You can use this feature
 by setting the address to an OTLP collector with the `--otlp-endpoint` argument.
 
-### Local install
+## Local install
 
-#### CPU
+### CPU
 
 You can also opt to install `text-embeddings-inference` locally.

@@ -292,9 +295,11 @@ text-embeddings-router --model-id $model --revision $revision --port 8080
 sudo apt-get install libssl-dev gcc -y
 ```
 
-#### Cuda
+### Cuda
+
+GPUs with Cuda compute capabilities < 7.5 are not supported (V100, Titan V, GTX 1000 series, ...).
 
-Make sure you have Cuda and the nvidia drivers installed. We recommend using NVIDIA drivers with CUDA version 12 or higher.
+Make sure you have Cuda and the nvidia drivers installed. We recommend using NVIDIA drivers with CUDA version 12.2 or higher.
 You also need to add the nvidia binaries to your path:
 
 ```shell
@@ -305,6 +310,11 @@ Then run:
 
 ```shell
 # This can take a while as we need to compile a lot of cuda kernels
+
+# On Turing GPUs (T4, RTX 2000 series ... )
+cargo install --path router -F candle-cuda-turing --no-default-features
+
+# On Ampere and Hopper
 cargo install --path router -F candle-cuda --no-default-features
 ```

@@ -316,3 +326,32 @@ revision=refs/pr/5
 
 text-embeddings-router --model-id $model --revision $revision --port 8080
 ```
+
+## Docker build
+
+You can build the CPU container with:
+
+```shell
+docker build .
+```
+
+To build the Cuda containers, you need to know the compute cap of the GPU you will be using
+at runtime.
+
+Then you can build the container with:
+
+```shell
+# Example for Turing (T4, RTX 2000 series, ...)
+runtime_compute_cap=75
+
+# Example for A100
+runtime_compute_cap=80
+
+# Example for A10
+runtime_compute_cap=86
+
+# Example for H100
+runtime_compute_cap=90
+
+docker build . -f Dockerfile-cuda --build-arg CUDA_COMPUTE_CAP=$runtime_compute_cap
+```
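
The `runtime_compute_cap` values above must match the GPU you will run on; a minimal sketch for looking this up, assuming a recent NVIDIA driver (the `compute_cap` query field is a driver feature, not something documented in this repo):

```shell
# Sketch: query the compute capability of the installed GPU
nvidia-smi --query-gpu=name,compute_cap --format=csv
# e.g. "Tesla T4, 7.5" -> runtime_compute_cap=75
```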

router/src/main.rs

Lines changed: 12 additions & 4 deletions
@@ -49,8 +49,8 @@ struct Args {
     tokenization_workers: Option<usize>,
 
     /// The dtype to be forced upon the model.
-    #[clap(default_value = "float16", long, env, value_enum)]
-    dtype: DType,
+    #[clap(long, env, value_enum)]
+    dtype: Option<DType>,
 
     /// Optionally control the pooling method.
     ///
@@ -230,11 +230,19 @@ async fn main() -> Result<()> {
         position_offset,
     );
 
+    // Get dtype
+    let dtype = args.dtype.unwrap_or_else(|| {
+        if cfg!(feature = "accelerate") {
+            return DType::Float32;
+        }
+        DType::Float16
+    });
+
     // Create backend
     tracing::info!("Starting model backend");
     let backend = Backend::new(
         model_root,
-        args.dtype.clone(),
+        dtype.clone(),
         pool.clone(),
         args.uds_path,
         args.otlp_endpoint,
@@ -265,7 +273,7 @@ async fn main() -> Result<()> {
     let info = Info {
         model_id: args.model_id,
         model_sha: args.revision,
-        model_dtype: args.dtype.to_string(),
+        model_dtype: dtype.to_string(),
         model_pooling: pool.to_string(),
         max_concurrent_requests: args.max_concurrent_requests,
         max_input_length: config.max_position_embeddings,
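
A hedged way to observe this change at runtime (the `/info` route is an assumption based on the `Info` struct above, not something shown in this diff): on an accelerate build started without `--dtype`, the reported `model_dtype` should now be `float32`.

```shell
# Assumption: the router serves the Info struct over an /info route
curl 127.0.0.1:8080/info
```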
