Skip to content

Commit 4156088

Browse files
committed
Fix wrong Pytorch version with CUDA 11.8
1 parent 42011ee commit 4156088

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

.github/workflows/build-wheels-release.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ jobs:
6363
$cudaNum = [int]$cudaVersion.substring($cudaVersion.LastIndexOf('.')+1)
6464
while ($cudaNum -ge 0) { $cudaChannels += '-c nvidia/label/cuda-' + $cudaVersion.Remove($cudaVersion.LastIndexOf('.')+1) + $cudaNum + ' '; $cudaNum-- }
6565
mamba install -y 'cuda' $cudaChannels.TrimEnd().Split()
66-
if ([version]$env:CUDAVER -gt [version]'11.8.0') {$torch = "torch==2.1.0"} else {$torch = "torch==2.0.1"}
66+
if ([version]$env:CUDAVER -lt [version]'11.8.0') {$torch = "torch==2.0.1"} else {$torch = "torch==2.1.0"}
6767
python -m pip install build wheel safetensors sentencepiece ninja $torch --extra-index-url "https://download.pytorch.org/whl/cu$cudaVersionPytorch"
6868
6969
- name: Build Wheel

.github/workflows/build-wheels.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ jobs:
6060
$cudaNum = [int]$cudaVersion.substring($cudaVersion.LastIndexOf('.')+1)
6161
while ($cudaNum -ge 0) { $cudaChannels += '-c nvidia/label/cuda-' + $cudaVersion.Remove($cudaVersion.LastIndexOf('.')+1) + $cudaNum + ' '; $cudaNum-- }
6262
mamba install -y 'cuda' $cudaChannels.TrimEnd().Split()
63-
if ([version]$env:CUDAVER -gt [version]'11.8.0') {$torch = "torch==2.1.0"} else {$torch = "torch==2.0.1"}
63+
if ([version]$env:CUDAVER -lt [version]'11.8.0') {$torch = "torch==2.0.1"} else {$torch = "torch==2.1.0"}
6464
python -m pip install build wheel safetensors sentencepiece ninja $torch --extra-index-url "https://download.pytorch.org/whl/cu$cudaVersionPytorch"
6565
6666
- name: Build Wheel

0 commit comments

Comments
 (0)