From 299102c9ac4bb3e36b547975b86acfe2df659731 Mon Sep 17 00:00:00 2001
From: pass_lin <935499957@qq.com>
Date: Fri, 2 May 2025 20:55:59 +0800
Subject: [PATCH 1/9] Implement left padding in StartEndPacker
---
.../layers/preprocessing/start_end_packer.py | 18 ++-
.../preprocessing/start_end_packer_test.py | 103 ++++++++++++++++++
2 files changed, 117 insertions(+), 4 deletions(-)
diff --git a/keras_hub/src/layers/preprocessing/start_end_packer.py b/keras_hub/src/layers/preprocessing/start_end_packer.py
index f50f09b7f1..53181c2bf3 100644
--- a/keras_hub/src/layers/preprocessing/start_end_packer.py
+++ b/keras_hub/src/layers/preprocessing/start_end_packer.py
@@ -111,6 +111,7 @@ def __init__(
pad_value=None,
return_padding_mask=False,
name=None,
+ padding_side="right",
**kwargs,
):
super().__init__(name=name, **kwargs)
@@ -139,6 +140,7 @@ def check_special_value_type(value, value_name):
self.pad_value = pad_value
self.return_padding_mask = return_padding_mask
+ self.padding_side = padding_side
@preprocessing_function
def call(
@@ -172,10 +174,17 @@ def call(
x = tf.concat([x, end_token_id_tensor], axis=-1)
# Pad to desired length.
- outputs = x.to_tensor(
- default_value=self.pad_value,
- shape=(batch_size, sequence_length),
- )
+ if self.padding_side == "left":
+ outputs = x[..., ::-1].to_tensor(
+ default_value=self.pad_value,
+ shape=(batch_size, sequence_length),
+ )
+ outputs = outputs[..., ::-1]
+ else:
+ outputs = x.to_tensor(
+ default_value=self.pad_value,
+ shape=(batch_size, sequence_length),
+ )
outputs = tf.squeeze(outputs, axis=0) if unbatched else outputs
if self.return_padding_mask:
@@ -195,6 +204,7 @@ def get_config(self):
"end_value": self._end_value,
"pad_value": self.pad_value,
"return_padding_mask": self.return_padding_mask,
+ "padding_side": self.padding_side,
}
)
return config
diff --git a/keras_hub/src/layers/preprocessing/start_end_packer_test.py b/keras_hub/src/layers/preprocessing/start_end_packer_test.py
index 3b0ea65a9c..d6f44303cd 100644
--- a/keras_hub/src/layers/preprocessing/start_end_packer_test.py
+++ b/keras_hub/src/layers/preprocessing/start_end_packer_test.py
@@ -6,11 +6,19 @@
class StartEndPackerTest(TestCase):
def test_dense_input(self):
+ # right padding
input_data = [5, 6, 7]
start_end_packer = StartEndPacker(sequence_length=5)
output = start_end_packer(input_data)
expected_output = [5, 6, 7, 0, 0]
self.assertAllEqual(output, expected_output)
+ # left padding
+ start_end_packer = StartEndPacker(
+ sequence_length=5, padding_side="left"
+ )
+ output = start_end_packer(input_data)
+ expected_output = [0, 0, 5, 6, 7]
+ self.assertAllEqual(output, expected_output)
def test_bfloat16_dtype(self):
# Core Keras has a strange bug where it converts int to floats in
@@ -21,20 +29,37 @@ def test_bfloat16_dtype(self):
self.assertDTypeEqual(output, "int32")
def test_dense_2D_input(self):
+ # right padding
input_data = [[5, 6, 7]]
start_end_packer = StartEndPacker(sequence_length=5)
output = start_end_packer(input_data)
expected_output = [[5, 6, 7, 0, 0]]
self.assertAllEqual(output, expected_output)
+ # left padding
+ start_end_packer = StartEndPacker(
+ sequence_length=5, padding_side="left"
+ )
+ output = start_end_packer(input_data)
+ expected_output = [[0, 0, 5, 6, 7]]
+ self.assertAllEqual(output, expected_output)
def test_ragged_input(self):
+ # right padding
input_data = [[5, 6, 7], [8, 9, 10, 11]]
start_end_packer = StartEndPacker(sequence_length=5)
output = start_end_packer(input_data)
expected_output = [[5, 6, 7, 0, 0], [8, 9, 10, 11, 0]]
self.assertAllEqual(output, expected_output)
+ # left padding
+ start_end_packer = StartEndPacker(
+ sequence_length=5, padding_side="left"
+ )
+ output = start_end_packer(input_data)
+ expected_output = [[0, 0, 5, 6, 7], [0, 8, 9, 10, 11]]
+ self.assertAllEqual(output, expected_output)
def test_start_end_token(self):
+ # right padding
input_data = [[5, 6, 7], [8, 9, 10, 11]]
start_end_packer = StartEndPacker(
sequence_length=6, start_value=1, end_value=2
@@ -42,8 +67,16 @@ def test_start_end_token(self):
output = start_end_packer(input_data)
expected_output = [[1, 5, 6, 7, 2, 0], [1, 8, 9, 10, 11, 2]]
self.assertAllEqual(output, expected_output)
+ # left padding
+ start_end_packer = StartEndPacker(
+ sequence_length=6, start_value=1, end_value=2, padding_side="left"
+ )
+ output = start_end_packer(input_data)
+ expected_output = [[0, 1, 5, 6, 7, 2], [1, 8, 9, 10, 11, 2]]
+ self.assertAllEqual(output, expected_output)
def test_multiple_start_end_tokens(self):
+ # right padding
input_data = [[5, 6, 7], [8, 9, 10, 11, 12, 13]]
start_end_packer = StartEndPacker(
sequence_length=8,
@@ -55,7 +88,20 @@ def test_multiple_start_end_tokens(self):
expected_output = [[1, 2, 5, 6, 7, 3, 4, 0], [1, 2, 8, 9, 10, 11, 3, 4]]
self.assertAllEqual(output, expected_output)
+ # left padding
+ start_end_packer = StartEndPacker(
+ sequence_length=8,
+ start_value=[1, 2],
+ end_value=[3, 4],
+ pad_value=0,
+ padding_side="left",
+ )
+ output = start_end_packer(input_data)
+ expected_output = [[0, 1, 2, 5, 6, 7, 3, 4], [1, 2, 8, 9, 10, 11, 3, 4]]
+ self.assertAllEqual(output, expected_output)
+
def test_start_end_padding_value(self):
+ # right padding
input_data = [[5, 6, 7], [8, 9, 10, 11]]
start_end_packer = StartEndPacker(
sequence_length=7, start_value=1, end_value=2, pad_value=3
@@ -64,7 +110,20 @@ def test_start_end_padding_value(self):
expected_output = [[1, 5, 6, 7, 2, 3, 3], [1, 8, 9, 10, 11, 2, 3]]
self.assertAllEqual(output, expected_output)
+ # left padding
+ start_end_packer = StartEndPacker(
+ sequence_length=7,
+ start_value=1,
+ end_value=2,
+ pad_value=3,
+ padding_side="left",
+ )
+ output = start_end_packer(input_data)
+ expected_output = [[3, 3, 1, 5, 6, 7, 2], [3, 1, 8, 9, 10, 11, 2]]
+ self.assertAllEqual(output, expected_output)
+
def test_end_token_value_during_truncation(self):
+ # right padding
input_data = [[5, 6], [8, 9, 10, 11, 12, 13]]
start_end_packer = StartEndPacker(
sequence_length=5, start_value=1, end_value=2, pad_value=0
@@ -73,7 +132,20 @@ def test_end_token_value_during_truncation(self):
expected_output = [[1, 5, 6, 2, 0], [1, 8, 9, 10, 2]]
self.assertAllEqual(output, expected_output)
+ # left padding
+ start_end_packer = StartEndPacker(
+ sequence_length=5,
+ start_value=1,
+ end_value=2,
+ pad_value=0,
+ padding_side="left",
+ )
+ output = start_end_packer(input_data)
+ expected_output = [[0, 1, 5, 6, 2], [1, 8, 9, 10, 2]]
+ self.assertAllEqual(output, expected_output)
+
def test_string_input(self):
+ # right padding
input_data = [["KerasHub", "is", "awesome"], ["amazing"]]
start_end_packer = StartEndPacker(
sequence_length=5,
@@ -88,7 +160,23 @@ def test_string_input(self):
]
self.assertAllEqual(output, expected_output)
+ # left padding
+ start_end_packer = StartEndPacker(
+ sequence_length=5,
+ start_value="[START]",
+ end_value="[END]",
+ pad_value="[PAD]",
+ padding_side="left",
+ )
+ output = start_end_packer(input_data)
+ expected_output = [
+ ["[START]", "KerasHub", "is", "awesome", "[END]"],
+ ["[PAD]", "[PAD]", "[START]", "amazing", "[END]"],
+ ]
+ self.assertAllEqual(output, expected_output)
+
def test_string_input_with_multiple_special_values(self):
+ # right padding
input_data = [["KerasHub", "is", "awesome"], ["amazing"]]
start_end_packer = StartEndPacker(
sequence_length=6,
@@ -103,6 +191,21 @@ def test_string_input_with_multiple_special_values(self):
]
self.assertAllEqual(output, expected_output)
+ # left padding
+ start_end_packer = StartEndPacker(
+ sequence_length=6,
+ start_value=["[END]", "[START]"],
+ end_value="[END]",
+ pad_value="[PAD]",
+ padding_side="left",
+ )
+ output = start_end_packer(input_data)
+ expected_output = [
+ ["[END]", "[START]", "KerasHub", "is", "awesome", "[END]"],
+ ["[PAD]", "[PAD]", "[END]", "[START]", "amazing", "[END]"],
+ ]
+ self.assertAllEqual(output, expected_output)
+
def test_special_token_dtype_error(self):
with self.assertRaises(ValueError):
StartEndPacker(sequence_length=5, start_value=1.0)
From 59627a46106409163d817f9b391547f9317fccb3 Mon Sep 17 00:00:00 2001
From: pass_lin <935499957@qq.com>
Date: Sat, 3 May 2025 15:46:54 +0800
Subject: [PATCH 2/9] Document the padding_side argument
---
keras_hub/src/layers/preprocessing/start_end_packer.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/keras_hub/src/layers/preprocessing/start_end_packer.py b/keras_hub/src/layers/preprocessing/start_end_packer.py
index 53181c2bf3..f4f2c014fe 100644
--- a/keras_hub/src/layers/preprocessing/start_end_packer.py
+++ b/keras_hub/src/layers/preprocessing/start_end_packer.py
@@ -39,6 +39,8 @@ class StartEndPacker(PreprocessingLayer):
0 or "" will be added depending on the dtype of the input tensor.
return_padding_mask: bool. Whether to return a boolean padding mask of
all locations that are filled in with the `pad_value`.
+        padding_side: str. The side on which to pad the input, either
+            `"left"` or `"right"`. Defaults to `"right"`.
Call arguments:
inputs: A `tf.Tensor`, `tf.RaggedTensor`, or list of python strings.
From 5d1b2c0f48a8c2f9f203acb336003fd6172902fe Mon Sep 17 00:00:00 2001
From: pass_lin <935499957@qq.com>
Date: Thu, 8 May 2025 13:36:27 +0800
Subject: [PATCH 3/9] Rework left-padding flow in call() and add truncation tests
---
.../layers/preprocessing/start_end_packer.py | 34 +++++++++--------
.../preprocessing/start_end_packer_test.py | 38 +++++++++++++++++++
2 files changed, 57 insertions(+), 15 deletions(-)
diff --git a/keras_hub/src/layers/preprocessing/start_end_packer.py b/keras_hub/src/layers/preprocessing/start_end_packer.py
index f4f2c014fe..3b981c4109 100644
--- a/keras_hub/src/layers/preprocessing/start_end_packer.py
+++ b/keras_hub/src/layers/preprocessing/start_end_packer.py
@@ -158,35 +158,40 @@ def call(
batch_size = tf.shape(x)[0]
sequence_length = sequence_length or self.sequence_length
dtype = inputs.dtype
-
+ if self.padding_side == "left":
+ x = x[..., ::-1]
# Concatenate start and end tokens.
if add_start_value and self.start_value is not None:
start_value = tf.convert_to_tensor(self.start_value, dtype=dtype)
start_token_id_tensor = tf.repeat(
start_value[tf.newaxis, :], repeats=batch_size, axis=0
)
- x = tf.concat([start_token_id_tensor, x], axis=-1)
+ if self.padding_side == "left":
+ x = tf.concat([x, start_token_id_tensor[..., ::-1]], axis=-1)
+ else:
+ x = tf.concat([start_token_id_tensor, x], axis=-1)
if add_end_value and self.end_value is not None:
end_value = tf.convert_to_tensor(self.end_value, dtype=dtype)
end_token_id_tensor = tf.repeat(
end_value[tf.newaxis, :], repeats=batch_size, axis=0
)
# Trim to leave room for end token.
- x = x[..., : sequence_length - len(self.end_value)]
- x = tf.concat([x, end_token_id_tensor], axis=-1)
-
+ if self.padding_side == "left":
+ x = x[..., -(sequence_length - len(self.end_value)) :]
+ x = tf.concat([end_token_id_tensor[..., ::-1], x], axis=-1)
+ else:
+ x = x[..., : sequence_length - len(self.end_value)]
+ x = tf.concat([x, end_token_id_tensor], axis=-1)
+ if self.padding_side == "left":
+ x = x[..., -sequence_length:]
+ outputs = x.to_tensor(
+ default_value=self.pad_value,
+ shape=(batch_size, sequence_length),
+ )
# Pad to desired length.
if self.padding_side == "left":
- outputs = x[..., ::-1].to_tensor(
- default_value=self.pad_value,
- shape=(batch_size, sequence_length),
- )
outputs = outputs[..., ::-1]
- else:
- outputs = x.to_tensor(
- default_value=self.pad_value,
- shape=(batch_size, sequence_length),
- )
+
outputs = tf.squeeze(outputs, axis=0) if unbatched else outputs
if self.return_padding_mask:
@@ -194,7 +199,6 @@ def call(
mask = mask.to_tensor(shape=(batch_size, sequence_length))
mask = tf.squeeze(mask, axis=0) if unbatched else mask
return outputs, mask
-
return outputs
def get_config(self):
diff --git a/keras_hub/src/layers/preprocessing/start_end_packer_test.py b/keras_hub/src/layers/preprocessing/start_end_packer_test.py
index d6f44303cd..277cb6026d 100644
--- a/keras_hub/src/layers/preprocessing/start_end_packer_test.py
+++ b/keras_hub/src/layers/preprocessing/start_end_packer_test.py
@@ -122,6 +122,44 @@ def test_start_end_padding_value(self):
expected_output = [[3, 3, 1, 5, 6, 7, 2], [3, 1, 8, 9, 10, 11, 2]]
self.assertAllEqual(output, expected_output)
+ def test_truncation_side_flips(self):
+ # right padding
+ input_data = list(range(10))
+ packer = StartEndPacker(
+ sequence_length=7,
+ start_value=98,
+ end_value=99,
+ )
+ expected_output = [98, 0, 1, 2, 3, 4, 99]
+ self.assertAllEqual(packer(input_data), expected_output)
+
+ # left padding
+ packer = StartEndPacker(
+ sequence_length=7,
+ start_value=98,
+ end_value=99,
+ padding_side="left",
+ )
+ self.assertAllEqual(packer(input_data), expected_output)
+
+ def test_truncation_side_flips_wo_endvalue(self):
+ # right padding
+ input_data = list(range(10))
+ packer = StartEndPacker(
+ sequence_length=7,
+ start_value=98,
+ )
+ expected_output = [98, 0, 1, 2, 3, 4, 5]
+ self.assertAllEqual(packer(input_data), expected_output)
+
+ # left padding
+ packer = StartEndPacker(
+ sequence_length=7,
+ start_value=98,
+ padding_side="left",
+ )
+ self.assertAllEqual(packer(input_data), expected_output)
+
def test_end_token_value_during_truncation(self):
# right padding
input_data = [[5, 6], [8, 9, 10, 11, 12, 13]]
From 97cada723f58db347fef6817493431990107f54e Mon Sep 17 00:00:00 2001
From: pass_lin <935499957@qq.com>
Date: Tue, 20 May 2025 20:08:00 +0800
Subject: [PATCH 4/9] Unify truncation, extract pad() helper, and pad the padding mask consistently
---
.../layers/preprocessing/start_end_packer.py | 51 ++++++++++---------
.../preprocessing/start_end_packer_test.py | 31 +++++++++++
2 files changed, 59 insertions(+), 23 deletions(-)
diff --git a/keras_hub/src/layers/preprocessing/start_end_packer.py b/keras_hub/src/layers/preprocessing/start_end_packer.py
index 3b981c4109..0c3d918c6c 100644
--- a/keras_hub/src/layers/preprocessing/start_end_packer.py
+++ b/keras_hub/src/layers/preprocessing/start_end_packer.py
@@ -143,7 +143,16 @@ def check_special_value_type(value, value_name):
self.pad_value = pad_value
self.return_padding_mask = return_padding_mask
self.padding_side = padding_side
-
+ def pad(self,x, shape):
+ if self.padding_side == "left":
+ x = x[...,::-1]
+ outputs = x.to_tensor(
+ default_value=self.pad_value,
+ shape=shape,
+ )
+ if self.padding_side == "left":
+ outputs = outputs[..., ::-1]
+ return outputs
@preprocessing_function
def call(
self,
@@ -158,45 +167,41 @@ def call(
batch_size = tf.shape(x)[0]
sequence_length = sequence_length or self.sequence_length
dtype = inputs.dtype
- if self.padding_side == "left":
- x = x[..., ::-1]
+ # Truncate.
+ truncation_length = sequence_length
+ if add_start_value and self.start_value is not None:
+ truncation_length -= len(self.start_value)
+ if add_end_value and self.end_value is not None:
+ truncation_length -= len(self.end_value)
+ x = x[..., : truncation_length]
+
# Concatenate start and end tokens.
if add_start_value and self.start_value is not None:
start_value = tf.convert_to_tensor(self.start_value, dtype=dtype)
start_token_id_tensor = tf.repeat(
start_value[tf.newaxis, :], repeats=batch_size, axis=0
)
- if self.padding_side == "left":
- x = tf.concat([x, start_token_id_tensor[..., ::-1]], axis=-1)
- else:
- x = tf.concat([start_token_id_tensor, x], axis=-1)
+ x = tf.concat([start_token_id_tensor, x], axis=-1)
if add_end_value and self.end_value is not None:
end_value = tf.convert_to_tensor(self.end_value, dtype=dtype)
end_token_id_tensor = tf.repeat(
end_value[tf.newaxis, :], repeats=batch_size, axis=0
)
- # Trim to leave room for end token.
- if self.padding_side == "left":
- x = x[..., -(sequence_length - len(self.end_value)) :]
- x = tf.concat([end_token_id_tensor[..., ::-1], x], axis=-1)
- else:
- x = x[..., : sequence_length - len(self.end_value)]
- x = tf.concat([x, end_token_id_tensor], axis=-1)
- if self.padding_side == "left":
- x = x[..., -sequence_length:]
- outputs = x.to_tensor(
- default_value=self.pad_value,
+ x = tf.concat([x, end_token_id_tensor], axis=-1)
+
+ # Pad to desired length.
+ outputs = self.pad(
+ x,
shape=(batch_size, sequence_length),
)
- # Pad to desired length.
- if self.padding_side == "left":
- outputs = outputs[..., ::-1]
-
outputs = tf.squeeze(outputs, axis=0) if unbatched else outputs
if self.return_padding_mask:
mask = tf.ones_like(x, dtype="bool")
- mask = mask.to_tensor(shape=(batch_size, sequence_length))
+ mask =self.pad(
+ mask,
+ shape=(batch_size, sequence_length),
+ )
mask = tf.squeeze(mask, axis=0) if unbatched else mask
return outputs, mask
return outputs
diff --git a/keras_hub/src/layers/preprocessing/start_end_packer_test.py b/keras_hub/src/layers/preprocessing/start_end_packer_test.py
index 277cb6026d..ab3bf91639 100644
--- a/keras_hub/src/layers/preprocessing/start_end_packer_test.py
+++ b/keras_hub/src/layers/preprocessing/start_end_packer_test.py
@@ -288,3 +288,34 @@ def test_get_config(self):
}
self.assertEqual(config, {**config, **expected_config_subset})
+ def test_return_padding_mask_right_padding(self):
+ input_data = [[5, 6, 7], [8, 9, 10, 11]]
+ start_end_packer = StartEndPacker(
+ sequence_length=6,
+ start_value=1,
+ end_value=2,
+ return_padding_mask=True,
+ )
+ output, padding_mask = start_end_packer(input_data)
+ expected_output = [[1, 5, 6, 7, 2, 0], [1, 8, 9, 10, 11, 2]]
+ expected_padding_mask = [[True, True, True, True, True, False],
+ [True, True, True, True, True, True]]
+ print(padding_mask)
+ self.assertAllEqual(output, expected_output)
+ self.assertAllEqual(padding_mask, expected_padding_mask)
+
+ def test_return_padding_mask_left_padding(self):
+ input_data = [[5, 6, 7], [8, 9, 10, 11]]
+ start_end_packer = StartEndPacker(
+ sequence_length=6,
+ start_value=1,
+ end_value=2,
+ return_padding_mask=True,
+ padding_side="left",
+ )
+ output, padding_mask = start_end_packer(input_data)
+ expected_output = [[0, 1, 5, 6, 7, 2], [1, 8, 9, 10, 11, 2]]
+ expected_padding_mask = [[False, True, True, True, True, True],
+ [True, True, True, True, True, True]]
+ self.assertAllEqual(output, expected_output)
+ self.assertAllEqual(padding_mask, expected_padding_mask)
\ No newline at end of file
From 85bb2566ad96d09a29266269df78e5f53149b158 Mon Sep 17 00:00:00 2001
From: pass_lin <935499957@qq.com>
Date: Tue, 20 May 2025 20:09:36 +0800
Subject: [PATCH 5/9] Apply code formatting
---
.../layers/preprocessing/start_end_packer.py | 20 ++++++++++---------
.../preprocessing/start_end_packer_test.py | 15 +++++++++-----
2 files changed, 21 insertions(+), 14 deletions(-)
diff --git a/keras_hub/src/layers/preprocessing/start_end_packer.py b/keras_hub/src/layers/preprocessing/start_end_packer.py
index 0c3d918c6c..b94e513bc2 100644
--- a/keras_hub/src/layers/preprocessing/start_end_packer.py
+++ b/keras_hub/src/layers/preprocessing/start_end_packer.py
@@ -143,16 +143,18 @@ def check_special_value_type(value, value_name):
self.pad_value = pad_value
self.return_padding_mask = return_padding_mask
self.padding_side = padding_side
- def pad(self,x, shape):
+
+ def pad(self, x, shape):
if self.padding_side == "left":
- x = x[...,::-1]
+ x = x[..., ::-1]
outputs = x.to_tensor(
- default_value=self.pad_value,
- shape=shape,
- )
+ default_value=self.pad_value,
+ shape=shape,
+ )
if self.padding_side == "left":
outputs = outputs[..., ::-1]
return outputs
+
@preprocessing_function
def call(
self,
@@ -173,7 +175,7 @@ def call(
truncation_length -= len(self.start_value)
if add_end_value and self.end_value is not None:
truncation_length -= len(self.end_value)
- x = x[..., : truncation_length]
+ x = x[..., :truncation_length]
# Concatenate start and end tokens.
if add_start_value and self.start_value is not None:
@@ -191,15 +193,15 @@ def call(
# Pad to desired length.
outputs = self.pad(
- x,
+ x,
shape=(batch_size, sequence_length),
)
outputs = tf.squeeze(outputs, axis=0) if unbatched else outputs
if self.return_padding_mask:
mask = tf.ones_like(x, dtype="bool")
- mask =self.pad(
- mask,
+ mask = self.pad(
+ mask,
shape=(batch_size, sequence_length),
)
mask = tf.squeeze(mask, axis=0) if unbatched else mask
diff --git a/keras_hub/src/layers/preprocessing/start_end_packer_test.py b/keras_hub/src/layers/preprocessing/start_end_packer_test.py
index ab3bf91639..705b8799e2 100644
--- a/keras_hub/src/layers/preprocessing/start_end_packer_test.py
+++ b/keras_hub/src/layers/preprocessing/start_end_packer_test.py
@@ -288,6 +288,7 @@ def test_get_config(self):
}
self.assertEqual(config, {**config, **expected_config_subset})
+
def test_return_padding_mask_right_padding(self):
input_data = [[5, 6, 7], [8, 9, 10, 11]]
start_end_packer = StartEndPacker(
@@ -298,8 +299,10 @@ def test_return_padding_mask_right_padding(self):
)
output, padding_mask = start_end_packer(input_data)
expected_output = [[1, 5, 6, 7, 2, 0], [1, 8, 9, 10, 11, 2]]
- expected_padding_mask = [[True, True, True, True, True, False],
- [True, True, True, True, True, True]]
+ expected_padding_mask = [
+ [True, True, True, True, True, False],
+ [True, True, True, True, True, True],
+ ]
print(padding_mask)
self.assertAllEqual(output, expected_output)
self.assertAllEqual(padding_mask, expected_padding_mask)
@@ -315,7 +318,9 @@ def test_return_padding_mask_left_padding(self):
)
output, padding_mask = start_end_packer(input_data)
expected_output = [[0, 1, 5, 6, 7, 2], [1, 8, 9, 10, 11, 2]]
- expected_padding_mask = [[False, True, True, True, True, True],
- [True, True, True, True, True, True]]
+ expected_padding_mask = [
+ [False, True, True, True, True, True],
+ [True, True, True, True, True, True],
+ ]
self.assertAllEqual(output, expected_output)
- self.assertAllEqual(padding_mask, expected_padding_mask)
\ No newline at end of file
+ self.assertAllEqual(padding_mask, expected_padding_mask)
From 6ab5ea96609ba716fa0f2989cf76f1fe95d711cc Mon Sep 17 00:00:00 2001
From: pass_lin <935499957@qq.com>
Date: Tue, 20 May 2025 20:52:58 +0800
Subject: [PATCH 6/9] Parameterize pad() value so the padding mask is padded with False
---
keras_hub/src/layers/preprocessing/start_end_packer.py | 9 +++++++--
1 file changed, 7 insertions(+), 2 deletions(-)
diff --git a/keras_hub/src/layers/preprocessing/start_end_packer.py b/keras_hub/src/layers/preprocessing/start_end_packer.py
index b94e513bc2..542eec39ac 100644
--- a/keras_hub/src/layers/preprocessing/start_end_packer.py
+++ b/keras_hub/src/layers/preprocessing/start_end_packer.py
@@ -144,13 +144,15 @@ def check_special_value_type(value, value_name):
self.return_padding_mask = return_padding_mask
self.padding_side = padding_side
- def pad(self, x, shape):
+ def pad(self, x, shape, pad_value):
if self.padding_side == "left":
x = x[..., ::-1]
+
outputs = x.to_tensor(
- default_value=self.pad_value,
+ default_value=pad_value,
shape=shape,
)
+
if self.padding_side == "left":
outputs = outputs[..., ::-1]
return outputs
@@ -195,14 +197,17 @@ def call(
outputs = self.pad(
x,
shape=(batch_size, sequence_length),
+ pad_value=self.pad_value,
)
outputs = tf.squeeze(outputs, axis=0) if unbatched else outputs
if self.return_padding_mask:
mask = tf.ones_like(x, dtype="bool")
+
mask = self.pad(
mask,
shape=(batch_size, sequence_length),
+ pad_value=False,
)
mask = tf.squeeze(mask, axis=0) if unbatched else mask
return outputs, mask
From f1c55ac5fcfec834875b38b94c465f792294a469 Mon Sep 17 00:00:00 2001
From: pass_lin <935499957@qq.com>
Date: Tue, 20 May 2025 23:40:23 +0800
Subject: [PATCH 7/9] Rename and consolidate padding tests
---
.../src/layers/preprocessing/start_end_packer_test.py | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/keras_hub/src/layers/preprocessing/start_end_packer_test.py b/keras_hub/src/layers/preprocessing/start_end_packer_test.py
index 705b8799e2..78f65405f0 100644
--- a/keras_hub/src/layers/preprocessing/start_end_packer_test.py
+++ b/keras_hub/src/layers/preprocessing/start_end_packer_test.py
@@ -122,7 +122,7 @@ def test_start_end_padding_value(self):
expected_output = [[3, 3, 1, 5, 6, 7, 2], [3, 1, 8, 9, 10, 11, 2]]
self.assertAllEqual(output, expected_output)
- def test_truncation_side_flips(self):
+ def test_truncation(self):
# right padding
input_data = list(range(10))
packer = StartEndPacker(
@@ -142,7 +142,7 @@ def test_truncation_side_flips(self):
)
self.assertAllEqual(packer(input_data), expected_output)
- def test_truncation_side_flips_wo_endvalue(self):
+ def test_truncation_wo_endvalue(self):
# right padding
input_data = list(range(10))
packer = StartEndPacker(
@@ -289,7 +289,8 @@ def test_get_config(self):
self.assertEqual(config, {**config, **expected_config_subset})
- def test_return_padding_mask_right_padding(self):
+ def test_return_padding_mask(self):
+ # right_padding
input_data = [[5, 6, 7], [8, 9, 10, 11]]
start_end_packer = StartEndPacker(
sequence_length=6,
@@ -307,8 +308,7 @@ def test_return_padding_mask_right_padding(self):
self.assertAllEqual(output, expected_output)
self.assertAllEqual(padding_mask, expected_padding_mask)
- def test_return_padding_mask_left_padding(self):
- input_data = [[5, 6, 7], [8, 9, 10, 11]]
+ # left_padding
start_end_packer = StartEndPacker(
sequence_length=6,
start_value=1,
From 8c40279d7493a63f9d45b9a9ce14b28b6a2886d9 Mon Sep 17 00:00:00 2001
From: pass_lin <935499957@qq.com>
Date: Wed, 21 May 2025 09:54:48 +0800
Subject: [PATCH 8/9] Add left padding support to MultiSegmentPacker
---
.../preprocessing/multi_segment_packer.py | 19 +-
.../multi_segment_packer_test.py | 173 ++++++++++++++++++
.../layers/preprocessing/start_end_packer.py | 24 +--
keras_hub/src/utils/tensor_utils.py | 14 ++
4 files changed, 210 insertions(+), 20 deletions(-)
diff --git a/keras_hub/src/layers/preprocessing/multi_segment_packer.py b/keras_hub/src/layers/preprocessing/multi_segment_packer.py
index 3c0a8e346d..54608292f0 100644
--- a/keras_hub/src/layers/preprocessing/multi_segment_packer.py
+++ b/keras_hub/src/layers/preprocessing/multi_segment_packer.py
@@ -3,6 +3,7 @@
PreprocessingLayer,
)
from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
+from keras_hub.src.utils.tensor_utils import pad
from keras_hub.src.utils.tensor_utils import preprocessing_function
try:
@@ -124,6 +125,7 @@ def __init__(
sep_value=None,
pad_value=None,
truncate="round_robin",
+ padding_side="right",
**kwargs,
):
super().__init__(**kwargs)
@@ -163,6 +165,8 @@ def check_special_value_type(value, value_name):
self.pad_value = pad_value
+ self.padding_side = padding_side
+
def get_config(self):
config = super().get_config()
config.update(
@@ -173,6 +177,7 @@ def get_config(self):
"sep_value": self._sep_value,
"pad_value": self.pad_value,
"truncate": self.truncate,
+ "padding_side": self.padding_side,
}
)
return config
@@ -287,10 +292,18 @@ def call(
# Pad to dense tensor output.
sequence_length = sequence_length or self.sequence_length
shape = tf.cast([-1, sequence_length], "int64")
- token_ids = token_ids.to_tensor(
- shape=shape, default_value=self.pad_value
+ token_ids = pad(
+ token_ids,
+ shape=shape,
+ padding_side=self.padding_side,
+ pad_value=self.pad_value,
+ )
+ segment_ids = pad(
+ segment_ids,
+ shape=shape,
+ padding_side=self.padding_side,
+ pad_value=0,
)
- segment_ids = segment_ids.to_tensor(shape=shape)
# Remove the batch dim if added.
if unbatched:
token_ids = tf.squeeze(token_ids, 0)
diff --git a/keras_hub/src/layers/preprocessing/multi_segment_packer_test.py b/keras_hub/src/layers/preprocessing/multi_segment_packer_test.py
index c38aa137e5..7ffba1dd7e 100644
--- a/keras_hub/src/layers/preprocessing/multi_segment_packer_test.py
+++ b/keras_hub/src/layers/preprocessing/multi_segment_packer_test.py
@@ -8,6 +8,7 @@
class MultiSegmentPackerTest(TestCase):
def test_trim_single_input_ints(self):
+ # right padding
input_data = np.arange(3, 10)
packer = MultiSegmentPacker(
sequence_length=8, start_value=1, end_value=2
@@ -16,7 +17,20 @@ def test_trim_single_input_ints(self):
self.assertAllEqual(token_ids, [1, 3, 4, 5, 6, 7, 8, 2])
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0, 0, 0, 0])
+ # left padding
+ input_data = np.arange(3, 10)
+ packer = MultiSegmentPacker(
+ sequence_length=8,
+ start_value=1,
+ end_value=2,
+ padding_side="left",
+ )
+ token_ids, segment_ids = packer(input_data)
+ self.assertAllEqual(token_ids, [1, 3, 4, 5, 6, 7, 8, 2])
+ self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0, 0, 0, 0])
+
def test_trim_single_input_strings(self):
+ # right padding
input_data = ["a", "b", "c", "d"]
packer = MultiSegmentPacker(
sequence_length=5, start_value="[CLS]", end_value="[SEP]"
@@ -25,7 +39,19 @@ def test_trim_single_input_strings(self):
self.assertAllEqual(token_ids, ["[CLS]", "a", "b", "c", "[SEP]"])
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0])
+ # left padding
+ packer = MultiSegmentPacker(
+ sequence_length=5,
+ start_value="[CLS]",
+ end_value="[SEP]",
+ padding_side="left",
+ )
+ token_ids, segment_ids = packer(input_data)
+ self.assertAllEqual(token_ids, ["[CLS]", "a", "b", "c", "[SEP]"])
+ self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0])
+
def test_trim_multiple_inputs_round_robin(self):
+ # right padding
seq1 = ["a", "b", "c"]
seq2 = ["x", "y", "z"]
packer = MultiSegmentPacker(
@@ -40,7 +66,22 @@ def test_trim_multiple_inputs_round_robin(self):
)
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 1, 1, 1])
+ # left padding
+ packer = MultiSegmentPacker(
+ sequence_length=7,
+ start_value="[CLS]",
+ end_value="[SEP]",
+ truncate="round_robin",
+ padding_side="left",
+ )
+ token_ids, segment_ids = packer((seq1, seq2))
+ self.assertAllEqual(
+ token_ids, ["[CLS]", "a", "b", "[SEP]", "x", "y", "[SEP]"]
+ )
+ self.assertAllEqual(segment_ids, [0, 0, 0, 0, 1, 1, 1])
+
def test_trim_multiple_inputs_waterfall(self):
+ # right padding
seq1 = ["a", "b", "c"]
seq2 = ["x", "y", "z"]
packer = MultiSegmentPacker(
@@ -55,7 +96,22 @@ def test_trim_multiple_inputs_waterfall(self):
)
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0, 1, 1])
+ # left padding
+ packer = MultiSegmentPacker(
+ sequence_length=7,
+ start_value="[CLS]",
+ end_value="[SEP]",
+ truncate="waterfall",
+ padding_side="left",
+ )
+ token_ids, segment_ids = packer((seq1, seq2))
+ self.assertAllEqual(
+ token_ids, ["[CLS]", "a", "b", "c", "[SEP]", "x", "[SEP]"]
+ )
+ self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0, 1, 1])
+
def test_trim_batched_inputs_round_robin(self):
+ # right padding
seq1 = [["a", "b", "c"], ["a", "b", "c"]]
seq2 = [["x", "y", "z"], ["x", "y", "z"]]
packer = MultiSegmentPacker(
@@ -80,7 +136,32 @@ def test_trim_batched_inputs_round_robin(self):
],
)
+ # left padding
+ packer = MultiSegmentPacker(
+ sequence_length=7,
+ start_value="[CLS]",
+ end_value="[SEP]",
+ truncate="round_robin",
+ padding_side="left",
+ )
+ token_ids, segment_ids = packer((seq1, seq2))
+ self.assertAllEqual(
+ token_ids,
+ [
+ ["[CLS]", "a", "b", "[SEP]", "x", "y", "[SEP]"],
+ ["[CLS]", "a", "b", "[SEP]", "x", "y", "[SEP]"],
+ ],
+ )
+ self.assertAllEqual(
+ segment_ids,
+ [
+ [0, 0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 0, 1, 1, 1],
+ ],
+ )
+
def test_trim_batched_inputs_waterfall(self):
+ # right padding
seq1 = [["a", "b", "c"], ["a", "b"]]
seq2 = [["x", "y", "z"], ["x", "y", "z"]]
packer = MultiSegmentPacker(
@@ -105,7 +186,32 @@ def test_trim_batched_inputs_waterfall(self):
],
)
+ # left padding
+ packer = MultiSegmentPacker(
+ sequence_length=7,
+ start_value="[CLS]",
+ end_value="[SEP]",
+ truncate="waterfall",
+ padding_side="left",
+ )
+ token_ids, segment_ids = packer((seq1, seq2))
+ self.assertAllEqual(
+ token_ids,
+ [
+ ["[CLS]", "a", "b", "c", "[SEP]", "x", "[SEP]"],
+ ["[CLS]", "a", "b", "[SEP]", "x", "y", "[SEP]"],
+ ],
+ )
+ self.assertAllEqual(
+ segment_ids,
+ [
+ [0, 0, 0, 0, 0, 1, 1],
+ [0, 0, 0, 0, 1, 1, 1],
+ ],
+ )
+
def test_pad_inputs(self):
+ # right padding
seq1 = ["a"]
seq2 = ["x"]
packer = MultiSegmentPacker(
@@ -118,7 +224,23 @@ def test_pad_inputs(self):
)
self.assertAllEqual(segment_ids, [0, 0, 0, 1, 1, 0])
+ # left padding
+ packer = MultiSegmentPacker(
+ 6,
+ start_value="[CLS]",
+ end_value="[SEP]",
+ pad_value="[PAD]",
+ padding_side="left",
+ )
+ token_ids, segment_ids = packer((seq1, seq2))
+ self.assertAllEqual(
+ token_ids,
+ ["[PAD]", "[CLS]", "a", "[SEP]", "x", "[SEP]"],
+ )
+ self.assertAllEqual(segment_ids, [0, 0, 0, 0, 1, 1])
+
def test_pad_batched_inputs(self):
+ # right padding
seq1 = [["a"], ["a"]]
seq2 = [["x"], ["x", "y"]]
packer = MultiSegmentPacker(
@@ -143,7 +265,32 @@ def test_pad_batched_inputs(self):
],
)
+ # left padding
+ packer = MultiSegmentPacker(
+ sequence_length=7,
+ start_value="[CLS]",
+ end_value="[SEP]",
+ pad_value="[PAD]",
+ padding_side="left",
+ )
+ token_ids, segment_ids = packer((seq1, seq2))
+ self.assertAllEqual(
+ token_ids,
+ [
+ ["[PAD]", "[PAD]", "[CLS]", "a", "[SEP]", "x", "[SEP]"],
+ ["[PAD]", "[CLS]", "a", "[SEP]", "x", "y", "[SEP]"],
+ ],
+ )
+ self.assertAllEqual(
+ segment_ids,
+ [
+ [0, 0, 0, 0, 0, 1, 1],
+ [0, 0, 0, 0, 1, 1, 1],
+ ],
+ )
+
def test_list_special_tokens(self):
+ # right padding
seq1 = [["a", "b"], ["a", "b"]]
seq2 = [["x", "y"], ["x"]]
packer = MultiSegmentPacker(
@@ -170,6 +317,32 @@ def test_list_special_tokens(self):
],
)
+ # left padding
+ packer = MultiSegmentPacker(
+ 8,
+            start_value="<s>",
+            end_value="</s>",
+            sep_value=["</s>", "</s>"],
+            pad_value="<pad>",
+ truncate="round_robin",
+ padding_side="left",
+ )
+ token_ids, segment_ids = packer((seq1, seq2))
+ self.assertAllEqual(
+ token_ids,
+ [
+                ["<s>", "a", "b", "</s>", "</s>", "x", "y", "</s>"],
+                ["<pad>", "<s>", "a", "b", "</s>", "</s>", "x", "</s>"],
+ ],
+ )
+ self.assertAllEqual(
+ segment_ids,
+ [
+ [0, 0, 0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 0, 0, 0, 1, 1],
+ ],
+ )
+
def test_config(self):
seq1 = [["a", "b", "c"], ["a", "b"]]
seq2 = [["x", "y", "z"], ["x", "y", "z"]]
diff --git a/keras_hub/src/layers/preprocessing/start_end_packer.py b/keras_hub/src/layers/preprocessing/start_end_packer.py
index 542eec39ac..efe10a4585 100644
--- a/keras_hub/src/layers/preprocessing/start_end_packer.py
+++ b/keras_hub/src/layers/preprocessing/start_end_packer.py
@@ -3,6 +3,7 @@
PreprocessingLayer,
)
from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
+from keras_hub.src.utils.tensor_utils import pad
from keras_hub.src.utils.tensor_utils import preprocessing_function
try:
@@ -144,19 +145,6 @@ def check_special_value_type(value, value_name):
self.return_padding_mask = return_padding_mask
self.padding_side = padding_side
- def pad(self, x, shape, pad_value):
- if self.padding_side == "left":
- x = x[..., ::-1]
-
- outputs = x.to_tensor(
- default_value=pad_value,
- shape=shape,
- )
-
- if self.padding_side == "left":
- outputs = outputs[..., ::-1]
- return outputs
-
@preprocessing_function
def call(
self,
@@ -194,20 +182,22 @@ def call(
x = tf.concat([x, end_token_id_tensor], axis=-1)
# Pad to desired length.
- outputs = self.pad(
+ outputs = pad(
x,
- shape=(batch_size, sequence_length),
pad_value=self.pad_value,
+ padding_side=self.padding_side,
+ shape=(batch_size, sequence_length),
)
outputs = tf.squeeze(outputs, axis=0) if unbatched else outputs
if self.return_padding_mask:
mask = tf.ones_like(x, dtype="bool")
- mask = self.pad(
+ mask = pad(
mask,
- shape=(batch_size, sequence_length),
pad_value=False,
+ padding_side=self.padding_side,
+ shape=(batch_size, sequence_length),
)
mask = tf.squeeze(mask, axis=0) if unbatched else mask
return outputs, mask
diff --git a/keras_hub/src/utils/tensor_utils.py b/keras_hub/src/utils/tensor_utils.py
index a602963cf0..fa4a5abad1 100644
--- a/keras_hub/src/utils/tensor_utils.py
+++ b/keras_hub/src/utils/tensor_utils.py
@@ -19,6 +19,20 @@
NO_CONVERT_COUNTER = threading.local()
+def pad(x, shape, padding_side, pad_value):
+ if padding_side == "left":
+ x = x[..., ::-1]
+
+ outputs = x.to_tensor(
+ default_value=pad_value,
+ shape=shape,
+ )
+
+ if padding_side == "left":
+ outputs = outputs[..., ::-1]
+ return outputs
+
+
@contextlib.contextmanager
def no_convert_scope():
try:
From 7ceef484194679b4a7dfe847d2feb3cb11697898 Mon Sep 17 00:00:00 2001
From: pass_lin <935499957@qq.com>
Date: Fri, 23 May 2025 11:50:49 +0800
Subject: [PATCH 9/9] add doc.
---
keras_hub/src/layers/preprocessing/multi_segment_packer.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/keras_hub/src/layers/preprocessing/multi_segment_packer.py b/keras_hub/src/layers/preprocessing/multi_segment_packer.py
index 54608292f0..a4e0ba7ef4 100644
--- a/keras_hub/src/layers/preprocessing/multi_segment_packer.py
+++ b/keras_hub/src/layers/preprocessing/multi_segment_packer.py
@@ -67,6 +67,8 @@ class MultiSegmentPacker(PreprocessingLayer):
"waterfall" algorithm that allocates quota in a
left-to-right manner and fills up the buckets until we run
out of budget. It support arbitrary number of segments.
+ padding_side: str. Whether to pad the input on the "left" or "right".
+ Defaults to "right".
Returns:
A tuple with two elements. The first is the dense, packed token
@@ -164,7 +166,6 @@ def check_special_value_type(value, value_name):
self.end_value = end_value
self.pad_value = pad_value
-
self.padding_side = padding_side
def get_config(self):