@@ -724,46 +724,52 @@ def convert_to_tensors(
724
724
)
725
725
import tensorflow as tf
726
726
727
- as_tensor = tf .constant
727
+ def as_tensor (value ):
728
+ return tf .constant (value , dtype = tf .int32 )
729
+
728
730
is_tensor = tf .is_tensor
731
+
729
732
elif tensor_type == TensorType .PYTORCH :
730
733
if not is_torch_available ():
731
734
raise ImportError ("Unable to convert output to PyTorch tensors format, PyTorch is not installed." )
732
735
import torch
733
736
734
- is_tensor = torch .is_tensor
735
-
736
- def as_tensor (value , dtype = None ):
737
+ def as_tensor (value ):
737
738
if isinstance (value , list ) and isinstance (value [0 ], np .ndarray ):
738
- return torch .from_numpy (np .array (value ))
739
- return torch .tensor (value )
739
+ return torch .from_numpy (np .array (value )).to (torch .int64 )
740
+ return torch .tensor (value , dtype = torch .int64 )
741
+
742
+ is_tensor = torch .is_tensor
740
743
741
744
elif tensor_type == TensorType .JAX :
742
745
if not is_flax_available ():
743
746
raise ImportError ("Unable to convert output to JAX tensors format, JAX is not installed." )
744
747
import jax .numpy as jnp # noqa: F811
745
748
746
- as_tensor = jnp .array
749
+ def as_tensor (value ):
750
+ return jnp .array (value , dtype = jnp .int32 )
751
+
747
752
is_tensor = is_jax_tensor
748
753
749
754
elif tensor_type == TensorType .MLX :
750
755
if not is_mlx_available ():
751
756
raise ImportError ("Unable to convert output to MLX tensors format, MLX is not installed." )
752
757
import mlx .core as mx
753
758
754
- as_tensor = mx .array
759
+ def as_tensor (value ):
760
+ return mx .array (value , dtype = mx .int32 )
755
761
756
762
def is_tensor (obj ):
757
763
return isinstance (obj , mx .array )
758
764
else :
759
765
760
- def as_tensor (value , dtype = None ):
766
+ def as_tensor (value ):
761
767
if isinstance (value , (list , tuple )) and isinstance (value [0 ], (list , tuple , np .ndarray )):
762
768
value_lens = [len (val ) for val in value ]
763
- if len (set (value_lens )) > 1 and dtype is None :
769
+ if len (set (value_lens )) > 1 :
764
770
# we have a ragged list so handle explicitly
765
771
value = as_tensor ([np .asarray (val ) for val in value ], dtype = object )
766
- return np .asarray (value , dtype = dtype )
772
+ return np .asarray (value , dtype = np . int64 )
767
773
768
774
is_tensor = is_numpy_array
769
775
0 commit comments