This repository contains JAX code for pretraining a GPT2 model on Google TPUs. The notebook is adapted from the official miniGPT tutorial and works on free-tier Colab, Kaggle, and Cloud TPU (single-board TPUs from v2 through v6e all work).
Only GPT2 and GPT2-medium have been tested; bigger variants will OOM on TPU v3. The dataset is OpenWebText, the same as nanoGPT, because I wanted to compare the final losses.
First, get your Kaggle credentials and Weights & Biases (W&B) API key ready. Add them to your secrets if you are using Colab or Kaggle.
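For reference, here is a minimal sketch (not the notebook's exact code) of how such secrets are typically read back inside a notebook; the secret name `WANDB_API_KEY` is an assumption and must match whatever you used when adding it:

```python
# Minimal sketch, assuming the secret was stored under the name "WANDB_API_KEY".
import os

try:
    from google.colab import userdata                  # Colab secrets
    os.environ["WANDB_API_KEY"] = userdata.get("WANDB_API_KEY")
except ImportError:
    from kaggle_secrets import UserSecretsClient       # Kaggle secrets
    os.environ["WANDB_API_KEY"] = UserSecretsClient().get_secret("WANDB_API_KEY")
```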
- Manually download the OpenWebText dataset using the Kaggle CLI (you need to set up `kaggle.json` first):
  `kaggle datasets download -d windmaple/OpenWebText-gpt2 && unzip OpenWebText-gpt2.zip`
- I store the `.bin` files on my Google Drive (`/content/drive/MyDrive/LLM-pretraining/OpenWebText/`) so that they are cached. Change `data_dir` if your `.bin` files are in a different folder.
- Make sure your W&B API key is accessible to your notebook
- Connect to the TPU v2 runtime and run. The free tier gives you an hour or two before disconnecting, so you can't really finish the training
- With a paid account you may be able to see it through, although I haven't tried it myself (TPU v2 is just too slow). Colab now offers TPU v5e as well, but it is a single chip (unlike v2's 4 chips), so you need to change the mesh to run on it; see the sketch below
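What "change the mesh" means in practice, as a rough sketch (the axis names and mesh shape here are assumptions, not the notebook's exact code): build the mesh from whatever devices the runtime exposes instead of hard-coding the v2 topology.

```python
# Minimal sketch with made-up axis names. Colab's TPU v2 runtime exposes 8 cores,
# while the v5e runtime exposes a single chip, so a hard-coded mesh shape breaks.
import numpy as np
import jax
from jax.sharding import Mesh

devices = np.array(jax.devices())
mesh = Mesh(devices.reshape(len(devices), 1), axis_names=("batch", "model"))
print(mesh)
```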
Kaggle is more generous, offering 9 hours of uninterrupted TPU v3 per session, which is sufficient to train the smallest GPT2 variant.
- Import the notebook on Kaggle
- Add the OpenWebText dataset as an input via the side panel on the right (top-right corner)
- Make sure your W&B API key is accessible to your notebook
- Choose TPU v3
- Run the notebook. Kaggle will download the dataset first; after that, training takes ~7 hours to finish.
- You can also try GPT2-medium if you change the `GPT2_variant` variable (a sketch is below), though Kaggle will stop the run before it finishes
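For reference, a rough sketch of what that switch amounts to; the variable and dict names here are made up, but the layer/head/width numbers are the standard GPT2 configurations:

```python
# Minimal sketch; the notebook's actual config structure may differ.
GPT2_CONFIGS = {
    "gpt2":        dict(n_layer=12, n_head=12, n_embd=768),   # ~124M parameters
    "gpt2-medium": dict(n_layer=24, n_head=16, n_embd=1024),  # ~350M parameters
}
GPT2_variant = "gpt2-medium"
config = GPT2_CONFIGS[GPT2_variant]
```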
An alternative way to run is `Save Version` -> `Save & Run All`, which just runs the notebook in the background.
Technically, you can also train a GPT2-medium model on Kaggle: although Kaggle disconnects you after the 9 hours, it can save checkpoint files for you so that you can resume training in a later session. I haven't tried this because it's a bit of a pain; a rough sketch of the idea is below.
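If you want to attempt it, here is one way to do resumable training with Orbax. This is a sketch under assumptions, not the notebook's code: the directory, save interval, and the training objects are all stand-ins.

```python
# Minimal sketch of resumable training with Orbax; everything except the
# CheckpointManager calls is a dummy stand-in for the notebook's real objects.
import itertools
import numpy as np
import orbax.checkpoint as ocp

# Dummy "model state", "train step", and "data" just to make the sketch runnable.
train_state = {"params": np.zeros(4, dtype=np.float32)}
def train_step(state, batch):
    return {"params": state["params"] - 0.01 * batch}, float(batch.mean())
batches = itertools.repeat(np.ones(4, dtype=np.float32))
num_steps = 3000

ckpt_dir = "/kaggle/working/checkpoints"   # kept in the Kaggle run's output
mngr = ocp.CheckpointManager(
    ckpt_dir,
    options=ocp.CheckpointManagerOptions(max_to_keep=2, save_interval_steps=1000),
)

start = 0
if mngr.latest_step() is not None:         # a previous session left a checkpoint
    train_state = mngr.restore(mngr.latest_step(),
                               args=ocp.args.StandardRestore(train_state))
    start = mngr.latest_step() + 1

for step in range(start, num_steps):
    train_state, loss = train_step(train_state, next(batches))
    mngr.save(step, args=ocp.args.StandardSave(train_state))  # writes only every 1000 steps
mngr.wait_until_finished()
```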
OK, I lied. This one is not free. But since you are the paying God, you can pretty much do whatever you want, like training GPT2 medium to completion.
- Spin up your TPU VM and ssh into it
- Download the notebook
- `pip install kaggle` and get your `kaggle.json`
- `pip install jupyter`
- Start a tmux or screen session and then run the notebook like this (an alternative is to convert it to a `.py` file):
  `export WANDB_API_KEY=$your_key; time jupyter execute GPT2_pretrain.ipynb`
- There won't be much logging shown in the console, but don't worry: everything is directed to W&B, so you can see the output there
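For context, W&B picks up the exported key automatically; the notebook presumably does something along these lines (the project name and metric key here are made up):

```python
# Minimal sketch; with WANDB_API_KEY set in the environment, wandb.init()
# authenticates without an interactive login prompt.
import wandb

run = wandb.init(project="gpt2-jax-pretraining")
run.log({"train/loss": 3.0})   # the real loop logs per-step metrics like this
run.finish()
```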
W&B has integration with Cloud TPU and reports TPU metrics in the systems panel (2nd page) automatically.
You can also, in another console, `pip install tpu-info` and then `watch -n 1 tpu-info`.
Google Cloud console has additional monitoring tools if you use v4 or newer.
Neither is integrated with W&B, unfortunately. But you can still `pip install tpu-info` and then add `!tpu-info` in the middle of the training loop, as sketched below. Note that this might slow down training a bit.
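A rough sketch of that idea (the loop here is a dummy placeholder for the real training loop; outside IPython the `!` magic becomes a subprocess call):

```python
# Sketch: poll tpu-info every N steps rather than every step, since shelling out
# on each step would slow training. `tpu-info` must be on PATH (pip install tpu-info).
import subprocess

num_steps = 2000
for step in range(num_steps):
    # ... run the real train_step(state, batch) here ...
    if step % 500 == 0:
        subprocess.run(["tpu-info"])   # prints per-chip HBM usage, like `!tpu-info`
```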
If the stars align, you should get final losses like the ones below:

which are very much in line with nanoGPT's.
Trillium (TPU v6e) chips, which have 32GB of HBM per chip and can accommodate a 2X batch size, can finish training in just 82 minutes.