Skip to content

ENH: enable RunsOn with custom ami #451

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
May 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions .github/workflows/cache.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,27 @@ on:
workflow_dispatch:
jobs:
cache:
runs-on: quantecon-gpu
container:
image: ghcr.io/quantecon/lecture-python-container:cuda-12.8.1-anaconda-2024-10-py312
options: --gpus all
runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large"
steps:
- uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.head.sha }}
- name: Setup Anaconda
uses: conda-incubator/setup-miniconda@v3
with:
auto-update-conda: true
auto-activate-base: true
miniconda-version: 'latest'
python-version: "3.12"
environment-file: environment.yml
activate-environment: quantecon
- name: Install JAX, Numpyro, PyTorch
shell: bash -l {0}
run: |
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
pip install --upgrade "jax[cuda12-local]"
pip install numpyro
python scripts/test-jax-install.py
- name: Check nvidia drivers
shell: bash -l {0}
run: |
Expand Down
28 changes: 21 additions & 7 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -1,17 +1,31 @@
name: Build Project [using jupyter-book]
on: [pull_request]
on:
pull_request:
workflow_dispatch:
jobs:
preview:
runs-on: quantecon-gpu
container:
image: ghcr.io/quantecon/lecture-python-container:cuda-12.8.1-anaconda-2024-10-py312
options: --gpus all
runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large"
steps:
- uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.head.sha }}
# Check nvidia drivers
- name: nvidia Drivers
- name: Setup Anaconda
uses: conda-incubator/setup-miniconda@v3
with:
auto-update-conda: true
auto-activate-base: true
miniconda-version: 'latest'
python-version: "3.12"
environment-file: environment.yml
activate-environment: quantecon
- name: Install JAX, Numpyro, PyTorch
shell: bash -l {0}
run: |
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
pip install --upgrade "jax[cuda12-local]"
pip install numpyro
python scripts/test-jax-install.py
- name: Check nvidia Drivers
shell: bash -l {0}
run: nvidia-smi
- name: Display Conda Environment Versions
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/collab.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: Build Project on Google Collab (Execution)
on: [pull_request]
jobs:
execution-checks:
runs-on: quantecon-gpu
runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=ubuntu24-gpu-x64/disk=large"
container:
image: docker://us-docker.pkg.dev/colab-images/public/runtime
options: --gpus all
Expand Down
22 changes: 16 additions & 6 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,26 @@ on:
jobs:
publish:
if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
runs-on: quantecon-gpu
container:
image: ghcr.io/quantecon/lecture-python-container:cuda-12.8.1-anaconda-2024-10-py312
options: --gpus all
runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large"
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Git (required to commit notebooks)
- name: Setup Anaconda
uses: conda-incubator/setup-miniconda@v3
with:
auto-update-conda: true
auto-activate-base: true
miniconda-version: 'latest'
python-version: "3.12"
environment-file: environment.yml
activate-environment: quantecon
- name: Install JAX, Numpyro, PyTorch
shell: bash -l {0}
run: apt-get install -y git
run: |
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
pip install --upgrade "jax[cuda12-local]"
pip install numpyro
python scripts/test-jax-install.py
- name: Check nvidia drivers
shell: bash -l {0}
run: |
Expand Down
21 changes: 21 additions & 0 deletions scripts/test-jax-install.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import jax
import jax.numpy as jnp

devices = jax.devices()
print(f"The available devices are: {devices}")

@jax.jit
def matrix_multiply(a, b):
return jnp.dot(a, b)

# Example usage:
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1000, 1000))
y = jax.random.normal(key, (1000, 1000))
z = matrix_multiply(x, y)

# Now the function is JIT compiled and will likely run on GPU (if available)
print(z)

devices = jax.devices()
print(f"The available devices are: {devices}")
Loading