ngmq/onelayer-transformer-ICR-DA-NTP

Prerequisites

  • Python 3.9 or later
  • PyTorch 2.3.1 or later
  • NumPy 1.20.3 or later
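
To confirm the environment meets these prerequisites, a quick sanity check such as the following can be run (not part of the repository):

```python
import sys

import numpy as np
import torch

# Print the installed versions to compare against the list above.
print("Python :", sys.version.split()[0])
print("PyTorch:", torch.__version__)
print("NumPy  :", np.__version__)
```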

File and Directory Structure

├── data/                    # Generated datasets
│   ├── seed*-*.npy          # Training/test data
│   └── finite-seed*-*.npy   # Finite sample data
├── logs/                    # Training outputs
│   ├── seed*-*.npy          # Loss trajectories
│   ├── seed*-*.pth          # Model checkpoints
│   └── finite-*.npy         # Finite sample results
├── viz/                     # Generated figures
│   └── fig-*.pdf            # Vector plots
├── models.py                # Model implementations
├── train.py                 # Main training script
└── README.md                # This file

File Description

| File | Description |
|-----------|-------------|
| models.py | Defines the Origin and Reparameterized classes of one-layer transformers |
| train.py | Defines the training and testing procedures |
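
For orientation, here is a minimal sketch of what a one-layer transformer with a selectable attention type might look like. All class and parameter names here are illustrative assumptions, not the actual definitions in models.py:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class OneLayerTransformer(nn.Module):
    """Illustrative sketch only; the actual Origin/Reparam classes in
    models.py may be parameterized quite differently."""

    def __init__(self, d_model: int, att_type: str = "softmax"):
        super().__init__()
        self.att_type = att_type
        self.W_qk = nn.Linear(d_model, d_model, bias=False)  # fused query-key map (assumption)
        self.W_v = nn.Linear(d_model, d_model, bias=False)   # value map
        self.W_o = nn.Linear(d_model, d_model, bias=False)   # output / feed-forward map

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model)
        scores = x @ self.W_qk(x).transpose(1, 2)            # (batch, seq_len, seq_len)
        if self.att_type == "softmax":
            att = F.softmax(scores, dim=-1)
        elif self.att_type == "relu":
            att = F.relu(scores)
        else:                                                # "linear": raw scores
            att = scores
        return self.W_o(att @ self.W_v(x))
```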

Command Line Arguments for train.py

| Argument | Type | Description | Default | Options |
|----------|------|-------------|---------|---------|
| --s/--seed | int | Random seed for reproducibility | 0 | Any integer |
| --n/--noise | int | Whether to add noise to the data | - | 0 (no), 1 (yes) |
| --alpha | float | Noise probability when noise is enabled | 0.5 | 0.0-1.0 |
| --m/--model_type | str | Transformer variant to use | - | origin, reparam, reparamW |
| --a/--att_type | str | Attention mechanism type | - | linear, softmax, relu |
| --k/--kappa | int | Flag for full input to the feed-forward layer | - | 0, 1 |
| --lr | float | Learning rate for the optimizer | 0.1 | > 0.0 |
| --nst/--nsteps | int | Number of training steps for population loss minimization experiments | 2000 | > 0 |
| --fnst/--finite_nsteps | int | Number of samples for finite-sample experiments | 4 | > 0 |
| -g/--generate | int | Flag to generate data | 0 | 0 (no), 1 (yes) |
| -v/--visualize | int | Flag to generate plots | 0 | 0 (no), 1 (yes) |
| --pl/--plot | int | Type of plot to generate | -1 | 0, 1, 2, 3 |
| --f/--finite | int | Flag for finite-sample experiments | 0 | 0 (no), 1 (yes) |
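
The paired short/long spellings in the table suggest an argparse setup along the following lines. This is only a sketch of how such flags could be declared, not the actual code in train.py:

```python
import argparse

parser = argparse.ArgumentParser()
# Registering both spellings explicitly would explain why most short
# forms in the table also take a double dash (e.g. --s as well as --seed).
parser.add_argument("--s", "--seed", dest="seed", type=int, default=0,
                    help="random seed for reproducibility")
parser.add_argument("--n", "--noise", dest="noise", type=int, choices=[0, 1],
                    help="whether to add noise to the data")
parser.add_argument("--alpha", type=float, default=0.5,
                    help="noise probability when noise is enabled")
parser.add_argument("--m", "--model_type", dest="model_type",
                    choices=["origin", "reparam", "reparamW"])
parser.add_argument("--lr", type=float, default=0.1)

args = parser.parse_args(["--s", "0", "--n", "1", "--alpha", "0.2"])
print(args.seed, args.noise, args.alpha)  # -> 0 1 0.2
```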

Usage

1. Data Generation

Generate train and test datasets for noiseless ($\alpha = 0$) and noisy ($\alpha \in \{0.2, 0.5, 0.8\}$) settings.

python train.py -g 1 --s 0 --n 0
python train.py -g 1 --s 0 --n 1 --alpha 0.2
python train.py -g 1 --s 0 --n 1 --alpha 0.5
python train.py -g 1 --s 0 --n 1 --alpha 0.8

The datasets are saved into the ./data/ directory.
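
The generated files are plain NumPy arrays, so they can be inspected directly. The filename below is an illustrative stand-in for one of the seed*-*.npy files; substitute a file that actually exists in ./data/ after running the commands above:

```python
import numpy as np

# Hypothetical filename following the seed*-*.npy pattern in ./data/.
data = np.load("data/seed0-alpha0.5.npy")
print(data.shape, data.dtype)
```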

2. Model Training

Train all $13$ models, or any subset of them, with learning rates $\eta = 0.1$ and $\eta = 0.5$ using train.py. For example, the following command trains Origin-Softmax-FA on noiseless learning with learning rate $0.5$.

python train.py --s 0 --n 0 --m origin --a softmax --k 1 --lr 0.5

The following command trains Reparam-Linear-F on noisy learning with $\alpha = 0.2$ and learning rate $0.1$.

python train.py --s 0 --n 1 --m reparam --a linear --k 0 --alpha 0.2 --lr 0.1

To train the model where $W$ is not reparameterized, use --m reparamW.
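
A small driver script can sweep one model variant over every $(\alpha, \eta)$ combination used above. This is a convenience sketch, not part of the repository; extend the grids to cover the remaining variants:

```python
import itertools
import subprocess

# Sweep one variant (Origin-Softmax-FA) over all noise levels and
# learning rates; alpha = 0.0 maps to the noiseless setting (--n 0).
for alpha, lr in itertools.product([0.0, 0.2, 0.5, 0.8], [0.1, 0.5]):
    cmd = ["python", "train.py", "--s", "0", "--m", "origin",
           "--a", "softmax", "--k", "1", "--lr", str(lr)]
    cmd += ["--n", "0"] if alpha == 0.0 else ["--n", "1", "--alpha", str(alpha)]
    subprocess.run(cmd, check=True)
```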

The training results are saved into the ./logs/ directory.
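
The saved trajectories and checkpoints can be inspected directly. The filenames below are illustrative stand-ins for the seed*-* patterns listed above:

```python
import numpy as np
import torch

# Hypothetical filenames; match them to the actual files in ./logs/.
losses = np.load("logs/seed0-origin-softmax.npy")                # loss trajectory
state = torch.load("logs/seed0-origin-softmax.pth", map_location="cpu")

print("final loss:", losses[-1])
print("checkpoint entries:", list(state)[:5])
```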

3. Visualization

Once all models have been trained for all $\alpha \in \{0, 0.2, 0.5, 0.8\}$ and $\eta \in \{0.1, 0.5\}$, the figures can be generated by running

python train.py -v 1 --pl x

where x is 0, 1, 2, or 3, specifying the type of plots to be generated according to the table below.

| x (plot type) | Description |
|---------------|-------------|
| 0 | Generate Figure 1 in the main text |
| 1 | Generate Figure 2 in the main text |
| 2 | Generate the additional results on population loss minimization (Figures 3 to 9 in the Appendix) |
| 3 | Generate the additional results on layer-specific learning (Figures 10 to 12 in the Appendix) |
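
To regenerate every figure at once, the four plot types can be swept with a small driver (a convenience sketch, not part of the repository):

```python
import subprocess

# Produce all four plot types once every model has been trained.
for plot_type in range(4):
    subprocess.run(["python", "train.py", "-v", "1", "--pl", str(plot_type)],
                   check=True)
```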

License

This project is licensed under the MIT License.
