From 6bed56a51e70d41b3e064ef8a0f8d49d4909a448 Mon Sep 17 00:00:00 2001 From: Saurav Maheshkar Date: Fri, 16 May 2025 15:49:19 +0100 Subject: [PATCH 1/3] feat: init ModernBertBackbone --- keras_hub/api/models/__init__.py | 3 + keras_hub/src/models/modernbert/__init__.py | 0 .../models/modernbert/modernbert_backbone.py | 138 ++++++++++++++++++ .../modernbert/modernbert_backbone_test.py | 38 +++++ 4 files changed, 179 insertions(+) create mode 100644 keras_hub/src/models/modernbert/__init__.py create mode 100644 keras_hub/src/models/modernbert/modernbert_backbone.py create mode 100644 keras_hub/src/models/modernbert/modernbert_backbone_test.py diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 0a71dbcace..6e10b7a6f0 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -369,6 +369,9 @@ from keras_hub.src.models.mobilenet.mobilenet_image_classifier_preprocessor import ( MobileNetImageClassifierPreprocessor as MobileNetImageClassifierPreprocessor, ) +from keras_hub.src.models.modernbert.modernbert_backbone import ( + ModernBertBackbone as ModernBertBackbone, +) from keras_hub.src.models.moonshine.moonshine_audio_to_text import ( MoonshineAudioToText as MoonshineAudioToText, ) diff --git a/keras_hub/src/models/modernbert/__init__.py b/keras_hub/src/models/modernbert/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/keras_hub/src/models/modernbert/modernbert_backbone.py b/keras_hub/src/models/modernbert/modernbert_backbone.py new file mode 100644 index 0000000000..d00cc02b88 --- /dev/null +++ b/keras_hub/src/models/modernbert/modernbert_backbone.py @@ -0,0 +1,138 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.modeling.reversible_embedding import ( + ReversibleEmbedding, +) +from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding +from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.gemma.rms_normalization import RMSNormalization +from keras_hub.src.utils.keras_utils import gelu_approximate + + +@keras_hub_export("keras_hub.models.ModernBertBackbone") +class ModernBertBackbone(Backbone): + """A ModernBERT encoder network. + + This class implements the ModernBERT backbone, using rotary embeddings, + RMS normalization, and a stack of TransformerEncoder layers. 
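+
+    The default constructor gives a fully customizable, randomly initialized
+    ModernBERT encoder with any number of layers, heads, and embedding
+    dimensions.
+
+    Args:
+        vocabulary_size: int. The size of the token vocabulary.
+        num_layers: int. The number of transformer layers.
+        num_heads: int. The number of attention heads for each transformer
+            layer.
+        hidden_dim: int. The size of the transformer encoding layers.
+        intermediate_dim: int. The output dimension of the first Dense layer
+            in the feedforward network of each transformer layer.
+        max_sequence_length: int. The maximum sequence length this encoder
+            can consume. Defaults to 8192.
+        dropout: float. Dropout probability for the embeddings and the
+            transformer layers. Defaults to 0.0.
+        rotary_max_wavelength: float. The maximum wavelength for the rotary
+            position embeddings. Defaults to 160000.0.
+        layer_norm_epsilon: float. Epsilon for the normalization layers.
+            Defaults to 1e-5.
+        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to
+            use for model computations and weights.
+
+    Example:
+    ```python
+    import numpy as np
+    import keras_hub
+
+    input_data = {
+        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
+        "padding_mask": np.array([[1] * 12]),
+    }
+
+    # Randomly initialized ModernBERT backbone with a custom config.
+    model = keras_hub.models.ModernBertBackbone(
+        vocabulary_size=30522,
+        num_layers=4,
+        num_heads=4,
+        hidden_dim=256,
+        intermediate_dim=512,
+    )
+    model(input_data)
+    ```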
+ """ + + def __init__( + self, + vocabulary_size, + num_layers, + num_heads, + hidden_dim, + intermediate_dim, + max_sequence_length=8192, + dropout=0.0, + rotary_max_wavelength=160000.0, + layer_norm_epsilon=1e-5, + dtype=None, + **kwargs, + ): + # === Layers === + self.token_embedding = ReversibleEmbedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + embeddings_initializer=keras.initializers.TruncatedNormal( + stddev=0.02 + ), + dtype=dtype, + name="token_embedding", + ) + self.position_embedding = RotaryEmbedding( + max_wavelength=rotary_max_wavelength, + sequence_axis=1, + feature_axis=-1, + dtype=dtype, + name="rotary_embedding", + ) + self.embeddings_layer_norm = RMSNormalization( + dtype=dtype, + epsilon=layer_norm_epsilon, + ) + self.embeddings_dropout = keras.layers.Dropout( + dropout, dtype=dtype, name="embeddings_dropout" + ) + self.transformer_layers = [] + for i in range(num_layers): + layer = TransformerEncoder( + num_heads=num_heads, + intermediate_dim=intermediate_dim, + activation=gelu_approximate, + dropout=dropout, + layer_norm_epsilon=layer_norm_epsilon, + kernel_initializer=keras.initializers.TruncatedNormal( + stddev=0.02 + ), + dtype=dtype, + name=f"transformer_layer_{i}", + ) + self.transformer_layers.append(layer) + self.final_norm = RMSNormalization( + dtype=dtype, + epsilon=layer_norm_epsilon, + name="final_normalization", + ) + + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + + # Embed tokens and apply rotary position embedding + x = self.token_embedding(token_id_input) + x = self.position_embedding(x) + x = self.embeddings_layer_norm(x) + x = self.embeddings_dropout(x) + + # Transformer layers + for transformer_layer in self.transformer_layers: + x = transformer_layer(x, padding_mask=padding_mask_input) + + # Final normalization + sequence_output = self.final_norm(x) + + super().__init__( + inputs={ + "token_ids": token_id_input, + "padding_mask": padding_mask_input, + }, + outputs=sequence_output, + dtype=dtype, + **kwargs, + ) + + # === Config === + self.vocabulary_size = vocabulary_size + self.num_layers = num_layers + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.max_sequence_length = max_sequence_length + self.dropout = dropout + self.rotary_max_wavelength = rotary_max_wavelength + self.layer_norm_epsilon = layer_norm_epsilon + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "num_layers": self.num_layers, + "num_heads": self.num_heads, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "max_sequence_length": self.max_sequence_length, + "dropout": self.dropout, + "rotary_max_wavelength": self.rotary_max_wavelength, + "layer_norm_epsilon": self.layer_norm_epsilon, + } + ) + return config diff --git a/keras_hub/src/models/modernbert/modernbert_backbone_test.py b/keras_hub/src/models/modernbert/modernbert_backbone_test.py new file mode 100644 index 0000000000..e3c797cfd0 --- /dev/null +++ b/keras_hub/src/models/modernbert/modernbert_backbone_test.py @@ -0,0 +1,38 @@ +import pytest +from keras import ops + +from keras_hub.src.models.modernbert.modernbert_backbone import ( + ModernBertBackbone, +) +from keras_hub.src.tests.test_case import TestCase + + +class ModernBertBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = 
{ + "vocabulary_size": 10, + "num_layers": 2, + "num_heads": 4, + "hidden_dim": 8, + "intermediate_dim": 32, + } + self.input_data = { + "token_ids": ops.ones((2, 5), dtype="int32"), + "padding_mask": ops.ones((2, 5), dtype="int32"), + } + + def test_backbone_basics(self): + self.run_backbone_test( + cls=ModernBertBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 5, 8), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=ModernBertBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) From 5b6e62aa0644f3902c6a5a19db209abf55e0e76b Mon Sep 17 00:00:00 2001 From: Saurav Maheshkar Date: Sat, 17 May 2025 04:01:18 +0100 Subject: [PATCH 2/3] feat: update backbone + add tokenizer --- .../models/modernbert/modernbert_backbone.py | 50 +++----- .../models/modernbert/modernbert_layers.py | 95 ++++++++++++++ .../models/modernbert/modernbert_tokenizer.py | 36 ++++++ .../modernbert/modernbert_tokenizer_test.py | 35 +++++ .../convert_modernbert_checkpoints.py | 120 ++++++++++++++++++ 5 files changed, 303 insertions(+), 33 deletions(-) create mode 100644 keras_hub/src/models/modernbert/modernbert_layers.py create mode 100644 keras_hub/src/models/modernbert/modernbert_tokenizer.py create mode 100644 keras_hub/src/models/modernbert/modernbert_tokenizer_test.py create mode 100644 tools/checkpoint_conversion/convert_modernbert_checkpoints.py diff --git a/keras_hub/src/models/modernbert/modernbert_backbone.py b/keras_hub/src/models/modernbert/modernbert_backbone.py index d00cc02b88..dabd488bd5 100644 --- a/keras_hub/src/models/modernbert/modernbert_backbone.py +++ b/keras_hub/src/models/modernbert/modernbert_backbone.py @@ -5,20 +5,15 @@ ReversibleEmbedding, ) from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding -from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder from keras_hub.src.models.backbone import Backbone -from keras_hub.src.models.gemma.rms_normalization import RMSNormalization +from keras_hub.src.models.modernbert.modernbert_layers import ( + ModernBERTEncoderLayer, +) from keras_hub.src.utils.keras_utils import gelu_approximate @keras_hub_export("keras_hub.models.ModernBertBackbone") class ModernBertBackbone(Backbone): - """A ModernBERT encoder network. - - This class implements the ModernBERT backbone, using rotary embeddings, - RMS normalization, and a stack of TransformerEncoder layers. 
-    """
-
     def __init__(
         self,
         vocabulary_size,
@@ -45,37 +40,33 @@ def __init__(
         )
         self.position_embedding = RotaryEmbedding(
             max_wavelength=rotary_max_wavelength,
-            sequence_axis=1,
-            feature_axis=-1,
+            # Queries and keys are rotated as (batch, heads, sequence,
+            # head_dim) tensors, so the sequence axis is 2 rather than the
+            # layer's default of 1.
+            sequence_axis=2,
             dtype=dtype,
             name="rotary_embedding",
         )
-        self.embeddings_layer_norm = RMSNormalization(
-            dtype=dtype,
+        self.embeddings_layer_norm = keras.layers.LayerNormalization(
             epsilon=layer_norm_epsilon,
-        )
-        self.embeddings_dropout = keras.layers.Dropout(
-            dropout, dtype=dtype, name="embeddings_dropout"
+            dtype=dtype,
+            rms_scaling=True,
+            name="embeddings_layer_norm",
         )
         self.transformer_layers = []
         for i in range(num_layers):
-            layer = TransformerEncoder(
+            layer = ModernBERTEncoderLayer(
+                hidden_size=hidden_dim,
+                intermediate_size=intermediate_dim,
                 num_heads=num_heads,
-                intermediate_dim=intermediate_dim,
                 activation=gelu_approximate,
-                dropout=dropout,
                 layer_norm_epsilon=layer_norm_epsilon,
-                kernel_initializer=keras.initializers.TruncatedNormal(
-                    stddev=0.02
-                ),
+                rotary_embedding=self.position_embedding,
                 dtype=dtype,
                 name=f"transformer_layer_{i}",
             )
             self.transformer_layers.append(layer)
-        self.final_norm = RMSNormalization(
-            dtype=dtype,
+        self.final_norm = keras.layers.LayerNormalization(
             epsilon=layer_norm_epsilon,
-            name="final_normalization",
+            rms_scaling=True,
+            dtype=dtype,
+            name="final_layernorm",
         )
 
         # === Functional Model ===
@@ -85,20 +76,13 @@ def __init__(
         padding_mask_input = keras.Input(
             shape=(None,), dtype="int32", name="padding_mask"
        )
-
-        # Embed tokens and apply rotary position embedding
         x = self.token_embedding(token_id_input)
-        x = self.position_embedding(x)
         x = self.embeddings_layer_norm(x)
-        x = self.embeddings_dropout(x)
-
-        # Transformer layers
         for transformer_layer in self.transformer_layers:
-            x = transformer_layer(x, padding_mask=padding_mask_input)
-
-        # Final normalization
+            x = transformer_layer(x)
         sequence_output = self.final_norm(x)
 
+        # Instantiate using the functional API `Model` constructor.
         super().__init__(
             inputs={
                 "token_ids": token_id_input,
diff --git a/keras_hub/src/models/modernbert/modernbert_layers.py b/keras_hub/src/models/modernbert/modernbert_layers.py
new file mode 100644
index 0000000000..82fe16b004
--- /dev/null
+++ b/keras_hub/src/models/modernbert/modernbert_layers.py
@@ -0,0 +1,95 @@
+import keras
+from keras import layers
+from keras import ops
+
+from keras_hub.src.models.flux.flux_maths import rearrange_symbolic_tensors
+from keras_hub.src.models.flux.flux_maths import scaled_dot_product_attention
+
+
+class MLP(keras.layers.Layer):
+    """Gated MLP block.
+
+    `Wi` projects to twice the intermediate size; one half of the projection
+    gates the other (GeGLU-style) before `Wo` projects back down.
+    """
+
+    def __init__(
+        self,
+        hidden_size,
+        intermediate_size,
+        activation="gelu",
+        dtype=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.Wi = layers.Dense(
+            intermediate_size * 2,
+            use_bias=False,
+            dtype=dtype,
+        )
+        self.act = keras.activations.get(activation)
+        self.Wo = layers.Dense(
+            hidden_size,
+            use_bias=False,
+            dtype=dtype,
+        )
+
+    def call(self, x):
+        hidden, gate = ops.split(self.Wi(x), 2, axis=-1)
+        return self.Wo(self.act(hidden) * gate)
+
+
+class ModernBERTAttention(keras.Model):
+    def __init__(
+        self, hidden_size, num_heads, rotary_embedding, dtype=None, **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.num_heads = num_heads
+        self.hidden_size = hidden_size
+        self.rotary_embedding = rotary_embedding
+        self.Wqkv = layers.Dense(hidden_size * 3, use_bias=False, dtype=dtype)
+        self.Wo = layers.Dense(hidden_size, use_bias=False, dtype=dtype)
+
+    def build(self, input_shape):
+        self.Wqkv.build(input_shape)
+        self.Wo.build((None, input_shape[1], input_shape[-1]))
+
+    def call(self, x):
+        qkv = self.Wqkv(x)
+        q, k, v = rearrange_symbolic_tensors(qkv, K=3, H=self.num_heads)
+
+        # Apply rotary embeddings to queries and keys.
+        q = self.rotary_embedding(q)
+        k = self.rotary_embedding(k)
+
+        # Apply scaled dot product attention.
+        x = scaled_dot_product_attention(q, k, v)
+
+        # Merge the heads back together and apply the output projection.
+        x = ops.transpose(x, (0, 2, 1, 3))
+        b, s, h, d = ops.shape(x)
+        x = ops.reshape(x, (b, s, h * d))
+        x = self.Wo(x)
+        return x
+
+
+class ModernBERTEncoderLayer(keras.Model):
+    def __init__(
+        self,
+        hidden_size,
+        intermediate_size,
+        num_heads,
+        activation="gelu",
+        layer_norm_epsilon=1e-05,
+        rotary_embedding=None,
+        dtype=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.attn = ModernBERTAttention(
+            hidden_size, num_heads, rotary_embedding, dtype=dtype
+        )
+        self.mlp_norm = layers.LayerNormalization(
+            epsilon=layer_norm_epsilon, dtype=dtype
+        )
+        self.mlp = MLP(hidden_size, intermediate_size, activation, dtype=dtype)
+
+    def call(self, x):
+        # Residual (skip) connections around the attention and MLP blocks,
+        # with pre-normalization on the MLP input.
+        x = x + self.attn(x)
+        x = x + self.mlp(self.mlp_norm(x))
+        return x
diff --git a/keras_hub/src/models/modernbert/modernbert_tokenizer.py b/keras_hub/src/models/modernbert/modernbert_tokenizer.py
new file mode 100644
index 0000000000..8d4fe18ab2
--- /dev/null
+++ b/keras_hub/src/models/modernbert/modernbert_tokenizer.py
@@ -0,0 +1,36 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.modernbert.modernbert_backbone import (
+    ModernBertBackbone,
+)
+from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
+
+
+@keras_hub_export(
+    [
+        "keras_hub.tokenizers.ModernBertTokenizer",
+        "keras_hub.models.ModernBertTokenizer",
+    ]
+)
+class ModernBertTokenizer(BytePairTokenizer):
+    """A ModernBERT tokenizer using Byte-Pair Encoding subword segmentation.
+
+    This tokenizer maps raw strings to integer sequences and is based on
+    `keras_hub.tokenizers.BytePairTokenizer`.
+
+    Args:
+        vocabulary: string or dict, maps tokens to integer ids. If it is a
+            string, it should be the file path to a json file.
+        merges: string or list, contains the merge rules. If it is a
+            string, it should be the file path to a merge rules file.
+    """
+
+    backbone_cls = ModernBertBackbone
+
+    def __init__(
+        self,
+        vocabulary=None,
+        merges=None,
+        **kwargs,
+    ):
+        self._add_special_token("[CLS]", "cls_token")
+        self._add_special_token("[SEP]", "sep_token")
+        self._add_special_token("[PAD]", "pad_token")
+        self._add_special_token("[UNK]", "unk_token")
+        self._add_special_token("[MASK]", "mask_token")
+        # Also add `tokenizer.start_token` and `tokenizer.end_token` for
+        # compatibility with other tokenizers.
+        self._add_special_token("[CLS]", "start_token")
+        self._add_special_token("[SEP]", "end_token")
+        super().__init__(
+            vocabulary=vocabulary,
+            merges=merges,
+            **kwargs,
+        )
diff --git a/keras_hub/src/models/modernbert/modernbert_tokenizer_test.py b/keras_hub/src/models/modernbert/modernbert_tokenizer_test.py
new file mode 100644
index 0000000000..b863d03724
--- /dev/null
+++ b/keras_hub/src/models/modernbert/modernbert_tokenizer_test.py
@@ -0,0 +1,35 @@
+from keras_hub.src.models.modernbert.modernbert_tokenizer import (
+    ModernBertTokenizer,
+)
+from keras_hub.src.tests.test_case import TestCase
+
+
+class ModernBertTokenizerTest(TestCase):
+    def setUp(self):
+        self.vocab = ["[CLS]", "[PAD]", "[SEP]", "air", "Ġair", "plane", "Ġat"]
+        self.vocab += ["port", "[MASK]", "[UNK]"]
+        self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
+        self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
+        self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
+        self.merges += ["Ġai r", "Ġa i", "pla ne"]
+        self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges}
+        self.input_data = [
+            "[CLS] airplane at airport[SEP][PAD]",
+            " airplane airport",
+        ]
+
+    def test_tokenizer_basics(self):
+        self.run_preprocessing_layer_test(
+            cls=ModernBertTokenizer,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+            expected_output=[[0, 4, 5, 6, 4, 7, 2, 1], [4, 5, 4, 7]],
+            expected_detokenize_output=[
+                "[CLS] airplane at airport[SEP][PAD]",
+                " airplane airport",
+            ],
+        )
+
+    def test_errors_missing_special_tokens(self):
+        with self.assertRaises(ValueError):
+            ModernBertTokenizer(vocabulary=["a", "b", "c"], merges=[])
diff --git a/tools/checkpoint_conversion/convert_modernbert_checkpoints.py b/tools/checkpoint_conversion/convert_modernbert_checkpoints.py
new file mode 100644
index 0000000000..e6ff5e9599
--- /dev/null
+++ b/tools/checkpoint_conversion/convert_modernbert_checkpoints.py
@@ -0,0 +1,120 @@
+"""Convert ModernBERT checkpoints.
+
+python tools/checkpoint_conversion/convert_modernbert_checkpoints.py \
+    --preset modernbert_base
+python tools/checkpoint_conversion/convert_modernbert_checkpoints.py \
+    --preset modernbert_large
+"""
+
+import json
+import os
+
+import numpy as np
+import requests
+import transformers
+from absl import app
+from absl import flags
+
+from keras_hub.src.models.modernbert.modernbert_backbone import (
+    ModernBertBackbone,
+)
+
+PRESET_MAP = {
+    "modernbert_base": "answerdotai/ModernBERT-base",
+    "modernbert_large": "answerdotai/ModernBERT-large",
+}
+
+EXTRACT_DIR = "./{}"
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string(
+    "preset",
+    None,
+    f"Must be one of {','.join(PRESET_MAP.keys())}",
+)
+
+
+def download_files(hf_model_name):
+    extract_dir = EXTRACT_DIR.format(FLAGS.preset)
+    if not os.path.exists(extract_dir):
+        os.makedirs(extract_dir)
+
+    # Config.
+    config_path = os.path.join(extract_dir, "config.json")
+    response = requests.get(
+        f"https://huggingface.co/{hf_model_name}/raw/main/config.json"
+    )
+    response.raise_for_status()
+    with open(config_path, "wb") as f:
+        f.write(response.content)
+
+
+def convert_model(hf_model):
+    extract_dir = EXTRACT_DIR.format(FLAGS.preset)
+    config_path = os.path.join(extract_dir, "config.json")
+
+    # Build config.
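+    # Map the Hugging Face `config.json` fields onto the
+    # `ModernBertBackbone` constructor arguments.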
+    cfg = {}
+    with open(config_path, "r") as pt_cfg_handler:
+        pt_cfg = json.load(pt_cfg_handler)
+    cfg["vocabulary_size"] = pt_cfg["vocab_size"]
+    cfg["num_layers"] = pt_cfg["num_hidden_layers"]
+    cfg["num_heads"] = pt_cfg["num_attention_heads"]
+    cfg["hidden_dim"] = pt_cfg["hidden_size"]
+    cfg["intermediate_dim"] = pt_cfg["intermediate_size"]
+    cfg["dropout"] = pt_cfg["embedding_dropout"]
+    cfg["max_sequence_length"] = pt_cfg["max_position_embeddings"]
+
+    return ModernBertBackbone(**cfg)
+
+
+def convert_weights(keras_model, hf_model):
+    # Get `state_dict` from `hf_model`.
+    state_dict = hf_model.state_dict()
+
+    keras_model.get_layer("token_embedding").set_weights(
+        [np.asarray(state_dict["embeddings.tok_embeddings.weight"])]
+    )
+
+    keras_model.get_layer("embeddings_layer_norm").set_weights(
+        [np.asarray(state_dict["embeddings.norm.weight"])]
+    )
+
+    # `torch.nn.Linear` stores weights as `(out_features, in_features)`,
+    # while Keras `Dense` kernels are `(input_dim, units)`, so every dense
+    # kernel below is transposed.
+    for i in range(keras_model.num_layers):
+        keras_model.transformer_layers[i].attn.Wqkv.kernel.assign(
+            state_dict[f"layers.{i}.attn.Wqkv.weight"].T
+        )
+        keras_model.transformer_layers[i].attn.Wo.kernel.assign(
+            state_dict[f"layers.{i}.attn.Wo.weight"].T
+        )
+        keras_model.transformer_layers[i].mlp_norm.gamma.assign(
+            state_dict[f"layers.{i}.mlp_norm.weight"]
+        )
+        keras_model.transformer_layers[i].mlp.Wi.kernel.assign(
+            state_dict[f"layers.{i}.mlp.Wi.weight"].T
+        )
+        keras_model.transformer_layers[i].mlp.Wo.kernel.assign(
+            state_dict[f"layers.{i}.mlp.Wo.weight"].T
+        )
+
+    keras_model.get_layer("final_layernorm").set_weights(
+        [np.asarray(state_dict["final_norm.weight"])]
+    )
+
+
+def main(_):
+    hf_model_name = PRESET_MAP[FLAGS.preset]
+    download_files(hf_model_name)
+
+    hf_model = transformers.AutoModel.from_pretrained(hf_model_name)
+    hf_model.eval()
+
+    print(f"🏃 Converting {FLAGS.preset}")
+    keras_model = convert_model(hf_model)
+    print("✅ KerasHub model loaded.")
+
+    convert_weights(keras_model, hf_model)
+    print("✅ Weights converted.")
+
+
+if __name__ == "__main__":
+    flags.mark_flag_as_required("preset")
+    app.run(main)

From 89e7dcb1595a175d0f5bf028ef20eb8338a0e5be Mon Sep 17 00:00:00 2001
From: Saurav Maheshkar
Date: Sat, 17 May 2025 04:05:35 +0100
Subject: [PATCH 3/3] chore: api-gen

---
 keras_hub/api/models/__init__.py     | 3 +++
 keras_hub/api/tokenizers/__init__.py | 3 +++
 2 files changed, 6 insertions(+)

diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py
index 6e10b7a6f0..4a413ff165 100644
--- a/keras_hub/api/models/__init__.py
+++ b/keras_hub/api/models/__init__.py
@@ -372,6 +372,9 @@
 from keras_hub.src.models.modernbert.modernbert_backbone import (
     ModernBertBackbone as ModernBertBackbone,
 )
+from keras_hub.src.models.modernbert.modernbert_tokenizer import (
+    ModernBertTokenizer as ModernBertTokenizer,
+)
 from keras_hub.src.models.moonshine.moonshine_audio_to_text import (
     MoonshineAudioToText as MoonshineAudioToText,
 )
diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py
index 082078184f..428aa62d7f 100644
--- a/keras_hub/api/tokenizers/__init__.py
+++ b/keras_hub/api/tokenizers/__init__.py
@@ -58,6 +58,9 @@
 from keras_hub.src.models.mixtral.mixtral_tokenizer import (
     MixtralTokenizer as MixtralTokenizer,
 )
+from keras_hub.src.models.modernbert.modernbert_tokenizer import (
+    ModernBertTokenizer as ModernBertTokenizer,
+)
 from keras_hub.src.models.moonshine.moonshine_tokenizer import (
     MoonshineTokenizer as MoonshineTokenizer,
 )