Skip to content

Commit 96e6301

Browse files
authored
Replace np.cast (deprecated in numpy > 2) with np.asarray.
This should also work for older numpy versions. Signed-off-by: Satge96 <35660525+Satge96@users.noreply.github.com>
1 parent 3dd7729 commit 96e6301

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tf2onnx/tfonnx.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def rewrite_constant_fold(g, ops):
7474
func_map = {
7575
"Add": np.add,
7676
"GreaterEqual": np.greater_equal,
77-
"Cast": np.cast,
77+
"Cast": np.asarray,
7878
"ConcatV2": np.concatenate,
7979
"Less": np.less,
8080
"ListDiff": np.setdiff1d,
@@ -107,7 +107,7 @@ def rewrite_constant_fold(g, ops):
107107
if op.type == "Cast":
108108
dst = op.get_attr_int("to")
109109
np_type = tf2onnx.utils.map_onnx_to_numpy_type(dst)
110-
val = np.cast[np_type](*inputs)
110+
val = np.asarray(*inputs, dtype=np_type)
111111
elif op.type == "ConcatV2":
112112
axis = inputs[-1]
113113
values = inputs[:-1]

0 commit comments

Comments
 (0)