This repository is an extension of the remarkable project ayeka/llama-2-jax by Ayeka. It adds support for a wider range of models, including Mistral, Vicuna, Orca, Zephyr, and more, whose parameters have been converted for use with JAX.
The models converted so far are:
- Mistral-7B-v0.1
- Mistral-7B-Instruct-v0.2
- SOLAR-10.7B-v1.0
- SOLAR-10.7B-Instruct-v1.0
- StableBeluga-13B
- StableBeluga-7B
- StableLM-Zephyr-3B
- Zephyr-7B-beta
- llama1-30B
- Vicuna-7B
- Vicuna-13B
- llama1-13B
- llama2-13B
- llama1-7B
- llama2-7B
- Orca-2-13B
- Orca-2-7B
- Mistral-7B-Instruct-v0.1
- llama2-7B-chat
- llama2-13B-chat
The converted model parameters are available as pickle files on Hugging Face.
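A downloaded pickle can be inspected before use. The sketch below assumes the file holds a nested pytree (dicts, lists, tuples or NamedTuples) whose leaves are arrays exposing `.shape`; the helpers `iter_leaves`, `summarize_params` and `load_params` are hypothetical, and unpickling may require this repository's own classes to be importable.

```python
import pickle
from typing import Any, Iterator, Tuple


def iter_leaves(tree: Any, prefix: str = "") -> Iterator[Tuple[str, Any]]:
    """Yield (path, leaf) pairs from nested dicts, lists, tuples and NamedTuples."""
    if isinstance(tree, dict):
        for key, value in tree.items():
            yield from iter_leaves(value, f"{prefix}/{key}")
    elif isinstance(tree, tuple) and hasattr(tree, "_fields"):
        # NamedTuple: use its field names in the path
        for key, value in zip(tree._fields, tree):
            yield from iter_leaves(value, f"{prefix}/{key}")
    elif isinstance(tree, (list, tuple)):
        for i, value in enumerate(tree):
            yield from iter_leaves(value, f"{prefix}/{i}")
    else:
        yield prefix or "/", tree


def summarize_params(tree: Any) -> dict:
    """Map each leaf path to its shape (None if the leaf has no .shape)."""
    return {path: getattr(leaf, "shape", None) for path, leaf in iter_leaves(tree)}


def load_params(path: str) -> Any:
    # Only unpickle files from a source you trust: pickle can execute code.
    with open(path, "rb") as f:
        return pickle.load(f)


if __name__ == "__main__":
    import sys

    for name, shape in summarize_params(load_params(sys.argv[1])).items():
        print(name, shape)
```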
The project is backed by Cloud TPUs from Google's TPU Research Cloud.
Comprehensive documentation for this library, authored by Ayeka, is available on GitHub Pages.
Run the following commands on a TPU VM to set up the environment:
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt update
sudo apt install python3.11-full -y
python3.11 -m venv venv
. venv/bin/activate
pip install -U pip
pip install -U wheel
pip install jupyter notebook
git clone https://github.com/divyapatel4/llama-2-jax.git
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
pip install git+https://github.com/huggingface/transformers.git
pip install git+https://github.com/deepmind/optax.git # https://github.com/google-deepmind/optax/issues/472
pip install -r llama-2-jax/requirements.txt
sudo apt-get install -y libpython3.11
python -c "from huggingface_hub.hf_api import HfFolder; HfFolder.save_token('<YOUR_HUGGING_FACE_TOKEN>')"
cd llama-2-jax
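After the setup commands, a quick check can confirm the interpreter version and that the main packages resolve. This is a minimal sketch; the list of import names is an assumption drawn from the install commands, and `jax.devices()` is called only if JAX is importable.

```python
import importlib.util
import sys

# Import names of packages installed above (assumed; adjust to requirements.txt).
REQUIRED = ("jax", "optax", "torch", "transformers", "huggingface_hub")


def missing_packages(names=REQUIRED) -> list:
    """Return the top-level module names that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def check_environment(names=REQUIRED) -> bool:
    """Print diagnostics and return True if the interpreter and packages look right."""
    ok = True
    if sys.version_info[:2] != (3, 11):
        print(f"warning: expected Python 3.11, got {sys.version.split()[0]}")
        ok = False
    missing = missing_packages(names)
    if missing:
        print("missing packages:", ", ".join(missing))
        ok = False
    if importlib.util.find_spec("jax") is not None:
        import jax

        # On a TPU VM this should list TPU devices rather than only the CPU.
        print("JAX devices:", jax.devices())
    return ok


if __name__ == "__main__":
    sys.exit(0 if check_environment() else 1)
```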
If you are using a TPU pod, list the IP addresses of all hosts in ~/podips.txt, and make sure the local host can SSH into itself and into every listed IP.
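Before running pod commands, it can help to validate ~/podips.txt and confirm passwordless SSH. The sketch assumes one IP address per line (the exact format podrun accepts is not specified here); `BatchMode=yes` makes `ssh` fail instead of prompting, so a host without key-based access shows up as a failure.

```python
import ipaddress
from pathlib import Path


def parse_podips(text: str) -> list[str]:
    """Return the IP addresses listed one per line, ignoring blank lines."""
    hosts = []
    for lineno, raw in enumerate(text.splitlines(), start=1):
        line = raw.strip()
        if not line:
            continue
        try:
            ipaddress.ip_address(line)  # raises ValueError for malformed entries
        except ValueError:
            raise ValueError(f"line {lineno}: not an IP address: {line!r}") from None
        hosts.append(line)
    return hosts


def ssh_check_command(host: str) -> list[str]:
    """ssh invocation that fails instead of prompting if key-based login is missing."""
    return ["ssh", "-o", "BatchMode=yes", "-o", "ConnectTimeout=5", host, "true"]


if __name__ == "__main__":
    import subprocess

    hosts = parse_podips((Path.home() / "podips.txt").read_text())
    for host in ["127.0.0.1", *hosts]:  # the local host must reach itself too
        result = subprocess.run(ssh_check_command(host), capture_output=True)
        print(host, "ok" if result.returncode == 0 else "FAILED")
```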
On TPU Pods:
./podrun -i -- ~/venv/bin/pip install git+https://github.com/huggingface/transformers.git
./podrun -i -- ~/venv/bin/pip install git+https://github.com/deepmind/optax.git
./podrun -iw -- ~/venv/bin/pip install -r requirements.txt
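The pod commands rely on `podrun` to execute the same command on every host. Purely as an illustration of that idea (this is not podrun's implementation and does not reproduce its flags), the sketch below runs one command on each host over SSH in parallel and collects exit codes. Arguments are quoted with `shlex.join`, so a leading `~` would not be expanded remotely; the usage example therefore uses paths relative to the remote home directory and assumes the repository is cloned at ~/llama-2-jax on every host.

```python
import shlex
import subprocess
from concurrent.futures import ThreadPoolExecutor


def remote_command(host: str, command: list[str]) -> list[str]:
    """Build an ssh invocation that runs `command` on `host` non-interactively."""
    return ["ssh", "-o", "BatchMode=yes", host, shlex.join(command)]


def run_on_hosts(hosts: list[str], command: list[str]) -> dict[str, int]:
    """Run `command` on every host concurrently; return each host's exit code."""
    with ThreadPoolExecutor(max_workers=max(1, len(hosts))) as pool:
        futures = {
            host: pool.submit(subprocess.run, remote_command(host, command))
            for host in hosts
        }
        return {host: fut.result().returncode for host, fut in futures.items()}


# Usage (hypothetical host list):
# run_on_hosts(["10.0.0.2", "10.0.0.3"],
#              ["venv/bin/pip", "install", "-r", "llama-2-jax/requirements.txt"])
```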
If you are working with Llama 2 models, log in with the Hugging Face CLI:
huggingface-cli login
In non-interactive mode:
python -c "from huggingface_hub.hf_api import HfFolder; HfFolder.save_token('<YOUR_HUGGING_FACE_TOKEN>')"
On TPU Pods:
./podrun -i -- ~/venv/bin/python -c "from huggingface_hub.hf_api import HfFolder; HfFolder.save_token('<YOUR_HUGGING_FACE_TOKEN>')"
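Passing the token inline puts it in shell history. A minimal alternative sketch reads it from an environment variable and stores it with `huggingface_hub.login(token=...)`; the variable name `HF_TOKEN` and the helper names are assumptions for illustration.

```python
import os


def read_hf_token(env_var: str = "HF_TOKEN") -> str:
    """Read the Hugging Face token from an environment variable instead of the command line."""
    token = os.environ.get(env_var, "").strip()
    if not token:
        raise RuntimeError(f"set {env_var} to your Hugging Face access token")
    return token


def save_hf_token(env_var: str = "HF_TOKEN") -> None:
    # Imported lazily so read_hf_token works without huggingface_hub installed.
    from huggingface_hub import login

    login(token=read_hf_token(env_var))  # stores the token like `huggingface-cli login`


if __name__ == "__main__":
    save_hf_token()
```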
Convert the parameters of a model, for example llama2-7B:
python scripts/convert_params_runner.py llama2-7B
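Converting a checkpoint generally involves renaming parameter keys and reshaping some tensors. The sketch below is not `convert_params_runner.py`; it illustrates two common steps under stated assumptions: nesting dotted Hugging Face keys such as `model.layers.0.self_attn.q_proj.weight`, and transposing linear weights, because PyTorch `nn.Linear` stores them as `(out_features, in_features)` while computing `x @ W` needs `(in_features, out_features)`. Plain nested lists stand in for arrays, and the `_proj.weight` rule is hypothetical.

```python
from typing import Any


def nest_keys(flat: dict[str, Any]) -> dict[str, Any]:
    """Turn {'a.b.c': v} into {'a': {'b': {'c': v}}}."""
    nested: dict[str, Any] = {}
    for key, value in flat.items():
        *parents, leaf = key.split(".")
        node = nested
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


def transpose(matrix: list[list[float]]) -> list[list[float]]:
    """(out, in) -> (in, out), as needed when computing x @ W."""
    return [list(col) for col in zip(*matrix)]


def convert(flat: dict[str, Any]) -> dict[str, Any]:
    # Hypothetical rule: transpose every '*_proj.weight'; leave other tensors as-is.
    converted = {
        k: transpose(v) if k.endswith("_proj.weight") else v
        for k, v in flat.items()
    }
    return nest_keys(converted)
```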
Run text generation:
python generate.py
On TPU Pods:
./podrun -icw ~/venv/bin/python generate.py
Before training, log in to W&B:
wandb login <YOUR_WANDB_API_KEY>
The training example uses the Grade School Math (GSM8K) dataset. Clone it into the parent directory (the subshell keeps you inside the repository):
(cd .. && git clone --depth=1 https://github.com/openai/grade-school-math.git)
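The cloned repository stores GSM8K as JSON Lines (assumed path `grade-school-math/grade_school_math/data/train.jsonl`), one object per line with `question` and `answer` fields, where the final numeric answer follows `####`. A small reader, independent of how `train.py` actually consumes the data:

```python
import json
from typing import Iterable, Iterator


def read_gsm8k(lines: Iterable[str]) -> Iterator[dict]:
    """Parse GSM8K JSON Lines into question, reasoning and final answer."""
    for line in lines:
        line = line.strip()
        if not line:
            continue
        record = json.loads(line)
        # The worked solution precedes '####'; the final answer follows it.
        reasoning, _, final = record["answer"].rpartition("####")
        yield {
            "question": record["question"],
            "reasoning": reasoning.strip(),
            "final_answer": final.strip().replace(",", ""),
        }


# Usage (path assumed relative to the directory containing both clones):
# with open("grade-school-math/grade_school_math/data/train.jsonl") as f:
#     examples = list(read_gsm8k(f))
```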
python train.py
On TPU Pods:
./podrun -icw ~/venv/bin/python train.py
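Causal language-model training typically minimizes next-token cross-entropy, and fine-tuning on question/answer data often scores only the answer tokens. Whether `train.py` masks the prompt this way is an assumption; this pure-Python sketch shows the loss itself, which a JAX implementation would vectorize (for example with `jax.nn.log_softmax`).

```python
import math
from typing import Sequence


def log_softmax(logits: Sequence[float]) -> list[float]:
    """Numerically stable log-softmax over one row of scores."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def masked_next_token_loss(
    logits: Sequence[Sequence[float]],  # one row of vocabulary scores per position
    token_ids: Sequence[int],           # the input sequence
    loss_mask: Sequence[bool],          # True where the target token should be scored
) -> float:
    """Mean cross-entropy of predicting token t+1 from position t, over masked targets."""
    total, count = 0.0, 0
    for t in range(len(token_ids) - 1):
        if not loss_mask[t + 1]:
            continue
        total -= log_softmax(logits[t])[token_ids[t + 1]]
        count += 1
    return total / max(count, 1)
```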