diff --git a/.github/workflows/cache.yml b/.github/workflows/cache.yml index 7e7c6ef42..c499a1599 100644 --- a/.github/workflows/cache.yml +++ b/.github/workflows/cache.yml @@ -6,14 +6,27 @@ on: workflow_dispatch: jobs: cache: - runs-on: quantecon-gpu - container: - image: ghcr.io/quantecon/lecture-python-container:cuda-12.8.1-anaconda-2024-10-py312 - options: --gpus all + runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large" steps: - uses: actions/checkout@v4 with: ref: ${{ github.event.pull_request.head.sha }} + - name: Setup Anaconda + uses: conda-incubator/setup-miniconda@v3 + with: + auto-update-conda: true + auto-activate-base: true + miniconda-version: 'latest' + python-version: "3.12" + environment-file: environment.yml + activate-environment: quantecon + - name: Install JAX, Numpyro, PyTorch + shell: bash -l {0} + run: | + pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 + pip install --upgrade "jax[cuda12-local]" + pip install numpyro + python scripts/test-jax-install.py - name: Check nvidia drivers shell: bash -l {0} run: | diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 115186510..663720fd2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,17 +1,31 @@ name: Build Project [using jupyter-book] -on: [pull_request] +on: + pull_request: + workflow_dispatch: jobs: preview: - runs-on: quantecon-gpu - container: - image: ghcr.io/quantecon/lecture-python-container:cuda-12.8.1-anaconda-2024-10-py312 - options: --gpus all + runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large" steps: - uses: actions/checkout@v4 with: ref: ${{ github.event.pull_request.head.sha }} - # Check nvidia drivers - - name: nvidia Drivers + - name: Setup Anaconda + uses: conda-incubator/setup-miniconda@v3 + with: + auto-update-conda: true + auto-activate-base: true + miniconda-version: 'latest' + python-version: "3.12" + environment-file: environment.yml + activate-environment: quantecon + - name: Install JAX, Numpyro, PyTorch + shell: bash -l {0} + run: | + pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 + pip install --upgrade "jax[cuda12-local]" + pip install numpyro + python scripts/test-jax-install.py + - name: Check nvidia Drivers shell: bash -l {0} run: nvidia-smi - name: Display Conda Environment Versions diff --git a/.github/workflows/collab.yml b/.github/workflows/collab.yml index efd21e264..40d996535 100644 --- a/.github/workflows/collab.yml +++ b/.github/workflows/collab.yml @@ -2,7 +2,7 @@ name: Build Project on Google Collab (Execution) on: [pull_request] jobs: execution-checks: - runs-on: quantecon-gpu + runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=ubuntu24-gpu-x64/disk=large" container: image: docker://us-docker.pkg.dev/colab-images/public/runtime options: --gpus all diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 52d194208..cba997aed 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -6,16 +6,26 @@ on: jobs: publish: if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags') - runs-on: quantecon-gpu - container: - image: ghcr.io/quantecon/lecture-python-container:cuda-12.8.1-anaconda-2024-10-py312 - options: --gpus all + runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large" steps: - name: Checkout uses: actions/checkout@v4 - - name: Install Git (required to commit notebooks) + - name: Setup Anaconda + uses: conda-incubator/setup-miniconda@v3 + with: + auto-update-conda: true + auto-activate-base: true + miniconda-version: 'latest' + python-version: "3.12" + environment-file: environment.yml + activate-environment: quantecon + - name: Install JAX, Numpyro, PyTorch shell: bash -l {0} - run: apt-get install -y git + run: | + pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 + pip install --upgrade "jax[cuda12-local]" + pip install numpyro + python scripts/test-jax-install.py - name: Check nvidia drivers shell: bash -l {0} run: | diff --git a/scripts/test-jax-install.py b/scripts/test-jax-install.py new file mode 100644 index 000000000..988a6ceaa --- /dev/null +++ b/scripts/test-jax-install.py @@ -0,0 +1,21 @@ +import jax +import jax.numpy as jnp + +devices = jax.devices() +print(f"The available devices are: {devices}") + +@jax.jit +def matrix_multiply(a, b): + return jnp.dot(a, b) + +# Example usage: +key = jax.random.PRNGKey(0) +x = jax.random.normal(key, (1000, 1000)) +y = jax.random.normal(key, (1000, 1000)) +z = matrix_multiply(x, y) + +# Now the function is JIT compiled and will likely run on GPU (if available) +print(z) + +devices = jax.devices() +print(f"The available devices are: {devices}") \ No newline at end of file