This repository contains PyTorch and JAX implementations of In-Context Learning (ICL) experiments on classification tasks.
- main.py - Main PyTorch implementation
- main_jax.py - Legacy JAX implementation
- model.py - PyTorch model implementation
- model_jax.py - JAX model implementation
- dataset.py - PyTorch dataset implementation
- dataset_jax.py - JAX dataset implementation
- visualize.ipynb - Visualization notebooks for analysis (old results)
- Various run scripts for different experiments
- Install dependencies:
pip install torch numpy wandb tqdm
- For the legacy JAX implementation, also install:
pip install jax jaxlib
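Before running either implementation, you may want to confirm that the packages above are importable. The following is a minimal standard-library sketch. The helper name missing_packages and the script itself are hypothetical and are not part of this repository.

```python
import importlib.util

# Package lists taken from the install commands above (JAX packages only matter for the legacy code).
TORCH_DEPS = ["torch", "numpy", "wandb", "tqdm"]
JAX_DEPS = ["jax", "jaxlib"]


def missing_packages(names):
    """Return the subset of `names` that cannot be found on the current Python path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing_torch = missing_packages(TORCH_DEPS)
    missing_jax = missing_packages(JAX_DEPS)
    print("PyTorch setup:", "OK" if not missing_torch else f"missing {missing_torch}")
    print("JAX setup (legacy):", "OK" if not missing_jax else f"missing {missing_jax}")
```

find_spec only locates a package and does not import it. This keeps the check fast, but it will not catch packages that are installed yet fail at import time.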
The main PyTorch implementation can be run using various scripts:
- Basic run:
./run_torch.sh
For the legacy JAX implementation:
./run.sh
Use the provided Jupyter notebooks for visualization and analysis:
- visualize.ipynb - Main visualization notebook
- visualize_layer4.ipynb - Layer 4-specific visualizations
Results are stored in:
- outs_torch/ - PyTorch implementation outputs
- outs/ - JAX implementation outputs
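When loading results from a notebook, a small helper can pick the right directory for each implementation. This is a hypothetical sketch: the directory names come from the list above, but results_dir and list_results are illustrative helpers, not repository code. It also assumes the directories sit at the repository root.

```python
from pathlib import Path

# Directory names as documented above; this mapping is an illustrative helper, not repository code.
OUTPUT_DIRS = {"torch": "outs_torch", "jax": "outs"}


def results_dir(framework, root="."):
    """Return the results directory for 'torch' or 'jax', relative to the repository root."""
    if framework not in OUTPUT_DIRS:
        raise ValueError(f"unknown framework {framework!r}; expected one of {sorted(OUTPUT_DIRS)}")
    return Path(root) / OUTPUT_DIRS[framework]


def list_results(framework, root="."):
    """List entry names in the results directory, or [] if it does not exist yet."""
    directory = results_dir(framework, root)
    return sorted(p.name for p in directory.iterdir()) if directory.is_dir() else []
```

For example, list_results("torch") would list whatever the PyTorch runs have written to outs_torch/.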
- The PyTorch implementation (main.py) is the current active codebase. WandB report: https://wandb.ai/explainableml/ICL_torch/reports/Results---VmlldzoxMjE1Mjk0Nw?accessToken=l7pw9osgk32n02dgt2ysd8dr12mfmpb6v8pblg6gn54ph7ywxbu7kdgk1bq57r6m
- The JAX implementation (main_jax.py) is maintained for reference. WandB report: https://api.wandb.ai/links/explainableml/iawqfvdf
- WandB integration is available for experiment tracking (set WANDB = True in main.py).
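A module-level flag like WANDB typically gates every tracking call. The sketch below shows that pattern under stated assumptions. Only the flag name comes from this README. The project name "ICL_torch" is read off the report URL above, and the helper functions are hypothetical, not code from main.py. The wandb.init and wandb.log calls are the standard Weights & Biases API.

```python
# Hypothetical sketch of a module-level WANDB switch gating experiment tracking.
# Only the flag name mirrors the README; the helpers are illustrative, not main.py itself.
WANDB = False  # set to True to log runs to Weights & Biases

if WANDB:
    import wandb  # imported only when tracking is on, so the module runs without wandb otherwise


def format_metrics(metrics, step):
    """Render a metrics dict as a single console line."""
    return f"step {step}: " + ", ".join(f"{k}={v:.4f}" for k, v in metrics.items())


def init_tracking(config, project="ICL_torch"):
    """Start a W&B run when WANDB is True; a no-op otherwise (project name assumed from the report URL)."""
    if WANDB:
        wandb.init(project=project, config=config)


def log_metrics(metrics, step):
    """Send metrics to W&B when enabled, otherwise print them locally."""
    if WANDB:
        wandb.log(metrics, step=step)
    else:
        print(format_metrics(metrics, step))


if __name__ == "__main__":
    init_tracking({"lr": 1e-3})
    log_metrics({"loss": 0.5, "acc": 0.75}, step=3)  # prints "step 3: loss=0.5000, acc=0.7500"
```

Keeping the import behind the flag means a run with tracking disabled never touches the network or requires a W&B login.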