@@ -874,6 +874,42 @@ def call(self, x, mask=None):
874
874
y = layer (x )
875
875
self .assertAllClose (y ._keras_mask , mask )
876
876
877
+ @pytest .mark .skipif (
878
+ backend .backend () == "numpy" , reason = "masking not supported with numpy"
879
+ )
880
+ def test_masking_with_explicit_kwarg_propagation (self ):
881
+ """This test validates that an explicit `mask` kwarg is correctly
882
+ used to compute the output mask.
883
+ """
884
+
885
+ class PassthroughMaskLayer (layers .Layer ):
886
+ def __init__ (self ):
887
+ super ().__init__ ()
888
+ self .supports_masking = True
889
+
890
+ def call (self , x , mask = None ):
891
+ # The layer itself can use the mask.
892
+ self .used_mask = mask is not None
893
+ return x
894
+
895
+ layer = PassthroughMaskLayer ()
896
+ # Create an input tensor WITHOUT an attached mask.
897
+ x = backend .numpy .ones ((4 , 4 ))
898
+ self .assertIsNone (getattr (x , "_keras_mask" , None ))
899
+
900
+ # Create a mask to be passed explicitly.
901
+ explicit_mask = backend .numpy .array ([True , True , False , False ])
902
+
903
+ # Call the layer, passing the mask as a keyword argument.
904
+ y = layer (x , mask = explicit_mask )
905
+
906
+ # Assert that the layer's internal call received the mask.
907
+ self .assertTrue (layer .used_mask )
908
+
909
+ # Assert that the output tensor 'y' now has the explicit mask attached
910
+ # for propagation to the next layer.
911
+ self .assertAllClose (backend .get_keras_mask (y ), explicit_mask )
912
+
877
913
def test_stateless_call (self ):
878
914
class TestLayer (layers .Layer ):
879
915
def __init__ (self ):
0 commit comments