This project implements Meta’s LLaMA2 language model architecture from scratch using JAX, with a focus on clarity and modularity. It's intended as a learning resource and a base for further experimentation.
```
llama2-from-scratch-jax/
├── Model/                      # Saved trained model checkpoints
├── Tokenized/                  # Tokenized dataset used for training
├── Tokenizer/                  # Trained tokenizer model and vocab files
├── .gitattributes
├── .gitignore
├── Tokenizer_for_llama2.ipynb  # Notebook for training/preparing the tokenizer
├── configuration.py            # Model/training hyperparameters and config
├── export.py                   # Utility for exporting model checkpoints
├── model.py                    # Core model implementation (LLaMA2 in JAX)
└── train.py                    # Training loop
```
Clone the repository:

```bash
git clone https://github.com/your-username/llama2-from-scratch-jax.git
cd llama2-from-scratch-jax
```

Make sure you have Python 3.10+, then install the necessary packages:

```bash
pip install -r requirements.txt
```
Required packages typically include:

- `jax`
- `flax`
- `optax`
- `numpy`
- `tokenizers` (Hugging Face)
- `tqdm`, `matplotlib`, etc. (optional)
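After installing, an optional quick check that JAX imports correctly and can see a device:

```python
# Optional sanity check for the JAX installation.
import jax
import jax.numpy as jnp

print(jax.__version__)    # installed JAX version
print(jax.devices())      # available devices (CPU/GPU/TPU)
print(jnp.arange(3) + 1)  # a trivial computation on the default device
```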
- Model: Implements the LLaMA2 transformer architecture using JAX and Flax (see the illustrative sketch after this list).
- Tokenizer: Prepares and stores the vocabulary and tokenizer model.
- Training: Modular and easy-to-read training loop with checkpoint support.
- Export: Utilities to save/export trained model weights.
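To give a flavour of what a LLaMA-style component looks like in Flax, here is a minimal RMSNorm layer. This is an illustrative sketch only and is not taken from `model.py`, which may implement it differently.

```python
# Minimal sketch of a LLaMA-style RMSNorm layer in Flax (illustration only).
import jax.numpy as jnp
import flax.linen as nn

class RMSNorm(nn.Module):
    dim: int
    eps: float = 1e-6

    @nn.compact
    def __call__(self, x):
        # Learnable per-feature scale, initialised to ones.
        weight = self.param("weight", nn.initializers.ones, (self.dim,))
        # Normalise by the root-mean-square of the features (no mean subtraction).
        rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + self.eps)
        return x / rms * weight
```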
- Tokenizer Preparation: Use `Tokenizer_for_llama2.ipynb` to train a tokenizer or load a pretrained one. The output is saved to the `Tokenizer/` directory (see the tokenizer sketch after this list).
- Tokenize Dataset: Store your processed/tokenized data in the `Tokenized/` directory (the tokenizer sketch below also shows one way to dump token ids).
- Training the Model: Edit the configs in `configuration.py` as needed, then run `python train.py`. Checkpoints will be saved to the `Model/` directory (a generic training-step sketch follows below).
- Exporting: To convert model weights to a usable format, run `python export.py` (see the serialization sketch below).
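As a rough illustration of the first two steps, the sketch below trains a byte-level BPE tokenizer with the Hugging Face `tokenizers` library and dumps token ids as a NumPy array. The corpus path, vocabulary size, special tokens, and output file names are assumptions; the notebook and `train.py` may expect something different.

```python
# Illustrative only: train a BPE tokenizer and dump token ids for training.
# Paths, vocab size, and special tokens below are assumptions.
import numpy as np
from tokenizers import Tokenizer, models, pre_tokenizers, trainers

# Train a byte-level BPE tokenizer on a raw text corpus (hypothetical path).
tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()
trainer = trainers.BpeTrainer(
    vocab_size=32_000,
    special_tokens=["<unk>", "<s>", "</s>"],
)
tokenizer.train(["data/corpus.txt"], trainer)
tokenizer.save("Tokenizer/tokenizer.json")

# Encode the corpus and store the ids where the training script can load them.
with open("data/corpus.txt", "r", encoding="utf-8") as f:
    ids = tokenizer.encode(f.read()).ids
np.save("Tokenized/train_ids.npy", np.array(ids, dtype=np.uint16))
```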
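For orientation, the following is a generic JAX/Flax/Optax training-step pattern, not the actual loop in `train.py`. Here `model`, `optimizer`, and `batch["tokens"]` are placeholders for whatever the real code defines.

```python
# Generic next-token-prediction training step with JAX, Flax and Optax.
# `model` is a Flax module, `optimizer` an optax GradientTransformation,
# and `batch["tokens"]` an int array of shape (batch, seq_len). Sketch only.
import jax
import optax

def make_train_step(model, optimizer):
    @jax.jit
    def train_step(params, opt_state, batch):
        def loss_fn(p):
            # p is the Flax parameter pytree; shift inputs/targets by one token.
            logits = model.apply({"params": p}, batch["tokens"][:, :-1])
            targets = batch["tokens"][:, 1:]
            return optax.softmax_cross_entropy_with_integer_labels(
                logits, targets
            ).mean()

        loss, grads = jax.value_and_grad(loss_fn)(params)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss

    return train_step
```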
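The exact export format is defined in `export.py`. As a general illustration, Flax parameters are often serialized to a single binary file with `flax.serialization`; the file name below is hypothetical.

```python
# One common way to serialize Flax parameters (msgpack via flax.serialization).
# export.py may use a different format; the path here is a placeholder.
from flax import serialization

def save_params(params, path="Model/params.msgpack"):
    with open(path, "wb") as f:
        f.write(serialization.to_bytes(params))

def load_params(template_params, path="Model/params.msgpack"):
    # `template_params` provides the pytree structure to restore into.
    with open(path, "rb") as f:
        return serialization.from_bytes(template_params, f.read())
```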
- This is an experimental implementation and is not optimized for large-scale training.
- No license has been specified for this repository yet.