diff --git a/keras_hub/src/layers/preprocessing/multi_segment_packer.py b/keras_hub/src/layers/preprocessing/multi_segment_packer.py index 3c0a8e346d..a4e0ba7ef4 100644 --- a/keras_hub/src/layers/preprocessing/multi_segment_packer.py +++ b/keras_hub/src/layers/preprocessing/multi_segment_packer.py @@ -3,6 +3,7 @@ PreprocessingLayer, ) from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch +from keras_hub.src.utils.tensor_utils import pad from keras_hub.src.utils.tensor_utils import preprocessing_function try: @@ -66,6 +67,8 @@ class MultiSegmentPacker(PreprocessingLayer): "waterfall" algorithm that allocates quota in a left-to-right manner and fills up the buckets until we run out of budget. It support arbitrary number of segments. + padding_side: str. Whether to pad the input on the "left" or "right". + Defaults to "right". Returns: A tuple with two elements. The first is the dense, packed token @@ -124,6 +127,7 @@ def __init__( sep_value=None, pad_value=None, truncate="round_robin", + padding_side="right", **kwargs, ): super().__init__(**kwargs) @@ -162,6 +166,7 @@ def check_special_value_type(value, value_name): self.end_value = end_value self.pad_value = pad_value + self.padding_side = padding_side def get_config(self): config = super().get_config() @@ -173,6 +178,7 @@ def get_config(self): "sep_value": self._sep_value, "pad_value": self.pad_value, "truncate": self.truncate, + "padding_side": self.padding_side, } ) return config @@ -287,10 +293,18 @@ def call( # Pad to dense tensor output. 
sequence_length = sequence_length or self.sequence_length shape = tf.cast([-1, sequence_length], "int64") - token_ids = token_ids.to_tensor( - shape=shape, default_value=self.pad_value + token_ids = pad( + token_ids, + shape=shape, + padding_side=self.padding_side, + pad_value=self.pad_value, + ) + segment_ids = pad( + segment_ids, + shape=shape, + padding_side=self.padding_side, + pad_value=0, ) - segment_ids = segment_ids.to_tensor(shape=shape) # Remove the batch dim if added. if unbatched: token_ids = tf.squeeze(token_ids, 0) diff --git a/keras_hub/src/layers/preprocessing/multi_segment_packer_test.py b/keras_hub/src/layers/preprocessing/multi_segment_packer_test.py index c38aa137e5..7ffba1dd7e 100644 --- a/keras_hub/src/layers/preprocessing/multi_segment_packer_test.py +++ b/keras_hub/src/layers/preprocessing/multi_segment_packer_test.py @@ -8,6 +8,7 @@ class MultiSegmentPackerTest(TestCase): def test_trim_single_input_ints(self): + # right padding input_data = np.arange(3, 10) packer = MultiSegmentPacker( sequence_length=8, start_value=1, end_value=2 @@ -16,7 +17,20 @@ def test_trim_single_input_ints(self): self.assertAllEqual(token_ids, [1, 3, 4, 5, 6, 7, 8, 2]) self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0, 0, 0, 0]) + # left padding + input_data = np.arange(3, 10) + packer = MultiSegmentPacker( + sequence_length=8, + start_value=1, + end_value=2, + padding_side="left", + ) + token_ids, segment_ids = packer(input_data) + self.assertAllEqual(token_ids, [1, 3, 4, 5, 6, 7, 8, 2]) + self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0, 0, 0, 0]) + def test_trim_single_input_strings(self): + # right padding input_data = ["a", "b", "c", "d"] packer = MultiSegmentPacker( sequence_length=5, start_value="[CLS]", end_value="[SEP]" @@ -25,7 +39,19 @@ def test_trim_single_input_strings(self): self.assertAllEqual(token_ids, ["[CLS]", "a", "b", "c", "[SEP]"]) self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0]) + # left padding + packer = MultiSegmentPacker( + 
sequence_length=5, + start_value="[CLS]", + end_value="[SEP]", + padding_side="left", + ) + token_ids, segment_ids = packer(input_data) + self.assertAllEqual(token_ids, ["[CLS]", "a", "b", "c", "[SEP]"]) + self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0]) + def test_trim_multiple_inputs_round_robin(self): + # right padding seq1 = ["a", "b", "c"] seq2 = ["x", "y", "z"] packer = MultiSegmentPacker( @@ -40,7 +66,22 @@ def test_trim_multiple_inputs_round_robin(self): ) self.assertAllEqual(segment_ids, [0, 0, 0, 0, 1, 1, 1]) + # left padding + packer = MultiSegmentPacker( + sequence_length=7, + start_value="[CLS]", + end_value="[SEP]", + truncate="round_robin", + padding_side="left", + ) + token_ids, segment_ids = packer((seq1, seq2)) + self.assertAllEqual( + token_ids, ["[CLS]", "a", "b", "[SEP]", "x", "y", "[SEP]"] + ) + self.assertAllEqual(segment_ids, [0, 0, 0, 0, 1, 1, 1]) + def test_trim_multiple_inputs_waterfall(self): + # right padding seq1 = ["a", "b", "c"] seq2 = ["x", "y", "z"] packer = MultiSegmentPacker( @@ -55,7 +96,22 @@ def test_trim_multiple_inputs_waterfall(self): ) self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0, 1, 1]) + # left padding + packer = MultiSegmentPacker( + sequence_length=7, + start_value="[CLS]", + end_value="[SEP]", + truncate="waterfall", + padding_side="left", + ) + token_ids, segment_ids = packer((seq1, seq2)) + self.assertAllEqual( + token_ids, ["[CLS]", "a", "b", "c", "[SEP]", "x", "[SEP]"] + ) + self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0, 1, 1]) + def test_trim_batched_inputs_round_robin(self): + # right padding seq1 = [["a", "b", "c"], ["a", "b", "c"]] seq2 = [["x", "y", "z"], ["x", "y", "z"]] packer = MultiSegmentPacker( @@ -80,7 +136,32 @@ def test_trim_batched_inputs_round_robin(self): ], ) + # left padding + packer = MultiSegmentPacker( + sequence_length=7, + start_value="[CLS]", + end_value="[SEP]", + truncate="round_robin", + padding_side="left", + ) + token_ids, segment_ids = packer((seq1, seq2)) + 
self.assertAllEqual( + token_ids, + [ + ["[CLS]", "a", "b", "[SEP]", "x", "y", "[SEP]"], + ["[CLS]", "a", "b", "[SEP]", "x", "y", "[SEP]"], + ], + ) + self.assertAllEqual( + segment_ids, + [ + [0, 0, 0, 0, 1, 1, 1], + [0, 0, 0, 0, 1, 1, 1], + ], + ) + def test_trim_batched_inputs_waterfall(self): + # right padding seq1 = [["a", "b", "c"], ["a", "b"]] seq2 = [["x", "y", "z"], ["x", "y", "z"]] packer = MultiSegmentPacker( @@ -105,7 +186,32 @@ def test_trim_batched_inputs_waterfall(self): ], ) + # left padding + packer = MultiSegmentPacker( + sequence_length=7, + start_value="[CLS]", + end_value="[SEP]", + truncate="waterfall", + padding_side="left", + ) + token_ids, segment_ids = packer((seq1, seq2)) + self.assertAllEqual( + token_ids, + [ + ["[CLS]", "a", "b", "c", "[SEP]", "x", "[SEP]"], + ["[CLS]", "a", "b", "[SEP]", "x", "y", "[SEP]"], + ], + ) + self.assertAllEqual( + segment_ids, + [ + [0, 0, 0, 0, 0, 1, 1], + [0, 0, 0, 0, 1, 1, 1], + ], + ) + def test_pad_inputs(self): + # right padding seq1 = ["a"] seq2 = ["x"] packer = MultiSegmentPacker( @@ -118,7 +224,23 @@ def test_pad_inputs(self): ) self.assertAllEqual(segment_ids, [0, 0, 0, 1, 1, 0]) + # left padding + packer = MultiSegmentPacker( + 6, + start_value="[CLS]", + end_value="[SEP]", + pad_value="[PAD]", + padding_side="left", + ) + token_ids, segment_ids = packer((seq1, seq2)) + self.assertAllEqual( + token_ids, + ["[PAD]", "[CLS]", "a", "[SEP]", "x", "[SEP]"], + ) + self.assertAllEqual(segment_ids, [0, 0, 0, 0, 1, 1]) + def test_pad_batched_inputs(self): + # right padding seq1 = [["a"], ["a"]] seq2 = [["x"], ["x", "y"]] packer = MultiSegmentPacker( @@ -143,7 +265,32 @@ def test_pad_batched_inputs(self): ], ) + # left padding + packer = MultiSegmentPacker( + sequence_length=7, + start_value="[CLS]", + end_value="[SEP]", + pad_value="[PAD]", + padding_side="left", + ) + token_ids, segment_ids = packer((seq1, seq2)) + self.assertAllEqual( + token_ids, + [ + ["[PAD]", "[PAD]", "[CLS]", "a", "[SEP]", "x", 
"[SEP]"], + ["[PAD]", "[CLS]", "a", "[SEP]", "x", "y", "[SEP]"], + ], + ) + self.assertAllEqual( + segment_ids, + [ + [0, 0, 0, 0, 0, 1, 1], + [0, 0, 0, 0, 1, 1, 1], + ], + ) + def test_list_special_tokens(self): + # right padding seq1 = [["a", "b"], ["a", "b"]] seq2 = [["x", "y"], ["x"]] packer = MultiSegmentPacker( @@ -170,6 +317,32 @@ def test_list_special_tokens(self): ], ) + # left padding + packer = MultiSegmentPacker( + 8, + start_value="", + end_value="", + sep_value=["", ""], + pad_value="", + truncate="round_robin", + padding_side="left", + ) + token_ids, segment_ids = packer((seq1, seq2)) + self.assertAllEqual( + token_ids, + [ + ["", "a", "b", "", "", "x", "y", ""], + ["", "", "a", "b", "", "", "x", ""], + ], + ) + self.assertAllEqual( + segment_ids, + [ + [0, 0, 0, 0, 0, 1, 1, 1], + [0, 0, 0, 0, 0, 0, 1, 1], + ], + ) + def test_config(self): seq1 = [["a", "b", "c"], ["a", "b"]] seq2 = [["x", "y", "z"], ["x", "y", "z"]] diff --git a/keras_hub/src/layers/preprocessing/start_end_packer.py b/keras_hub/src/layers/preprocessing/start_end_packer.py index f50f09b7f1..efe10a4585 100644 --- a/keras_hub/src/layers/preprocessing/start_end_packer.py +++ b/keras_hub/src/layers/preprocessing/start_end_packer.py @@ -3,6 +3,7 @@ PreprocessingLayer, ) from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch +from keras_hub.src.utils.tensor_utils import pad from keras_hub.src.utils.tensor_utils import preprocessing_function try: @@ -39,6 +40,8 @@ class StartEndPacker(PreprocessingLayer): 0 or "" will be added depending on the dtype of the input tensor. return_padding_mask: bool. Whether to return a boolean padding mask of all locations that are filled in with the `pad_value`. + padding_side: str. Whether to pad the input on the "left" or "right". + Defaults to "right". Call arguments: inputs: A `tf.Tensor`, `tf.RaggedTensor`, or list of python strings. 
@@ -111,6 +114,7 @@ def __init__( pad_value=None, return_padding_mask=False, name=None, + padding_side="right", **kwargs, ): super().__init__(name=name, **kwargs) @@ -139,6 +143,7 @@ def check_special_value_type(value, value_name): self.pad_value = pad_value self.return_padding_mask = return_padding_mask + self.padding_side = padding_side @preprocessing_function def call( @@ -154,6 +159,13 @@ def call( batch_size = tf.shape(x)[0] sequence_length = sequence_length or self.sequence_length dtype = inputs.dtype + # Truncate. + truncation_length = sequence_length + if add_start_value and self.start_value is not None: + truncation_length -= len(self.start_value) + if add_end_value and self.end_value is not None: + truncation_length -= len(self.end_value) + x = x[..., :truncation_length] # Concatenate start and end tokens. if add_start_value and self.start_value is not None: @@ -167,23 +179,28 @@ def call( end_token_id_tensor = tf.repeat( end_value[tf.newaxis, :], repeats=batch_size, axis=0 ) - # Trim to leave room for end token. - x = x[..., : sequence_length - len(self.end_value)] x = tf.concat([x, end_token_id_tensor], axis=-1) # Pad to desired length. 
- outputs = x.to_tensor( - default_value=self.pad_value, + outputs = pad( + x, + pad_value=self.pad_value, + padding_side=self.padding_side, shape=(batch_size, sequence_length), ) outputs = tf.squeeze(outputs, axis=0) if unbatched else outputs if self.return_padding_mask: mask = tf.ones_like(x, dtype="bool") - mask = mask.to_tensor(shape=(batch_size, sequence_length)) + + mask = pad( + mask, + pad_value=False, + padding_side=self.padding_side, + shape=(batch_size, sequence_length), + ) mask = tf.squeeze(mask, axis=0) if unbatched else mask return outputs, mask - return outputs def get_config(self): @@ -195,6 +212,7 @@ def get_config(self): "end_value": self._end_value, "pad_value": self.pad_value, "return_padding_mask": self.return_padding_mask, + "padding_side": self.padding_side, } ) return config diff --git a/keras_hub/src/layers/preprocessing/start_end_packer_test.py b/keras_hub/src/layers/preprocessing/start_end_packer_test.py index 3b0ea65a9c..78f65405f0 100644 --- a/keras_hub/src/layers/preprocessing/start_end_packer_test.py +++ b/keras_hub/src/layers/preprocessing/start_end_packer_test.py @@ -6,11 +6,19 @@ class StartEndPackerTest(TestCase): def test_dense_input(self): + # right padding input_data = [5, 6, 7] start_end_packer = StartEndPacker(sequence_length=5) output = start_end_packer(input_data) expected_output = [5, 6, 7, 0, 0] self.assertAllEqual(output, expected_output) + # left padding + start_end_packer = StartEndPacker( + sequence_length=5, padding_side="left" + ) + output = start_end_packer(input_data) + expected_output = [0, 0, 5, 6, 7] + self.assertAllEqual(output, expected_output) def test_bfloat16_dtype(self): # Core Keras has a strange bug where it converts int to floats in @@ -21,20 +29,37 @@ def test_bfloat16_dtype(self): self.assertDTypeEqual(output, "int32") def test_dense_2D_input(self): + # right padding input_data = [[5, 6, 7]] start_end_packer = StartEndPacker(sequence_length=5) output = start_end_packer(input_data) expected_output = 
[[5, 6, 7, 0, 0]] self.assertAllEqual(output, expected_output) + # left padding + start_end_packer = StartEndPacker( + sequence_length=5, padding_side="left" + ) + output = start_end_packer(input_data) + expected_output = [[0, 0, 5, 6, 7]] + self.assertAllEqual(output, expected_output) def test_ragged_input(self): + # right padding input_data = [[5, 6, 7], [8, 9, 10, 11]] start_end_packer = StartEndPacker(sequence_length=5) output = start_end_packer(input_data) expected_output = [[5, 6, 7, 0, 0], [8, 9, 10, 11, 0]] self.assertAllEqual(output, expected_output) + # left padding + start_end_packer = StartEndPacker( + sequence_length=5, padding_side="left" + ) + output = start_end_packer(input_data) + expected_output = [[0, 0, 5, 6, 7], [0, 8, 9, 10, 11]] + self.assertAllEqual(output, expected_output) def test_start_end_token(self): + # right padding input_data = [[5, 6, 7], [8, 9, 10, 11]] start_end_packer = StartEndPacker( sequence_length=6, start_value=1, end_value=2 @@ -42,8 +67,16 @@ def test_start_end_token(self): output = start_end_packer(input_data) expected_output = [[1, 5, 6, 7, 2, 0], [1, 8, 9, 10, 11, 2]] self.assertAllEqual(output, expected_output) + # left padding + start_end_packer = StartEndPacker( + sequence_length=6, start_value=1, end_value=2, padding_side="left" + ) + output = start_end_packer(input_data) + expected_output = [[0, 1, 5, 6, 7, 2], [1, 8, 9, 10, 11, 2]] + self.assertAllEqual(output, expected_output) def test_multiple_start_end_tokens(self): + # right padding input_data = [[5, 6, 7], [8, 9, 10, 11, 12, 13]] start_end_packer = StartEndPacker( sequence_length=8, @@ -55,7 +88,20 @@ def test_multiple_start_end_tokens(self): expected_output = [[1, 2, 5, 6, 7, 3, 4, 0], [1, 2, 8, 9, 10, 11, 3, 4]] self.assertAllEqual(output, expected_output) + # left padding + start_end_packer = StartEndPacker( + sequence_length=8, + start_value=[1, 2], + end_value=[3, 4], + pad_value=0, + padding_side="left", + ) + output = start_end_packer(input_data) + 
expected_output = [[0, 1, 2, 5, 6, 7, 3, 4], [1, 2, 8, 9, 10, 11, 3, 4]] + self.assertAllEqual(output, expected_output) + def test_start_end_padding_value(self): + # right padding input_data = [[5, 6, 7], [8, 9, 10, 11]] start_end_packer = StartEndPacker( sequence_length=7, start_value=1, end_value=2, pad_value=3 @@ -64,7 +110,58 @@ def test_start_end_padding_value(self): expected_output = [[1, 5, 6, 7, 2, 3, 3], [1, 8, 9, 10, 11, 2, 3]] self.assertAllEqual(output, expected_output) + # left padding + start_end_packer = StartEndPacker( + sequence_length=7, + start_value=1, + end_value=2, + pad_value=3, + padding_side="left", + ) + output = start_end_packer(input_data) + expected_output = [[3, 3, 1, 5, 6, 7, 2], [3, 1, 8, 9, 10, 11, 2]] + self.assertAllEqual(output, expected_output) + + def test_truncation(self): + # right padding + input_data = list(range(10)) + packer = StartEndPacker( + sequence_length=7, + start_value=98, + end_value=99, + ) + expected_output = [98, 0, 1, 2, 3, 4, 99] + self.assertAllEqual(packer(input_data), expected_output) + + # left padding + packer = StartEndPacker( + sequence_length=7, + start_value=98, + end_value=99, + padding_side="left", + ) + self.assertAllEqual(packer(input_data), expected_output) + + def test_truncation_wo_endvalue(self): + # right padding + input_data = list(range(10)) + packer = StartEndPacker( + sequence_length=7, + start_value=98, + ) + expected_output = [98, 0, 1, 2, 3, 4, 5] + self.assertAllEqual(packer(input_data), expected_output) + + # left padding + packer = StartEndPacker( + sequence_length=7, + start_value=98, + padding_side="left", + ) + self.assertAllEqual(packer(input_data), expected_output) + def test_end_token_value_during_truncation(self): + # right padding input_data = [[5, 6], [8, 9, 10, 11, 12, 13]] start_end_packer = StartEndPacker( sequence_length=5, start_value=1, end_value=2, pad_value=0 @@ -73,7 +170,20 @@ def test_end_token_value_during_truncation(self): expected_output = [[1, 5, 6, 2, 0], 
[1, 8, 9, 10, 2]] self.assertAllEqual(output, expected_output) + # left padding + start_end_packer = StartEndPacker( + sequence_length=5, + start_value=1, + end_value=2, + pad_value=0, + padding_side="left", + ) + output = start_end_packer(input_data) + expected_output = [[0, 1, 5, 6, 2], [1, 8, 9, 10, 2]] + self.assertAllEqual(output, expected_output) + def test_string_input(self): + # right padding input_data = [["KerasHub", "is", "awesome"], ["amazing"]] start_end_packer = StartEndPacker( sequence_length=5, @@ -88,7 +198,23 @@ def test_string_input(self): ] self.assertAllEqual(output, expected_output) + # left padding + start_end_packer = StartEndPacker( + sequence_length=5, + start_value="[START]", + end_value="[END]", + pad_value="[PAD]", + padding_side="left", + ) + output = start_end_packer(input_data) + expected_output = [ + ["[START]", "KerasHub", "is", "awesome", "[END]"], + ["[PAD]", "[PAD]", "[START]", "amazing", "[END]"], + ] + self.assertAllEqual(output, expected_output) + def test_string_input_with_multiple_special_values(self): + # right padding input_data = [["KerasHub", "is", "awesome"], ["amazing"]] start_end_packer = StartEndPacker( sequence_length=6, @@ -103,6 +229,21 @@ def test_string_input_with_multiple_special_values(self): ] self.assertAllEqual(output, expected_output) + # left padding + start_end_packer = StartEndPacker( + sequence_length=6, + start_value=["[END]", "[START]"], + end_value="[END]", + pad_value="[PAD]", + padding_side="left", + ) + output = start_end_packer(input_data) + expected_output = [ + ["[END]", "[START]", "KerasHub", "is", "awesome", "[END]"], + ["[PAD]", "[PAD]", "[END]", "[START]", "amazing", "[END]"], + ] + self.assertAllEqual(output, expected_output) + def test_special_token_dtype_error(self): with self.assertRaises(ValueError): StartEndPacker(sequence_length=5, start_value=1.0) @@ -147,3 +288,39 @@ def test_get_config(self): } self.assertEqual(config, {**config, **expected_config_subset}) + + def 
test_return_padding_mask(self):
+        # right padding
+        input_data = [[5, 6, 7], [8, 9, 10, 11]]
+        start_end_packer = StartEndPacker(
+            sequence_length=6,
+            start_value=1,
+            end_value=2,
+            return_padding_mask=True,
+        )
+        output, padding_mask = start_end_packer(input_data)
+        expected_output = [[1, 5, 6, 7, 2, 0], [1, 8, 9, 10, 11, 2]]
+        expected_padding_mask = [
+            [True, True, True, True, True, False],
+            [True, True, True, True, True, True],
+        ]
+
+        self.assertAllEqual(output, expected_output)
+        self.assertAllEqual(padding_mask, expected_padding_mask)
+
+        # left padding
+        start_end_packer = StartEndPacker(
+            sequence_length=6,
+            start_value=1,
+            end_value=2,
+            return_padding_mask=True,
+            padding_side="left",
+        )
+        output, padding_mask = start_end_packer(input_data)
+        expected_output = [[0, 1, 5, 6, 7, 2], [1, 8, 9, 10, 11, 2]]
+        expected_padding_mask = [
+            [False, True, True, True, True, True],
+            [True, True, True, True, True, True],
+        ]
+        self.assertAllEqual(output, expected_output)
+        self.assertAllEqual(padding_mask, expected_padding_mask)
diff --git a/keras_hub/src/utils/tensor_utils.py b/keras_hub/src/utils/tensor_utils.py
index a602963cf0..fa4a5abad1 100644
--- a/keras_hub/src/utils/tensor_utils.py
+++ b/keras_hub/src/utils/tensor_utils.py
@@ -19,6 +19,20 @@
 NO_CONVERT_COUNTER = threading.local()
 
 
+def pad(x, shape, padding_side, pad_value):
+    if padding_side == "left":
+        x = x[..., ::-1]
+
+    outputs = x.to_tensor(
+        default_value=pad_value,
+        shape=shape,
+    )
+
+    if padding_side == "left":
+        outputs = outputs[..., ::-1]
+    return outputs
+
+
 @contextlib.contextmanager
 def no_convert_scope():
     try: