Skip to content

Commit 40eccc2

Browse files
New TPU Bazel Presubmit
1 parent 9bdc04c commit 40eccc2

File tree

3 files changed

+148
-0
lines changed

3 files changed

+148
-0
lines changed

.bazelrc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,10 @@ build:rocm --action_env=TF_HIPCC_CLANG="1"
218218
build:public_cache --remote_cache="https://storage.googleapis.com/jax-bazel-cache/" --remote_upload_local_results=false
219219
# Cache pushes are limited to JAX's CI system.
220220
build:public_cache_push --config=public_cache --remote_upload_local_results=true --google_default_credentials
221+
# CI-only Bazel cache for presubmits.
222+
build:ci_non_rbe_cache --remote_cache="https://storage.googleapis.com/jax-presubmit-bazel-cache/"
223+
build:ci_non_rbe_cache --remote_upload_local_results=true
224+
build:ci_non_rbe_cache --google_default_credentials
221225

222226
# Note: the following cache configs are deprecated and will be removed soon.
223227
# Public read-only cache for Mac builds. JAX uses a GCS bucket to store cache

.github/workflows/cloud-tpu-ci-presubmit.yml

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,66 @@ jobs:
4646
clone_main_xla: 1
4747
upload_artifacts_to_gcs: true
4848
gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
49+
run-bazel-tpu:
  # Builds JAX with Bazel and runs the TPU test suite on an 8-chip TPU runner
  # using free-threaded Python 3.13.
  name: "TPU bazel tests"
  defaults:
    run:
      shell: bash
  runs-on: "linux-x86-ct6e-180-8tpu"
  container: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest"
  env:
    JAXCI_HERMETIC_PYTHON_VERSION: "3.13-ft"
    JAXCI_PYTHON: "python3.13t"
  steps:
    - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
      with:
        persist-credentials: false
    - name: Wait For Connection
      uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c
      with:
        halt-dispatch-input: ${{ inputs.halt-for-connection }}
    - name: Install nightly libtpu
      run: |
        $JAXCI_PYTHON -m uv pip install --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    - name: Run Bazel Tests
      run: |
        SRCDIR=$(pwd)
        # Run bazel test. Room for improvement: use a prebuilt jaxlib, a
        # cache, or an RBE instance.
        # The TPU_* / topology variables must be forwarded explicitly with
        # --test_env because the tests do not have access to the TPU metadata
        # information. Each variable is forwarded exactly once under its own
        # name (TPU_TOPOLOGY_WRAP was previously sent as TPU_TOPOLOGY_ALT).
        bazel test \
          --config=ci_linux_x86_64 \
          --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \
          --@rules_python//python/config_settings:py_freethreaded='yes' \
          --run_under "$SRCDIR/build/parallel_accelerator_execute.sh" \
          --test_env=JAX_ACCELERATOR_COUNT=8 \
          --test_env=JAX_TESTS_PER_ACCELERATOR=1 \
          --local_test_jobs=8 \
          --test_tag_filters=-multiaccelerator \
          --test_env=ALLOW_MULTIPLE_LIBTPU_LOAD=1 \
          --test_env=JAX_TEST_NUM_THREADS=16 \
          --test_sharding_strategy=disabled \
          --nocache_test_results \
          --generate_json_trace_profile \
          --test_timeout=600 \
          --test_env=JAX_SKIP_SLOW_TESTS=1 \
          --test_env=TPU_SKIP_MDS_QUERY=true \
          --test_env=CHIPS_PER_HOST_BOUNDS="$CHIPS_PER_HOST_BOUNDS" \
          --test_env=HOST_BOUNDS="$HOST_BOUNDS" \
          --test_env=ALT=false \
          --test_env=WRAP="$WRAP" \
          --test_env=TPU_WORKER_ID="$TPU_WORKER_ID" \
          --test_env=TPU_ACCELERATOR_TYPE="$TPU_ACCELERATOR_TYPE" \
          --test_env=TPU_WORKER_HOSTNAMES="$TPU_WORKER_HOSTNAMES" \
          --test_env=TPU_RUNTIME_METRICS_PORTS="$TPU_RUNTIME_METRICS_PORTS" \
          --test_env=TPU_HOST_BOUNDS="$TPU_HOST_BOUNDS" \
          --test_env=TPU_TOPOLOGY_ALT="$TPU_TOPOLOGY_ALT" \
          --test_env=TPU_TOPOLOGY_WRAP="$TPU_TOPOLOGY_WRAP" \
          --test_env=VBAR_CONTROL_SERVICE_URL="$VBAR_CONTROL_SERVICE_URL" \
          -- //tests:tpu_tests
    - name: Copy TPU logs to transient bucket
      # Kept as its own step so the logs are still collected when the test
      # step fails, which is when they are most useful.
      if: ${{ !cancelled() }}
      run: ./ci/copy_logs.sh
50109
run-pytest-tpu:
51110
if: github.event.repository.fork == false
52111
needs: [build-jax-artifacts]

ci/copy_logs.sh

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
#!/bin/bash
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Script that can be run at the end of jobs to copy the logs or pytest results
# to a transient CI bucket.
#
# Usage: run from the root of the Bazel workspace, after `bazel test`.
#
# Environment:
#   GITHUB_REPOSITORY, GITHUB_RUN_ID, GITHUB_RUN_ATTEMPT
#       Used to build the destination path inside the bucket.
#   XDG_CACHE_HOME (optional)
#       When set, the bazel-out directory is derived from it without starting
#       Bazel; otherwise (or when that guess does not exist) `bazel info` is
#       consulted.
#
# Exit status: 1 when not run from a workspace root or when no Bazel output
# directory can be located; 0 otherwise.

GCS_BUCKET="gs://general-ml-ci-transient/jax-github-actions/logs/$GITHUB_REPOSITORY/$GITHUB_RUN_ID/$GITHUB_RUN_ATTEMPT"

# Check if we are in a Bazel workspace.
if [[ ! -f "WORKSPACE" && ! -f "WORKSPACE.bazel" ]]; then
  echo "ERROR: No WORKSPACE file found. Please run this script from the root of your Bazel workspace." >&2
  exit 1
fi

# --- Main Logic ---

echo "🚀 Starting Bazel test log upload to GCS..."
echo "Destination Bucket: $GCS_BUCKET"

# Get the path to the bazel-out directory. We use XDG_CACHE_HOME if set as it
# doesn't require starting bazel and is much quicker.
BAZEL_OUT_DIR=""
if [[ -n "$XDG_CACHE_HOME" ]]; then
  # The output-base directory name is the MD5 of the workspace path. The path
  # is quoted so that whitespace or glob characters cannot alter the hash.
  read -r workspace_md5 _ < <(printf '%s' "$(pwd)" | md5sum)
  # NOTE(review): '_bazel_root' and 'execroot/__main__' assume the build ran as
  # root with the default workspace name -- confirm for new runners.
  BAZEL_OUT_DIR="$XDG_CACHE_HOME/bazel/_bazel_root/$workspace_md5/execroot/__main__/bazel-out"
fi

# Ask Bazel directly when there is no shortcut, or when the shortcut guessed a
# directory that does not exist.
if [[ ! -d "$BAZEL_OUT_DIR" ]]; then
  BAZEL_OUT_DIR=$(bazel info output_path)
fi

if [[ ! -d "$BAZEL_OUT_DIR" ]]; then
  echo "ERROR: Could not find the Bazel output directory at '$BAZEL_OUT_DIR'." >&2
  echo "Have you built or tested any targets yet?" >&2
  exit 1
fi

echo "Searching for 'testlogs' directories under: $BAZEL_OUT_DIR"

# Use 'find' to locate all directories named 'testlogs'.
# The loop reads from a process substitution rather than a pipe: a piped
# 'while' runs in a subshell, so 'found_logs' would never be updated in this
# shell and the final check would always report that nothing was found.
# Paths are NUL-delimited so unusual characters in them are handled.
found_logs=0
while IFS= read -r -d '' testlogs_path; do
  found_logs=1
  echo "======================================================================"
  echo "Found testlogs directory: $testlogs_path"

  # To avoid naming collisions in GCS, we create a descriptive path from the
  # log file's location relative to the bazel-out directory.
  # e.g., 'k8-fastbuild/testlogs' becomes 'k8-fastbuild-testlogs'
  relative_path=${testlogs_path#"$BAZEL_OUT_DIR/"}
  gcs_prefix=${relative_path//\//-}

  # Define the final destination path in the GCS bucket.
  GCS_DESTINATION_PATH="${GCS_BUCKET}/${gcs_prefix}/"

  echo "Uploading contents to: $GCS_DESTINATION_PATH"

  # NOTE(review): the upload itself is still disabled (as in the original
  # script); the messages around it only show what would be copied. Re-enable
  # the command below once the runner can write to the bucket.
  # The '-m' flag enables parallel (multi-threaded/multi-processing) uploads.
  # The '-r' flag copies directories recursively.
  # The trailing '*' ensures the *contents* of the directory are copied.
  # gsutil -m cp -r "${testlogs_path}/*" "$GCS_DESTINATION_PATH"

  echo "Upload complete for $testlogs_path"
done < <(find "$BAZEL_OUT_DIR" -type d -name "testlogs" -print0)

echo "Found logs $found_logs"
# --- Final Check ---
if [[ "$found_logs" -eq 0 ]]; then
  echo "Log: No 'testlogs' directories were found."
fi

echo "======================================================================"
echo "Script finished."
85+

0 commit comments

Comments
 (0)