
Commit bc83804

Peter Szemraj (pszemraj) authored

Improve training config (#11)
This PR aims to improve the training args used for pretraining MPNet from several angles:

1. Improved default values for the training args[^1], with updates to some so they more closely follow the hyperparameters in the MPNet paper.
2. Clearer, more succinct descriptions of what the core args are/do and how to use them.
3. Addition of A) new options for some existing training args[^2] and B) new CLI args exposing/integrating some previously hardcoded parameters[^3] so they can be adjusted by the user.

[^1]: e.g. gradient clipping, which has become standard during pretraining since the original repo came out.
[^2]: added support for the new activation fns "silu" and "relu2".
[^3]: the relative attention hyperparameters `relative_attention_num_buckets` and `max_distance` were hardcoded to values for a 512 context length; they should be settable by the user, with reasonable defaults.

---------

Signed-off-by: peter szemraj <peterszemraj@gmail.com>
Signed-off-by: Peter Szemraj <peterszemraj+dev@gmail.com>
Co-authored-by: Peter Szemraj <peterszemraj+dev@gmail.com>
1 parent cec50e3 commit bc83804
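For orientation, here is a minimal sketch of how the options described above might be exposed on the command line. The flag names, defaults, and help strings are assumptions inferred from the attribute names used in the diffs below (`args.activation_fn`, `args.relative_attention_num_buckets`, `args.relative_attention_max_distance`), not the repository's actual argument parser.

```python
# Hypothetical CLI wiring for the new options; flag names and defaults are
# assumptions for illustration, not the repo's real pretraining script.
import argparse

parser = argparse.ArgumentParser(description="MPNet pretraining args (sketch)")
parser.add_argument(
    "--activation-fn",
    default="gelu",
    choices=["relu", "gelu", "silu", "relu2"],
    help="activation function used in the encoder layers",
)
parser.add_argument(
    "--relative-attention-num-buckets",
    type=int,
    default=None,
    help="number of relative attention buckets; unset = scale with context length",
)
parser.add_argument(
    "--relative-attention-max-distance",
    type=int,
    default=None,
    help="max relative distance in tokens; unset = scale with context length",
)
args = parser.parse_args()
print(args.activation_fn, args.relative_attention_num_buckets, args.relative_attention_max_distance)
```

Leaving the relative-attention flags unset lets the encoder derive them from the context length, as shown in the sentence_encoder.py diff below.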

File tree

8 files changed: +437 −82 lines changed


annotated_mpnet/data/mpnet_data.py

Lines changed: 4 additions & 3 deletions
@@ -3,8 +3,9 @@
 as the data collator
 """
 
-import os
 import logging
+import os
+import random
 from typing import Dict, Iterator, Sized
 
 from rich.logging import RichHandler
@@ -15,12 +16,12 @@
 )
 LOGGER = logging.getLogger(__name__)
 
+
 import numpy as np
 import torch
+from datasets import load_dataset
 from torch.utils.data import Sampler
 from transformers import PreTrainedTokenizer
-from datasets import load_dataset
-import random
 
 from annotated_mpnet.utils import utils
 from annotated_mpnet.utils.perm_utils_fast import make_span_perm

annotated_mpnet/modeling/mpnet_for_pretraining.py

Lines changed: 6 additions & 0 deletions
@@ -72,6 +72,8 @@ def __init__(self, args, tokenizer) -> None:
             encoder_normalize_before=True,
             activation_fn=args.activation_fn,
             normalize_before=args.normalize_before,
+            relative_attention_num_buckets=args.relative_attention_num_buckets,
+            relative_attention_max_distance=args.relative_attention_max_distance,
         )
 
         # Add the language modeling head
@@ -534,6 +536,10 @@ def make_query_and_content_mask(
        [ 0 0 0 0 1 1 1 0 0 0 ]
        [ 0 0 0 0 1 1 1 0 0 0 ]
        [ 0 0 0 0 1 1 1 0 0 0 ]
+
+        Note: This function is designed to scale automatically with sequence length as it's
+        matrix-based and constructs masks based on the provided seq_len and pred_size.
+        There's no need to modify this function when changing context length.
        """
 
        # Define helper function to keep things organized

annotated_mpnet/transformer_modules/sentence_encoder.py

Lines changed: 58 additions & 8 deletions
@@ -71,7 +71,8 @@ def __init__(
         embed_scale: float = None,
         freeze_embeddings: bool = False,
         n_trans_layers_to_freeze: int = 0,
-        relative_attention_num_buckets: int = 32,
+        relative_attention_num_buckets: int = None,
+        relative_attention_max_distance: int = None,
         normalize_before: bool = False,
         export: bool = False,
     ) -> None:
@@ -115,6 +116,8 @@ def __init__(
                 This is probably only useful for finetuning
             relative_attention_num_buckets: the number of buckets to add to the relative atttention
                 portion of the attention mechanism
+            relative_attention_max_distance: the maximum distance (in tokens) to consider in the relative
+                attention mechanism
             normalize_before: boolean dictating if a layer norm should be applied before the encoder
                 layers
             export: boolean dictating ONNX exporting, which I think we won't be using
@@ -160,7 +163,27 @@ def __init__(
         )
 
         # Set up relative attention bias for the attention mechanism
-        self.relative_attention_num_buckets = relative_attention_num_buckets
+        # and compute params for relative attention if they are not specified
+        base_context = 512
+        base_buckets = 32  # Default buckets for 512 context length is 32
+        base_max_distance = 128  # Default max distance for 512 context length is 128
+
+        if relative_attention_num_buckets is None:
+            # linear scaling of num buckets based on seq len (round up to nearest 8)
+            scaled_buckets = max(32, int(base_buckets * max_seq_len / base_context))
+            self.relative_attention_num_buckets = (scaled_buckets + 7) // 8 * 8
+        else:
+            self.relative_attention_num_buckets = relative_attention_num_buckets
+
+        if relative_attention_max_distance is None:
+            # linear scaling of max distance based on seq len (round up to nearest 8)
+            scaled_max_distance = max(
+                128, int(base_max_distance * max_seq_len / base_context)
+            )
+            self.relative_attention_max_distance = (scaled_max_distance + 7) // 8 * 8
+        else:
+            self.relative_attention_max_distance = relative_attention_max_distance
+
         self.relative_attention_bias = nn.Embedding(
             self.relative_attention_num_buckets, num_attention_heads, padding_idx=None
         )
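To make the new scaling behavior concrete, the following standalone sketch reproduces the formula added above and prints the values it yields for a few context lengths; at 512 it recovers the original MPNet defaults of 32 buckets and a max distance of 128.

```python
# Standalone reproduction of the scaling logic added above, for illustration only.
def scaled_relative_attention_params(max_seq_len: int,
                                     base_context: int = 512,
                                     base_buckets: int = 32,
                                     base_max_distance: int = 128):
    """Linearly scale buckets/max distance with context length, rounded up to a multiple of 8."""
    scaled_buckets = max(base_buckets, int(base_buckets * max_seq_len / base_context))
    num_buckets = (scaled_buckets + 7) // 8 * 8
    scaled_max_distance = max(base_max_distance, int(base_max_distance * max_seq_len / base_context))
    max_distance = (scaled_max_distance + 7) // 8 * 8
    return num_buckets, max_distance


for ctx in (512, 1024, 2048, 4096):
    print(f"{ctx} -> {scaled_relative_attention_params(ctx)}")
# 512 -> (32, 128)
# 1024 -> (64, 256)
# 2048 -> (128, 512)
# 4096 -> (256, 1024)
```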
@@ -259,7 +282,7 @@ def forward(
 
         # Compute the relative attention bias
         positions_bias = self.compute_position_bias(
-            x, self.relative_attention_num_buckets
+            x, self.relative_attention_num_buckets, self.relative_attention_max_distance
         )
 
         # If the user wants ALL hidden states, we keep track of it here
@@ -293,10 +316,18 @@ def forward(
 
         return inner_states, sentence_rep
 
-    # Helper function below
-    def compute_position_bias(self, x, num_buckets):
+    def compute_position_bias(self, x, num_buckets, max_distance):
         """
-        Helper function that computes the position bias based on the number of buckets provided
+        Computes the relative position bias for self-attention.
+
+        Args:
+            x: Input tensor with shape (seq_len, batch_size, embed_dim).
+            num_buckets: Number of buckets to use for relative position encoding.
+            max_distance: The maximum distance to consider for relative positions.
+
+        Returns:
+            A tensor representing the relative position bias, with shape
+            (batch_size * num_heads, qlen, klen).
         """
 
         # Get the batch size, q and k len
@@ -307,7 +338,9 @@ def compute_position_bias(self, x, num_buckets):
         relative_position = memory_position - context_position
 
         rp_bucket = self.relative_position_bucket(
-            relative_position, num_buckets=num_buckets
+            relative_position,
+            num_buckets=num_buckets,
+            max_distance=max_distance,
         )
         rp_bucket = rp_bucket.to(x.device)
         values = self.relative_attention_bias(rp_bucket)
@@ -317,7 +350,24 @@ def compute_position_bias(self, x, num_buckets):
         return values
 
     @staticmethod
-    def relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
+    def relative_position_bucket(
+        relative_position, num_buckets: int = 32, max_distance: int = 128
+    ):
+        """
+        Computes the relative position bias for a given tensor of relative positions.
+        Defaults are for original MPNet @ context length 512.
+
+        Args:
+            relative_position: Tensor of shape (bsz, qlen, klen) containing the relative
+                positions between the queries and keys.
+            num_buckets: The number of buckets to use for the relative position bias.
+                Defaults to 32.
+            max_distance: The maximum distance to consider when computing the relative
+                position bias. Defaults to 128.
+
+        Returns:
+            A tensor of shape (bsz, qlen, klen) containing the relative position biases.
+        """
         ret = 0
         n = -relative_position
 
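The diff shows only the first lines of `relative_position_bucket` (`ret = 0`, `n = -relative_position`). For reference, below is a self-contained sketch of the standard T5-style bidirectional bucketing that MPNet's relative attention follows; the repository's exact body may differ slightly.

```python
# Reference sketch of T5-style bidirectional relative-position bucketing,
# consistent with the first two lines visible in the diff above. Not claimed
# to be byte-for-byte identical to this repo's implementation.
import math
import torch


def relative_position_bucket(relative_position, num_buckets: int = 32, max_distance: int = 128):
    ret = 0
    n = -relative_position
    # Half the buckets encode direction (key before vs. after the query)
    num_buckets //= 2
    ret += (n < 0).to(torch.long) * num_buckets
    n = torch.abs(n)
    # Small offsets each get their own bucket; larger ones are binned logarithmically
    max_exact = num_buckets // 2
    is_small = n < max_exact
    val_if_large = max_exact + (
        torch.log(n.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).to(torch.long)
    val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
    ret += torch.where(is_small, n, val_if_large)
    return ret


# Tiny demo: bucket indices for relative positions -4..4
rel = torch.arange(-4, 5).unsqueeze(0)
print(relative_position_bucket(rel))
```

Because half the buckets encode direction and the rest encode log-scaled distance up to `max_distance`, scaling `num_buckets` and `max_distance` with the context length keeps distant tokens distinguishable at longer contexts.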
annotated_mpnet/transformer_modules/sentence_encoder_layer.py

Lines changed: 1 addition & 3 deletions
@@ -56,9 +56,7 @@ def __init__(
                 forward pass
             attention_dropout: similar to above, but is the dropout prob within the self-attention
                 mechanism
-            activation_fn: the activation function you will be using in this network. Although ReLU
-                is the default, more and more evidence points towards GELU being better for large
-                NLP-based transformers
+            activation_fn: the activation function you will be using in this network.
             add_bias_kv: boolean that dictates whether or not to add a bias parameter to the K, V
                 matrices in the self-attention mechanism
             add_zero_attn: boolean that dictate whether or not to add zero attention to the
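Since the commit message adds "silu" and "relu2" as selectable activation functions, here is a hypothetical sketch of the kind of lookup such an option maps to, assuming "relu2" denotes squared ReLU; the repository's actual helper may be named and organized differently.

```python
# Illustrative activation lookup covering the newly mentioned "silu" and "relu2"
# options. Assumes "relu2" means squared ReLU; this is a sketch, not the repo's helper.
import torch
import torch.nn.functional as F


def get_activation_fn(name: str):
    if name == "relu":
        return F.relu
    if name == "gelu":
        return F.gelu
    if name == "silu":
        return F.silu  # a.k.a. swish: x * sigmoid(x)
    if name == "relu2":
        return lambda x: torch.square(F.relu(x))  # squared ReLU
    raise ValueError(f"unknown activation: {name}")


x = torch.randn(4)
print(get_activation_fn("relu2")(x))
```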
