We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents 026e12f + f4608dc commit 10b2d31Copy full SHA for 10b2d31
tf2onnx/tfonnx.py
@@ -74,7 +74,7 @@ def rewrite_constant_fold(g, ops):
74
func_map = {
75
"Add": np.add,
76
"GreaterEqual": np.greater_equal,
77
- "Cast": np.cast,
+ "Cast": np.asarray,
78
"ConcatV2": np.concatenate,
79
"Less": np.less,
80
"ListDiff": np.setdiff1d,
@@ -107,7 +107,7 @@ def rewrite_constant_fold(g, ops):
107
if op.type == "Cast":
108
dst = op.get_attr_int("to")
109
np_type = tf2onnx.utils.map_onnx_to_numpy_type(dst)
110
- val = np.cast[np_type](*inputs)
+ val = np.asarray(*inputs, dtype=np_type)
111
elif op.type == "ConcatV2":
112
axis = inputs[-1]
113
values = inputs[:-1]
0 commit comments