
YiranHuangIrene/ICL

In-Context Learning (ICL) Implementation

This repository contains implementations of In-Context Learning for classification experiments using both JAX and PyTorch frameworks.

Project Structure

  • main.py - Main PyTorch implementation
  • main_jax.py - Legacy JAX implementation
  • model.py - PyTorch model implementation
  • model_jax.py - JAX model implementation
  • dataset.py - PyTorch dataset implementation
  • dataset_jax.py - JAX dataset implementation
  • visualize.ipynb - Visualization notebook for analysis (old results)
  • run_torch.sh, run.sh and other run scripts for different experiments

Setup

  1. Install dependencies:
pip install torch numpy wandb tqdm
  2. For the legacy JAX implementation:
pip install jax jaxlib
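Before launching a run, a quick standalone check can confirm that the packages from the Setup steps are importable. This is a sketch, not a script shipped with the repository: the helper name `missing_packages` is hypothetical, and it only checks that each package can be found on the Python path, not which version is installed.

```python
import importlib.util

# Packages from the Setup steps; jax/jaxlib are only needed for the legacy JAX path.
REQUIRED = ["torch", "numpy", "wandb", "tqdm"]
OPTIONAL_JAX = ["jax", "jaxlib"]


def missing_packages(names):
    """Hypothetical helper: return the names that cannot be found on the Python path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages(REQUIRED)
    if missing:
        print("Missing PyTorch dependencies:", ", ".join(missing))
    else:
        print("All PyTorch dependencies found.")

    missing_jax = missing_packages(OPTIONAL_JAX)
    if missing_jax:
        print("Legacy JAX path unavailable; missing:", ", ".join(missing_jax))
```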

Usage

PyTorch Implementation

The PyTorch implementation is launched through shell scripts, for example:

  • Basic run: ./run_torch.sh

JAX Implementation (Legacy)

For the legacy JAX implementation:

./run.sh

Visualization

Use the provided Jupyter notebooks for visualization and analysis:

  • visualize.ipynb - Main visualization notebook
  • visualize_layer4.ipynb - Layer 4 specific visualizations

Outputs

Results are stored in:

  • outs_torch/ - PyTorch implementation outputs
  • outs/ - JAX implementation outputs
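The two output directories can also be inspected from Python, for example at the top of an analysis notebook. This is a hypothetical helper (`list_outputs` is not part of the repository). Because this README does not describe which files a run writes, it makes no assumption about file names or formats and simply lists every file found under each directory.

```python
from pathlib import Path

# Output directories named in this README; their internal layout is not documented,
# so this sketch just collects every file below them.
OUTPUT_DIRS = {"pytorch": Path("outs_torch"), "jax": Path("outs")}


def list_outputs(root=Path(".")):
    """Hypothetical helper: map each implementation to a sorted list of its output files."""
    found = {}
    for impl, rel_dir in OUTPUT_DIRS.items():
        out_dir = Path(root) / rel_dir
        if out_dir.is_dir():
            found[impl] = sorted(p for p in out_dir.rglob("*") if p.is_file())
        else:
            # A missing directory just means that implementation has not been run yet.
            found[impl] = []
    return found


if __name__ == "__main__":
    for impl, files in list_outputs().items():
        print(f"{impl}: {len(files)} file(s)")
        for path in files[:10]:  # show at most the first ten paths
            print("  ", path)
```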

Notes
