
🦙 LLaMA2-from-Scratch (JAX)

This project implements Meta’s LLaMA2 language model architecture from scratch using JAX, with a focus on clarity and modularity. It's intended as a learning resource and a base for further experimentation.


📁 Project Structure

llama2-from-scratch-jax/
├── Model/                         # Saved trained model checkpoints
├── Tokenized/                    # Tokenized dataset used for training
├── Tokenizer/                    # Trained tokenizer model and vocab files
├── .gitattributes
├── .gitignore
├── Tokenizer_for_llama2.ipynb    # Notebook for training/preparing tokenizer
├── configuration.py              # Model/training hyperparameters and config
├── export.py                     # Utility for exporting model checkpoints
├── model.py                      # Core model implementation (LLaMA2 in JAX)
├── train.py                      # Training loop

🚀 Getting Started

1. Clone the repository

git clone https://github.com/AYUSH27112021/llama2-from-scratch-jax.git
cd llama2-from-scratch-jax

2. Install dependencies

Make sure you have Python 3.10+ installed, then install the necessary packages:

pip install -r requirements.txt

Required packages typically include:

  • jax, flax, optax
  • numpy
  • tokenizers (the Hugging Face tokenizers library)
  • tqdm, matplotlib, etc. (optional)
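
The project tree above does not list a requirements.txt; if one is missing, a minimal file covering the packages above could look like this (package names only; pin versions as needed):

jax
flax
optax
numpy
tokenizers
tqdm
matplotlib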

🧠 Overview

  • Model: Implements the LLaMA2 transformer architecture using JAX and Flax (see the sketch after this list).
  • Tokenizer: Prepares and stores the vocabulary and tokenizer model.
  • Training: Modular and easy-to-read training loop with checkpoint support.
  • Export: Utilities to save/export trained model weights.
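
As an illustration of the Model component, here is a minimal sketch of one LLaMA2 building block, RMSNorm (the normalization LLaMA2 uses in place of LayerNorm), written with Flax. The names, shapes, and epsilon value are illustrative; model.py is the authoritative implementation.

import jax
import jax.numpy as jnp
import flax.linen as nn

class RMSNorm(nn.Module):
    """Root-mean-square normalization, as used throughout LLaMA2."""
    dim: int
    eps: float = 1e-5

    @nn.compact
    def __call__(self, x):
        # Learnable per-feature scale, initialized to ones.
        weight = self.param("weight", nn.initializers.ones, (self.dim,))
        # Normalize by the RMS of the features; no mean subtraction, no bias.
        rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + self.eps)
        return x / rms * weight

# Usage on a dummy (batch, seq, dim) activation:
x = jnp.ones((2, 16, 512))
layer = RMSNorm(dim=512)
params = layer.init(jax.random.PRNGKey(0), x)
y = layer.apply(params, x)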

📊 Workflow

  1. Tokenizer Preparation
    Use Tokenizer_for_llama2.ipynb to train a tokenizer or load a pretrained one. The output is saved to the Tokenizer/ directory (a sketch follows this list).

  2. Tokenize Dataset
    Store your processed/tokenized data in the Tokenized/ directory.

  3. Training the Model
    Edit the settings in configuration.py as needed (illustrative config and train-step sketches follow this list), then run:

    python train.py

    Checkpoints will be saved to the Model/ directory.

  4. Exporting
    To export the trained model weights to a portable format, run:

    python export.py
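
For steps 1 and 2, Tokenizer_for_llama2.ipynb is the authoritative recipe. Purely as an illustration, a tokenizer could be trained and a corpus tokenized with the Hugging Face tokenizers library roughly like this (the file names corpus.txt, tokenizer.json, and train.npy are hypothetical):

import numpy as np
from tokenizers import Tokenizer, models, pre_tokenizers, trainers

# Step 1 (sketch): train a byte-level BPE tokenizer and save it to Tokenizer/.
tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()
trainer = trainers.BpeTrainer(vocab_size=32000, special_tokens=["<unk>", "<s>", "</s>"])
tokenizer.train(["corpus.txt"], trainer)
tokenizer.save("Tokenizer/tokenizer.json")

# Step 2 (sketch): encode the corpus into a flat array of token ids
# and store it in Tokenized/ (uint16 is enough for a 32k vocabulary).
with open("corpus.txt") as f:
    ids = tokenizer.encode(f.read()).ids
np.save("Tokenized/train.npy", np.array(ids, dtype=np.uint16))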
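
The exact fields in configuration.py are specific to this repository; the sketch below only illustrates the kind of hyperparameters step 3 has you edit (every name and value here is made up, not a repo default):

from dataclasses import dataclass

@dataclass
class Config:
    # Model hyperparameters
    vocab_size: int = 32000
    dim: int = 512            # hidden size
    n_layers: int = 8         # number of transformer blocks
    n_heads: int = 8          # attention heads
    max_seq_len: int = 1024
    # Training hyperparameters
    batch_size: int = 32
    learning_rate: float = 3e-4
    num_steps: int = 10_000
    checkpoint_dir: str = "Model/"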
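
Likewise, train.py defines the real training loop; the following self-contained sketch only shows the usual shape of a jitted JAX/Optax train step, with a toy embedding-plus-projection model standing in for the full transformer:

import jax
import jax.numpy as jnp
import optax

def apply_model(params, tokens):
    # Toy next-token model: embedding lookup followed by an output projection.
    h = params["embed"][tokens]   # (batch, seq, dim)
    return h @ params["out"]      # (batch, seq, vocab)

optimizer = optax.adamw(learning_rate=3e-4)

@jax.jit
def train_step(params, opt_state, inputs, targets):
    def loss_fn(p):
        logits = apply_model(p, inputs)
        return optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean()
    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state, loss

key = jax.random.PRNGKey(0)
params = {
    "embed": 0.02 * jax.random.normal(key, (32000, 64)),
    "out": 0.02 * jax.random.normal(key, (64, 32000)),
}
opt_state = optimizer.init(params)

# One step on a dummy batch:
inputs = jnp.zeros((4, 16), dtype=jnp.int32)
targets = jnp.zeros((4, 16), dtype=jnp.int32)
params, opt_state, loss = train_step(params, opt_state, inputs, targets)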

❗ Notes

  • This is an experimental implementation; not optimized for large-scale training.
  • No license has been specified yet.

