
Implement left padding #2242


Merged
merged 9 commits into from May 28, 2025
Changes from 8 commits
19 changes: 16 additions & 3 deletions keras_hub/src/layers/preprocessing/multi_segment_packer.py
@@ -3,6 +3,7 @@
PreprocessingLayer,
)
from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
from keras_hub.src.utils.tensor_utils import pad
from keras_hub.src.utils.tensor_utils import preprocessing_function

try:
@@ -124,6 +125,7 @@ def __init__(
sep_value=None,
pad_value=None,
truncate="round_robin",
padding_side="right",
Review comment (Member): please add a docstring
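A docstring entry mirroring the one this PR adds to StartEndPacker would fit here (suggested wording, not part of the diff):

    padding_side: str. Whether to pad the input on the "left" or "right".
        Defaults to "right".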

**kwargs,
):
super().__init__(**kwargs)
@@ -163,6 +165,8 @@ def check_special_value_type(value, value_name):

self.pad_value = pad_value

Review comment (Member): nit: remove this empty line

self.padding_side = padding_side

def get_config(self):
config = super().get_config()
config.update(
@@ -173,6 +177,7 @@ def get_config(self):
"sep_value": self._sep_value,
"pad_value": self.pad_value,
"truncate": self.truncate,
"padding_side": self.padding_side,
}
)
return config
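To make the new option concrete, a usage sketch (the public keras_hub.layers import path is assumed; expected values match the left-padding case of test_pad_inputs below):

import keras_hub

packer = keras_hub.layers.MultiSegmentPacker(
    sequence_length=6,
    start_value="[CLS]",
    end_value="[SEP]",
    pad_value="[PAD]",
    padding_side="left",
)
token_ids, segment_ids = packer((["a"], ["x"]))
# token_ids   -> ["[PAD]", "[CLS]", "a", "[SEP]", "x", "[SEP]"]
# segment_ids -> [0, 0, 0, 0, 1, 1]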
@@ -287,10 +292,18 @@ def call(
# Pad to dense tensor output.
sequence_length = sequence_length or self.sequence_length
shape = tf.cast([-1, sequence_length], "int64")
-        token_ids = token_ids.to_tensor(
-            shape=shape, default_value=self.pad_value
+        token_ids = pad(
+            token_ids,
+            shape=shape,
+            padding_side=self.padding_side,
+            pad_value=self.pad_value,
         )
+        segment_ids = pad(
+            segment_ids,
+            shape=shape,
+            padding_side=self.padding_side,
+            pad_value=0,
+        )
-        segment_ids = segment_ids.to_tensor(shape=shape)
# Remove the batch dim if added.
if unbatched:
token_ids = tf.squeeze(token_ids, 0)
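The pad helper used above comes from keras_hub/src/utils/tensor_utils.py, whose diff is not part of this view. A minimal sketch of the behavior these call sites rely on (an assumed implementation, not the actual utility):

import tensorflow as tf

def pad(x, shape, padding_side, pad_value):
    # Densify a ragged batch, placing padding on the requested side.
    # Right padding is what RaggedTensor.to_tensor does natively; left
    # padding reverses each row, right-pads, then reverses back.
    if padding_side == "left":
        x = tf.reverse(x, axis=[1])
        x = x.to_tensor(shape=shape, default_value=pad_value)
        return tf.reverse(x, axis=[1])
    return x.to_tensor(shape=shape, default_value=pad_value)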
173 changes: 173 additions & 0 deletions keras_hub/src/layers/preprocessing/multi_segment_packer_test.py
@@ -8,6 +8,7 @@

class MultiSegmentPackerTest(TestCase):
def test_trim_single_input_ints(self):
# right padding
input_data = np.arange(3, 10)
packer = MultiSegmentPacker(
sequence_length=8, start_value=1, end_value=2
@@ -16,7 +17,20 @@ def test_trim_single_input_ints(self):
self.assertAllEqual(token_ids, [1, 3, 4, 5, 6, 7, 8, 2])
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0, 0, 0, 0])

# left padding: identical to right padding here, since the packed
# sequence fills sequence_length exactly (no pad tokens are added)
input_data = np.arange(3, 10)
packer = MultiSegmentPacker(
sequence_length=8,
start_value=1,
end_value=2,
padding_side="left",
)
token_ids, segment_ids = packer(input_data)
self.assertAllEqual(token_ids, [1, 3, 4, 5, 6, 7, 8, 2])
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0, 0, 0, 0])

def test_trim_single_input_strings(self):
# right padding
input_data = ["a", "b", "c", "d"]
packer = MultiSegmentPacker(
sequence_length=5, start_value="[CLS]", end_value="[SEP]"
@@ -25,7 +39,19 @@ def test_trim_single_input_strings(self):
self.assertAllEqual(token_ids, ["[CLS]", "a", "b", "c", "[SEP]"])
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0])

# left padding
packer = MultiSegmentPacker(
sequence_length=5,
start_value="[CLS]",
end_value="[SEP]",
padding_side="left",
)
token_ids, segment_ids = packer(input_data)
self.assertAllEqual(token_ids, ["[CLS]", "a", "b", "c", "[SEP]"])
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0])

def test_trim_multiple_inputs_round_robin(self):
# right padding
seq1 = ["a", "b", "c"]
seq2 = ["x", "y", "z"]
packer = MultiSegmentPacker(
@@ -40,7 +66,22 @@ def test_trim_multiple_inputs_round_robin(self):
)
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 1, 1, 1])

# left padding
packer = MultiSegmentPacker(
sequence_length=7,
start_value="[CLS]",
end_value="[SEP]",
truncate="round_robin",
padding_side="left",
)
token_ids, segment_ids = packer((seq1, seq2))
self.assertAllEqual(
token_ids, ["[CLS]", "a", "b", "[SEP]", "x", "y", "[SEP]"]
)
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 1, 1, 1])

def test_trim_multiple_inputs_waterfall(self):
# right padding
seq1 = ["a", "b", "c"]
seq2 = ["x", "y", "z"]
packer = MultiSegmentPacker(
@@ -55,7 +96,22 @@ def test_trim_multiple_inputs_waterfall(self):
)
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0, 1, 1])

# left padding
packer = MultiSegmentPacker(
sequence_length=7,
start_value="[CLS]",
end_value="[SEP]",
truncate="waterfall",
padding_side="left",
)
token_ids, segment_ids = packer((seq1, seq2))
self.assertAllEqual(
token_ids, ["[CLS]", "a", "b", "c", "[SEP]", "x", "[SEP]"]
)
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0, 1, 1])

def test_trim_batched_inputs_round_robin(self):
# right padding
seq1 = [["a", "b", "c"], ["a", "b", "c"]]
seq2 = [["x", "y", "z"], ["x", "y", "z"]]
packer = MultiSegmentPacker(
@@ -80,7 +136,32 @@ def test_trim_batched_inputs_round_robin(self):
],
)

# left padding
packer = MultiSegmentPacker(
sequence_length=7,
start_value="[CLS]",
end_value="[SEP]",
truncate="round_robin",
padding_side="left",
)
token_ids, segment_ids = packer((seq1, seq2))
self.assertAllEqual(
token_ids,
[
["[CLS]", "a", "b", "[SEP]", "x", "y", "[SEP]"],
["[CLS]", "a", "b", "[SEP]", "x", "y", "[SEP]"],
],
)
self.assertAllEqual(
segment_ids,
[
[0, 0, 0, 0, 1, 1, 1],
[0, 0, 0, 0, 1, 1, 1],
],
)

def test_trim_batched_inputs_waterfall(self):
# right padding
seq1 = [["a", "b", "c"], ["a", "b"]]
seq2 = [["x", "y", "z"], ["x", "y", "z"]]
packer = MultiSegmentPacker(
@@ -105,7 +186,32 @@ def test_trim_batched_inputs_waterfall(self):
],
)

# left padding
packer = MultiSegmentPacker(
sequence_length=7,
start_value="[CLS]",
end_value="[SEP]",
truncate="waterfall",
padding_side="left",
)
token_ids, segment_ids = packer((seq1, seq2))
self.assertAllEqual(
token_ids,
[
["[CLS]", "a", "b", "c", "[SEP]", "x", "[SEP]"],
["[CLS]", "a", "b", "[SEP]", "x", "y", "[SEP]"],
],
)
self.assertAllEqual(
segment_ids,
[
[0, 0, 0, 0, 0, 1, 1],
[0, 0, 0, 0, 1, 1, 1],
],
)

def test_pad_inputs(self):
# right padding
seq1 = ["a"]
seq2 = ["x"]
packer = MultiSegmentPacker(
@@ -118,7 +224,23 @@ def test_pad_inputs(self):
)
self.assertAllEqual(segment_ids, [0, 0, 0, 1, 1, 0])

# left padding: pad positions keep segment_id 0, so the padding
# shifts both outputs to the front
packer = MultiSegmentPacker(
6,
start_value="[CLS]",
end_value="[SEP]",
pad_value="[PAD]",
padding_side="left",
)
token_ids, segment_ids = packer((seq1, seq2))
self.assertAllEqual(
token_ids,
["[PAD]", "[CLS]", "a", "[SEP]", "x", "[SEP]"],
)
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 1, 1])

def test_pad_batched_inputs(self):
# right padding
seq1 = [["a"], ["a"]]
seq2 = [["x"], ["x", "y"]]
packer = MultiSegmentPacker(
@@ -143,7 +265,32 @@ def test_pad_batched_inputs(self):
],
)

# left padding
packer = MultiSegmentPacker(
sequence_length=7,
start_value="[CLS]",
end_value="[SEP]",
pad_value="[PAD]",
padding_side="left",
)
token_ids, segment_ids = packer((seq1, seq2))
self.assertAllEqual(
token_ids,
[
["[PAD]", "[PAD]", "[CLS]", "a", "[SEP]", "x", "[SEP]"],
["[PAD]", "[CLS]", "a", "[SEP]", "x", "y", "[SEP]"],
],
)
self.assertAllEqual(
segment_ids,
[
[0, 0, 0, 0, 0, 1, 1],
[0, 0, 0, 0, 1, 1, 1],
],
)

def test_list_special_tokens(self):
# right padding
seq1 = [["a", "b"], ["a", "b"]]
seq2 = [["x", "y"], ["x"]]
packer = MultiSegmentPacker(
@@ -170,6 +317,32 @@ def test_list_special_tokens(self):
],
)

# left padding
packer = MultiSegmentPacker(
8,
start_value="<s>",
end_value="</s>",
sep_value=["</s>", "</s>"],
pad_value="<pad>",
truncate="round_robin",
padding_side="left",
)
token_ids, segment_ids = packer((seq1, seq2))
self.assertAllEqual(
token_ids,
[
["<s>", "a", "b", "</s>", "</s>", "x", "y", "</s>"],
["<pad>", "<s>", "a", "b", "</s>", "</s>", "x", "</s>"],
],
)
self.assertAllEqual(
segment_ids,
[
[0, 0, 0, 0, 0, 1, 1, 1],
[0, 0, 0, 0, 0, 0, 1, 1],
],
)

def test_config(self):
seq1 = [["a", "b", "c"], ["a", "b"]]
seq2 = [["x", "y", "z"], ["x", "y", "z"]]
30 changes: 24 additions & 6 deletions keras_hub/src/layers/preprocessing/start_end_packer.py
@@ -3,6 +3,7 @@
PreprocessingLayer,
)
from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
from keras_hub.src.utils.tensor_utils import pad
from keras_hub.src.utils.tensor_utils import preprocessing_function

try:
@@ -39,6 +40,8 @@ class StartEndPacker(PreprocessingLayer):
0 or "" will be added depending on the dtype of the input tensor.
return_padding_mask: bool. Whether to return a boolean padding mask of
all locations that are filled in with the `pad_value`.
padding_side: str. Whether to pad the input on the "left" or "right".
Defaults to "right".

Call arguments:
inputs: A `tf.Tensor`, `tf.RaggedTensor`, or list of python strings.
@@ -111,6 +114,7 @@ def __init__(
pad_value=None,
return_padding_mask=False,
name=None,
padding_side="right",
**kwargs,
):
super().__init__(name=name, **kwargs)
@@ -139,6 +143,7 @@ def check_special_value_type(value, value_name):

self.pad_value = pad_value
self.return_padding_mask = return_padding_mask
self.padding_side = padding_side

@preprocessing_function
def call(
@@ -154,6 +159,13 @@ def call(
batch_size = tf.shape(x)[0]
sequence_length = sequence_length or self.sequence_length
dtype = inputs.dtype
# Truncate.
truncation_length = sequence_length
if add_start_value and self.start_value is not None:
truncation_length -= len(self.start_value)
if add_end_value and self.end_value is not None:
truncation_length -= len(self.end_value)
x = x[..., :truncation_length]
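# Note: truncation is now done up front so that room is reserved for
# both the start and end tokens; the removed lines further down only
# trimmed for the end token at concat time. Example arithmetic
# (assuming length-1 special tokens): sequence_length=6 gives
# truncation_length = 6 - 1 - 1 = 4 surviving content tokens.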

# Concatenate start and end tokens.
if add_start_value and self.start_value is not None:
@@ -167,23 +179,28 @@
end_token_id_tensor = tf.repeat(
end_value[tf.newaxis, :], repeats=batch_size, axis=0
)
-            # Trim to leave room for end token.
-            x = x[..., : sequence_length - len(self.end_value)]
x = tf.concat([x, end_token_id_tensor], axis=-1)

# Pad to desired length.
-        outputs = x.to_tensor(
-            default_value=self.pad_value,
+        outputs = pad(
+            x,
+            pad_value=self.pad_value,
+            padding_side=self.padding_side,
shape=(batch_size, sequence_length),
)
outputs = tf.squeeze(outputs, axis=0) if unbatched else outputs

if self.return_padding_mask:
mask = tf.ones_like(x, dtype="bool")
-            mask = mask.to_tensor(shape=(batch_size, sequence_length))
-
+            mask = pad(
+                mask,
+                pad_value=False,
+                padding_side=self.padding_side,
+                shape=(batch_size, sequence_length),
+            )
mask = tf.squeeze(mask, axis=0) if unbatched else mask
return outputs, mask

return outputs

def get_config(self):
@@ -195,6 +212,7 @@
"end_value": self._end_value,
"pad_value": self.pad_value,
"return_padding_mask": self.return_padding_mask,
"padding_side": self.padding_side,
}
)
return config
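For completeness, a usage sketch of the new option on this layer (keras_hub.layers import path assumed; expected values follow from the truncation and padding logic above):

import keras_hub

packer = keras_hub.layers.StartEndPacker(
    sequence_length=6,
    start_value=1,
    end_value=2,
    pad_value=0,
    padding_side="left",
    return_padding_mask=True,
)
outputs, mask = packer([5, 6, 7])
# outputs -> [0, 1, 5, 6, 7, 2]
# mask    -> [False, True, True, True, True, True]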