
Commit 463329a

feat: USE_FLASH_ATTENTION env var (#57)
1 parent 04f94f6 · commit 463329a

File tree

.github/workflows/build_75.yaml
Dockerfile-cuda
README.md
backends/candle/src/lib.rs

4 files changed: +14 −7 lines

.github/workflows/build_75.yaml

Lines changed: 1 addition & 0 deletions
@@ -130,6 +130,7 @@
             CUDA_COMPUTE_CAP=75
             GIT_SHA=${{ env.GITHUB_SHA }}
             DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}
+            DEFAULT_USE_FLASH_ATTENTION=False
           tags: ${{ steps.meta-75.outputs.tags }}
           labels: ${{ steps.meta-75.outputs.labels }}
           cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-75,mode=max
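The net effect of this one-line change is that the CI build for compute capability 7.5 (Turing) now bakes `DEFAULT_USE_FLASH_ATTENTION=False` into the image at build time. A rough local equivalent, as a sketch only (the image tag is a placeholder; CI-specific build args such as GIT_SHA, DOCKER_LABEL, and the registry cache settings are omitted for brevity):

```shell
# Sketch: mirror the CI job's new build arg in a local build of the Turing image.
docker build -f Dockerfile-cuda \
  --build-arg CUDA_COMPUTE_CAP=75 \
  --build-arg DEFAULT_USE_FLASH_ATTENTION=False \
  -t text-embeddings-inference:75-local \
  .
```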

Dockerfile-cuda

Lines changed: 6 additions & 3 deletions
@@ -1,4 +1,4 @@
-FROM nvidia/cuda:12.2.0-devel-ubuntu22.04 AS base-builder
+FROM nvidia/cuda:12.0.0-devel-ubuntu22.04 AS base-builder
 
 ENV SCCACHE=0.5.4
 ENV RUSTC_WRAPPER=/usr/local/bin/sccache
@@ -77,10 +77,13 @@ RUN if [ ${CUDA_COMPUTE_CAP} -ge 75 -a ${CUDA_COMPUTE_CAP} -lt 80 ]; \
     cargo build --release --bin text-embeddings-router -F candle-cuda -F static-linking --no-default-features && sccache -s; \
     fi;
 
-FROM nvidia/cuda:12.2.0-base-ubuntu22.04
+FROM nvidia/cuda:12.0.0-base-ubuntu22.04
+
+ARG DEFAULT_USE_FLASH_ATTENTION=True
 
 ENV HUGGINGFACE_HUB_CACHE=/data \
-    PORT=80
+    PORT=80 \
+    USE_FLASH_ATTENTION=$DEFAULT_USE_FLASH_ATTENTION
 
 COPY --from=builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router
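Because the new `ARG` only seeds `ENV USE_FLASH_ATTENTION`, the value chosen at build time ends up as an ordinary environment variable in the final image, so it can still be flipped when the container is started. A minimal sketch, assuming the Turing image built by the workflow above (the image tag and model id are placeholders):

```shell
# Sketch: re-enable Flash Attention v1 on a Turing image at run time.
docker run --gpus all -p 8080:80 \
  -e USE_FLASH_ATTENTION=True \
  ghcr.io/huggingface/text-embeddings-inference:turing-0.3.0 \
  --model-id BAAI/bge-base-en-v1.5
```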

README.md

Lines changed: 5 additions & 2 deletions
@@ -97,7 +97,7 @@ curl 127.0.0.1:8080/embed \
 ```
 
 **Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html).
-We also recommend using NVIDIA drivers with CUDA version 12.2 or higher.
+We also recommend using NVIDIA drivers with CUDA version 12.0 or higher.
 
 To see all options to serve your models:
 
@@ -236,6 +236,9 @@ Text Embeddings Inference ships with multiple Docker images that you can use to
 | Ada Lovelace (RTX 4000 series, ...) | ghcr.io/huggingface/text-embeddings-inference:89-0.3.0 |
 | Hopper (H100) | ghcr.io/huggingface/text-embeddings-inference:hopper-0.3.0 (experimental) |
 
+**Warning**: Flash Attention is turned off by default for the Turing image as it suffers from precision issues.
+You can turn Flash Attention v1 ON by using the `USE_FLASH_ATTENTION=True` environment variable.
+
 ### API documentation
 
 You can consult the OpenAPI documentation of the `text-embeddings-inference` REST API using the `/docs` route.
@@ -307,7 +310,7 @@ sudo apt-get install libssl-dev gcc -y
 
 GPUs with Cuda compute capabilities < 7.5 are not supported (V100, Titan V, GTX 1000 series, ...).
 
-Make sure you have Cuda and the nvidia drivers installed. We recommend using NVIDIA drivers with CUDA version 12.2 or higher.
+Make sure you have Cuda and the nvidia drivers installed. We recommend using NVIDIA drivers with CUDA version 12.0 or higher.
 You also need to add the nvidia binaries to your path:
 
 ```shell

backends/candle/src/lib.rs

Lines changed: 2 additions & 2 deletions
@@ -132,9 +132,9 @@ impl CandleBackend {
         if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
             && dtype == DType::F16
             && config.position_embedding_type == PositionEmbeddingType::Absolute
-            // Flash attention v1 precision problem with head_size == 32
+            // Allow disabling because of flash attention v1 precision problems
             // See: https://github.com/huggingface/text-embeddings-inference/issues/37
-            && !(*RUNTIME_COMPUTE_CAP == 75 && (config.hidden_size / config.num_attention_heads) == 32)
+            && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
         {
             tracing::info!("Starting FlashBert model on Cuda");
             Box::new(FlashBertModel::load(vb, &config, pool).s()?)
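With this hunk the router keys Flash Attention solely off `USE_FLASH_ATTENTION`: the value defaults to `True` when the variable is unset and is lowercased before comparison, so the flag is effectively case-insensitive. The old hard-coded exclusion for compute capability 75 with `head_size == 32` is gone, which is why the Turing CI image above sets the default to `False` instead. Since the variable is read straight from the process environment, the same switch also works outside Docker when running a locally built router; a sketch with a placeholder model id, assuming a build with the flash-attn features enabled:

```shell
# Sketch: disable Flash Attention for a locally built router binary.
# Any of "false"/"False"/"FALSE" works because the value is lowercased first.
USE_FLASH_ATTENTION=False text-embeddings-router \
  --model-id BAAI/bge-base-en-v1.5 \
  --port 8080
```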
