Skip to content

Commit 6f1510b

Browse files
authored
ENH: enable RunsOn with custom ami (#451)
* ENH: enable RunsOn with custom ami * fix custom ami name for RunsOn * use g4dn.xlarge * enable manual trigering' * use 2xlarge instance * tmp: disable build cache * Revert "tmp: disable build cache" This reverts commit f18500e. * convert cache, collab and publish to use runson
1 parent 13abb75 commit 6f1510b

File tree

5 files changed

+76
-18
lines changed

5 files changed

+76
-18
lines changed

.github/workflows/cache.yml

+17-4
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,27 @@ on:
66
workflow_dispatch:
77
jobs:
88
cache:
9-
runs-on: quantecon-gpu
10-
container:
11-
image: ghcr.io/quantecon/lecture-python-container:cuda-12.8.1-anaconda-2024-10-py312
12-
options: --gpus all
9+
runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large"
1310
steps:
1411
- uses: actions/checkout@v4
1512
with:
1613
ref: ${{ github.event.pull_request.head.sha }}
14+
- name: Setup Anaconda
15+
uses: conda-incubator/setup-miniconda@v3
16+
with:
17+
auto-update-conda: true
18+
auto-activate-base: true
19+
miniconda-version: 'latest'
20+
python-version: "3.12"
21+
environment-file: environment.yml
22+
activate-environment: quantecon
23+
- name: Install JAX, Numpyro, PyTorch
24+
shell: bash -l {0}
25+
run: |
26+
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
27+
pip install --upgrade "jax[cuda12-local]"
28+
pip install numpyro
29+
python scripts/test-jax-install.py
1730
- name: Check nvidia drivers
1831
shell: bash -l {0}
1932
run: |

.github/workflows/ci.yml

+21-7
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,31 @@
11
name: Build Project [using jupyter-book]
2-
on: [pull_request]
2+
on:
3+
pull_request:
4+
workflow_dispatch:
35
jobs:
46
preview:
5-
runs-on: quantecon-gpu
6-
container:
7-
image: ghcr.io/quantecon/lecture-python-container:cuda-12.8.1-anaconda-2024-10-py312
8-
options: --gpus all
7+
runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large"
98
steps:
109
- uses: actions/checkout@v4
1110
with:
1211
ref: ${{ github.event.pull_request.head.sha }}
13-
# Check nvidia drivers
14-
- name: nvidia Drivers
12+
- name: Setup Anaconda
13+
uses: conda-incubator/setup-miniconda@v3
14+
with:
15+
auto-update-conda: true
16+
auto-activate-base: true
17+
miniconda-version: 'latest'
18+
python-version: "3.12"
19+
environment-file: environment.yml
20+
activate-environment: quantecon
21+
- name: Install JAX, Numpyro, PyTorch
22+
shell: bash -l {0}
23+
run: |
24+
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
25+
pip install --upgrade "jax[cuda12-local]"
26+
pip install numpyro
27+
python scripts/test-jax-install.py
28+
- name: Check nvidia Drivers
1529
shell: bash -l {0}
1630
run: nvidia-smi
1731
- name: Display Conda Environment Versions

.github/workflows/collab.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ name: Build Project on Google Collab (Execution)
22
on: [pull_request]
33
jobs:
44
execution-checks:
5-
runs-on: quantecon-gpu
5+
runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=ubuntu24-gpu-x64/disk=large"
66
container:
77
image: docker://us-docker.pkg.dev/colab-images/public/runtime
88
options: --gpus all

.github/workflows/publish.yml

+16-6
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,26 @@ on:
66
jobs:
77
publish:
88
if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
9-
runs-on: quantecon-gpu
10-
container:
11-
image: ghcr.io/quantecon/lecture-python-container:cuda-12.8.1-anaconda-2024-10-py312
12-
options: --gpus all
9+
runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large"
1310
steps:
1411
- name: Checkout
1512
uses: actions/checkout@v4
16-
- name: Install Git (required to commit notebooks)
13+
- name: Setup Anaconda
14+
uses: conda-incubator/setup-miniconda@v3
15+
with:
16+
auto-update-conda: true
17+
auto-activate-base: true
18+
miniconda-version: 'latest'
19+
python-version: "3.12"
20+
environment-file: environment.yml
21+
activate-environment: quantecon
22+
- name: Install JAX, Numpyro, PyTorch
1723
shell: bash -l {0}
18-
run: apt-get install -y git
24+
run: |
25+
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
26+
pip install --upgrade "jax[cuda12-local]"
27+
pip install numpyro
28+
python scripts/test-jax-install.py
1929
- name: Check nvidia drivers
2030
shell: bash -l {0}
2131
run: |

scripts/test-jax-install.py

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import jax
2+
import jax.numpy as jnp
3+
4+
devices = jax.devices()
5+
print(f"The available devices are: {devices}")
6+
7+
@jax.jit
8+
def matrix_multiply(a, b):
9+
return jnp.dot(a, b)
10+
11+
# Example usage:
12+
key = jax.random.PRNGKey(0)
13+
x = jax.random.normal(key, (1000, 1000))
14+
y = jax.random.normal(key, (1000, 1000))
15+
z = matrix_multiply(x, y)
16+
17+
# Now the function is JIT compiled and will likely run on GPU (if available)
18+
print(z)
19+
20+
devices = jax.devices()
21+
print(f"The available devices are: {devices}")

0 commit comments

Comments
 (0)