[WIP] add ModernBERT #2256
keras_hub/src/models/modernbert/modernbert_backbone.py
@@ -0,0 +1,122 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.modeling.reversible_embedding import (
    ReversibleEmbedding,
)
from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.modernbert.modernbert_layers import (
    ModernBERTEncoderLayer,
)
from keras_hub.src.utils.keras_utils import gelu_approximate


@keras_hub_export("keras_hub.models.ModernBertBackbone")
class ModernBertBackbone(Backbone):
    def __init__(
        self,
        vocabulary_size,
        num_layers,
        num_heads,
        hidden_dim,
        intermediate_dim,
        max_sequence_length=8192,
Review comment: Unused? If this doesn't actually do anything, no need for it. We don't need config to track pretraining params (though we can document on Kaggle the longest sequence length checkpoints were trained on).
        dropout=0.0,
        rotary_max_wavelength=160000.0,
        layer_norm_epsilon=1e-5,
        dtype=None,
        **kwargs,
    ):
        # === Layers ===
        self.token_embedding = ReversibleEmbedding(
            input_dim=vocabulary_size,
            output_dim=hidden_dim,
            embeddings_initializer=keras.initializers.TruncatedNormal(
                stddev=0.02
            ),
            dtype=dtype,
            name="token_embedding",
        )
        self.position_embedding = RotaryEmbedding(
Review comment: I don't think we cache the rotary tensor or anything, so there's probably no advantage to using a shared layer here. Instead I would just create a RotaryEmbedding in each layer.
            max_wavelength=rotary_max_wavelength,
            dtype=dtype,
            name="rotary_embedding",
        )
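Following up on the comment above, a rough sketch of that suggestion, relying on the keras and RotaryEmbedding imports already at the top of this file. It is hypothetical, not the final layer: each encoder layer would own its rotary embedding and the backbone would stop constructing a shared position_embedding.

# Sketch only: per-layer rotary embedding instead of a shared instance.
class ModernBertEncoderLayer(keras.layers.Layer):
    def __init__(
        self,
        hidden_dim,
        num_heads,
        rotary_max_wavelength=160000.0,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        # Built here (not cached or shared), then used by this layer's
        # attention sub-layer.
        self.rotary_embedding = RotaryEmbedding(
            max_wavelength=rotary_max_wavelength,
            dtype=dtype,
            name="rotary_embedding",
        )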
        self.embeddings_layer_norm = keras.layers.LayerNormalization(
            epsilon=layer_norm_epsilon,
            dtype=dtype,
            rms_scaling=True,
            name="embeddings_layer_norm",
        )
        self.transformer_layers = []
        for i in range(num_layers):
            layer = ModernBERTEncoderLayer(
Review comment: ModernBERTEncoderLayer -> ModernBertEncoderLayer
                hidden_size=hidden_dim,
                intermediate_size=intermediate_dim,
                num_heads=num_heads,
                activation=gelu_approximate,
                layer_norm_epsilon=layer_norm_epsilon,
                rotary_embedding=self.position_embedding,
                dtype=dtype,
                name=f"transformer_layer_{i}",
            )
            self.transformer_layers.append(layer)
        self.final_norm = keras.layers.LayerNormalization(
            epsilon=layer_norm_epsilon,
            rms_scaling=True,
            dtype=dtype,
            name="final_layernorm",
        )

        # === Functional Model ===
        token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="token_ids"
        )
        padding_mask_input = keras.Input(
            shape=(None,), dtype="int32", name="padding_mask"
        )
        x = self.token_embedding(token_id_input)
        x = self.embeddings_layer_norm(x)
        for transformer_layer in self.transformer_layers:
            x = transformer_layer(x)
        sequence_output = self.final_norm(x)

        # Instantiate using Functional API Model constructor
        super().__init__(
            inputs={
                "token_ids": token_id_input,
                "padding_mask": padding_mask_input,
            },
            outputs=sequence_output,
            dtype=dtype,
            **kwargs,
        )

        # === Config ===
        self.vocabulary_size = vocabulary_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.max_sequence_length = max_sequence_length
        self.dropout = dropout
        self.rotary_max_wavelength = rotary_max_wavelength
        self.layer_norm_epsilon = layer_norm_epsilon

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "num_layers": self.num_layers,
                "num_heads": self.num_heads,
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "max_sequence_length": self.max_sequence_length,
                "dropout": self.dropout,
                "rotary_max_wavelength": self.rotary_max_wavelength,
                "layer_norm_epsilon": self.layer_norm_epsilon,
            }
        )
        return config
keras_hub/src/models/modernbert/modernbert_backbone_test.py
@@ -0,0 +1,38 @@
import pytest
from keras import ops

from keras_hub.src.models.modernbert.modernbert_backbone import (
    ModernBertBackbone,
)
from keras_hub.src.tests.test_case import TestCase


class ModernBertBackboneTest(TestCase):
    def setUp(self):
        self.init_kwargs = {
            "vocabulary_size": 10,
            "num_layers": 2,
            "num_heads": 4,
            "hidden_dim": 8,
            "intermediate_dim": 32,
        }
        self.input_data = {
            "token_ids": ops.ones((2, 5), dtype="int32"),
            "padding_mask": ops.ones((2, 5), dtype="int32"),
        }

    def test_backbone_basics(self):
        self.run_backbone_test(
            cls=ModernBertBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output_shape=(2, 5, 8),
        )

    @pytest.mark.large
    def test_saved_model(self):
        self.run_model_saving_test(
            cls=ModernBertBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
        )
keras_hub/src/models/modernbert/modernbert_layers.py
@@ -0,0 +1,95 @@
import keras
from keras import layers
from keras import ops

from keras_hub.src.models.flux.flux_maths import rearrange_symbolic_tensors
Review comment: Don't add a cross-model dep like this. If we want stuff like this across models, we should pull it into a util. But in this case, I think we probably want to unfuse the qkv matrix, for consistent UX for things like LoRA.
from keras_hub.src.models.flux.flux_maths import scaled_dot_product_attention

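Following the comment above, one possible replacement for the flux imports is a small local helper built on keras.ops. This is only a sketch, assuming the head dimension is static; the name dot_product_attention and where it should live (here or in a shared util) are not final.

import math

from keras import ops


def dot_product_attention(query, key, value):
    # Inputs are shaped (batch, num_heads, seq_len, head_dim); head_dim is
    # assumed static so the Python scale factor below is well defined.
    scale = 1.0 / math.sqrt(query.shape[-1])
    scores = ops.einsum("bhqd,bhkd->bhqk", query, key) * scale
    probs = ops.softmax(scores, axis=-1)
    return ops.einsum("bhqk,bhkd->bhqd", probs, value)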
class MLP(keras.layers.Layer):
    def __init__(
        self,
        hidden_size,
        intermediate_size,
        activation="gelu",
Review comment: Have we checked that we want gelu and not gelu(x, approximate=True)? https://keras.io/api/layers/activations/#gelu-function. We should do whatever matches upstream.
        dtype=None,
        **kwargs,
    ):
        super(MLP, self).__init__(**kwargs)
        self.Wi = layers.Dense(
Review comment: We generally don't use algebraic notation like this in attr names.
            intermediate_size * 2,
            use_bias=False,
            dtype=dtype,
        )
        self.act = keras.activations.get(activation)
        self.Wo = layers.Dense(
            hidden_size,
            use_bias=False,
            dtype=dtype,
        )

    def call(self, x):
        input, gate = ops.split(self.Wi(x), 2, axis=-1)
        return self.Wo(self.act(input) * gate)
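Pulling together the two comments above (approximate GELU and attribute naming), a hedged sketch of how the MLP could look, relying on this file's existing keras, layers, and ops imports. The attribute names are suggestions only, and whether approximate=True matches the upstream implementation still needs to be confirmed.

class ModernBertMLP(keras.layers.Layer):
    """Gated MLP: fused input/gate projection, GELU gate, output projection."""

    def __init__(self, hidden_size, intermediate_size, dtype=None, **kwargs):
        super().__init__(dtype=dtype, **kwargs)
        # A single projection producing both the hidden and gate halves.
        self.input_dense = layers.Dense(
            intermediate_size * 2, use_bias=False, dtype=dtype, name="input_dense"
        )
        # Tanh-approximated GELU, mirroring gelu_approximate in the backbone.
        self.activation = lambda x: keras.activations.gelu(x, approximate=True)
        self.output_dense = layers.Dense(
            hidden_size, use_bias=False, dtype=dtype, name="output_dense"
        )

    def call(self, x):
        hidden, gate = ops.split(self.input_dense(x), 2, axis=-1)
        return self.output_dense(self.activation(hidden) * gate)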


class ModernBERTAttention(keras.Model):
Review comment: ModernBERTAttention -> ModernBertAttention. Same anywhere we have BERT in a class name.
    def __init__(
        self, hidden_size, num_heads, rotary_embedding, dtype=None, **kwargs
    ):
        super(ModernBERTAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.rotary_embedding = rotary_embedding
        self.Wqkv = layers.Dense(hidden_size * 3, use_bias=False, dtype=dtype)
Review comment: Same here. We may want to split this for consistency with other models. We generally don't do the fused qkv like this.
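A sketch of the unfused version suggested above, reusing the hypothetical dot_product_attention helper from the earlier sketch and this file's existing imports. Attribute names and the rotary handling are placeholders, not the final API.

class ModernBertAttention(keras.layers.Layer):
    """Sketch only: self-attention with separate q/k/v projections."""

    def __init__(
        self, hidden_size, num_heads, rotary_embedding=None, dtype=None, **kwargs
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.rotary_embedding = rotary_embedding
        # One Dense per tensor keeps weight names stable for LoRA-style adapters.
        self.query_dense = layers.Dense(hidden_size, use_bias=False, dtype=dtype)
        self.key_dense = layers.Dense(hidden_size, use_bias=False, dtype=dtype)
        self.value_dense = layers.Dense(hidden_size, use_bias=False, dtype=dtype)
        self.output_dense = layers.Dense(hidden_size, use_bias=False, dtype=dtype)

    def _split_heads(self, x, batch, seq_len):
        x = ops.reshape(x, (batch, seq_len, self.num_heads, self.head_dim))
        return ops.transpose(x, (0, 2, 1, 3))

    def call(self, x):
        batch, seq_len = ops.shape(x)[0], ops.shape(x)[1]
        q = self._split_heads(self.query_dense(x), batch, seq_len)
        k = self._split_heads(self.key_dense(x), batch, seq_len)
        v = self._split_heads(self.value_dense(x), batch, seq_len)
        if self.rotary_embedding is not None:
            # Placeholder: the rotary axes may need adjusting for (b, h, s, d).
            q = self.rotary_embedding(q)
            k = self.rotary_embedding(k)
        # Hypothetical local helper sketched after the imports above.
        out = dot_product_attention(q, k, v)
        out = ops.transpose(out, (0, 2, 1, 3))
        out = ops.reshape(out, (batch, seq_len, self.num_heads * self.head_dim))
        return self.output_dense(out)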
        self.Wo = layers.Dense(hidden_size, use_bias=False, dtype=dtype)

    def build(self, input_shape):
        self.Wqkv.build(input_shape)
        self.Wo.build((None, input_shape[1], input_shape[-1]))

    def call(self, x):
        qkv = self.Wqkv(x)
        q, k, v = rearrange_symbolic_tensors(qkv, K=3, H=self.num_heads)

        # Apply rotary embeddings
        q = self.rotary_embedding(q)
        k = self.rotary_embedding(k)

        # Apply scaled dot product attention
        x = scaled_dot_product_attention(q, k, v)

        # Reshape and apply final dense layer
        x = ops.transpose(x, (0, 2, 1, 3))
        b, s, h, d = ops.shape(x)
        x = ops.reshape(x, (b, s, h * d))
        x = self.Wo(x)
        return x


class ModernBERTEncoderLayer(keras.Model):
    def __init__(
        self,
        hidden_size,
        intermediate_size,
        num_heads,
        activation="gelu",
        layer_norm_epsilon=1e-05,
        rotary_embedding=None,
        dtype=None,
        **kwargs,
    ):
        super(ModernBERTEncoderLayer, self).__init__(**kwargs)
        self.attn = ModernBERTAttention(
            hidden_size, num_heads, rotary_embedding, dtype=dtype
        )
        self.mlp_norm = layers.LayerNormalization(
            epsilon=layer_norm_epsilon, dtype=dtype
        )
        self.mlp = MLP(hidden_size, intermediate_size, activation, dtype=dtype)

    def call(self, x):
        x = self.attn(x)
        x = self.mlp_norm(x)
        x = self.mlp(x)
        return x
keras_hub/src/models/modernbert/modernbert_tokenizer.py
@@ -0,0 +1,36 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.modernbert.modernbert_backbone import (
    ModernBertBackbone,
)
from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer


@keras_hub_export(
    [
        "keras_hub.tokenizers.ModernBertTokenizer",
        "keras_hub.models.ModernBertTokenizer",
    ]
)
class ModernBertTokenizer(BytePairTokenizer):
    backbone_cls = ModernBertBackbone
Review comment: Docstring for all public symbols.
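A rough draft of the requested docstring, modeled on other keras-hub tokenizers. The wording is a placeholder, and the argument descriptions should be double-checked against BytePairTokenizer before merging.

class ModernBertTokenizer(BytePairTokenizer):
    """A ModernBERT tokenizer using Byte-Pair Encoding subword segmentation.

    This tokenizer maps raw strings to integer token ids and back. It adds the
    [CLS], [SEP], [PAD], [UNK] and [MASK] special tokens, plus start_token and
    end_token aliases for [CLS] and [SEP] for compatibility with other
    keras-hub tokenizers.

    Args:
        vocabulary: string or dict, maps tokens to integer ids. If it is a
            string, it should be the file path to a JSON vocabulary file.
        merges: string or list, contains the merge rules. If it is a string,
            it should be the file path to a merges file.
    """

    backbone_cls = ModernBertBackbone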

    def __init__(
        self,
        vocabulary=None,
        merges=None,
        **kwargs,
    ):
        self._add_special_token("[CLS]", "cls_token")
        self._add_special_token("[SEP]", "sep_token")
        self._add_special_token("[PAD]", "pad_token")
        self._add_special_token("[UNK]", "unk_token")
        self._add_special_token("[MASK]", "mask_token")
        # Also add `tokenizer.start_token` and `tokenizer.end_token` for
        # compatibility with other tokenizers.
        self._add_special_token("[CLS]", "start_token")
        self._add_special_token("[SEP]", "end_token")
        super().__init__(
            vocabulary=vocabulary,
            merges=merges,
            **kwargs,
        )
keras_hub/src/models/modernbert/modernbert_tokenizer_test.py
@@ -0,0 +1,35 @@
from keras_hub.src.models.modernbert.modernbert_tokenizer import (
    ModernBertTokenizer,
)
from keras_hub.src.tests.test_case import TestCase


class ModernBertTokenizerTest(TestCase):
    def setUp(self):
        self.vocab = ["[CLS]", "[PAD]", "[SEP]", "air", "Ġair", "plane", "Ġat"]
        self.vocab += ["port", "[MASK]", "[UNK]"]
        self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
        self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
        self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
        self.merges += ["Ġai r", "Ġa i", "pla ne"]
        self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges}
        self.input_data = [
            "[CLS] airplane at airport[SEP][PAD]",
            " airplane airport",
        ]

    def test_tokenizer_basics(self):
        self.run_preprocessing_layer_test(
            cls=ModernBertTokenizer,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output=[[0, 4, 5, 6, 4, 7, 2, 1], [4, 5, 4, 7]],
            expected_detokenize_output=[
                "[CLS] airplane at airport[SEP][PAD]",
                " airplane airport",
            ],
        )

    def test_errors_missing_special_tokens(self):
        with self.assertRaises(ValueError):
            ModernBertTokenizer(vocabulary=["a", "b", "c"], merges=[])
Review comment: Please add docstrings.
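For the backbone, a draft along the lines of other keras-hub models; the description and argument docs below are inferred from this diff and should be verified before merging.

class ModernBertBackbone(Backbone):
    """ModernBERT encoder network.

    A bi-directional Transformer encoder with rotary position embeddings and
    gated GELU MLP blocks, without any task-specific head.

    Args:
        vocabulary_size: int. The size of the token vocabulary.
        num_layers: int. The number of transformer encoder layers.
        num_heads: int. The number of attention heads in each layer.
        hidden_dim: int. The hidden size of the transformer encoder.
        intermediate_dim: int. The output dimension of the first Dense layer
            in each layer's gated MLP.
        dropout: float. Dropout probability for the encoder.
        rotary_max_wavelength: float. The maximum angular wavelength of the
            sine/cosine curves used in the rotary embedding.
        layer_norm_epsilon: float. Epsilon used by the layer normalization
            layers.
        dtype: string or keras.mixed_precision.DTypePolicy. The dtype to use
            for the model's computations and weights.
    """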