This repository contains a PyTorch implementation and experiment interface for the supervised NLI task, with different models for learning universal sentence representations.
A baseline model, MeanEmbedding, and three LSTM-based models (LSTM, BiLSTM, and BiLSTM-maxpool) are trained on the NLI task using SNLI data. The sentence embeddings are evaluated on 8 transfer tasks using the SentEval framework.
The micro and macro metrics for the SentEval tasks are computed as defined in Section 5 of the InferSent paper [1]. The results are tabulated below:
| Model | snli-dev | snli-test | senteval-micro | senteval-macro | 
|---|---|---|---|---|
| MeanEmbedding | 69.5 | 69.1 | 77.31 | 77.92 | 
| LSTM | 80.5 | 80.2 | 70.467 | 70.282 | 
| BiLSTM | 80.00 | 80.08 | 71.997 | 71.531 | 
| BiLSTM-maxpool | 86.50 | 85.87 | 79.075 | 78.831 | 
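The senteval-micro and senteval-macro columns are aggregate scores. As a minimal sketch, assuming the definition in [1] (macro is the plain average of the per-task dev accuracies, micro weights each accuracy by the number of dev samples), they could be computed as follows. The function name, task names, and numbers are hypothetical and are not taken from eval.py or from the table above.

```python
from typing import Dict, Tuple


def aggregate_scores(results: Dict[str, Tuple[float, int]]) -> Tuple[float, float]:
    """Return (micro, macro) averages.

    `results` maps a task name to (dev accuracy in %, number of dev samples).
    Assumed definitions: macro = unweighted mean of accuracies,
    micro = mean of accuracies weighted by dev-set size.
    """
    if not results:
        raise ValueError("results must contain at least one task")
    accuracies = [acc for acc, _ in results.values()]
    total_samples = sum(n for _, n in results.values())
    macro = sum(accuracies) / len(accuracies)
    micro = sum(acc * n for acc, n in results.values()) / total_samples
    return micro, macro


# Hypothetical example values, for illustration only.
example = {"task_a": (80.0, 1000), "task_b": (90.0, 3000)}
micro, macro = aggregate_scores(example)
print(f"micro={micro:.2f} macro={macro:.2f}")  # micro=87.50 macro=85.00
```

The larger task pulls the micro average toward its own accuracy, which is why the two columns can differ for the same model.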
This repository is organized into the following major components:
- models.py: PyTorch modules for the encoder and classifier models.
- data.py: SNLIDataclass for preparing data for training and evaluation.
- train.py: PyTorch Lightning model and training CLI for training with different encoders.
- eval.py: CLI that takes a model checkpoint and runs evaluation on SNLI and SentEval tasks.
- demo.ipynb: Jupyter notebook for testing model inference and analyzing the results.
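To illustrate the kind of encoder module described above, here is a minimal sketch of a BiLSTM encoder with max pooling over time. It is not the code in models.py: the class name, constructor arguments, and default sizes are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class BiLSTMMaxPoolEncoder(nn.Module):
    """Hypothetical sketch: bidirectional LSTM followed by max pooling over time."""

    def __init__(self, embed_dim: int = 300, hidden_dim: int = 2048):
        super().__init__()
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)

    def forward(self, embeddings: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        # embeddings: (batch, max_len, embed_dim); lengths: (batch,) true token counts.
        # Packing keeps padded positions out of the recurrence entirely.
        packed = pack_padded_sequence(
            embeddings, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_out, _ = self.lstm(packed)
        # Pad with -inf so padded time steps can never win the max.
        outputs, _ = pad_packed_sequence(
            packed_out, batch_first=True, padding_value=float("-inf")
        )
        # Max over the time dimension gives one (2 * hidden_dim) vector per sentence.
        return outputs.max(dim=1).values


# Usage with random stand-in embeddings (small, assumed sizes).
encoder = BiLSTMMaxPoolEncoder(embed_dim=4, hidden_dim=3)
batch = torch.randn(2, 5, 4)
sentence_vectors = encoder(batch, torch.tensor([5, 3]))
print(sentence_vectors.shape)
```

Because the pooled vector takes the maximum of each feature across all time steps, its size is independent of sentence length.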
# Using pip
pip install -r requirements.txt
# Using conda
conda env create -f environment.yml
# Download the English model for the spaCy tokenizer
python -m spacy download en_core_web_sm
To run evaluation with SentEval, install SentEval as follows:
git clone https://github.com/facebookresearch/SentEval.git
cd SentEval/ && python setup.py install
# Download datasets (run from inside SentEval/, where the previous command leaves the shell)
cd data/downstream/ && ./get_transfer_data.bash

Run train.py with one of the following encoder types: MeanEmbedding, LSTM, BiLSTM, BiLSTM-maxpool. The training process creates model checkpoints, TensorBoard logs, and a hyperparameter file, hparams.yaml, in the ./logs directory.
python train.py --encoder_type='BiLSTM'

Run eval.py with a model checkpoint flag to run the evaluation tasks on SNLI and SentEval.
python eval.py --checkpoint_path='./logs/MeanEmbedding/version_0/checkpoints/epoch=2-step=12875.ckpt'

The model checkpoints and TensorBoard logs are publicly available here: https://drive.google.com/drive/folders/1Ebjyf0wj31EZMPEBiG1nHW-1JOMMl1IY?usp=sharing
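Checkpoint paths like the one passed to eval.py follow the pattern ./logs/<encoder_type>/version_<n>/checkpoints/<name>.ckpt. As a convenience, the sketch below picks the most recently modified checkpoint for an encoder so that its path can be passed to --checkpoint_path. The helper is hypothetical and is not part of this repository; it assumes only the directory layout seen in the example path.

```python
from pathlib import Path
from typing import Optional


def latest_checkpoint(log_dir: str, encoder_type: str) -> Optional[Path]:
    """Return the newest .ckpt file under log_dir/encoder_type, or None.

    Assumes the layout <log_dir>/<encoder_type>/version_<n>/checkpoints/*.ckpt.
    """
    pattern = f"{encoder_type}/version_*/checkpoints/*.ckpt"
    candidates = list(Path(log_dir).glob(pattern))
    if not candidates:
        return None
    # Newest by modification time, regardless of version number.
    return max(candidates, key=lambda path: path.stat().st_mtime)
```

For example, `latest_checkpoint("./logs", "BiLSTM")` returns a `Path` to the newest BiLSTM checkpoint, or `None` if no run has been saved yet.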
[1] A. Conneau, D. Kiela, H. Schwenk, L. Barrault, A. Bordes. Supervised Learning of Universal Sentence Representations from Natural Language Inference Data. EMNLP 2017.