Skip to content

Implementation of left padding #2242

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
May 28, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 34 additions & 6 deletions keras_hub/src/layers/preprocessing/start_end_packer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class StartEndPacker(PreprocessingLayer):
0 or "" will be added depending on the dtype of the input tensor.
return_padding_mask: bool. Whether to return a boolean padding mask of
all locations that are filled in with the `pad_value`.
padding_side: str. Whether to pad the input on the "left" or "right".
Defaults to "right".

Call arguments:
inputs: A `tf.Tensor`, `tf.RaggedTensor`, or list of python strings.
Expand Down Expand Up @@ -111,6 +113,7 @@ def __init__(
pad_value=None,
return_padding_mask=False,
name=None,
padding_side="right",
**kwargs,
):
super().__init__(name=name, **kwargs)
Expand Down Expand Up @@ -139,6 +142,20 @@ def check_special_value_type(value, value_name):

self.pad_value = pad_value
self.return_padding_mask = return_padding_mask
self.padding_side = padding_side

def pad(self, x, shape, pad_value):
    """Densify a ragged tensor to `shape`, filling with `pad_value`.

    Honors `self.padding_side`: `to_tensor` always fills on the right, so
    for left padding the ragged rows are reversed first (making the fill
    land on the left) and the dense result is reversed back afterwards to
    restore the original token order.
    """
    pad_on_left = self.padding_side == "left"

    if pad_on_left:
        # Reverse each row so to_tensor's right-fill becomes a left-fill.
        x = x[..., ::-1]

    dense = x.to_tensor(
        default_value=pad_value,
        shape=shape,
    )

    if pad_on_left:
        # Undo the reversal to restore row order.
        dense = dense[..., ::-1]
    return dense

@preprocessing_function
def call(
Expand All @@ -154,6 +171,13 @@ def call(
batch_size = tf.shape(x)[0]
sequence_length = sequence_length or self.sequence_length
dtype = inputs.dtype
# Truncate.
truncation_length = sequence_length
if add_start_value and self.start_value is not None:
truncation_length -= len(self.start_value)
if add_end_value and self.end_value is not None:
truncation_length -= len(self.end_value)
x = x[..., :truncation_length]

# Concatenate start and end tokens.
if add_start_value and self.start_value is not None:
Expand All @@ -167,23 +191,26 @@ def call(
end_token_id_tensor = tf.repeat(
end_value[tf.newaxis, :], repeats=batch_size, axis=0
)
# Trim to leave room for end token.
x = x[..., : sequence_length - len(self.end_value)]
x = tf.concat([x, end_token_id_tensor], axis=-1)

# Pad to desired length.
outputs = x.to_tensor(
default_value=self.pad_value,
outputs = self.pad(
x,
shape=(batch_size, sequence_length),
pad_value=self.pad_value,
)
outputs = tf.squeeze(outputs, axis=0) if unbatched else outputs

if self.return_padding_mask:
mask = tf.ones_like(x, dtype="bool")
mask = mask.to_tensor(shape=(batch_size, sequence_length))

mask = self.pad(
mask,
shape=(batch_size, sequence_length),
pad_value=False,
)
mask = tf.squeeze(mask, axis=0) if unbatched else mask
return outputs, mask

return outputs

def get_config(self):
Expand All @@ -195,6 +222,7 @@ def get_config(self):
"end_value": self._end_value,
"pad_value": self.pad_value,
"return_padding_mask": self.return_padding_mask,
"padding_side": self.padding_side,
}
)
return config
Expand Down
177 changes: 177 additions & 0 deletions keras_hub/src/layers/preprocessing/start_end_packer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,19 @@

class StartEndPackerTest(TestCase):
def test_dense_input(self):
    """1D dense input is padded up to `sequence_length` on either side."""
    data = [5, 6, 7]
    # right padding (the default)
    packer = StartEndPacker(sequence_length=5)
    self.assertAllEqual(packer(data), [5, 6, 7, 0, 0])
    # left padding
    packer = StartEndPacker(sequence_length=5, padding_side="left")
    self.assertAllEqual(packer(data), [0, 0, 5, 6, 7])

def test_bfloat16_dtype(self):
# Core Keras has a strange bug where it converts int to floats in
Expand All @@ -21,29 +29,54 @@ def test_bfloat16_dtype(self):
self.assertDTypeEqual(output, "int32")

def test_dense_2D_input(self):
    """Batched (2D) dense input pads each row to `sequence_length`."""
    data = [[5, 6, 7]]
    # right padding (the default)
    packer = StartEndPacker(sequence_length=5)
    self.assertAllEqual(packer(data), [[5, 6, 7, 0, 0]])
    # left padding
    packer = StartEndPacker(sequence_length=5, padding_side="left")
    self.assertAllEqual(packer(data), [[0, 0, 5, 6, 7]])

def test_ragged_input(self):
    """Ragged rows of different lengths are each padded independently."""
    data = [[5, 6, 7], [8, 9, 10, 11]]
    # right padding (the default)
    packer = StartEndPacker(sequence_length=5)
    self.assertAllEqual(
        packer(data), [[5, 6, 7, 0, 0], [8, 9, 10, 11, 0]]
    )
    # left padding
    packer = StartEndPacker(sequence_length=5, padding_side="left")
    self.assertAllEqual(
        packer(data), [[0, 0, 5, 6, 7], [0, 8, 9, 10, 11]]
    )

def test_start_end_token(self):
    """Start/end tokens wrap each row; padding fills the chosen side."""
    data = [[5, 6, 7], [8, 9, 10, 11]]
    # right padding (the default)
    packer = StartEndPacker(sequence_length=6, start_value=1, end_value=2)
    self.assertAllEqual(
        packer(data), [[1, 5, 6, 7, 2, 0], [1, 8, 9, 10, 11, 2]]
    )
    # left padding
    packer = StartEndPacker(
        sequence_length=6, start_value=1, end_value=2, padding_side="left"
    )
    self.assertAllEqual(
        packer(data), [[0, 1, 5, 6, 7, 2], [1, 8, 9, 10, 11, 2]]
    )

def test_multiple_start_end_tokens(self):
# right padding
input_data = [[5, 6, 7], [8, 9, 10, 11, 12, 13]]
start_end_packer = StartEndPacker(
sequence_length=8,
Expand All @@ -55,7 +88,20 @@ def test_multiple_start_end_tokens(self):
expected_output = [[1, 2, 5, 6, 7, 3, 4, 0], [1, 2, 8, 9, 10, 11, 3, 4]]
self.assertAllEqual(output, expected_output)

# left padding
start_end_packer = StartEndPacker(
sequence_length=8,
start_value=[1, 2],
end_value=[3, 4],
pad_value=0,
padding_side="left",
)
output = start_end_packer(input_data)
expected_output = [[0, 1, 2, 5, 6, 7, 3, 4], [1, 2, 8, 9, 10, 11, 3, 4]]
self.assertAllEqual(output, expected_output)

def test_start_end_padding_value(self):
# right padding
input_data = [[5, 6, 7], [8, 9, 10, 11]]
start_end_packer = StartEndPacker(
sequence_length=7, start_value=1, end_value=2, pad_value=3
Expand All @@ -64,7 +110,58 @@ def test_start_end_padding_value(self):
expected_output = [[1, 5, 6, 7, 2, 3, 3], [1, 8, 9, 10, 11, 2, 3]]
self.assertAllEqual(output, expected_output)

# left padding
start_end_packer = StartEndPacker(
sequence_length=7,
start_value=1,
end_value=2,
pad_value=3,
padding_side="left",
)
output = start_end_packer(input_data)
expected_output = [[3, 3, 1, 5, 6, 7, 2], [3, 1, 8, 9, 10, 11, 2]]
self.assertAllEqual(output, expected_output)

def test_truncation_side_flips(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure we need all of these. We might just need test_truncation and test_truncation_without_end_value for these next three tests.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure we need all of these. We might just need and for these next three tests.test_truncation``test_truncation_without_end_value

I did both left and right scenarios in all the tests. I think it's still necessary.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's at least remove the "side_flips" from the name. I don't think any reader would understand what that means. test_truncation and test_truncation_without_endvalue

# right padding
input_data = list(range(10))
packer = StartEndPacker(
sequence_length=7,
start_value=98,
end_value=99,
)
expected_output = [98, 0, 1, 2, 3, 4, 99]
self.assertAllEqual(packer(input_data), expected_output)

# left padding
packer = StartEndPacker(
sequence_length=7,
start_value=98,
end_value=99,
padding_side="left",
)
self.assertAllEqual(packer(input_data), expected_output)

def test_truncation_side_flips_wo_endvalue(self):
    """Truncation without an end token keeps one extra input token.

    With no end token, only the start token consumes space, so six input
    tokens survive; the full row again pads identically on both sides.
    """
    data = list(range(10))
    expected = [98, 0, 1, 2, 3, 4, 5]
    # right padding (the default)
    packer = StartEndPacker(
        sequence_length=7,
        start_value=98,
    )
    self.assertAllEqual(packer(data), expected)
    # left padding: no pad positions remain, so output is unchanged.
    packer = StartEndPacker(
        sequence_length=7,
        start_value=98,
        padding_side="left",
    )
    self.assertAllEqual(packer(data), expected)

def test_end_token_value_during_truncation(self):
# right padding
input_data = [[5, 6], [8, 9, 10, 11, 12, 13]]
start_end_packer = StartEndPacker(
sequence_length=5, start_value=1, end_value=2, pad_value=0
Expand All @@ -73,7 +170,20 @@ def test_end_token_value_during_truncation(self):
expected_output = [[1, 5, 6, 2, 0], [1, 8, 9, 10, 2]]
self.assertAllEqual(output, expected_output)

# left padding
start_end_packer = StartEndPacker(
sequence_length=5,
start_value=1,
end_value=2,
pad_value=0,
padding_side="left",
)
output = start_end_packer(input_data)
expected_output = [[0, 1, 5, 6, 2], [1, 8, 9, 10, 2]]
self.assertAllEqual(output, expected_output)

def test_string_input(self):
# right padding
input_data = [["KerasHub", "is", "awesome"], ["amazing"]]
start_end_packer = StartEndPacker(
sequence_length=5,
Expand All @@ -88,7 +198,23 @@ def test_string_input(self):
]
self.assertAllEqual(output, expected_output)

# left padding
start_end_packer = StartEndPacker(
sequence_length=5,
start_value="[START]",
end_value="[END]",
pad_value="[PAD]",
padding_side="left",
)
output = start_end_packer(input_data)
expected_output = [
["[START]", "KerasHub", "is", "awesome", "[END]"],
["[PAD]", "[PAD]", "[START]", "amazing", "[END]"],
]
self.assertAllEqual(output, expected_output)

def test_string_input_with_multiple_special_values(self):
# right padding
input_data = [["KerasHub", "is", "awesome"], ["amazing"]]
start_end_packer = StartEndPacker(
sequence_length=6,
Expand All @@ -103,6 +229,21 @@ def test_string_input_with_multiple_special_values(self):
]
self.assertAllEqual(output, expected_output)

# left padding
start_end_packer = StartEndPacker(
sequence_length=6,
start_value=["[END]", "[START]"],
end_value="[END]",
pad_value="[PAD]",
padding_side="left",
)
output = start_end_packer(input_data)
expected_output = [
["[END]", "[START]", "KerasHub", "is", "awesome", "[END]"],
["[PAD]", "[PAD]", "[END]", "[START]", "amazing", "[END]"],
]
self.assertAllEqual(output, expected_output)

def test_special_token_dtype_error(self):
with self.assertRaises(ValueError):
StartEndPacker(sequence_length=5, start_value=1.0)
Expand Down Expand Up @@ -147,3 +288,39 @@ def test_get_config(self):
}

self.assertEqual(config, {**config, **expected_config_subset})

def test_return_padding_mask_right_padding(self):
    """With `return_padding_mask=True`, right padding yields a trailing
    `False` in the mask for each pad position."""
    input_data = [[5, 6, 7], [8, 9, 10, 11]]
    start_end_packer = StartEndPacker(
        sequence_length=6,
        start_value=1,
        end_value=2,
        return_padding_mask=True,
    )
    output, padding_mask = start_end_packer(input_data)
    expected_output = [[1, 5, 6, 7, 2, 0], [1, 8, 9, 10, 11, 2]]
    # Mask is False exactly where the pad_value was inserted.
    expected_padding_mask = [
        [True, True, True, True, True, False],
        [True, True, True, True, True, True],
    ]
    self.assertAllEqual(output, expected_output)
    self.assertAllEqual(padding_mask, expected_padding_mask)

def test_return_padding_mask_left_padding(self):
    """With `return_padding_mask=True`, left padding yields a leading
    `False` in the mask for each pad position."""
    data = [[5, 6, 7], [8, 9, 10, 11]]
    packer = StartEndPacker(
        sequence_length=6,
        start_value=1,
        end_value=2,
        return_padding_mask=True,
        padding_side="left",
    )
    output, mask = packer(data)
    # Mask is False exactly where the pad_value was inserted (on the left).
    self.assertAllEqual(output, [[0, 1, 5, 6, 7, 2], [1, 8, 9, 10, 11, 2]])
    self.assertAllEqual(
        mask,
        [
            [False, True, True, True, True, True],
            [True, True, True, True, True, True],
        ],
    )
Loading