Skip to content

Commit 2fc644c

Browse files
committed
add dockerfile
1 parent dbefcd8 commit 2fc644c

File tree

2 files changed

+151
-0
lines changed

2 files changed

+151
-0
lines changed

Dockerfile-rocm

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
FROM rocm/dev-ubuntu-22.04:6.0.2 AS base-builder
2+
3+
ENV SCCACHE=0.5.4
4+
ENV RUSTC_WRAPPER=/usr/local/bin/sccache
5+
ENV PATH="/root/.cargo/bin:${PATH}"
6+
7+
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
8+
curl \
9+
libssl-dev \
10+
pkg-config \
11+
&& rm -rf /var/lib/apt/lists/*
12+
13+
# Download and configure sccache
14+
RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \
15+
chmod +x /usr/local/bin/sccache
16+
17+
RUN curl https://sh.rustup.rs -sSf | bash -s -- -y
18+
RUN cargo install cargo-chef --locked
19+
20+
FROM base-builder AS planner
21+
22+
WORKDIR /usr/src
23+
24+
COPY backends backends
25+
COPY core core
26+
COPY router router
27+
COPY Cargo.toml ./
28+
COPY Cargo.lock ./
29+
30+
RUN cargo chef prepare --recipe-path recipe.json
31+
32+
FROM base-builder AS builder
33+
34+
ARG CUDA_COMPUTE_CAP=80
35+
ARG GIT_SHA
36+
ARG DOCKER_LABEL
37+
38+
# sccache specific variables
39+
ARG ACTIONS_CACHE_URL
40+
ARG ACTIONS_RUNTIME_TOKEN
41+
ARG SCCACHE_GHA_ENABLED
42+
43+
WORKDIR /usr/src
44+
45+
COPY --from=planner /usr/src/recipe.json recipe.json
46+
47+
RUN cargo chef cook --release --features python --no-default-features --recipe-path recipe.json && sccache -s
48+
49+
COPY backends backends
50+
COPY core core
51+
COPY router router
52+
COPY Cargo.toml ./
53+
COPY Cargo.lock ./
54+
55+
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
56+
unzip \
57+
&& rm -rf /var/lib/apt/lists/*
58+
59+
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
60+
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
61+
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
62+
unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
63+
rm -f $PROTOC_ZIP
64+
65+
COPY proto proto
66+
67+
# Stage that compiles the router binary with the HTTP feature enabled.
FROM builder AS http-builder
68+
69+
RUN cargo build --release --bin text-embeddings-router -F python -F http --no-default-features && sccache -s
70+
71+
# Stage that compiles the router binary with the gRPC feature enabled.
FROM builder AS grpc-builder
72+
73+
RUN cargo build --release --bin text-embeddings-router -F python -F grpc --no-default-features && sccache -s
74+
75+
# Runtime base image: ROCm dev image plus the Python/torch/flash-attention stack installed below.
FROM rocm/dev-ubuntu-22.04:6.0.2 AS base
76+
77+
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
78+
git \
79+
python3-dev \
80+
rocthrust-dev \
81+
hipsparse-dev \
82+
hipblas-dev \
83+
hipblaslt-dev \
84+
rocblas-dev \
85+
hiprand-dev \
86+
rocrand-dev \
87+
&& rm -rf /var/lib/apt/lists/*
88+
89+
90+
# Keep in sync with `server/pyproject.toml`
91+
ARG MAMBA_VERSION=23.1.0-1
92+
ARG PYTORCH_VERSION='2.3.0'
93+
ARG ROCM_VERSION='6.0.2'
94+
ARG PYTHON_VERSION='3.10.10'
95+
# Automatically set by buildx
96+
ARG TARGETPLATFORM
97+
# Put the conda install (created at /opt/conda below) first on PATH so its python/pip/mamba are used.
ENV PATH=/opt/conda/bin:$PATH
98+
99+
# Download the Mambaforge installer to ~/mambaforge.sh; -o already names the output file,
# so no additional -O (remote-name) option is passed for this single URL.
RUN curl -fsSL -v -o ~/mambaforge.sh "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-x86_64.sh"
100+
RUN chmod +x ~/mambaforge.sh && \
101+
bash ~/mambaforge.sh -b -p /opt/conda && \
102+
mamba init && \
103+
rm ~/mambaforge.sh
104+
105+
# Install flash-attention, torch dependencies
106+
RUN pip install numpy einops ninja --no-cache-dir
107+
108+
# Install the ROCm 6.0 build of torch, pinned to PYTORCH_VERSION (declared above) so the
# image is reproducible; --no-cache-dir keeps pip's download cache out of the layer.
RUN pip install torch==${PYTORCH_VERSION} --index-url https://download.pytorch.org/whl/rocm6.0 --no-cache-dir
109+
110+
ARG DEFAULT_USE_FLASH_ATTENTION=True
111+
COPY backends/python/Makefile-flash-att-v2 Makefile-flash-att-v2
112+
RUN make -f Makefile-flash-att-v2 install-flash-attention-v2-rocm
113+
114+
ENV HUGGINGFACE_HUB_CACHE=/data \
115+
PORT=80 \
116+
USE_FLASH_ATTENTION=$DEFAULT_USE_FLASH_ATTENTION
117+
118+
# Final image variant shipping the gRPC build of the router.
FROM base AS grpc
119+
120+
COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router
121+
122+
ENTRYPOINT ["text-embeddings-router"]
123+
CMD ["--json-output"]
124+
125+
FROM base
126+
127+
COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router
128+
129+
ENTRYPOINT ["text-embeddings-router"]
130+
CMD ["--json-output"]

backends/python/Makefile-flash-att-v2

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
flash_att_v2_commit_cuda := v2.5.9.post1
2+
flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6
3+
4+
# These targets are commands, not files; declare them phony so a stray file with
# the same name can never make them appear up to date.
.PHONY: build-flash-attention-v2-cuda install-flash-attention-v2-cuda build-flash-attention-v2-rocm install-flash-attention-v2-rocm

# Install the flash-attn release pinned by flash_att_v2_commit_cuda via pip.
build-flash-attention-v2-cuda:
5+
pip install -U packaging wheel
6+
pip install flash-attn==$(flash_att_v2_commit_cuda)
7+
8+
install-flash-attention-v2-cuda: build-flash-attention-v2-cuda
9+
echo "Flash v2 installed"
10+
11+
build-flash-attention-v2-rocm:
12+
if [ ! -d 'flash-attention-v2' ]; then \
13+
pip install -U packaging ninja --no-cache-dir && \
14+
git clone https://github.com/ROCm/flash-attention.git flash-attention-v2 && \
15+
cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) && \
16+
git submodule update --init --recursive && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build; \
17+
fi
18+
19+
install-flash-attention-v2-rocm: build-flash-attention-v2-rocm
20+
cd flash-attention-v2 && \
21+
GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install

0 commit comments

Comments
 (0)