Skip to content

Commit b855c42

Browse files
Ensures explicit mask keyword argument is used for output mask propagation (#21449)
1 parent 6a54c36 commit b855c42

File tree

2 files changed

+44
-3
lines changed

2 files changed

+44
-3
lines changed

keras/src/layers/layer.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -906,9 +906,14 @@ def maybe_convert(x):
906906

907907
# We need to cache the `previous_mask` before `__call__` because the
908908
# mask might be removed during the call, such as `MultiHeadAttention`.
909-
previous_mask = tree.map_structure(
910-
backend.get_keras_mask, call_spec.first_arg
911-
)
909+
if "mask" in kwargs and kwargs["mask"] is not None:
910+
# Case 1: Mask was explicitly passed or auto-populated in step 6.
911+
previous_mask = kwargs["mask"]
912+
else:
913+
# Case 2: Fallback to the mask attached to the first input tensor.
914+
previous_mask = tree.map_structure(
915+
backend.get_keras_mask, call_spec.first_arg
916+
)
912917

913918
####################
914919
# 7. Call the layer.

keras/src/layers/layer_test.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -874,6 +874,42 @@ def call(self, x, mask=None):
874874
y = layer(x)
875875
self.assertAllClose(y._keras_mask, mask)
876876

877+
@pytest.mark.skipif(
878+
backend.backend() == "numpy", reason="masking not supported with numpy"
879+
)
880+
def test_masking_with_explicit_kwarg_propagation(self):
881+
"""This test validates that an explicit `mask` kwarg is correctly
882+
used to compute the output mask.
883+
"""
884+
885+
class PassthroughMaskLayer(layers.Layer):
886+
def __init__(self):
887+
super().__init__()
888+
self.supports_masking = True
889+
890+
def call(self, x, mask=None):
891+
# The layer itself can use the mask.
892+
self.used_mask = mask is not None
893+
return x
894+
895+
layer = PassthroughMaskLayer()
896+
# Create an input tensor WITHOUT an attached mask.
897+
x = backend.numpy.ones((4, 4))
898+
self.assertIsNone(getattr(x, "_keras_mask", None))
899+
900+
# Create a mask to be passed explicitly.
901+
explicit_mask = backend.numpy.array([True, True, False, False])
902+
903+
# Call the layer, passing the mask as a keyword argument.
904+
y = layer(x, mask=explicit_mask)
905+
906+
# Assert that the layer's internal call received the mask.
907+
self.assertTrue(layer.used_mask)
908+
909+
# Assert that the output tensor 'y' now has the explicit mask attached
910+
# for propagation to the next layer.
911+
self.assertAllClose(backend.get_keras_mask(y), explicit_mask)
912+
877913
def test_stateless_call(self):
878914
class TestLayer(layers.Layer):
879915
def __init__(self):

0 commit comments

Comments
 (0)