-
Notifications
You must be signed in to change notification settings - Fork 286
Implement left padding #2242
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement left padding #2242
Changes from 6 commits
299102c
59627a4
5d1b2c0
97cada7
85bb256
6ab5ea9
f1c55ac
8c40279
7ceef48
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,6 +39,8 @@ class StartEndPacker(PreprocessingLayer): | |
0 or "" will be added depending on the dtype of the input tensor. | ||
return_padding_mask: bool. Whether to return a boolean padding mask of | ||
all locations that are filled in with the `pad_value`. | ||
padding_side: str. Whether to pad the input on the "left" or "right". | ||
Defaults to "right". | ||
|
||
Call arguments: | ||
inputs: A `tf.Tensor`, `tf.RaggedTensor`, or list of python strings. | ||
|
@@ -111,6 +113,7 @@ def __init__( | |
pad_value=None, | ||
return_padding_mask=False, | ||
name=None, | ||
padding_side="right", | ||
**kwargs, | ||
): | ||
super().__init__(name=name, **kwargs) | ||
|
@@ -139,6 +142,20 @@ def check_special_value_type(value, value_name): | |
|
||
self.pad_value = pad_value | ||
self.return_padding_mask = return_padding_mask | ||
self.padding_side = padding_side | ||
|
||
def pad(self, x, shape, pad_value): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Pull this into a util in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I don't think this is necessary. Only Bert-like models are using multi-segment packers. For Bert-like models, there is no essential difference between left padding and right padding. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We want it first and foremost for uniformity of the API. But also, this is not just for BERT-like. Gemma3 and PaliGemma use it, for example. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
OK, I have fulfilled this requirement. Please check |
||
if self.padding_side == "left": | ||
x = x[..., ::-1] | ||
|
||
outputs = x.to_tensor( | ||
default_value=pad_value, | ||
shape=shape, | ||
) | ||
|
||
if self.padding_side == "left": | ||
outputs = outputs[..., ::-1] | ||
return outputs | ||
|
||
@preprocessing_function | ||
def call( | ||
mattdangerw marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
@@ -154,6 +171,13 @@ def call( | |
batch_size = tf.shape(x)[0] | ||
sequence_length = sequence_length or self.sequence_length | ||
dtype = inputs.dtype | ||
# Truncate. | ||
truncation_length = sequence_length | ||
if add_start_value and self.start_value is not None: | ||
truncation_length -= len(self.start_value) | ||
if add_end_value and self.end_value is not None: | ||
truncation_length -= len(self.end_value) | ||
x = x[..., :truncation_length] | ||
|
||
# Concatenate start and end tokens. | ||
if add_start_value and self.start_value is not None: | ||
|
@@ -167,23 +191,26 @@ def call( | |
end_token_id_tensor = tf.repeat( | ||
end_value[tf.newaxis, :], repeats=batch_size, axis=0 | ||
) | ||
# Trim to leave room for end token. | ||
x = x[..., : sequence_length - len(self.end_value)] | ||
x = tf.concat([x, end_token_id_tensor], axis=-1) | ||
mattdangerw marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# Pad to desired length. | ||
outputs = x.to_tensor( | ||
default_value=self.pad_value, | ||
outputs = self.pad( | ||
x, | ||
shape=(batch_size, sequence_length), | ||
pad_value=self.pad_value, | ||
) | ||
outputs = tf.squeeze(outputs, axis=0) if unbatched else outputs | ||
|
||
if self.return_padding_mask: | ||
mattdangerw marked this conversation as resolved.
Show resolved
Hide resolved
|
||
mask = tf.ones_like(x, dtype="bool") | ||
mask = mask.to_tensor(shape=(batch_size, sequence_length)) | ||
|
||
mask = self.pad( | ||
mask, | ||
shape=(batch_size, sequence_length), | ||
pad_value=False, | ||
) | ||
mask = tf.squeeze(mask, axis=0) if unbatched else mask | ||
return outputs, mask | ||
|
||
return outputs | ||
|
||
def get_config(self): | ||
|
@@ -195,6 +222,7 @@ def get_config(self): | |
"end_value": self._end_value, | ||
"pad_value": self.pad_value, | ||
"return_padding_mask": self.return_padding_mask, | ||
"padding_side": self.padding_side, | ||
} | ||
) | ||
return config | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,11 +6,19 @@ | |
|
||
class StartEndPackerTest(TestCase): | ||
mattdangerw marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def test_dense_input(self): | ||
# right padding | ||
input_data = [5, 6, 7] | ||
start_end_packer = StartEndPacker(sequence_length=5) | ||
output = start_end_packer(input_data) | ||
expected_output = [5, 6, 7, 0, 0] | ||
self.assertAllEqual(output, expected_output) | ||
# left padding | ||
start_end_packer = StartEndPacker( | ||
sequence_length=5, padding_side="left" | ||
) | ||
output = start_end_packer(input_data) | ||
expected_output = [0, 0, 5, 6, 7] | ||
self.assertAllEqual(output, expected_output) | ||
|
||
def test_bfloat16_dtype(self): | ||
# Core Keras has a strange bug where it converts int to floats in | ||
|
@@ -21,29 +29,54 @@ def test_bfloat16_dtype(self): | |
self.assertDTypeEqual(output, "int32") | ||
|
||
def test_dense_2D_input(self): | ||
# right padding | ||
input_data = [[5, 6, 7]] | ||
start_end_packer = StartEndPacker(sequence_length=5) | ||
output = start_end_packer(input_data) | ||
expected_output = [[5, 6, 7, 0, 0]] | ||
self.assertAllEqual(output, expected_output) | ||
# left padding | ||
start_end_packer = StartEndPacker( | ||
sequence_length=5, padding_side="left" | ||
) | ||
output = start_end_packer(input_data) | ||
expected_output = [[0, 0, 5, 6, 7]] | ||
self.assertAllEqual(output, expected_output) | ||
|
||
def test_ragged_input(self): | ||
# right padding | ||
input_data = [[5, 6, 7], [8, 9, 10, 11]] | ||
start_end_packer = StartEndPacker(sequence_length=5) | ||
output = start_end_packer(input_data) | ||
expected_output = [[5, 6, 7, 0, 0], [8, 9, 10, 11, 0]] | ||
self.assertAllEqual(output, expected_output) | ||
# left padding | ||
start_end_packer = StartEndPacker( | ||
sequence_length=5, padding_side="left" | ||
) | ||
output = start_end_packer(input_data) | ||
expected_output = [[0, 0, 5, 6, 7], [0, 8, 9, 10, 11]] | ||
self.assertAllEqual(output, expected_output) | ||
|
||
def test_start_end_token(self): | ||
# right padding | ||
input_data = [[5, 6, 7], [8, 9, 10, 11]] | ||
start_end_packer = StartEndPacker( | ||
sequence_length=6, start_value=1, end_value=2 | ||
) | ||
output = start_end_packer(input_data) | ||
expected_output = [[1, 5, 6, 7, 2, 0], [1, 8, 9, 10, 11, 2]] | ||
self.assertAllEqual(output, expected_output) | ||
# left padding | ||
start_end_packer = StartEndPacker( | ||
sequence_length=6, start_value=1, end_value=2, padding_side="left" | ||
) | ||
output = start_end_packer(input_data) | ||
expected_output = [[0, 1, 5, 6, 7, 2], [1, 8, 9, 10, 11, 2]] | ||
self.assertAllEqual(output, expected_output) | ||
|
||
def test_multiple_start_end_tokens(self): | ||
# right padding | ||
input_data = [[5, 6, 7], [8, 9, 10, 11, 12, 13]] | ||
start_end_packer = StartEndPacker( | ||
sequence_length=8, | ||
|
@@ -55,7 +88,20 @@ def test_multiple_start_end_tokens(self): | |
expected_output = [[1, 2, 5, 6, 7, 3, 4, 0], [1, 2, 8, 9, 10, 11, 3, 4]] | ||
self.assertAllEqual(output, expected_output) | ||
|
||
# left padding | ||
start_end_packer = StartEndPacker( | ||
sequence_length=8, | ||
start_value=[1, 2], | ||
end_value=[3, 4], | ||
pad_value=0, | ||
padding_side="left", | ||
) | ||
output = start_end_packer(input_data) | ||
expected_output = [[0, 1, 2, 5, 6, 7, 3, 4], [1, 2, 8, 9, 10, 11, 3, 4]] | ||
self.assertAllEqual(output, expected_output) | ||
|
||
def test_start_end_padding_value(self): | ||
# right padding | ||
input_data = [[5, 6, 7], [8, 9, 10, 11]] | ||
start_end_packer = StartEndPacker( | ||
sequence_length=7, start_value=1, end_value=2, pad_value=3 | ||
|
@@ -64,7 +110,58 @@ def test_start_end_padding_value(self): | |
expected_output = [[1, 5, 6, 7, 2, 3, 3], [1, 8, 9, 10, 11, 2, 3]] | ||
self.assertAllEqual(output, expected_output) | ||
|
||
# left padding | ||
start_end_packer = StartEndPacker( | ||
sequence_length=7, | ||
start_value=1, | ||
end_value=2, | ||
pad_value=3, | ||
padding_side="left", | ||
) | ||
output = start_end_packer(input_data) | ||
expected_output = [[3, 3, 1, 5, 6, 7, 2], [3, 1, 8, 9, 10, 11, 2]] | ||
self.assertAllEqual(output, expected_output) | ||
|
||
def test_truncation_side_flips(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure we need all of these. We might just need There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I did both left and right scenarios in all the tests. I think it's still necessary. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's at least remove the "side_flips" from the name. I don't think any reader would understand what that means. |
||
# right padding | ||
input_data = list(range(10)) | ||
packer = StartEndPacker( | ||
sequence_length=7, | ||
start_value=98, | ||
end_value=99, | ||
) | ||
expected_output = [98, 0, 1, 2, 3, 4, 99] | ||
self.assertAllEqual(packer(input_data), expected_output) | ||
|
||
# left padding | ||
packer = StartEndPacker( | ||
sequence_length=7, | ||
start_value=98, | ||
end_value=99, | ||
padding_side="left", | ||
) | ||
self.assertAllEqual(packer(input_data), expected_output) | ||
|
||
def test_truncation_side_flips_wo_endvalue(self): | ||
# right padding | ||
input_data = list(range(10)) | ||
packer = StartEndPacker( | ||
sequence_length=7, | ||
start_value=98, | ||
) | ||
expected_output = [98, 0, 1, 2, 3, 4, 5] | ||
self.assertAllEqual(packer(input_data), expected_output) | ||
|
||
# left padding | ||
packer = StartEndPacker( | ||
sequence_length=7, | ||
start_value=98, | ||
padding_side="left", | ||
) | ||
self.assertAllEqual(packer(input_data), expected_output) | ||
|
||
def test_end_token_value_during_truncation(self): | ||
# right padding | ||
input_data = [[5, 6], [8, 9, 10, 11, 12, 13]] | ||
start_end_packer = StartEndPacker( | ||
sequence_length=5, start_value=1, end_value=2, pad_value=0 | ||
|
@@ -73,7 +170,20 @@ def test_end_token_value_during_truncation(self): | |
expected_output = [[1, 5, 6, 2, 0], [1, 8, 9, 10, 2]] | ||
self.assertAllEqual(output, expected_output) | ||
|
||
# left padding | ||
start_end_packer = StartEndPacker( | ||
sequence_length=5, | ||
start_value=1, | ||
end_value=2, | ||
pad_value=0, | ||
padding_side="left", | ||
) | ||
output = start_end_packer(input_data) | ||
expected_output = [[0, 1, 5, 6, 2], [1, 8, 9, 10, 2]] | ||
self.assertAllEqual(output, expected_output) | ||
|
||
def test_string_input(self): | ||
# right padding | ||
input_data = [["KerasHub", "is", "awesome"], ["amazing"]] | ||
start_end_packer = StartEndPacker( | ||
sequence_length=5, | ||
|
@@ -88,7 +198,23 @@ def test_string_input(self): | |
] | ||
self.assertAllEqual(output, expected_output) | ||
|
||
# left padding | ||
start_end_packer = StartEndPacker( | ||
sequence_length=5, | ||
start_value="[START]", | ||
end_value="[END]", | ||
pad_value="[PAD]", | ||
padding_side="left", | ||
) | ||
output = start_end_packer(input_data) | ||
expected_output = [ | ||
["[START]", "KerasHub", "is", "awesome", "[END]"], | ||
["[PAD]", "[PAD]", "[START]", "amazing", "[END]"], | ||
] | ||
self.assertAllEqual(output, expected_output) | ||
|
||
def test_string_input_with_multiple_special_values(self): | ||
# right padding | ||
input_data = [["KerasHub", "is", "awesome"], ["amazing"]] | ||
start_end_packer = StartEndPacker( | ||
sequence_length=6, | ||
|
@@ -103,6 +229,21 @@ def test_string_input_with_multiple_special_values(self): | |
] | ||
self.assertAllEqual(output, expected_output) | ||
|
||
# left padding | ||
start_end_packer = StartEndPacker( | ||
sequence_length=6, | ||
start_value=["[END]", "[START]"], | ||
end_value="[END]", | ||
pad_value="[PAD]", | ||
padding_side="left", | ||
) | ||
output = start_end_packer(input_data) | ||
expected_output = [ | ||
["[END]", "[START]", "KerasHub", "is", "awesome", "[END]"], | ||
["[PAD]", "[PAD]", "[END]", "[START]", "amazing", "[END]"], | ||
] | ||
self.assertAllEqual(output, expected_output) | ||
|
||
def test_special_token_dtype_error(self): | ||
with self.assertRaises(ValueError): | ||
StartEndPacker(sequence_length=5, start_value=1.0) | ||
|
@@ -147,3 +288,39 @@ def test_get_config(self): | |
} | ||
|
||
self.assertEqual(config, {**config, **expected_config_subset}) | ||
|
||
def test_return_padding_mask_right_padding(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In keeping with the other tests, just leave these in the same test with a |
||
input_data = [[5, 6, 7], [8, 9, 10, 11]] | ||
start_end_packer = StartEndPacker( | ||
sequence_length=6, | ||
start_value=1, | ||
end_value=2, | ||
return_padding_mask=True, | ||
) | ||
output, padding_mask = start_end_packer(input_data) | ||
expected_output = [[1, 5, 6, 7, 2, 0], [1, 8, 9, 10, 11, 2]] | ||
expected_padding_mask = [ | ||
[True, True, True, True, True, False], | ||
[True, True, True, True, True, True], | ||
] | ||
print(padding_mask) | ||
self.assertAllEqual(output, expected_output) | ||
self.assertAllEqual(padding_mask, expected_padding_mask) | ||
|
||
def test_return_padding_mask_left_padding(self): | ||
input_data = [[5, 6, 7], [8, 9, 10, 11]] | ||
start_end_packer = StartEndPacker( | ||
sequence_length=6, | ||
start_value=1, | ||
end_value=2, | ||
return_padding_mask=True, | ||
padding_side="left", | ||
) | ||
output, padding_mask = start_end_packer(input_data) | ||
expected_output = [[0, 1, 5, 6, 7, 2], [1, 8, 9, 10, 11, 2]] | ||
expected_padding_mask = [ | ||
[False, True, True, True, True, True], | ||
[True, True, True, True, True, True], | ||
] | ||
self.assertAllEqual(output, expected_output) | ||
self.assertAllEqual(padding_mask, expected_padding_mask) |
Uh oh!
There was an error while loading. Please reload this page.