File tree 5 files changed +76
-18
lines changed
5 files changed +76
-18
lines changed Original file line number Diff line number Diff line change 6
6
workflow_dispatch :
7
7
jobs :
8
8
cache :
9
- runs-on : quantecon-gpu
10
- container :
11
- image : ghcr.io/quantecon/lecture-python-container:cuda-12.8.1-anaconda-2024-10-py312
12
- options : --gpus all
9
+ runs-on : " runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large"
13
10
steps :
14
11
- uses : actions/checkout@v4
15
12
with :
16
13
ref : ${{ github.event.pull_request.head.sha }}
14
+ - name : Setup Anaconda
15
+ uses : conda-incubator/setup-miniconda@v3
16
+ with :
17
+ auto-update-conda : true
18
+ auto-activate-base : true
19
+ miniconda-version : ' latest'
20
+ python-version : " 3.12"
21
+ environment-file : environment.yml
22
+ activate-environment : quantecon
23
+ - name : Install JAX, Numpyro, PyTorch
24
+ shell : bash -l {0}
25
+ run : |
26
+ pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
27
+ pip install --upgrade "jax[cuda12-local]"
28
+ pip install numpyro
29
+ python scripts/test-jax-install.py
17
30
- name : Check nvidia drivers
18
31
shell : bash -l {0}
19
32
run : |
Original file line number Diff line number Diff line change 1
1
name : Build Project [using jupyter-book]
2
- on : [pull_request]
2
+ on :
3
+ pull_request :
4
+ workflow_dispatch :
3
5
jobs :
4
6
preview :
5
- runs-on : quantecon-gpu
6
- container :
7
- image : ghcr.io/quantecon/lecture-python-container:cuda-12.8.1-anaconda-2024-10-py312
8
- options : --gpus all
7
+ runs-on : " runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large"
9
8
steps :
10
9
- uses : actions/checkout@v4
11
10
with :
12
11
ref : ${{ github.event.pull_request.head.sha }}
13
- # Check nvidia drivers
14
- - name : nvidia Drivers
12
+ - name : Setup Anaconda
13
+ uses : conda-incubator/setup-miniconda@v3
14
+ with :
15
+ auto-update-conda : true
16
+ auto-activate-base : true
17
+ miniconda-version : ' latest'
18
+ python-version : " 3.12"
19
+ environment-file : environment.yml
20
+ activate-environment : quantecon
21
+ - name : Install JAX, Numpyro, PyTorch
22
+ shell : bash -l {0}
23
+ run : |
24
+ pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
25
+ pip install --upgrade "jax[cuda12-local]"
26
+ pip install numpyro
27
+ python scripts/test-jax-install.py
28
+ - name : Check nvidia Drivers
15
29
shell : bash -l {0}
16
30
run : nvidia-smi
17
31
- name : Display Conda Environment Versions
Original file line number Diff line number Diff line change @@ -2,7 +2,7 @@ name: Build Project on Google Collab (Execution)
2
2
on : [pull_request]
3
3
jobs :
4
4
execution-checks :
5
- runs-on : quantecon- gpu
5
+ runs-on : " runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=ubuntu24- gpu-x64/disk=large "
6
6
container :
7
7
image : docker://us-docker.pkg.dev/colab-images/public/runtime
8
8
options : --gpus all
Original file line number Diff line number Diff line change 6
6
jobs :
7
7
publish :
8
8
if : github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
9
- runs-on : quantecon-gpu
10
- container :
11
- image : ghcr.io/quantecon/lecture-python-container:cuda-12.8.1-anaconda-2024-10-py312
12
- options : --gpus all
9
+ runs-on : " runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large"
13
10
steps :
14
11
- name : Checkout
15
12
uses : actions/checkout@v4
16
- - name : Install Git (required to commit notebooks)
13
+ - name : Setup Anaconda
14
+ uses : conda-incubator/setup-miniconda@v3
15
+ with :
16
+ auto-update-conda : true
17
+ auto-activate-base : true
18
+ miniconda-version : ' latest'
19
+ python-version : " 3.12"
20
+ environment-file : environment.yml
21
+ activate-environment : quantecon
22
+ - name : Install JAX, Numpyro, PyTorch
17
23
shell : bash -l {0}
18
- run : apt-get install -y git
24
+ run : |
25
+ pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
26
+ pip install --upgrade "jax[cuda12-local]"
27
+ pip install numpyro
28
+ python scripts/test-jax-install.py
19
29
- name : Check nvidia drivers
20
30
shell : bash -l {0}
21
31
run : |
Original file line number Diff line number Diff line change
1
+ import jax
2
+ import jax .numpy as jnp
3
+
4
+ devices = jax .devices ()
5
+ print (f"The available devices are: { devices } " )
6
+
7
+ @jax .jit
8
+ def matrix_multiply (a , b ):
9
+ return jnp .dot (a , b )
10
+
11
+ # Example usage:
12
+ key = jax .random .PRNGKey (0 )
13
+ x = jax .random .normal (key , (1000 , 1000 ))
14
+ y = jax .random .normal (key , (1000 , 1000 ))
15
+ z = matrix_multiply (x , y )
16
+
17
+ # Now the function is JIT compiled and will likely run on GPU (if available)
18
+ print (z )
19
+
20
+ devices = jax .devices ()
21
+ print (f"The available devices are: { devices } " )
You can’t perform that action at this time.
0 commit comments