- Python 3.9 or later
- PyTorch 2.3.1 or later
- NumPy 1.20.3 or later
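A minimal setup sketch, assuming the standard PyPI package names (the repository does not pin its dependencies):

```bash
pip install "torch>=2.3.1" "numpy>=1.20.3"
```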
```
├── data/                   # Generated datasets
│   ├── seed*-*.npy         # Training/test data
│   └── finite-seed*-*.npy  # Finite sample data
├── logs/                   # Training outputs
│   ├── seed*-*.npy         # Loss trajectories
│   ├── seed*-*.pth         # Model checkpoints
│   └── finite-*.npy        # Finite sample results
├── viz/                    # Generated figures
│   └── fig-*.pdf           # Vector plots
├── models.py               # Model implementations
├── train.py                # Main training script
└── README.md               # This file
```
| File | Description |
|---|---|
| `models.py` | Defines the original and reparameterized classes of one-layer transformers |
| `train.py` | Defines the training and testing procedures |
| Argument | Type | Description | Default | Options |
|---|---|---|---|---|
| `--s/--seed` | int | Random seed for reproducibility | 0 | Any integer |
| `--n/--noise` | int | Whether to add noise to the data | - | 0 (no), 1 (yes) |
| `--alpha` | float | Noise probability when noise is enabled | 0.5 | 0.0-1.0 |
| `--m/--model_type` | str | Transformer variant to use | - | `origin`, `reparam`, `reparamW` |
| `--a/--att_type` | str | Attention mechanism type | - | `linear`, `softmax`, `relu` |
| `--k/--kappa` | int | Flag for passing the full input to the feed-forward layer | - | 0, 1 |
| `--lr` | float | Learning rate for the optimizer | 0.1 | > 0.0 |
| `--nst/--nsteps` | int | Number of training steps for population loss minimization experiments | 2000 | > 0 |
| `--fnst/--finite_nsteps` | int | Number of samples for finite-sample experiments | 4 | > 0 |
| `-g/--generate` | int | Flag to generate data | 0 | 0 (no), 1 (yes) |
| `-v/--visualize` | int | Flag to generate plots | 0 | 0 (no), 1 (yes) |
| `--pl/--plot` | int | Type of plot to generate | -1 | 0, 1, 2, 3 |
| `--f/--finite` | int | Flag for finite sample experiments | 0 | 0 (no), 1 (yes) |
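Assuming `train.py` parses these flags with argparse (suggested by the defaults above, though not stated explicitly in the repository), the same list can be printed from the command line:

```bash
python train.py --help
```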
Generate the train and test datasets for the noiseless setting and for the noisy settings with noise probabilities 0.2, 0.5, and 0.8:

```bash
python train.py -g 1 --s 0 --n 0
python train.py -g 1 --s 0 --n 1 --alpha 0.2
python train.py -g 1 --s 0 --n 1 --alpha 0.5
python train.py -g 1 --s 0 --n 1 --alpha 0.8
```

The datasets are saved into the `./data/` directory.
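As a quick sanity check, list the generated files (filename patterns as shown in the directory tree above):

```bash
ls ./data/
```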
Train all models, or a subset of them, with `train.py`. For example, the following command trains Origin-Softmax-FA on noiseless learning with learning rate 0.5:

```bash
python train.py -s 0 -n 0 -m origin -a softmax -k 1 -lr 0.5
```

Another example: the following command trains Reparam-Linear-F on noisy learning with noise probability 0.2 and learning rate 0.1:

```bash
python train.py -s 0 -n 1 -m reparam -a linear -k 0 -alpha 0.2 -lr 0.1
```

To train the `reparamW` variant, pass `-m reparamW`.

The training results are saved into the `./logs/` directory.
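To cover every configuration, a sweep along these lines can be used (a sketch only: it fixes the seed and learning rate, and assumes all model/attention/kappa combinations are valid; individual experiments may call for different learning rates):

```bash
# Sweep all model variants, attention types, and kappa flags (noiseless setting).
for m in origin reparam reparamW; do
  for a in linear softmax relu; do
    for k in 0 1; do
      python train.py -s 0 -n 0 -m "$m" -a "$a" -k "$k" -lr 0.1
    done
  done
done
```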
Once all models have been trained for all configurations, generate the plots with:

```bash
python train.py -v 1 -pl x
```

where `x` is `0`, `1`, `2`, or `3`, specifying the type of plots to be generated according to the table below.
| `x` (plot type) | Description |
|---|---|
| 0 | Generates Figure 1 in the main text |
| 1 | Generates Figure 2 in the main text |
| 2 | Generates the additional results on population loss minimization (Figures 3 to 9 in the Appendix) |
| 3 | Generates the additional results on layer-specific learning (Figures 10 to 12 in the Appendix) |
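To render all four plot types in one pass (assuming the required runs already exist in `./logs/`):

```bash
# Generate every plot type in turn.
for x in 0 1 2 3; do
  python train.py -v 1 -pl "$x"
done
```

Per the directory tree above, the figures are written to `./viz/` as `fig-*.pdf`.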
This project is licensed under the MIT License.