Skip to content

Commit f61631a

Browse files
wangzhen0518 authored and Rocketknight1 committed
Fix edge case for tokenize (#36277)
1 parent 8805600 commit f61631a

File tree

1 file changed

+17
-11
lines changed

1 file changed

+17
-11
lines changed

src/transformers/tokenization_utils_base.py

Lines changed: 17 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -724,46 +724,52 @@ def convert_to_tensors(
724724
)
725725
import tensorflow as tf
726726

727-
as_tensor = tf.constant
727+
def as_tensor(value):
728+
return tf.constant(value, dtype=tf.int32)
729+
728730
is_tensor = tf.is_tensor
731+
729732
elif tensor_type == TensorType.PYTORCH:
730733
if not is_torch_available():
731734
raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
732735
import torch
733736

734-
is_tensor = torch.is_tensor
735-
736-
def as_tensor(value, dtype=None):
737+
def as_tensor(value):
737738
if isinstance(value, list) and isinstance(value[0], np.ndarray):
738-
return torch.from_numpy(np.array(value))
739-
return torch.tensor(value)
739+
return torch.from_numpy(np.array(value)).to(torch.int64)
740+
return torch.tensor(value, dtype=torch.int64)
741+
742+
is_tensor = torch.is_tensor
740743

741744
elif tensor_type == TensorType.JAX:
742745
if not is_flax_available():
743746
raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
744747
import jax.numpy as jnp # noqa: F811
745748

746-
as_tensor = jnp.array
749+
def as_tensor(value):
750+
return jnp.array(value, dtype=jnp.int32)
751+
747752
is_tensor = is_jax_tensor
748753

749754
elif tensor_type == TensorType.MLX:
750755
if not is_mlx_available():
751756
raise ImportError("Unable to convert output to MLX tensors format, MLX is not installed.")
752757
import mlx.core as mx
753758

754-
as_tensor = mx.array
759+
def as_tensor(value):
760+
return mx.array(value, dtype=mx.int32)
755761

756762
def is_tensor(obj):
757763
return isinstance(obj, mx.array)
758764
else:
759765

760-
def as_tensor(value, dtype=None):
766+
def as_tensor(value):
761767
if isinstance(value, (list, tuple)) and isinstance(value[0], (list, tuple, np.ndarray)):
762768
value_lens = [len(val) for val in value]
763-
if len(set(value_lens)) > 1 and dtype is None:
769+
if len(set(value_lens)) > 1:
764770
# we have a ragged list so handle explicitly
765771
value = as_tensor([np.asarray(val) for val in value], dtype=object)
766-
return np.asarray(value, dtype=dtype)
772+
return np.asarray(value, dtype=np.int64)
767773

768774
is_tensor = is_numpy_array
769775

0 commit comments

Comments
 (0)