
[WIP] add ModernBERT #2256


Draft · wants to merge 3 commits into master
Conversation

SauravMaheshkar

Ref: #2027


SauravMaheshkar commented May 17, 2025

Hey folks 👋🏼, could I get some help with this failing test? It fails when using the mixed_float16 dtype policy.

🐕 ❯ pytest keras_hub/src/models/modernbert/modernbert_backbone_test.py
=================================================================================================== test session starts ===================================================================================================
platform darwin -- Python 3.11.11, pytest-8.3.5, pluggy-1.6.0 -- /Users/sauravmaheshkar/dev/keras-hub/.venv/bin/python3
cachedir: .pytest_cache
rootdir: /Users/sauravmaheshkar/dev/keras-hub
configfile: pyproject.toml
plugins: cov-6.1.1
collected 4 items                                                                                                                                                                                                         

keras_hub/src/models/modernbert/modernbert_backbone_test.py::TestCase::test_session <- .venv/lib/python3.11/site-packages/tensorflow/python/framework/test_util.py SKIPPED (Not a test.)                            [ 25%]
keras_hub/src/models/modernbert/modernbert_backbone_test.py::ModernBertBackboneTest::test_backbone_basics FAILED                                                                                                    [ 50%]
keras_hub/src/models/modernbert/modernbert_backbone_test.py::ModernBertBackboneTest::test_saved_model SKIPPED (need --run_large option to run)                                                                      [ 75%]
keras_hub/src/models/modernbert/modernbert_backbone_test.py::ModernBertBackboneTest::test_session <- .venv/lib/python3.11/site-packages/tensorflow/python/framework/test_util.py PASSED                             [100%]

======================================================================================================== FAILURES =========================================================================================================
_______________________________________________________________________________________ ModernBertBackboneTest.test_backbone_basics _______________________________________________________________________________________

self = <keras_hub.src.models.modernbert.modernbert_backbone_test.ModernBertBackboneTest testMethod=test_backbone_basics>

    def test_backbone_basics(self):
>       self.run_backbone_test(
            cls=ModernBertBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output_shape=(2, 5, 8),
        )

keras_hub/src/models/modernbert/modernbert_backbone_test.py:25: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
keras_hub/src/tests/test_case.py:490: in run_backbone_test
    self.run_precision_test(cls, init_kwargs, input_data)
keras_hub/src/tests/test_case.py:355: in run_precision_test
    self.assertEqual(policy.compute_dtype, sublayer.compute_dtype)
keras_hub/src/tests/test_case.py:57: in assertEqual
    super().assertEqual(x1, x2, msg=msg)
E   AssertionError: 
E   - float16
E   + float32
-------------------------------------------------------------------------------------------------- Captured stdout call ---------------------------------------------------------------------------------------------------
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 298ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step
================================================================================================= short test summary info =================================================================================================
FAILED keras_hub/src/models/modernbert/modernbert_backbone_test.py::ModernBertBackboneTest::test_backbone_basics - AssertionError: 
- float16
+ float32
========================================================================================= 1 failed, 1 passed, 2 skipped in 3.88s ==========================================================================================
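The assertion shows one sublayer keeping a float32 compute dtype while the policy expects float16; in keras-hub backbones that usually means some sublayer was constructed without forwarding the dtype argument. A toy stdlib mock of that failure mode (hypothetical classes, not the real Keras API):

```python
# Toy mock of dtype-policy propagation (NOT the real Keras API):
# a sublayer built without forwarding `dtype` keeps the float32
# default, which is exactly what the failing assertEqual reports.

class MockLayer:
    def __init__(self, dtype=None):
        # Compute dtype falls back to float32 when no policy is forwarded.
        self.compute_dtype = "float16" if dtype == "mixed_float16" else "float32"

class BrokenBackbone:
    def __init__(self, dtype=None):
        self.good = MockLayer(dtype=dtype)  # policy forwarded
        self.bad = MockLayer()              # policy dropped -> float32

backbone = BrokenBackbone(dtype="mixed_float16")
print(backbone.good.compute_dtype)  # float16
print(backbone.bad.compute_dtype)   # float32 -> run_precision_test fails here
```

If that is the cause here, auditing every sublayer constructor in the backbone for a missing dtype=dtype should surface the culprit.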

@mattdangerw (Member) left a comment


Thanks! Excited for modernbert! Great fit for the library.


@keras_hub_export("keras_hub.models.ModernBertBackbone")
class ModernBertBackbone(Backbone):
    def __init__(

Please add docstrings.
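For reference, a sketch of the expected docstring shape, following the usual keras-hub conventions; the argument names are assumed from this PR's signature (vocabulary_size is a guess) and may not match the final one:

```python
# Hypothetical docstring sketch; argument names are assumed from
# this PR and may not match the final signature.
class ModernBertBackbone:
    """A ModernBERT encoder network.

    Args:
        vocabulary_size: int. The size of the token vocabulary.
        num_layers: int. The number of transformer encoder layers.
        num_heads: int. The number of attention heads per layer.
        hidden_dim: int. The hidden size of the transformer layers.
        intermediate_dim: int. The output dimension of the first dense
            layer in each MLP block.
        dtype: string or dtype policy. The dtype to use for the
            layer's computations and weights.
    """
```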

        )
        self.transformer_layers = []
        for i in range(num_layers):
            layer = ModernBERTEncoderLayer(

ModernBERTEncoderLayer -> ModernBertEncoderLayer

            dtype=dtype,
            name="token_embedding",
        )
        self.position_embedding = RotaryEmbedding(

I don't think we cache the rotary tensor or anything, so there's probably no advantage to using a shared layer here. Instead I would just create a RotaryEmbedding inside the encoder layer, same as the other models that use rotary embeddings in this repo.
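To illustrate why sharing buys nothing: the rotary transform is stateless (no trainable weights, nothing cached), so a fresh instance per encoder layer is free. A stdlib sketch of the transform on one vector, rotating adjacent (even, odd) pairs (some implementations pair the two halves of the vector instead; this pairing is illustrative):

```python
import math

def apply_rotary(vec, position, base=10000.0):
    """Rotate consecutive (even, odd) pairs of `vec` by
    position-dependent angles. Stateless: output depends only on the
    inputs, so separate instances per layer cost nothing."""
    dim = len(vec)
    out = list(vec)
    for i in range(0, dim, 2):
        theta = position * base ** (-i / dim)
        c, s = math.cos(theta), math.sin(theta)
        x, y = vec[i], vec[i + 1]
        out[i] = x * c - y * s
        out[i + 1] = x * s + y * c
    return out
```

Each pair is a pure 2-D rotation, so the vector norm is preserved and position 0 is the identity.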

        num_heads,
        hidden_dim,
        intermediate_dim,
        max_sequence_length=8192,

Unused? If this doesn't actually do anything, no need for it. We don't need config to track pretraining params (though we can document on Kaggle the longest sequence length checkpoints were trained on).

    ]
)
class ModernBertTokenizer(BytePairTokenizer):
    backbone_cls = ModernBertBackbone

Docstrings for all public symbols.

        self,
        hidden_size,
        intermediate_size,
        activation="gelu",

Have we checked that we want gelu and not gelu(x, approximate=True)? https://keras.io/api/layers/activations/#gelu-function

We should do whatever matches upstream
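The distinction matters numerically: the string "gelu" resolves to the exact erf-based form, while approximate=True uses the tanh approximation, and the two differ slightly. A stdlib comparison of the two textbook formulas (not the Keras implementation itself):

```python
import math

def gelu_exact(x):
    # erf-based GELU: 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    # tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (
        1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3))
    )

for x in (-2.0, -0.5, 0.0, 0.5, 2.0):
    print(f"{x:+.1f}  exact={gelu_exact(x):+.6f}  tanh={gelu_tanh(x):+.6f}")
```

The gap is small (around 1e-4 near x=1) but nonzero, so matching upstream exactly matters for checkpoint-level numerics parity.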

        **kwargs,
    ):
        super(MLP, self).__init__(**kwargs)
        self.Wi = layers.Dense(

We generally don't do algebraic notation like this in attr names. self.input_dense and self.output_dense, or something like that.

        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.rotary_embedding = rotary_embedding
        self.Wqkv = layers.Dense(hidden_size * 3, use_bias=False, dtype=dtype)

Same here: dense_qkv.

We may want to split this for consistency with other models. We generally don't do the fused qkv like this.
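Unfusing is purely a layout change: slicing a fused QKV weight into three separate projections gives identical outputs. A stdlib sketch with toy shapes (names here are illustrative, not from this PR):

```python
# Stdlib sketch: a fused QKV projection and three split projections
# produce identical outputs; fusion is only a weight-layout choice,
# which is why splitting helps UX (e.g. per-matrix LoRA) at no cost.

def matvec(w, x):
    # w: list of rows (out_dim x in_dim), x: list of length in_dim
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]

hidden = 2
# Fused weight: (3 * hidden) output rows stacked as [q; k; v].
w_fused = [
    [1.0, 2.0], [3.0, 4.0],   # q rows
    [5.0, 6.0], [7.0, 8.0],   # k rows
    [9.0, 0.5], [0.25, 1.5],  # v rows
]
x = [1.0, -1.0]

fused_out = matvec(w_fused, x)
q_w, k_w, v_w = w_fused[:hidden], w_fused[hidden:2 * hidden], w_fused[2 * hidden:]
split_out = matvec(q_w, x) + matvec(k_w, x) + matvec(v_w, x)
print(fused_out == split_out)  # True
```

The same slicing works in reverse when converting upstream fused checkpoints into unfused Keras weights.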

from keras import layers
from keras import ops

from keras_hub.src.models.flux.flux_maths import rearrange_symbolic_tensors

Don't add a cross-model dep like this. If we want stuff like this across models, we should pull it into a util. But in this case, I think we probably want to unfuse the qkv matrix, for consistent UX for things like LoRA.

        return self.Wo(self.act(input) * gate)


class ModernBERTAttention(keras.Model):

ModernBERTAttention -> ModernBertAttention

Same anywhere we have BERT in a classname
