[WIP] add ModernBERT #2256

Draft · wants to merge 3 commits into base: master

Changes from all commits
6 changes: 6 additions & 0 deletions keras_hub/api/models/__init__.py
@@ -369,6 +369,12 @@
from keras_hub.src.models.mobilenet.mobilenet_image_classifier_preprocessor import (
    MobileNetImageClassifierPreprocessor as MobileNetImageClassifierPreprocessor,
)
from keras_hub.src.models.modernbert.modernbert_backbone import (
    ModernBertBackbone as ModernBertBackbone,
)
from keras_hub.src.models.modernbert.modernbert_tokenizer import (
    ModernBertTokenizer as ModernBertTokenizer,
)
from keras_hub.src.models.moonshine.moonshine_audio_to_text import (
    MoonshineAudioToText as MoonshineAudioToText,
)
3 changes: 3 additions & 0 deletions keras_hub/api/tokenizers/__init__.py
@@ -58,6 +58,9 @@
from keras_hub.src.models.mixtral.mixtral_tokenizer import (
    MixtralTokenizer as MixtralTokenizer,
)
from keras_hub.src.models.modernbert.modernbert_tokenizer import (
    ModernBertTokenizer as ModernBertTokenizer,
)
from keras_hub.src.models.moonshine.moonshine_tokenizer import (
    MoonshineTokenizer as MoonshineTokenizer,
)
Empty file.
122 changes: 122 additions & 0 deletions keras_hub/src/models/modernbert/modernbert_backbone.py
@@ -0,0 +1,122 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.modeling.reversible_embedding import (
    ReversibleEmbedding,
)
from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.modernbert.modernbert_layers import (
    ModernBERTEncoderLayer,
)
from keras_hub.src.utils.keras_utils import gelu_approximate


@keras_hub_export("keras_hub.models.ModernBertBackbone")
class ModernBertBackbone(Backbone):
    def __init__(

Reviewer comment (Member): Please add docstrings.

        self,
        vocabulary_size,
        num_layers,
        num_heads,
        hidden_dim,
        intermediate_dim,
        max_sequence_length=8192,

Reviewer comment (Member): Unused? If this doesn't actually do anything, there's no need for it. We don't need config to track pretraining params (though we can document on Kaggle the longest sequence length the checkpoints were trained on).

        dropout=0.0,
        rotary_max_wavelength=160000.0,
        layer_norm_epsilon=1e-5,
        dtype=None,
        **kwargs,
    ):
        # === Layers ===
        self.token_embedding = ReversibleEmbedding(
            input_dim=vocabulary_size,
            output_dim=hidden_dim,
            embeddings_initializer=keras.initializers.TruncatedNormal(
                stddev=0.02
            ),
            dtype=dtype,
            name="token_embedding",
        )
        self.position_embedding = RotaryEmbedding(

Reviewer comment (Member): I don't think we cache the rotary tensor or anything, so there's probably no advantage to using a shared layer here. Instead I would just create a RotaryEmbedding inside the encoder layer, same as other models that use rotary embeddings in this repo.
(A minimal sketch of this change follows the backbone diff below.)

            max_wavelength=rotary_max_wavelength,
            dtype=dtype,
            name="rotary_embedding",
        )
        self.embeddings_layer_norm = keras.layers.LayerNormalization(
            epsilon=layer_norm_epsilon,
            dtype=dtype,
            rms_scaling=True,
            name="embeddings_layer_norm",
        )
        self.transformer_layers = []
        for i in range(num_layers):
            layer = ModernBERTEncoderLayer(

Reviewer comment (Member): ModernBERTEncoderLayer -> ModernBertEncoderLayer

                hidden_size=hidden_dim,
                intermediate_size=intermediate_dim,
                num_heads=num_heads,
                activation=gelu_approximate,
                layer_norm_epsilon=layer_norm_epsilon,
                rotary_embedding=self.position_embedding,
                dtype=dtype,
                name=f"transformer_layer_{i}",
            )
            self.transformer_layers.append(layer)
        self.final_norm = keras.layers.LayerNormalization(
            epsilon=layer_norm_epsilon,
            rms_scaling=True,
            dtype=dtype,
            name="final_layernorm",
        )

        # === Functional Model ===
        token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="token_ids"
        )
        padding_mask_input = keras.Input(
            shape=(None,), dtype="int32", name="padding_mask"
        )
        x = self.token_embedding(token_id_input)
        x = self.embeddings_layer_norm(x)
        for transformer_layer in self.transformer_layers:
            x = transformer_layer(x)
        sequence_output = self.final_norm(x)

        # Instantiate using Functional API Model constructor
        super().__init__(
            inputs={
                "token_ids": token_id_input,
                "padding_mask": padding_mask_input,
            },
            outputs=sequence_output,
            dtype=dtype,
            **kwargs,
        )

        # === Config ===
        self.vocabulary_size = vocabulary_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.max_sequence_length = max_sequence_length
        self.dropout = dropout
        self.rotary_max_wavelength = rotary_max_wavelength
        self.layer_norm_epsilon = layer_norm_epsilon

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "num_layers": self.num_layers,
                "num_heads": self.num_heads,
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "max_sequence_length": self.max_sequence_length,
                "dropout": self.dropout,
                "rotary_max_wavelength": self.rotary_max_wavelength,
                "layer_norm_epsilon": self.layer_norm_epsilon,
            }
        )
        return config
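
To make the rotary comment above concrete: a minimal sketch, not part of this PR, of an encoder layer that builds its own RotaryEmbedding and hands it to the attention block it constructs, so the backbone only needs to pass rotary_max_wavelength down. It reuses the MLP and ModernBERTAttention classes from modernbert_layers.py below; the class name and argument defaults here are illustrative.

# Sketch only: per-layer rotary embedding, following the review suggestion.
from keras import layers

from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_hub.src.models.modernbert.modernbert_layers import MLP
from keras_hub.src.models.modernbert.modernbert_layers import (
    ModernBERTAttention,
)


class ModernBertEncoderLayerSketch(layers.Layer):
    def __init__(
        self,
        hidden_size,
        intermediate_size,
        num_heads,
        activation="gelu",
        layer_norm_epsilon=1e-5,
        rotary_max_wavelength=160000.0,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        # Owned by the layer itself; nothing is cached across layers, so a
        # single shared instance on the backbone buys nothing.
        self.rotary_embedding = RotaryEmbedding(
            max_wavelength=rotary_max_wavelength, dtype=dtype
        )
        self.attn = ModernBERTAttention(
            hidden_size, num_heads, self.rotary_embedding, dtype=dtype
        )
        self.mlp_norm = layers.LayerNormalization(
            epsilon=layer_norm_epsilon, dtype=dtype
        )
        self.mlp = MLP(hidden_size, intermediate_size, activation, dtype=dtype)

    def call(self, x):
        x = self.attn(x)
        x = self.mlp_norm(x)
        return self.mlp(x)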
38 changes: 38 additions & 0 deletions keras_hub/src/models/modernbert/modernbert_backbone_test.py
@@ -0,0 +1,38 @@
import pytest
from keras import ops

from keras_hub.src.models.modernbert.modernbert_backbone import (
    ModernBertBackbone,
)
from keras_hub.src.tests.test_case import TestCase


class ModernBertBackboneTest(TestCase):
    def setUp(self):
        self.init_kwargs = {
            "vocabulary_size": 10,
            "num_layers": 2,
            "num_heads": 4,
            "hidden_dim": 8,
            "intermediate_dim": 32,
        }
        self.input_data = {
            "token_ids": ops.ones((2, 5), dtype="int32"),
            "padding_mask": ops.ones((2, 5), dtype="int32"),
        }

    def test_backbone_basics(self):
        self.run_backbone_test(
            cls=ModernBertBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output_shape=(2, 5, 8),
        )

    @pytest.mark.large
    def test_saved_model(self):
        self.run_model_saving_test(
            cls=ModernBertBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
        )
95 changes: 95 additions & 0 deletions keras_hub/src/models/modernbert/modernbert_layers.py
@@ -0,0 +1,95 @@
import keras
from keras import layers
from keras import ops

from keras_hub.src.models.flux.flux_maths import rearrange_symbolic_tensors

Reviewer comment (Member): Don't add a cross-model dep like this. If we want stuff like this across models, we should pull it into a util. But in this case, I think we probably want to unfuse the qkv matrix, for a consistent UX for things like LoRA.
(A sketch of an unfused attention block follows this file's diff below.)

from keras_hub.src.models.flux.flux_maths import scaled_dot_product_attention


class MLP(keras.layers.Layer):
    def __init__(
        self,
        hidden_size,
        intermediate_size,
        activation="gelu",

Reviewer comment (Member): Have we checked whether we want gelu or gelu(x, approximate=True)? https://keras.io/api/layers/activations/#gelu-function
We should do whatever matches upstream.

        dtype=None,
        **kwargs,
    ):
        super(MLP, self).__init__(**kwargs)
        self.Wi = layers.Dense(

Reviewer comment (Member): We generally don't use algebraic notation like this in attr names; self.input_dense and self.output_dense, or something like that.

            intermediate_size * 2,
            use_bias=False,
            dtype=dtype,
        )
        self.act = keras.activations.get(activation)
        self.Wo = layers.Dense(
            hidden_size,
            use_bias=False,
            dtype=dtype,
        )

    def call(self, x):
        input, gate = ops.split(self.Wi(x), 2, axis=-1)
        return self.Wo(self.act(input) * gate)


class ModernBERTAttention(keras.Model):

Reviewer comment (Member): ModernBERTAttention -> ModernBertAttention. Same anywhere we have BERT in a class name.

    def __init__(
        self, hidden_size, num_heads, rotary_embedding, dtype=None, **kwargs
    ):
        super(ModernBERTAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.rotary_embedding = rotary_embedding
        self.Wqkv = layers.Dense(hidden_size * 3, use_bias=False, dtype=dtype)

Reviewer comment (Member): Same here: dense_qkv. We may also want to split this for consistency with other models; we generally don't do the fused qkv like this.

        self.Wo = layers.Dense(hidden_size, use_bias=False, dtype=dtype)

    def build(self, input_shape):
        self.Wqkv.build(input_shape)
        self.Wo.build((None, input_shape[1], input_shape[-1]))

    def call(self, x):
        qkv = self.Wqkv(x)
        q, k, v = rearrange_symbolic_tensors(qkv, K=3, H=self.num_heads)

        # Apply rotary embeddings
        q = self.rotary_embedding(q)
        k = self.rotary_embedding(k)

        # Apply scaled dot product attention
        x = scaled_dot_product_attention(q, k, v)

        # Reshape and apply final dense layer
        x = ops.transpose(x, (0, 2, 1, 3))
        b, s, h, d = ops.shape(x)
        x = ops.reshape(x, (b, s, h * d))
        x = self.Wo(x)
        return x


class ModernBERTEncoderLayer(keras.Model):
    def __init__(
        self,
        hidden_size,
        intermediate_size,
        num_heads,
        activation="gelu",
        layer_norm_epsilon=1e-05,
        rotary_embedding=None,
        dtype=None,
        **kwargs,
    ):
        super(ModernBERTEncoderLayer, self).__init__(**kwargs)
        self.attn = ModernBERTAttention(
            hidden_size, num_heads, rotary_embedding, dtype=dtype
        )
        self.mlp_norm = layers.LayerNormalization(
            epsilon=layer_norm_epsilon, dtype=dtype
        )
        self.mlp = MLP(hidden_size, intermediate_size, activation, dtype=dtype)

    def call(self, x):
        x = self.attn(x)
        x = self.mlp_norm(x)
        x = self.mlp(x)
        return x
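
Following the comments about the fused Wqkv and the flux import: a minimal sketch, again not part of this PR, of an attention block with separate query/key/value projections and no flux dependency. The attribute names and the plain keras.ops attention math here are illustrative; a real revision would follow whatever conventions the other keras_hub attention layers use.

# Sketch only: unfused q/k/v projections, using only keras.ops.
from keras import layers
from keras import ops

from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding


class ModernBertAttentionSketch(layers.Layer):
    def __init__(
        self, hidden_size, num_heads, rotary_max_wavelength=160000.0,
        dtype=None, **kwargs
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.rotary_embedding = RotaryEmbedding(
            max_wavelength=rotary_max_wavelength, dtype=dtype
        )
        # Separate projections are easier to target individually (e.g. LoRA).
        self.query_dense = layers.Dense(hidden_size, use_bias=False, dtype=dtype)
        self.key_dense = layers.Dense(hidden_size, use_bias=False, dtype=dtype)
        self.value_dense = layers.Dense(hidden_size, use_bias=False, dtype=dtype)
        self.output_dense = layers.Dense(hidden_size, use_bias=False, dtype=dtype)

    def call(self, x):
        batch, seq = ops.shape(x)[0], ops.shape(x)[1]
        shape = (batch, seq, self.num_heads, self.head_dim)
        # Rotary is applied in (batch, seq, heads, head_dim) layout, which
        # matches RotaryEmbedding's default sequence_axis=1.
        q = self.rotary_embedding(ops.reshape(self.query_dense(x), shape))
        k = self.rotary_embedding(ops.reshape(self.key_dense(x), shape))
        v = ops.reshape(self.value_dense(x), shape)
        # Move heads before the sequence axis for the attention matmuls.
        q, k, v = (ops.transpose(t, (0, 2, 1, 3)) for t in (q, k, v))
        scores = ops.einsum("bhqd,bhkd->bhqk", q, k)
        scores = scores / ops.sqrt(ops.cast(self.head_dim, self.compute_dtype))
        probs = ops.softmax(scores, axis=-1)
        out = ops.einsum("bhqk,bhkd->bhqd", probs, v)
        out = ops.transpose(out, (0, 2, 1, 3))
        out = ops.reshape(out, (batch, seq, self.num_heads * self.head_dim))
        return self.output_dense(out)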
36 changes: 36 additions & 0 deletions keras_hub/src/models/modernbert/modernbert_tokenizer.py
@@ -0,0 +1,36 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.modernbert.modernbert_backbone import (
    ModernBertBackbone,
)
from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer


@keras_hub_export(
    [
        "keras_hub.tokenizers.ModernBertTokenizer",
        "keras_hub.models.ModernBertTokenizer",
    ]
)
class ModernBertTokenizer(BytePairTokenizer):
    backbone_cls = ModernBertBackbone

Reviewer comment (Member): Docstrings, please, for all public symbols. (A sketch of one follows this file's diff below.)

    def __init__(
        self,
        vocabulary=None,
        merges=None,
        **kwargs,
    ):
        self._add_special_token("[CLS]", "cls_token")
        self._add_special_token("[SEP]", "sep_token")
        self._add_special_token("[PAD]", "pad_token")
        self._add_special_token("[UNK]", "unk_token")
        self._add_special_token("[MASK]", "mask_token")
        # Also add `tokenizer.start_token` and `tokenizer.end_token` for
        # compatibility with other tokenizers.
        self._add_special_token("[CLS]", "start_token")
        self._add_special_token("[SEP]", "end_token")
        super().__init__(
            vocabulary=vocabulary,
            merges=merges,
            **kwargs,
        )
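
On the docstring request above: a sketch of the kind of class docstring other keras_hub tokenizers carry, trimmed for illustration. The wording is a placeholder, not final documentation; the factual bits (BPE base class, special tokens, start/end aliases) come from the code in this diff.

# Sketch only: a possible class docstring for ModernBertTokenizer.
class ModernBertTokenizer(BytePairTokenizer):
    """ModernBERT tokenizer using Byte-Pair Encoding subword segmentation.

    This tokenizer maps raw strings to integer token id sequences and is
    based on `keras_hub.tokenizers.BytePairTokenizer`. It registers the
    `[CLS]`, `[SEP]`, `[PAD]`, `[UNK]` and `[MASK]` special tokens used by
    ModernBERT, with `[CLS]`/`[SEP]` doubling as the generic start/end
    tokens for compatibility with other tokenizers.

    Args:
        vocabulary: string or dict, maps tokens to integer ids. If it is a
            string, it should be the file path to a json file.
        merges: string or list, contains the merge rules. If it is a
            string, it should be the file path to the merge rules.
    """

    backbone_cls = ModernBertBackbone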
35 changes: 35 additions & 0 deletions keras_hub/src/models/modernbert/modernbert_tokenizer_test.py
@@ -0,0 +1,35 @@
from keras_hub.src.models.modernbert.modernbert_tokenizer import (
    ModernBertTokenizer,
)
from keras_hub.src.tests.test_case import TestCase


class ModernBertTokenizerTest(TestCase):
    def setUp(self):
        self.vocab = ["[CLS]", "[PAD]", "[SEP]", "air", "Ġair", "plane", "Ġat"]
        self.vocab += ["port", "[MASK]", "[UNK]"]
        self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
        self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
        self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
        self.merges += ["Ġai r", "Ġa i", "pla ne"]
        self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges}
        self.input_data = [
            "[CLS] airplane at airport[SEP][PAD]",
            " airplane airport",
        ]

    def test_tokenizer_basics(self):
        self.run_preprocessing_layer_test(
            cls=ModernBertTokenizer,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output=[[0, 4, 5, 6, 4, 7, 2, 1], [4, 5, 4, 7]],
            expected_detokenize_output=[
                "[CLS] airplane at airport[SEP][PAD]",
                " airplane airport",
            ],
        )

    def test_errors_missing_special_tokens(self):
        with self.assertRaises(ValueError):
            ModernBertTokenizer(vocabulary=["a", "b", "c"], merges=[])
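
For anyone checking the expected ids above by hand: a small sketch, assuming this branch is installed, that maps the first test input through the same vocabulary and merges used in setUp.

# Sketch only: reproduce the first expected_output row from the test above.
from keras_hub.src.models.modernbert.modernbert_tokenizer import (
    ModernBertTokenizer,
)

vocab = ["[CLS]", "[PAD]", "[SEP]", "air", "Ġair", "plane", "Ġat"]
vocab += ["port", "[MASK]", "[UNK]"]
vocab = {token: i for i, token in enumerate(vocab)}
merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
merges += ["Ġai r", "Ġa i", "pla ne"]

tokenizer = ModernBertTokenizer(vocabulary=vocab, merges=merges)
# "[CLS] airplane at airport[SEP][PAD]" tokenizes to
# [CLS]=0, Ġair=4, plane=5, Ġat=6, Ġair=4, port=7, [SEP]=2, [PAD]=1.
print(tokenizer("[CLS] airplane at airport[SEP][PAD]"))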