Fix edge case for tokenize (#36277) #36555

Status: Open · wants to merge 4 commits into main
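The edge case at issue (#36277): tokenizing an empty string with `return_tensors` set yields an empty list of ids, and without an explicit dtype each framework falls back to its float default (`np.asarray([])` is float64, `torch.tensor([])` is float32), so the returned `input_ids` stop looking like token ids. A minimal reproduction sketch, not taken from the PR itself; `gpt2` is only an example of a tokenizer that adds no special tokens, so `""` really does map to zero ids:

from transformers import AutoTokenizer

# Hypothetical repro for issue #36277; any tokenizer that adds no
# special tokens will do, gpt2 is just a convenient example.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

enc = tokenizer("", return_tensors="np")
# Before this PR: np.asarray([[]]) silently defaults to float64.
# After this PR: the empty case is detected and pinned to np.int64.
print(enc.input_ids.shape, enc.input_ids.dtype)  # (1, 0) int64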
src/transformers/__init__.py (3 additions, 0 deletions)
@@ -34,6 +34,7 @@
     is_g2p_en_available,
     is_keras_nlp_available,
     is_librosa_available,
+    is_mlx_available,
     is_pretty_midi_available,
     is_scipy_available,
     is_sentencepiece_available,
@@ -231,6 +232,7 @@
         "is_faiss_available",
         "is_flax_available",
         "is_keras_nlp_available",
+        "is_mlx_available",
         "is_phonemizer_available",
         "is_psutil_available",
         "is_py3nvml_available",
@@ -728,6 +730,7 @@
         is_faiss_available,
         is_flax_available,
         is_keras_nlp_available,
+        is_mlx_available,
         is_phonemizer_available,
         is_psutil_available,
         is_py3nvml_available,
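Side note on the `__init__.py` change: it only re-exports the existing `is_mlx_available` utility at the top level, mirroring the other `is_*_available` guards. A sketch of the usage pattern this enables (assuming MLX may or may not be installed):

from transformers import is_mlx_available

# Optional-dependency guard, in the same style as is_torch_available etc.
if is_mlx_available():
    import mlx.core as mx

    ids = mx.array([[101, 2023, 102]])  # example token ids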
src/transformers/tokenization_utils_base.py (39 additions, 8 deletions)
@@ -87,6 +87,17 @@ def import_protobuf_decode_error(error_message=""):
     raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message))
 
 
+def flatten(arr: list):
+    res = []
+    if len(arr) > 0:
+        for sub_arr in arr:
+            if isinstance(sub_arr, (list, tuple)):
+                res.extend(flatten(sub_arr))
+            else:
+                res.append(sub_arr)
+    return res
+
+
 if is_tokenizers_available():
     from tokenizers import AddedToken
     from tokenizers import Encoding as EncodingFast
@@ -723,45 +734,65 @@ def convert_to_tensors(
                 )
             import tensorflow as tf
 
-            as_tensor = tf.constant
+            def as_tensor(value, dtype=None):
+                if len(flatten(value)) == 0 and dtype is None:
+                    dtype = tf.int32
+                return tf.constant(value, dtype=dtype)
+
             is_tensor = tf.is_tensor
 
         elif tensor_type == TensorType.PYTORCH:
             if not is_torch_available():
                 raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
             import torch
 
-            is_tensor = torch.is_tensor
-
             def as_tensor(value, dtype=None):
-                if isinstance(value, list) and isinstance(value[0], np.ndarray):
+                if isinstance(value, list) and len(value) > 0 and isinstance(value[0], np.ndarray):
                     return torch.from_numpy(np.array(value))
-                return torch.tensor(value)
+                if len(flatten(value)) == 0 and dtype is None:
+                    dtype = torch.int64
+                return torch.tensor(value, dtype=dtype)
+
+            is_tensor = torch.is_tensor
 
         elif tensor_type == TensorType.JAX:
             if not is_flax_available():
                 raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
             import jax.numpy as jnp  # noqa: F811
 
-            as_tensor = jnp.array
+            def as_tensor(value, dtype=None):
+                if len(flatten(value)) == 0 and dtype is None:
+                    dtype = jnp.int32
+                return jnp.array(value, dtype=dtype)
+
             is_tensor = is_jax_tensor
 
         elif tensor_type == TensorType.MLX:
             if not is_mlx_available():
                 raise ImportError("Unable to convert output to MLX tensors format, MLX is not installed.")
             import mlx.core as mx
 
-            as_tensor = mx.array
+            def as_tensor(value, dtype=None):
+                if len(flatten(value)) == 0 and dtype is None:
+                    dtype = mx.int32
+                return mx.array(value, dtype=dtype)
 
             def is_tensor(obj):
                 return isinstance(obj, mx.array)
         else:
 
             def as_tensor(value, dtype=None):
-                if isinstance(value, (list, tuple)) and isinstance(value[0], (list, tuple, np.ndarray)):
+                if (
+                    isinstance(value, (list, tuple))
+                    and len(value) > 0
+                    and isinstance(value[0], (list, tuple, np.ndarray))
+                ):
                     value_lens = [len(val) for val in value]
                     if len(set(value_lens)) > 1 and dtype is None:
                         # we have a ragged list so handle explicitly
                         value = as_tensor([np.asarray(val) for val in value], dtype=object)
+                if len(flatten(value)) == 0 and dtype is None:
+                    dtype = np.int64
                 return np.asarray(value, dtype=dtype)
 
             is_tensor = is_numpy_array
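Why the `flatten` helper is the right trigger for the dtype default above: a batched empty input arrives as a nested list such as `[[]]`, which is non-empty at the top level but flattens to `[]`. A small sketch of the helper's behavior and of the NumPy default it works around (the helper body is copied from the hunk above):

import numpy as np

def flatten(arr: list):
    # Copied from the diff above: recursively collapse nested lists/tuples.
    res = []
    if len(arr) > 0:
        for sub_arr in arr:
            if isinstance(sub_arr, (list, tuple)):
                res.extend(flatten(sub_arr))
            else:
                res.append(sub_arr)
    return res

print(flatten([[]]))           # [] -- empty once flattened, so the int dtype kicks in
print(flatten([[1, 2], [3]]))  # [1, 2, 3]

# The default being worked around: empty arrays come out floating-point.
print(np.asarray([[]]).dtype)                  # float64 (old behavior)
print(np.asarray([[]], dtype=np.int64).dtype)  # int64  (with this fix)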
tests/test_tokenization_common.py (39 additions, 0 deletions)
@@ -44,6 +44,7 @@
     Trainer,
     TrainingArguments,
     is_flax_available,
+    is_mlx_available,
     is_tf_available,
     is_torch_available,
     logging,
@@ -4736,3 +4737,41 @@ def test_rust_tokenizer_add_prefix_space(self, add_prefix_space):
         # Only the ByteLevel pre-tokenizer has the `add_prefix_space` attribute, we have to ensure that it's set correctly
         if hasattr(fast_tokenizer.backend_tokenizer.pre_tokenizer, "add_prefix_space"):
             self.assertEqual(fast_tokenizer.backend_tokenizer.pre_tokenizer.add_prefix_space, add_prefix_space)
+
+    def test_empty_input_string(self):
+        empty_input_string = ""
+        tokenizer_return_type = []
+        output_tensor_type = []
+        if is_torch_available():
+            import numpy as np
+            import torch
+
+            tokenizer_return_type.append("pt")
+            output_tensor_type.append(torch.int64)
+            tokenizer_return_type.append("np")
+            output_tensor_type.append(np.int64)
+        if is_tf_available():
+            import tensorflow as tf
+
+            tokenizer_return_type.append("tf")
+            output_tensor_type.append(tf.int32)
+        if is_flax_available():
+            import jax.numpy as jnp
+
+            tokenizer_return_type.append("jax")
+            output_tensor_type.append(jnp.int32)
+        if is_mlx_available():
+            import mlx.core as mx
+
+            tokenizer_return_type.append("mlx")
+            output_tensor_type.append(mx.int32)
+
+        if len(tokenizer_return_type) == 0:
+            self.skipTest(reason="No expected framework from PT, TF, JAX or MLX found")
+
+        tokenizers = self.get_tokenizers()
+        for tokenizer in tokenizers:
+            with self.subTest(f"{tokenizer.__class__.__name__}"):
+                for return_type, target_type in zip(tokenizer_return_type, output_tensor_type):
+                    output = tokenizer(empty_input_string, return_tensors=return_type)
+                    self.assertEqual(output.input_ids.dtype, target_type)