[WIP] add ModernBERT #2256
Conversation
Hey folks 👋🏼, could I get some help with this failing test? It fails when running:

🐕 ❯ pytest keras_hub/src/models/modernbert/modernbert_backbone_test.py
=================================================================================================== test session starts ===================================================================================================
platform darwin -- Python 3.11.11, pytest-8.3.5, pluggy-1.6.0 -- /Users/sauravmaheshkar/dev/keras-hub/.venv/bin/python3
cachedir: .pytest_cache
rootdir: /Users/sauravmaheshkar/dev/keras-hub
configfile: pyproject.toml
plugins: cov-6.1.1
collected 4 items
keras_hub/src/models/modernbert/modernbert_backbone_test.py::TestCase::test_session <- .venv/lib/python3.11/site-packages/tensorflow/python/framework/test_util.py SKIPPED (Not a test.) [ 25%]
keras_hub/src/models/modernbert/modernbert_backbone_test.py::ModernBertBackboneTest::test_backbone_basics FAILED [ 50%]
keras_hub/src/models/modernbert/modernbert_backbone_test.py::ModernBertBackboneTest::test_saved_model SKIPPED (need --run_large option to run) [ 75%]
keras_hub/src/models/modernbert/modernbert_backbone_test.py::ModernBertBackboneTest::test_session <- .venv/lib/python3.11/site-packages/tensorflow/python/framework/test_util.py PASSED [100%]
======================================================================================================== FAILURES =========================================================================================================
_______________________________________________________________________________________ ModernBertBackboneTest.test_backbone_basics _______________________________________________________________________________________
self = <keras_hub.src.models.modernbert.modernbert_backbone_test.ModernBertBackboneTest testMethod=test_backbone_basics>
def test_backbone_basics(self):
> self.run_backbone_test(
cls=ModernBertBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 5, 8),
)
keras_hub/src/models/modernbert/modernbert_backbone_test.py:25:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
keras_hub/src/tests/test_case.py:490: in run_backbone_test
self.run_precision_test(cls, init_kwargs, input_data)
keras_hub/src/tests/test_case.py:355: in run_precision_test
self.assertEqual(policy.compute_dtype, sublayer.compute_dtype)
keras_hub/src/tests/test_case.py:57: in assertEqual
super().assertEqual(x1, x2, msg=msg)
E AssertionError:
E - float16
E + float32
-------------------------------------------------------------------------------------------------- Captured stdout call ---------------------------------------------------------------------------------------------------
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 298ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step
================================================================================================= short test summary info =================================================================================================
FAILED keras_hub/src/models/modernbert/modernbert_backbone_test.py::ModernBertBackboneTest::test_backbone_basics - AssertionError:
- float16
+ float32
========================================================================================= 1 failed, 1 passed, 2 skipped in 3.88s ==========================================================================================
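For context, the failing assertion comes from `run_precision_test` in `test_case.py`, which checks each sublayer's `compute_dtype` against the active dtype policy. A common cause (an assumption here, not confirmed in this thread) is a sublayer built without forwarding the backbone's `dtype` argument. A minimal sketch with illustrative names:

```python
from keras import layers


class TinyBlock(layers.Layer):
    """Toy layer; names are illustrative, not the PR's actual code."""

    def __init__(self, hidden_dim=8, dtype=None, **kwargs):
        super().__init__(dtype=dtype, **kwargs)
        # Forward `dtype` so a "mixed_float16" policy reaches the sublayer.
        # Omitting it leaves the sublayer on the default float32 policy,
        # which is exactly the float16 vs. float32 mismatch above.
        self.dense = layers.Dense(hidden_dim, dtype=dtype)


block = TinyBlock(dtype="mixed_float16")
print(block.dense.compute_dtype)  # "float16", matching the policy
```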
Thanks! Excited for ModernBERT! Great fit for the library.
@keras_hub_export("keras_hub.models.ModernBertBackbone")
class ModernBertBackbone(Backbone):
    def __init__(
Please add docstrings.
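For reference, a sketch of the KerasHub docstring convention; the argument names below are inferred from the diff and may not match the final signature:

```python
from keras_hub.src.models.backbone import Backbone


class ModernBertBackbone(Backbone):
    """A ModernBERT encoder network.

    Args:
        vocabulary_size: int. The size of the token vocabulary.
        num_layers: int. The number of transformer encoder layers.
        num_heads: int. The number of attention heads per layer.
        hidden_dim: int. The hidden size of the transformer encoder.
        intermediate_dim: int. The output dimension of the first Dense
            layer in the feedforward network of each encoder layer.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype
            to use for model computations and weights.
    """
```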
        )
        self.transformer_layers = []
        for i in range(num_layers):
            layer = ModernBERTEncoderLayer(
ModernBERTEncoderLayer -> ModernBertEncoderLayer
            dtype=dtype,
            name="token_embedding",
        )
        self.position_embedding = RotaryEmbedding(
I don't think we cache the rotary tensor or anything, so there's probably no advantage to using a shared layer here. Instead I would just create a RotaryEmbedding inside the encoder layer, same as other models that use rotary embeddings in this repo.
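A minimal sketch of that suggestion, assuming KerasHub's `RotaryEmbedding` layer and illustrative constructor args:

```python
import keras

from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding


class ModernBertAttention(keras.layers.Layer):
    """Sketch only; the constructor args are assumptions."""

    def __init__(self, num_heads, hidden_size, max_wavelength=10000,
                 dtype=None, **kwargs):
        super().__init__(dtype=dtype, **kwargs)
        # Build the rotary embedding inside the layer rather than passing
        # a shared instance in from the backbone; nothing is cached, so a
        # shared layer buys nothing.
        self.rotary_embedding = RotaryEmbedding(
            max_wavelength=max_wavelength, dtype=dtype
        )
```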
        num_heads,
        hidden_dim,
        intermediate_dim,
        max_sequence_length=8192,
Unused? If this doesn't actually do anything, no need for it. We don't need config to track pretraining params (though we can document on Kaggle the longest sequence length the checkpoints were trained on).
    ]
)
class ModernBertTokenizer(BytePairTokenizer):
    backbone_cls = ModernBertBackbone
Docstrings for all public symbols.
        self,
        hidden_size,
        intermediate_size,
        activation="gelu",
Have we checked that we want gelu and not gelu(x, approximate=True)? https://keras.io/api/layers/activations/#gelu-function
We should do whatever matches upstream.
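For reference (a sketch, not a claim about what upstream uses): the string `"gelu"` resolves to the exact erf-based form in Keras, so the tanh approximation has to be wired in explicitly:

```python
import keras
from keras import layers


def approximate_gelu(x):
    # Tanh-based approximation; `keras.activations.gelu` defaults to
    # approximate=False, which is what the string "gelu" resolves to.
    return keras.activations.gelu(x, approximate=True)


dense = layers.Dense(64, activation=approximate_gelu)
```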
        **kwargs,
    ):
        super(MLP, self).__init__(**kwargs)
        self.Wi = layers.Dense(
We generally don't do algebraic notation like this in attr names. self.input_dense and self.output_dense, or something like that.
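A sketch of the suggested renaming (the gated split mirrors the `act(input) * gate` line later in the diff; the exact arg names are assumptions):

```python
import keras
from keras import layers, ops


class MLP(layers.Layer):
    """Gated feedforward block; a sketch, not the PR's final code."""

    def __init__(self, hidden_size, intermediate_size, activation="gelu",
                 dtype=None, **kwargs):
        super().__init__(dtype=dtype, **kwargs)
        # Descriptive names instead of algebraic Wi/Wo.
        self.input_dense = layers.Dense(
            intermediate_size * 2, use_bias=False, dtype=dtype
        )
        self.output_dense = layers.Dense(
            hidden_size, use_bias=False, dtype=dtype
        )
        self.activation = keras.activations.get(activation)

    def call(self, x):
        # Split the doubled projection into the value and its gate.
        hidden, gate = ops.split(self.input_dense(x), 2, axis=-1)
        return self.output_dense(self.activation(hidden) * gate)
```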
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.rotary_embedding = rotary_embedding
        self.Wqkv = layers.Dense(hidden_size * 3, use_bias=False, dtype=dtype)
Same here: dense_qkv, or similar. We may also want to split this for consistency with other models; we generally don't do the fused qkv like this.
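A sketch of the unfused version (dims and names are assumptions); separate projections keep variable paths consistent with other KerasHub models and let LoRA target query/value individually:

```python
from keras import layers

hidden_size, dtype = 64, "float32"  # illustrative values

# Three separate projections instead of one fused hidden_size * 3 Dense.
query_dense = layers.Dense(hidden_size, use_bias=False, dtype=dtype)
key_dense = layers.Dense(hidden_size, use_bias=False, dtype=dtype)
value_dense = layers.Dense(hidden_size, use_bias=False, dtype=dtype)
```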
from keras import layers
from keras import ops

from keras_hub.src.models.flux.flux_maths import rearrange_symbolic_tensors
Don't add a cross-model dep like this. If we want stuff like this across models, we should pull it into a util. But in this case, I think we probably want to unfuse the qkv matrix, for a consistent UX for things like lora.
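If a head-split helper is still needed after unfusing, it can live locally on plain `keras.ops` rather than importing from flux. A sketch (the function name is made up):

```python
from keras import ops


def split_heads(x, num_heads):
    """(batch, seq, hidden) -> (batch, num_heads, seq, head_dim)."""
    batch = ops.shape(x)[0]
    seq = ops.shape(x)[1]
    head_dim = x.shape[-1] // num_heads
    x = ops.reshape(x, (batch, seq, num_heads, head_dim))
    return ops.transpose(x, (0, 2, 1, 3))
```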
        return self.Wo(self.act(input) * gate)


class ModernBERTAttention(keras.Model):
ModernBERTAttention -> ModernBertAttention
Same anywhere we have BERT in a class name.
Ref: #2027