@@ -12,9 +12,17 @@ ARG CUDA_MINOR_VERSION
12
12
# TORCHVISION_VERSION is mandatory
13
13
RUN test -n "$TORCHVISION_VERSION"
14
14
15
+ # Use mamba to speed up conda installs
16
+ RUN conda install -c conda-forge mamba
17
+
18
+ # Install cudf/cuml so that cudatoolkit upgrades are included in the pytorch build
19
+ RUN conda config --add channels nvidia && \
20
+ conda config --add channels rapidsai
21
+ RUN mamba install -y cudf cuml
22
+
15
23
# Build instructions: https://github.com/pytorch/pytorch#from-source
16
- RUN conda install astunparse numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing_extensions future six requests dataclasses
17
- RUN conda install -c pytorch magma-cuda${CUDA_MAJOR_VERSION}${CUDA_MINOR_VERSION}
24
+ RUN mamba install astunparse numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing_extensions future six requests dataclasses
25
+ RUN mamba install -c pytorch magma-cuda${CUDA_MAJOR_VERSION}${CUDA_MINOR_VERSION}
18
26
19
27
# By default, it uses the version from version.txt which includes the `a0` (alpha zero) suffix and part of the git hash.
20
28
# This causes dependency conflicts like these: https://paste.googleplex.com/4786486378496000
@@ -46,7 +54,7 @@ RUN sudo apt-get update && \
46
54
# ncurses.h is required for this install
47
55
sudo apt-get install libncurses-dev && \
48
56
# Fixing the build: https://github.com/pytorch/audio/issues/666#issuecomment-635928685
49
- conda install -c conda-forge ncurses && \
57
+ mamba install -c conda-forge ncurses && \
50
58
cd /usr/local/src && \
51
59
git clone https://github.com/pytorch/audio && \
52
60
cd audio && \
0 commit comments