Skip to content

Commit 10b2d31

Browse files
authored
Merge pull request #2392 from Satge96/main
Replace np.cast (deprecated in numpy >= 2) with np.asarray.
2 parents 026e12f + f4608dc commit 10b2d31

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tf2onnx/tfonnx.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def rewrite_constant_fold(g, ops):
7474
func_map = {
7575
"Add": np.add,
7676
"GreaterEqual": np.greater_equal,
77-
"Cast": np.cast,
77+
"Cast": np.asarray,
7878
"ConcatV2": np.concatenate,
7979
"Less": np.less,
8080
"ListDiff": np.setdiff1d,
@@ -107,7 +107,7 @@ def rewrite_constant_fold(g, ops):
107107
if op.type == "Cast":
108108
dst = op.get_attr_int("to")
109109
np_type = tf2onnx.utils.map_onnx_to_numpy_type(dst)
110-
val = np.cast[np_type](*inputs)
110+
val = np.asarray(*inputs, dtype=np_type)
111111
elif op.type == "ConcatV2":
112112
axis = inputs[-1]
113113
values = inputs[:-1]

0 commit comments

Comments
 (0)