Closed
99 commits
28b605a
Begin baseline model; skeleton code
SecroLoL Apr 21, 2024
368280e
Initial baseline + pointergen model. TODO fix dimensionalities and in…
SecroLoL Apr 22, 2024
5bdbc96
Add baseline model (seq2seq no coverage, no pgen)
SecroLoL Apr 30, 2024
e8a0407
Add constants for model building and args loading
SecroLoL Apr 30, 2024
6390a53
Delete out logging statements
SecroLoL Apr 30, 2024
d6811f2
Add printing settings to see deeper into tensors
SecroLoL May 2, 2024
94b2056
Add masking to the attention layer
SecroLoL May 2, 2024
92dc8a1
add embedding dim to the decoder linear layer input size
SecroLoL May 2, 2024
72e3f0f
Fix squeezing to make dimensions match during forward pass
SecroLoL May 2, 2024
18f1318
Add pgen to decoder, also consider OOV words
SecroLoL May 2, 2024
c6e552e
Add function to process OOV words to create extended vocab map and in…
SecroLoL May 2, 2024
65a9d06
Remove TODO items that were already finished
SecroLoL May 2, 2024
c7ab9b0
Add coverage to model, TODO remove the debugging statements after con…
SecroLoL May 7, 2024
af0f9e5
Improve documentation: add more comments and delete debugging print s…
SecroLoL May 13, 2024
027d33c
Improve documentation, change some print statements and add comments
SecroLoL May 13, 2024
ad317fc
When an OOV word is chosen as the next input to decoder, give the emb…
SecroLoL May 13, 2024
91a6a8f
Remove torch debugging print config
SecroLoL May 13, 2024
54ab2b2
Add boiler-plate imports and logging setup for trainer
SecroLoL May 13, 2024
b48e1fe
Add imports for argparsing and base model file
SecroLoL May 14, 2024
c05556b
Add trainer init, method to build the model after parsing args, and b…
SecroLoL May 14, 2024
e1b4c98
Write main to accept args from cli
SecroLoL May 14, 2024
82902a1
Delete documentation comments
SecroLoL May 14, 2024
2b08784
add argparse and load in vars for training
SecroLoL May 14, 2024
7368b3c
Improve documentation for trainer init
SecroLoL May 14, 2024
99cd01c
Add default values for train, eval, save, and wordvec pretrain files
SecroLoL May 14, 2024
e3950d8
Update constants
SecroLoL May 14, 2024
4936bb2
Add file parsing for train, eval, and wordvec pretrain files. Raise e…
SecroLoL May 14, 2024
a23661a
Move model components to device. Add dropout to Encoder after the LST…
SecroLoL May 21, 2024
52c6e22
Set up training loop with sample data, awaiting proper data pipeline
SecroLoL May 21, 2024
a8ee27c
Fix in-place tensor manipulation bug
SecroLoL May 21, 2024
d279148
Move optimizer zero-ing out the gradient to top of training loop
SecroLoL May 21, 2024
f4883ab
Include add_unsaved_module method to allow for char embeddings
SecroLoL May 24, 2024
bb37a6f
Add CharLM embeddings to model (optionally)
SecroLoL May 25, 2024
f935654
Add CharLM option to trainer
SecroLoL May 25, 2024
fa099a1
Start beam search decoding
SecroLoL May 25, 2024
f6cc54a
add utils function to convert text to vocab token ids
SecroLoL May 25, 2024
ee64121
Add CharLM to generate embedding for next word in decoding if not usi…
SecroLoL May 25, 2024
36a38c8
Remove original pointergen dir (accidentally included)
SecroLoL May 25, 2024
8326157
add helper methods to model class to assist with decoder beam search
SecroLoL May 26, 2024
f556635
Build beam search algorithm
SecroLoL May 26, 2024
b89ec15
Add bugfixes to the summary decoder, using proper inputs. Save model …
SecroLoL May 26, 2024
74d777c
Add preprocessing to data for loading CNN dataset
SecroLoL May 26, 2024
ca91f3d
Update output distribution with pgen to have size of the ext vocab ma…
SecroLoL May 27, 2024
828b3d5
Implement Dataset loader class
SecroLoL May 27, 2024
881a8b7
add new dataloading process to trainer loop
SecroLoL May 27, 2024
9abb6cb
add rouge eval for trained model
SecroLoL May 27, 2024
acd5b7b
Update default train and eval path roots
SecroLoL May 27, 2024
2ec986f
Add argparse
SecroLoL May 27, 2024
047e03f
Add function header documentation
SecroLoL May 27, 2024
fcde39d
Add argparse
SecroLoL May 27, 2024
bc9cb50
Add eval return objects for ROUGE
SecroLoL May 28, 2024
55276a1
add eval to train loop after every epoch
SecroLoL May 28, 2024
d42f073
Add enc_steps and dec_steps limits to truncate inputs if needed. Also…
SecroLoL May 31, 2024
80877cb
add max_enc_steps and max_dec_steps. Also, use OOV vocab inclusive ma…
SecroLoL May 31, 2024
55ac159
add max_enc_steps and max_dec_steps. Also, use OOV vocab inclusive ma…
SecroLoL May 31, 2024
6409fb6
add max_enc_steps and max_dec_steps. Also, use OOV vocab inclusive ma…
SecroLoL May 31, 2024
0e96bdd
add max_enc_steps and max_dec_steps. Also, use OOV vocab inclusive ma…
SecroLoL May 31, 2024
dccecf7
Add enc steps dec steps too
SecroLoL May 31, 2024
33f26cc
Remove prints
SecroLoL May 31, 2024
15595df
Add [START] and [STOP] tokens to embeddings layer
SecroLoL May 31, 2024
8eaa028
Code cleanup: add tqdm bar to training and log exceptions
SecroLoL Jun 1, 2024
778994b
Make log errors type ERROR
SecroLoL Jun 1, 2024
bad594a
Rename evaluation script due to name overlap with HF library
SecroLoL Jun 1, 2024
bdc9b80
Add random sampling filler for batches that are not evenly divisible
SecroLoL Jun 1, 2024
57dd86e
Update trainer
SecroLoL Jun 1, 2024
7fe812d
Add util function to generate checkpoint for model paths
SecroLoL Jun 1, 2024
4ef16d4
remove TODO items that are finished
SecroLoL Jun 1, 2024
44eea44
Change some variable names and improve readability
SecroLoL Jun 1, 2024
e95d125
Add logging + addtl documentation
SecroLoL Jun 2, 2024
39ce330
Bugfix the coverage detach None issue
SecroLoL Jun 2, 2024
85a959c
Move model to device
SecroLoL Jun 2, 2024
ed5a761
Add model checkpoint loading for training
SecroLoL Jun 2, 2024
f275f8a
Oops. The checkpoint load file should not be mandatory.
SecroLoL Jun 2, 2024
668fcb8
Update logger to be native to decode.py, not passed via evaluate_mode…
SecroLoL Jun 3, 2024
87d2be2
Fix issues with beam search decoding with None inputs
SecroLoL Jun 3, 2024
67a594a
Add logger statements for debugging evaluation steps
SecroLoL Jun 3, 2024
f70f435
Update eval after each training epoch: Use avg loss, not the val set acc
SecroLoL Jun 3, 2024
ab5faf9
Add fixes: use teacher-forcing during training and also use gradient …
SecroLoL Jun 5, 2024
533a5bb
Add layernorm to Encoder, Decoder
SecroLoL Jun 5, 2024
035baf8
Update documentation. Also fix layernorm dims to work with model stru…
SecroLoL Jun 8, 2024
92a80ab
Move START and STOP tokens to constants file instead of model file
SecroLoL Jun 8, 2024
c047e47
Put LayerNorm right before sigmoid (correctly now). Add Xavier initia…
SecroLoL Jun 8, 2024
a9a6686
Clumsy, missing the arg to sigmoid fixed
SecroLoL Jun 8, 2024
fcc35e4
Alter checkpoint loading to include coverage if not included before
SecroLoL Jun 8, 2024
6e64b64
Since the prev_coverage is a tensor, need to check the elements for N…
SecroLoL Jun 8, 2024
f08cd28
Implement validation loss for model selection
SecroLoL Jun 8, 2024
2e74f05
Rename helper functions for eval
SecroLoL Jun 8, 2024
99a926b
Fix error: use Softmax instead of LogSoftmax, and apply log at the ve…
SecroLoL Jun 8, 2024
7ab8466
Touchups to last modification
SecroLoL Jun 8, 2024
ef272b7
Delete old sections of training loop: gradnorm debugging and evaluati…
SecroLoL Jun 8, 2024
faca33c
Introduce more dropout to enhance regularization of the model
SecroLoL Jun 10, 2024
6502249
Add weight decay to Adam for L2 reg
SecroLoL Jun 10, 2024
f40c535
Revert LSTM dropout changes. cannot use built-in dropout with single-…
SecroLoL Jun 10, 2024
1083e21
Add verbose option for model eval: prints out reference summaries and…
SecroLoL Jun 12, 2024
01f7f22
Bugfix coverage being None when adding extension to hypothesis during…
SecroLoL Jun 12, 2024
fc6ef06
Update truncation for inputs: Add a STOP token to the end of the summ…
SecroLoL Jun 12, 2024
bd4f686
remove debugging statement
SecroLoL Jun 12, 2024
ee9647a
Bugfix padding tokens not appearing properly and handled by the loss fn
SecroLoL Jun 13, 2024
4c19a0f
Add error-checking for padding token, and update validation loss to u…
SecroLoL Jun 13, 2024
28 changes: 28 additions & 0 deletions stanza/models/summarization/constants.py
@@ -0,0 +1,28 @@
"""
Constant values for model building and execution
"""

import os

# general
PADDING_TOKEN = "<PAD>"

# for model.py
DEFAULT_ENCODER_HIDDEN_DIM = 128
DEFAULT_ENCODER_NUM_LAYERS = 1
DEFAULT_DECODER_HIDDEN_DIM = 128
DEFAULT_DECODER_NUM_LAYERS = 1
UNK_ID = 1
UNK = "<UNK>"
START_TOKEN = "<s>"
STOP_TOKEN = "</s>"

# model hyperparams
LSTM_DROPOUT_P = 0.2
ATTN_DROPOUT_P = 0.5

DEFAULT_BATCH_SIZE = 16
DEFAULT_EVAL_ROOT = os.path.join(os.path.dirname(__file__), "data", "validation")
DEFAULT_TRAIN_ROOT = os.path.join(os.path.dirname(__file__), "data", "train")
DEFAULT_SAVE_NAME = os.path.join(os.path.dirname(__file__), "saved_models", "summarization_model.pt")
DEFAULT_WORDVEC_PRETRAIN_FILE = os.path.join(os.path.dirname(__file__), "pretrain", "en", "glove.pt")
191 changes: 191 additions & 0 deletions stanza/models/summarization/src/beam_search.py
@@ -0,0 +1,191 @@
"""
Run beam search decoding from a trained abstractive summarization model
"""
import torch
import logging

logger = logging.getLogger('stanza.summarization')
logger.propagate = False

# Check if the logger has handlers already configured
if not logger.handlers:
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

from typing import List, Tuple, Mapping, Any
from stanza.models.summarization.src.model import BaselineSeq2Seq
from stanza.models.summarization.constants import UNK_ID, STOP_TOKEN, START_TOKEN
from stanza.models.common.vocab import BaseVocab, UNK

class Hypothesis():

    """
    Represents a hypothesis during beam search. Holds all information needed for the hypothesis.
    """

    def __init__(self, tokens, log_probs, state, attn_dists, p_gens, coverage):
        """
        Args:
            tokens (List[int]): The ids of the tokens that form the summary so far.
            log_probs (List[float]): Log probabilities of the tokens so far (same length as tokens).
            state (Tuple[Tensor, Tensor]): Current state of the decoder LSTM, a tuple of the hidden and cell states.
            attn_dists (List[Tensor]): Attention distributions from each decoder step so far (one per generated token; empty for the initial START-only hypothesis).
            p_gens (List[float]): Generation probabilities from each decoder step so far (one per generated token). Only meaningful when using pointer-gen.
            coverage (Tensor): Current coverage vector. None if not using coverage.
        """
        self.tokens = tokens
        self.log_probs = log_probs
        self.state = state
        self.attn_dists = attn_dists
        self.p_gens = p_gens
        self.coverage = coverage

    def extend(self, token, log_prob, state, attn_dist, p_gen, coverage):
        """
        Return a new hypothesis, extended with the information from the latest step of beam search.

        Args:
            token (int): The latest token ID produced by beam search.
            log_prob (float): Log probability of the latest token.
            state (Tuple[Tensor, Tensor]): Current decoder hidden and cell states.
            attn_dist (Tensor): Attention distribution from the latest step.
            p_gen (float): Generation probability from the latest step.
            coverage (Tensor): Latest coverage vector, or None if not using coverage.

        Returns:
            Hypothesis: New hypothesis for the next step.
        """
        new_hypothesis = Hypothesis(
            tokens=self.tokens + [token],
            log_probs=self.log_probs + [log_prob],
            state=state,
            attn_dists=self.attn_dists + [attn_dist],
            p_gens=self.p_gens + [p_gen],
            coverage=coverage
        )
        return new_hypothesis

    def get_latest_token(self):
        # Get the last token decoded in this hypothesis
        return self.tokens[-1]

    def get_log_prob(self):
        # The sum of the log probabilities so far
        return sum(self.log_probs)

    def get_avg_log_prob(self):
        # Normalize by sequence length (otherwise longer sequences would always score lower)
        return self.get_log_prob() / len(self.tokens)


def run_beam_search(model: BaselineSeq2Seq, unit2id: Mapping, id2unit: Mapping, example: List[str], beam_size: int,
                    max_dec_steps: int, min_dec_steps: int, max_enc_steps: int):
    """
    Performs beam search decoding on a single example.

    Returns the hypothesis with the highest average log probability, together with the
    id2unit mapping updated to include any OOV words seen during decoding.
    """
    # Truncate if needed
    example = example[: max_enc_steps]
    if example[-1] != STOP_TOKEN:
        example = example + [STOP_TOKEN]


    batch = [example for _ in range(beam_size)]  # each batch is a single example repeated `beam_size` times
    device = next(model.parameters()).device

    # Run encoder over the batch of examples to get the encoder hidden states and decoder init state
    # note that the batch is the same example repeated
    enc_states, dec_hidden_init, dec_cell_init = model.run_encoder(batch)

    # enc states shape (batch size, seq len, 2 * enc hidden dim)
    # dec states are shape (batch size, dec hidden dim)
    # note that the batch is one example repeated, so the batch size is beam_size and every row is identical

    # Initialize N-Hypotheses for beam search
    hyps = [
        Hypothesis(
            tokens=[unit2id.get(START_TOKEN, UNK_ID)],
            log_probs=[0.0],
            state=(dec_hidden_init[0], dec_cell_init[0]),  # all rows are the same example, so take the state from the first row
            attn_dists=[],
            p_gens=[],
            coverage=torch.zeros(enc_states.shape[1], device=device) if model.coverage else None  # sequence length
        ) for _ in range(beam_size)
    ]
    results = []  # stores our finished hypotheses (decoded out the STOP token)

    # Run the loop while we still have decoding steps and the number of finished results is less than the beam size
    steps = 0
    while steps < max_dec_steps and len(results) < beam_size:
        latest_tokens = [h.get_latest_token() for h in hyps]  # get latest token from each hypothesis
        latest_tokens = [t if t in range(len(unit2id)) else UNK_ID for t in latest_tokens]  # change any OOV words to UNK
        latest_tokens = [[id2unit.get(t)] for t in latest_tokens]  # convert back to word because model.decode_onestep() expects string
        hidden_states = [h.state[0] for h in hyps]
        cell_states = [h.state[1] for h in hyps]
        prev_coverage = [h.coverage for h in hyps]

        # run the decoder for one timestep, decoding out choices for the next token of each sequence
        topk_ids, topk_log_probs, new_hiddens, new_cells, attn_dists, p_gens, new_coverage, unit2id_ = model.decode_onestep(
            examples=batch,
            latest_tokens=latest_tokens,
            enc_states=enc_states,
            dec_hidden=torch.stack(hidden_states).to(device),
            dec_cell=torch.stack(cell_states).to(device),
            prev_coverage=torch.stack(prev_coverage).to(device) if prev_coverage[0] is not None else None
        )
        # create updated id2unit from unit2id_.
        # Note that the returned unit2id_ is updated on every call to model.decode_onestep(),
        # so id2unit always contains the most recent OOV words that can be chosen in our hyps
        id2unit = {idx: word.replace('\xa0', ' ') for word, idx in unit2id_.items()}

        # Extend the current hypotheses with possible next tokens: each one gets its top 2 * beam_size candidates
        all_hyps = []
        # On the first step all hypotheses are identical, so expand only one of them to avoid duplicates
        num_original_hyps = 1 if steps == 0 else len(hyps)
        for i in range(num_original_hyps):
            p_gen = [] if p_gens is None else p_gens[i]
            new_coverage_i = new_coverage[i] if new_coverage is not None else None
            h, new_hidden, new_cell, attn_dist = hyps[i], new_hiddens[i], new_cells[i], attn_dists[i]
            for j in range(2 * beam_size):  # for each of the top 2*beam_size hypotheses:
                # Extend the ith hypothesis with the jth option
                new_hyp = h.extend(
                    token=topk_ids[i, j].item(),
                    log_prob=topk_log_probs[i, j].item(),  # store a plain float, as documented in Hypothesis
                    state=(new_hidden, new_cell),
                    attn_dist=attn_dist,
                    p_gen=p_gen,
                    coverage=new_coverage_i
                )
                all_hyps.append(new_hyp)
        # Filter and collect any hypotheses that have produced the end token (or are over limit)
        hyps = []
        for h in sort_hypotheses(all_hyps):  # in order of most likely h
            if h.get_latest_token() == unit2id.get(STOP_TOKEN):  # if we reach the stop token
                # if the hypothesis is sufficiently long, then put in results, otherwise discard
                if steps >= min_dec_steps:
                    results.append(h)
            else:  # hasn't reached stop token, so continue to expand the hypothesis
                hyps.append(h)
            if len(hyps) == beam_size or len(results) == beam_size:
                # Once we've collected beam_size-many hypotheses for the next step or beam_size-many complete hypotheses, stop
                break
        steps += 1

    # We now have either beam_size results or reached the maximum number of decoder steps
    if len(results) == 0:
        # If we don't have any complete results, add all current hypotheses (incomplete summaries) to results
        results = hyps

    # Sort hypotheses by the average log probability and return the hypothesis with the highest average log prob
    hyps_sorted = sort_hypotheses(results)
    return hyps_sorted[0], id2unit


def sort_hypotheses(hyps: List[Hypothesis]):
    """
    Return a list of Hypothesis objects sorted by descending average log prob
    """
    return sorted(hyps, key=lambda h: h.get_avg_log_prob(), reverse=True)
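
Review note on beam_search.py: the control flow of `run_beam_search` (expand each hypothesis with its top `2 * beam_size` tokens, sort by average log probability, move finished hypotheses into `results`) can be hard to follow when mixed with tensor handling. The sketch below is a minimal, model-free illustration of the same loop. Everything in it is hypothetical and not part of this PR: `toy_beam_search`, `toy_step`, and the probability table are invented for illustration, hypotheses are plain `(tokens, log_probs)` tuples instead of `Hypothesis` objects, and the extra `and hyps` guard in the loop condition is an addition of this sketch.

```python
import math
from typing import Callable, Dict, List, Tuple

START = "<s>"   # mirrors START_TOKEN in constants.py
STOP = "</s>"   # mirrors STOP_TOKEN in constants.py

Hyp = Tuple[List[str], List[float]]  # (tokens, log_probs)


def avg_log_prob(log_probs: List[float]) -> float:
    """Length-normalized score, analogous to Hypothesis.get_avg_log_prob()."""
    return sum(log_probs) / len(log_probs)


def toy_beam_search(step_fn: Callable[[List[str]], Dict[str, float]],
                    beam_size: int, max_steps: int, min_steps: int = 0) -> Hyp:
    # Start from one hypothesis; the real code expands only one on step 0 as well
    hyps: List[Hyp] = [([START], [0.0])]
    results: List[Hyp] = []
    steps = 0
    while steps < max_steps and len(results) < beam_size and hyps:
        candidates: List[Hyp] = []
        for tokens, log_probs in hyps:
            # Keep the 2 * beam_size most likely next tokens for this hypothesis
            options = sorted(step_fn(tokens).items(), key=lambda kv: kv[1], reverse=True)
            for tok, prob in options[: 2 * beam_size]:
                candidates.append((tokens + [tok], log_probs + [math.log(prob)]))
        candidates.sort(key=lambda h: avg_log_prob(h[1]), reverse=True)
        hyps = []
        for cand in candidates:
            if cand[0][-1] == STOP:
                if steps >= min_steps:  # finished hypotheses that are too short are discarded
                    results.append(cand)
            else:
                hyps.append(cand)
            if len(hyps) == beam_size or len(results) == beam_size:
                break
        steps += 1
    if not results:  # nothing finished: fall back to the unfinished hypotheses
        results = hyps
    return max(results, key=lambda h: avg_log_prob(h[1]))


def toy_step(tokens: List[str]) -> Dict[str, float]:
    # Hypothetical next-token distribution that depends only on the last token
    table = {
        "<s>": {"the": 0.6, "a": 0.4},
        "the": {"cat": 0.5, "dog": 0.3, "</s>": 0.2},
        "a": {"cat": 0.9, "</s>": 0.1},
        "cat": {"</s>": 1.0},
        "dog": {"</s>": 1.0},
    }
    return table[tokens[-1]]


best_tokens, best_log_probs = toy_beam_search(toy_step, beam_size=2, max_steps=5)
```

With this toy table, a greedy decoder would commit to "the" on the first step, while a beam of size 2 keeps "a" alive long enough to find the higher-probability sequence `<s> a cat </s>`.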
89 changes: 89 additions & 0 deletions stanza/models/summarization/src/decode.py
@@ -0,0 +1,89 @@
"""
Takes an existing model and runs beam search decoding across many examples

"""

import logging  # imported explicitly rather than relying on the star import from beam_search
import torch

from copy import deepcopy
from typing import List, Tuple, Mapping, Any
from stanza.models.summarization.src.model import BaselineSeq2Seq
from stanza.models.summarization.src.beam_search import *
from stanza.models.common.vocab import BaseVocab, UNK
from logging import Logger
from stanza.utils.get_tqdm import get_tqdm

tqdm = get_tqdm()

logger = logging.getLogger('stanza.summarization')
logger.propagate = False

# Check if the logger has handlers already configured
if not logger.handlers:
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)


class BeamSearchDecoder():

    """
    Decoder for summarization using beam search
    """

    def __init__(self, model: BaselineSeq2Seq):
        self.model = model
        self.stop_token = "</s>"
        self.start_token = "<s>"
        self.ext_id2unit = {idx: word for word, idx in self.model.ext_vocab_map.items()}
        self.ext_unit2id = self.model.ext_vocab_map

        logger.info(f"Loaded model into BeamSearchDecoder on device {next(self.model.parameters()).device}")

    def decode_examples(self, examples: List[List[str]], beam_size: int, max_dec_steps: int = None, min_dec_steps: int = None,
                        max_enc_steps: int = None, verbose: bool = True) -> List[List[str]]:
        """
        Run beam search on each article and return one summary (a list of words) per article.
        """
        summaries = []  # outputs
        num_examples = len(examples)
        PRINT_EVERY = 1000
        for i, article in tqdm(enumerate(examples), desc="decoding examples for evaluation..."):
            if i % PRINT_EVERY == 0:
                logger.info(f"Attempting to generate examples for eval for article {i + 1} / {num_examples}")
            try:
                # Run beam search to get the best hypothesis
                best_hyp, id2unit = run_beam_search(self.model,
                                                    self.ext_unit2id,
                                                    self.ext_id2unit,
                                                    article,
                                                    beam_size=beam_size,
                                                    max_dec_steps=max_dec_steps,
                                                    min_dec_steps=min_dec_steps,
                                                    max_enc_steps=max_enc_steps,
                                                    )

                output_ids = [int(t) for t in best_hyp.tokens[1: ]]  # exclude the START token but not STOP, because STOP is not guaranteed to be present

                decoded_words = [id2unit.get(idx) for idx in output_ids]
                if self.stop_token in decoded_words:
                    fst_stop_index = decoded_words.index(self.stop_token)  # index of the first STOP token
                    decoded_words = decoded_words[: fst_stop_index]
                summaries.append(decoded_words)

                if verbose:
                    decoded_output = " ".join(decoded_words)
                    self.log_output(article, decoded_output)


            except Exception:
                logger.error(f'Error on article {i}: {" ".join(article)}\n')
                raise  # re-raise with the original traceback
        assert len(examples) == len(summaries), f"Expected number of summaries ({len(summaries)}) to match number of articles ({len(examples)})."
        return summaries

    def log_output(self, article: List[str], summary: str) -> None:
        if logger is None:
            raise ValueError(f"Cannot log output without a Logger. Logger: {logger}")
        article_text = " ".join(article)
        logger.info(f"ARTICLE TEXT: {article_text}\n--------\nSUMMARY: {summary}")
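
Review note on decode.py: the post-processing in `decode_examples` (drop the START token, map ids back to words, cut at the first STOP token) is easy to test in isolation. The following is a minimal sketch of that step, assuming a plain `dict` for the id-to-word mapping. The helper name `ids_to_summary` and the sample vocabulary are hypothetical, and the `<UNK>` fallback for unknown ids is an assumption of this sketch; the code in the diff calls `id2unit.get(idx)` without a default.

```python
from typing import Dict, List

START_TOKEN = "<s>"    # same strings as in constants.py
STOP_TOKEN = "</s>"
UNK = "<UNK>"


def ids_to_summary(token_ids: List[int], id2unit: Dict[int, str],
                   stop_token: str = STOP_TOKEN, unk: str = UNK) -> List[str]:
    """Convert hypothesis token ids into summary words (hypothetical helper)."""
    # Skip the first id, which is the START token of the hypothesis
    words = [id2unit.get(idx, unk) for idx in token_ids[1:]]
    # STOP is not guaranteed to be present, so only truncate when it is
    if stop_token in words:
        words = words[: words.index(stop_token)]
    return words


# Hypothetical extended vocabulary for illustration
sample_id2unit = {0: "<PAD>", 1: UNK, 2: START_TOKEN, 3: STOP_TOKEN,
                  4: "stocks", 5: "rose", 6: "sharply"}

finished = ids_to_summary([2, 4, 5, 3, 6], sample_id2unit)
unfinished = ids_to_summary([2, 4, 99], sample_id2unit)
```

Falling back to `<UNK>` keeps the later `" ".join(...)` call from failing on a `None` entry when an id is missing from the mapping; whether that is the desired behavior for this model is a design choice for the author.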
