diff --git a/keras_hub/src/layers/preprocessing/multi_segment_packer.py b/keras_hub/src/layers/preprocessing/multi_segment_packer.py
index 3c0a8e346d..a4e0ba7ef4 100644
--- a/keras_hub/src/layers/preprocessing/multi_segment_packer.py
+++ b/keras_hub/src/layers/preprocessing/multi_segment_packer.py
@@ -3,6 +3,7 @@
PreprocessingLayer,
)
from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
+from keras_hub.src.utils.tensor_utils import pad
from keras_hub.src.utils.tensor_utils import preprocessing_function
try:
@@ -66,6 +67,8 @@ class MultiSegmentPacker(PreprocessingLayer):
"waterfall" algorithm that allocates quota in a
left-to-right manner and fills up the buckets until we run
out of budget. It support arbitrary number of segments.
+ padding_side: str. Whether to pad the input on the "left" or "right".
+ Defaults to "right".
Returns:
A tuple with two elements. The first is the dense, packed token
@@ -124,6 +127,7 @@ def __init__(
sep_value=None,
pad_value=None,
truncate="round_robin",
+ padding_side="right",
**kwargs,
):
super().__init__(**kwargs)
@@ -162,6 +166,7 @@ def check_special_value_type(value, value_name):
self.end_value = end_value
self.pad_value = pad_value
+ self.padding_side = padding_side
def get_config(self):
config = super().get_config()
@@ -173,6 +178,7 @@ def get_config(self):
"sep_value": self._sep_value,
"pad_value": self.pad_value,
"truncate": self.truncate,
+ "padding_side": self.padding_side,
}
)
return config
@@ -287,10 +293,18 @@ def call(
# Pad to dense tensor output.
sequence_length = sequence_length or self.sequence_length
shape = tf.cast([-1, sequence_length], "int64")
- token_ids = token_ids.to_tensor(
- shape=shape, default_value=self.pad_value
+ token_ids = pad(
+ token_ids,
+ shape=shape,
+ padding_side=self.padding_side,
+ pad_value=self.pad_value,
+ )
+ segment_ids = pad(
+ segment_ids,
+ shape=shape,
+ padding_side=self.padding_side,
+ pad_value=0,
)
- segment_ids = segment_ids.to_tensor(shape=shape)
# Remove the batch dim if added.
if unbatched:
token_ids = tf.squeeze(token_ids, 0)
diff --git a/keras_hub/src/layers/preprocessing/multi_segment_packer_test.py b/keras_hub/src/layers/preprocessing/multi_segment_packer_test.py
index c38aa137e5..7ffba1dd7e 100644
--- a/keras_hub/src/layers/preprocessing/multi_segment_packer_test.py
+++ b/keras_hub/src/layers/preprocessing/multi_segment_packer_test.py
@@ -8,6 +8,7 @@
class MultiSegmentPackerTest(TestCase):
def test_trim_single_input_ints(self):
+ # right padding
input_data = np.arange(3, 10)
packer = MultiSegmentPacker(
sequence_length=8, start_value=1, end_value=2
@@ -16,7 +17,20 @@ def test_trim_single_input_ints(self):
self.assertAllEqual(token_ids, [1, 3, 4, 5, 6, 7, 8, 2])
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0, 0, 0, 0])
+ # left padding
+ input_data = np.arange(3, 10)
+ packer = MultiSegmentPacker(
+ sequence_length=8,
+ start_value=1,
+ end_value=2,
+ padding_side="left",
+ )
+ token_ids, segment_ids = packer(input_data)
+ self.assertAllEqual(token_ids, [1, 3, 4, 5, 6, 7, 8, 2])
+ self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0, 0, 0, 0])
+
def test_trim_single_input_strings(self):
+ # right padding
input_data = ["a", "b", "c", "d"]
packer = MultiSegmentPacker(
sequence_length=5, start_value="[CLS]", end_value="[SEP]"
@@ -25,7 +39,19 @@ def test_trim_single_input_strings(self):
self.assertAllEqual(token_ids, ["[CLS]", "a", "b", "c", "[SEP]"])
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0])
+ # left padding
+ packer = MultiSegmentPacker(
+ sequence_length=5,
+ start_value="[CLS]",
+ end_value="[SEP]",
+ padding_side="left",
+ )
+ token_ids, segment_ids = packer(input_data)
+ self.assertAllEqual(token_ids, ["[CLS]", "a", "b", "c", "[SEP]"])
+ self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0])
+
def test_trim_multiple_inputs_round_robin(self):
+ # right padding
seq1 = ["a", "b", "c"]
seq2 = ["x", "y", "z"]
packer = MultiSegmentPacker(
@@ -40,7 +66,22 @@ def test_trim_multiple_inputs_round_robin(self):
)
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 1, 1, 1])
+ # left padding
+ packer = MultiSegmentPacker(
+ sequence_length=7,
+ start_value="[CLS]",
+ end_value="[SEP]",
+ truncate="round_robin",
+ padding_side="left",
+ )
+ token_ids, segment_ids = packer((seq1, seq2))
+ self.assertAllEqual(
+ token_ids, ["[CLS]", "a", "b", "[SEP]", "x", "y", "[SEP]"]
+ )
+ self.assertAllEqual(segment_ids, [0, 0, 0, 0, 1, 1, 1])
+
def test_trim_multiple_inputs_waterfall(self):
+ # right padding
seq1 = ["a", "b", "c"]
seq2 = ["x", "y", "z"]
packer = MultiSegmentPacker(
@@ -55,7 +96,22 @@ def test_trim_multiple_inputs_waterfall(self):
)
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0, 1, 1])
+ # left padding
+ packer = MultiSegmentPacker(
+ sequence_length=7,
+ start_value="[CLS]",
+ end_value="[SEP]",
+ truncate="waterfall",
+ padding_side="left",
+ )
+ token_ids, segment_ids = packer((seq1, seq2))
+ self.assertAllEqual(
+ token_ids, ["[CLS]", "a", "b", "c", "[SEP]", "x", "[SEP]"]
+ )
+ self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0, 1, 1])
+
def test_trim_batched_inputs_round_robin(self):
+ # right padding
seq1 = [["a", "b", "c"], ["a", "b", "c"]]
seq2 = [["x", "y", "z"], ["x", "y", "z"]]
packer = MultiSegmentPacker(
@@ -80,7 +136,32 @@ def test_trim_batched_inputs_round_robin(self):
],
)
+ # left padding
+ packer = MultiSegmentPacker(
+ sequence_length=7,
+ start_value="[CLS]",
+ end_value="[SEP]",
+ truncate="round_robin",
+ padding_side="left",
+ )
+ token_ids, segment_ids = packer((seq1, seq2))
+ self.assertAllEqual(
+ token_ids,
+ [
+ ["[CLS]", "a", "b", "[SEP]", "x", "y", "[SEP]"],
+ ["[CLS]", "a", "b", "[SEP]", "x", "y", "[SEP]"],
+ ],
+ )
+ self.assertAllEqual(
+ segment_ids,
+ [
+ [0, 0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 0, 1, 1, 1],
+ ],
+ )
+
def test_trim_batched_inputs_waterfall(self):
+ # right padding
seq1 = [["a", "b", "c"], ["a", "b"]]
seq2 = [["x", "y", "z"], ["x", "y", "z"]]
packer = MultiSegmentPacker(
@@ -105,7 +186,32 @@ def test_trim_batched_inputs_waterfall(self):
],
)
+ # left padding
+ packer = MultiSegmentPacker(
+ sequence_length=7,
+ start_value="[CLS]",
+ end_value="[SEP]",
+ truncate="waterfall",
+ padding_side="left",
+ )
+ token_ids, segment_ids = packer((seq1, seq2))
+ self.assertAllEqual(
+ token_ids,
+ [
+ ["[CLS]", "a", "b", "c", "[SEP]", "x", "[SEP]"],
+ ["[CLS]", "a", "b", "[SEP]", "x", "y", "[SEP]"],
+ ],
+ )
+ self.assertAllEqual(
+ segment_ids,
+ [
+ [0, 0, 0, 0, 0, 1, 1],
+ [0, 0, 0, 0, 1, 1, 1],
+ ],
+ )
+
def test_pad_inputs(self):
+ # right padding
seq1 = ["a"]
seq2 = ["x"]
packer = MultiSegmentPacker(
@@ -118,7 +224,23 @@ def test_pad_inputs(self):
)
self.assertAllEqual(segment_ids, [0, 0, 0, 1, 1, 0])
+ # left padding
+ packer = MultiSegmentPacker(
+ 6,
+ start_value="[CLS]",
+ end_value="[SEP]",
+ pad_value="[PAD]",
+ padding_side="left",
+ )
+ token_ids, segment_ids = packer((seq1, seq2))
+ self.assertAllEqual(
+ token_ids,
+ ["[PAD]", "[CLS]", "a", "[SEP]", "x", "[SEP]"],
+ )
+ self.assertAllEqual(segment_ids, [0, 0, 0, 0, 1, 1])
+
def test_pad_batched_inputs(self):
+ # right padding
seq1 = [["a"], ["a"]]
seq2 = [["x"], ["x", "y"]]
packer = MultiSegmentPacker(
@@ -143,7 +265,32 @@ def test_pad_batched_inputs(self):
],
)
+ # left padding
+ packer = MultiSegmentPacker(
+ sequence_length=7,
+ start_value="[CLS]",
+ end_value="[SEP]",
+ pad_value="[PAD]",
+ padding_side="left",
+ )
+ token_ids, segment_ids = packer((seq1, seq2))
+ self.assertAllEqual(
+ token_ids,
+ [
+ ["[PAD]", "[PAD]", "[CLS]", "a", "[SEP]", "x", "[SEP]"],
+ ["[PAD]", "[CLS]", "a", "[SEP]", "x", "y", "[SEP]"],
+ ],
+ )
+ self.assertAllEqual(
+ segment_ids,
+ [
+ [0, 0, 0, 0, 0, 1, 1],
+ [0, 0, 0, 0, 1, 1, 1],
+ ],
+ )
+
def test_list_special_tokens(self):
+ # right padding
seq1 = [["a", "b"], ["a", "b"]]
seq2 = [["x", "y"], ["x"]]
packer = MultiSegmentPacker(
@@ -170,6 +317,32 @@ def test_list_special_tokens(self):
],
)
+ # left padding
+ packer = MultiSegmentPacker(
+ 8,
+ start_value="",
+ end_value="",
+ sep_value=["", ""],
+ pad_value="",
+ truncate="round_robin",
+ padding_side="left",
+ )
+ token_ids, segment_ids = packer((seq1, seq2))
+ self.assertAllEqual(
+ token_ids,
+ [
+ ["", "a", "b", "", "", "x", "y", ""],
+ ["", "", "a", "b", "", "", "x", ""],
+ ],
+ )
+ self.assertAllEqual(
+ segment_ids,
+ [
+ [0, 0, 0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 0, 0, 0, 1, 1],
+ ],
+ )
+
def test_config(self):
seq1 = [["a", "b", "c"], ["a", "b"]]
seq2 = [["x", "y", "z"], ["x", "y", "z"]]
diff --git a/keras_hub/src/layers/preprocessing/start_end_packer.py b/keras_hub/src/layers/preprocessing/start_end_packer.py
index f50f09b7f1..efe10a4585 100644
--- a/keras_hub/src/layers/preprocessing/start_end_packer.py
+++ b/keras_hub/src/layers/preprocessing/start_end_packer.py
@@ -3,6 +3,7 @@
PreprocessingLayer,
)
from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
+from keras_hub.src.utils.tensor_utils import pad
from keras_hub.src.utils.tensor_utils import preprocessing_function
try:
@@ -39,6 +40,8 @@ class StartEndPacker(PreprocessingLayer):
0 or "" will be added depending on the dtype of the input tensor.
return_padding_mask: bool. Whether to return a boolean padding mask of
all locations that are filled in with the `pad_value`.
+ padding_side: str. Whether to pad the input on the "left" or "right".
+ Defaults to "right".
Call arguments:
inputs: A `tf.Tensor`, `tf.RaggedTensor`, or list of python strings.
@@ -111,6 +114,7 @@ def __init__(
pad_value=None,
return_padding_mask=False,
name=None,
+ padding_side="right",
**kwargs,
):
super().__init__(name=name, **kwargs)
@@ -139,6 +143,7 @@ def check_special_value_type(value, value_name):
self.pad_value = pad_value
self.return_padding_mask = return_padding_mask
+ self.padding_side = padding_side
@preprocessing_function
def call(
@@ -154,6 +159,13 @@ def call(
batch_size = tf.shape(x)[0]
sequence_length = sequence_length or self.sequence_length
dtype = inputs.dtype
+ # Truncate.
+ truncation_length = sequence_length
+ if add_start_value and self.start_value is not None:
+ truncation_length -= len(self.start_value)
+ if add_end_value and self.end_value is not None:
+ truncation_length -= len(self.end_value)
+ x = x[..., :truncation_length]
# Concatenate start and end tokens.
if add_start_value and self.start_value is not None:
@@ -167,23 +179,28 @@ def call(
end_token_id_tensor = tf.repeat(
end_value[tf.newaxis, :], repeats=batch_size, axis=0
)
- # Trim to leave room for end token.
- x = x[..., : sequence_length - len(self.end_value)]
x = tf.concat([x, end_token_id_tensor], axis=-1)
# Pad to desired length.
- outputs = x.to_tensor(
- default_value=self.pad_value,
+ outputs = pad(
+ x,
+ pad_value=self.pad_value,
+ padding_side=self.padding_side,
shape=(batch_size, sequence_length),
)
outputs = tf.squeeze(outputs, axis=0) if unbatched else outputs
if self.return_padding_mask:
mask = tf.ones_like(x, dtype="bool")
- mask = mask.to_tensor(shape=(batch_size, sequence_length))
+ mask = pad(
+ mask,
+ pad_value=False,
+ padding_side=self.padding_side,
+ shape=(batch_size, sequence_length),
+ )
mask = tf.squeeze(mask, axis=0) if unbatched else mask
return outputs, mask
-
return outputs
def get_config(self):
@@ -195,6 +212,7 @@ def get_config(self):
"end_value": self._end_value,
"pad_value": self.pad_value,
"return_padding_mask": self.return_padding_mask,
+ "padding_side": self.padding_side,
}
)
return config
diff --git a/keras_hub/src/layers/preprocessing/start_end_packer_test.py b/keras_hub/src/layers/preprocessing/start_end_packer_test.py
index 3b0ea65a9c..78f65405f0 100644
--- a/keras_hub/src/layers/preprocessing/start_end_packer_test.py
+++ b/keras_hub/src/layers/preprocessing/start_end_packer_test.py
@@ -6,11 +6,19 @@
class StartEndPackerTest(TestCase):
def test_dense_input(self):
+ # right padding
input_data = [5, 6, 7]
start_end_packer = StartEndPacker(sequence_length=5)
output = start_end_packer(input_data)
expected_output = [5, 6, 7, 0, 0]
self.assertAllEqual(output, expected_output)
+ # left padding
+ start_end_packer = StartEndPacker(
+ sequence_length=5, padding_side="left"
+ )
+ output = start_end_packer(input_data)
+ expected_output = [0, 0, 5, 6, 7]
+ self.assertAllEqual(output, expected_output)
def test_bfloat16_dtype(self):
# Core Keras has a strange bug where it converts int to floats in
@@ -21,20 +29,37 @@ def test_bfloat16_dtype(self):
self.assertDTypeEqual(output, "int32")
def test_dense_2D_input(self):
+ # right padding
input_data = [[5, 6, 7]]
start_end_packer = StartEndPacker(sequence_length=5)
output = start_end_packer(input_data)
expected_output = [[5, 6, 7, 0, 0]]
self.assertAllEqual(output, expected_output)
+ # left padding
+ start_end_packer = StartEndPacker(
+ sequence_length=5, padding_side="left"
+ )
+ output = start_end_packer(input_data)
+ expected_output = [[0, 0, 5, 6, 7]]
+ self.assertAllEqual(output, expected_output)
def test_ragged_input(self):
+ # right padding
input_data = [[5, 6, 7], [8, 9, 10, 11]]
start_end_packer = StartEndPacker(sequence_length=5)
output = start_end_packer(input_data)
expected_output = [[5, 6, 7, 0, 0], [8, 9, 10, 11, 0]]
self.assertAllEqual(output, expected_output)
+ # left padding
+ start_end_packer = StartEndPacker(
+ sequence_length=5, padding_side="left"
+ )
+ output = start_end_packer(input_data)
+ expected_output = [[0, 0, 5, 6, 7], [0, 8, 9, 10, 11]]
+ self.assertAllEqual(output, expected_output)
def test_start_end_token(self):
+ # right padding
input_data = [[5, 6, 7], [8, 9, 10, 11]]
start_end_packer = StartEndPacker(
sequence_length=6, start_value=1, end_value=2
@@ -42,8 +67,16 @@ def test_start_end_token(self):
output = start_end_packer(input_data)
expected_output = [[1, 5, 6, 7, 2, 0], [1, 8, 9, 10, 11, 2]]
self.assertAllEqual(output, expected_output)
+ # left padding
+ start_end_packer = StartEndPacker(
+ sequence_length=6, start_value=1, end_value=2, padding_side="left"
+ )
+ output = start_end_packer(input_data)
+ expected_output = [[0, 1, 5, 6, 7, 2], [1, 8, 9, 10, 11, 2]]
+ self.assertAllEqual(output, expected_output)
def test_multiple_start_end_tokens(self):
+ # right padding
input_data = [[5, 6, 7], [8, 9, 10, 11, 12, 13]]
start_end_packer = StartEndPacker(
sequence_length=8,
@@ -55,7 +88,20 @@ def test_multiple_start_end_tokens(self):
expected_output = [[1, 2, 5, 6, 7, 3, 4, 0], [1, 2, 8, 9, 10, 11, 3, 4]]
self.assertAllEqual(output, expected_output)
+ # left padding
+ start_end_packer = StartEndPacker(
+ sequence_length=8,
+ start_value=[1, 2],
+ end_value=[3, 4],
+ pad_value=0,
+ padding_side="left",
+ )
+ output = start_end_packer(input_data)
+ expected_output = [[0, 1, 2, 5, 6, 7, 3, 4], [1, 2, 8, 9, 10, 11, 3, 4]]
+ self.assertAllEqual(output, expected_output)
+
def test_start_end_padding_value(self):
+ # right padding
input_data = [[5, 6, 7], [8, 9, 10, 11]]
start_end_packer = StartEndPacker(
sequence_length=7, start_value=1, end_value=2, pad_value=3
@@ -64,7 +110,58 @@ def test_start_end_padding_value(self):
expected_output = [[1, 5, 6, 7, 2, 3, 3], [1, 8, 9, 10, 11, 2, 3]]
self.assertAllEqual(output, expected_output)
+ # left padding
+ start_end_packer = StartEndPacker(
+ sequence_length=7,
+ start_value=1,
+ end_value=2,
+ pad_value=3,
+ padding_side="left",
+ )
+ output = start_end_packer(input_data)
+ expected_output = [[3, 3, 1, 5, 6, 7, 2], [3, 1, 8, 9, 10, 11, 2]]
+ self.assertAllEqual(output, expected_output)
+
+ def test_truncation(self):
+ # right padding
+ input_data = list(range(10))
+ packer = StartEndPacker(
+ sequence_length=7,
+ start_value=98,
+ end_value=99,
+ )
+ expected_output = [98, 0, 1, 2, 3, 4, 99]
+ self.assertAllEqual(packer(input_data), expected_output)
+
+ # left padding
+ packer = StartEndPacker(
+ sequence_length=7,
+ start_value=98,
+ end_value=99,
+ padding_side="left",
+ )
+ self.assertAllEqual(packer(input_data), expected_output)
+
+ def test_truncation_wo_endvalue(self):
+ # right padding
+ input_data = list(range(10))
+ packer = StartEndPacker(
+ sequence_length=7,
+ start_value=98,
+ )
+ expected_output = [98, 0, 1, 2, 3, 4, 5]
+ self.assertAllEqual(packer(input_data), expected_output)
+
+ # left padding
+ packer = StartEndPacker(
+ sequence_length=7,
+ start_value=98,
+ padding_side="left",
+ )
+ self.assertAllEqual(packer(input_data), expected_output)
+
def test_end_token_value_during_truncation(self):
+ # right padding
input_data = [[5, 6], [8, 9, 10, 11, 12, 13]]
start_end_packer = StartEndPacker(
sequence_length=5, start_value=1, end_value=2, pad_value=0
@@ -73,7 +170,20 @@ def test_end_token_value_during_truncation(self):
expected_output = [[1, 5, 6, 2, 0], [1, 8, 9, 10, 2]]
self.assertAllEqual(output, expected_output)
+ # left padding
+ start_end_packer = StartEndPacker(
+ sequence_length=5,
+ start_value=1,
+ end_value=2,
+ pad_value=0,
+ padding_side="left",
+ )
+ output = start_end_packer(input_data)
+ expected_output = [[0, 1, 5, 6, 2], [1, 8, 9, 10, 2]]
+ self.assertAllEqual(output, expected_output)
+
def test_string_input(self):
+ # right padding
input_data = [["KerasHub", "is", "awesome"], ["amazing"]]
start_end_packer = StartEndPacker(
sequence_length=5,
@@ -88,7 +198,23 @@ def test_string_input(self):
]
self.assertAllEqual(output, expected_output)
+ # left padding
+ start_end_packer = StartEndPacker(
+ sequence_length=5,
+ start_value="[START]",
+ end_value="[END]",
+ pad_value="[PAD]",
+ padding_side="left",
+ )
+ output = start_end_packer(input_data)
+ expected_output = [
+ ["[START]", "KerasHub", "is", "awesome", "[END]"],
+ ["[PAD]", "[PAD]", "[START]", "amazing", "[END]"],
+ ]
+ self.assertAllEqual(output, expected_output)
+
def test_string_input_with_multiple_special_values(self):
+ # right padding
input_data = [["KerasHub", "is", "awesome"], ["amazing"]]
start_end_packer = StartEndPacker(
sequence_length=6,
@@ -103,6 +229,21 @@ def test_string_input_with_multiple_special_values(self):
]
self.assertAllEqual(output, expected_output)
+ # left padding
+ start_end_packer = StartEndPacker(
+ sequence_length=6,
+ start_value=["[END]", "[START]"],
+ end_value="[END]",
+ pad_value="[PAD]",
+ padding_side="left",
+ )
+ output = start_end_packer(input_data)
+ expected_output = [
+ ["[END]", "[START]", "KerasHub", "is", "awesome", "[END]"],
+ ["[PAD]", "[PAD]", "[END]", "[START]", "amazing", "[END]"],
+ ]
+ self.assertAllEqual(output, expected_output)
+
def test_special_token_dtype_error(self):
with self.assertRaises(ValueError):
StartEndPacker(sequence_length=5, start_value=1.0)
@@ -147,3 +288,39 @@ def test_get_config(self):
}
self.assertEqual(config, {**config, **expected_config_subset})
+
+ def test_return_padding_mask(self):
+ # right_padding
+ input_data = [[5, 6, 7], [8, 9, 10, 11]]
+ start_end_packer = StartEndPacker(
+ sequence_length=6,
+ start_value=1,
+ end_value=2,
+ return_padding_mask=True,
+ )
+ output, padding_mask = start_end_packer(input_data)
+ expected_output = [[1, 5, 6, 7, 2, 0], [1, 8, 9, 10, 11, 2]]
+ expected_padding_mask = [
+ [True, True, True, True, True, False],
+ [True, True, True, True, True, True],
+ ]
+ self.assertAllEqual(output, expected_output)
+ self.assertAllEqual(padding_mask, expected_padding_mask)
+
+ # left_padding
+ start_end_packer = StartEndPacker(
+ sequence_length=6,
+ start_value=1,
+ end_value=2,
+ return_padding_mask=True,
+ padding_side="left",
+ )
+ output, padding_mask = start_end_packer(input_data)
+ expected_output = [[0, 1, 5, 6, 7, 2], [1, 8, 9, 10, 11, 2]]
+ expected_padding_mask = [
+ [False, True, True, True, True, True],
+ [True, True, True, True, True, True],
+ ]
+ self.assertAllEqual(output, expected_output)
+ self.assertAllEqual(padding_mask, expected_padding_mask)
diff --git a/keras_hub/src/utils/tensor_utils.py b/keras_hub/src/utils/tensor_utils.py
index a602963cf0..fa4a5abad1 100644
--- a/keras_hub/src/utils/tensor_utils.py
+++ b/keras_hub/src/utils/tensor_utils.py
@@ -19,6 +19,20 @@
NO_CONVERT_COUNTER = threading.local()
+def pad(x, shape, padding_side, pad_value):
+ if padding_side == "left":
+ x = x[..., ::-1]
+
+ outputs = x.to_tensor(
+ default_value=pad_value,
+ shape=shape,
+ )
+
+ if padding_side == "left":
+ outputs = outputs[..., ::-1]
+ return outputs
+
+
@contextlib.contextmanager
def no_convert_scope():
try: