Skip to content

Commit 522b48c

Browse files
authored
Merge pull request #1107 from fengyuentau:fix_dtype_nary_eltwise
Merge with opencv/opencv#24386 Also noted that random seed is added in this pr, so previous data is regenerated to keep consistency.
1 parent 00385e4 commit 522b48c

6 files changed

+23
-3
lines changed
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

testdata/dnn/onnx/generate_onnx_models_with_onnxscript.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
import onnxscript as ost
55
from onnxscript import opset19 as op # opset19 is the lastest by 202309
66

7+
np.random.seed(0)
8+
79
def make_model_and_data(model, *args, **kwargs):
810
name = model._name
911

1012
# TODO: support multiple outputs
11-
output = model(*args, **kwargs) # eager mode
13+
output = model(*args) # eager mode
1214

1315
# Save model
1416
model_proto = model.to_model_proto()
@@ -24,13 +26,20 @@ def make_model_and_data(model, *args, **kwargs):
2426
onnx.save(model_proto_, save_path)
2527

2628
# Save inputs and output
29+
inputs = args
30+
if "force_saving_input_as_dtype_float32" in kwargs and kwargs["force_saving_input_as_dtype_float32"]:
31+
inputs = []
32+
for input in args:
33+
inputs.append(input.astype(np.float32))
2734
if len(args) == 1:
2835
input_file = os.path.join("data", "input_" + name)
29-
np.save(input_file, args[0])
36+
np.save(input_file, inputs[0])
3037
else:
31-
for idx, input in enumerate(args, start=0):
38+
for idx, input in enumerate(inputs, start=0):
3239
input_files = os.path.join("data", "input_" + name + "_" + str(index))
3340
np.save(input_files, input)
41+
if "force_saving_output_as_dtype_float32" in kwargs and kwargs["force_saving_output_as_dtype_float32"]:
42+
output = output.astype(np.float32)
3443
output_files = os.path.join("data", "output_" + name)
3544
np.save(output_files, output)
3645

@@ -48,3 +57,14 @@ def gather_shared_indices(x: ost.FLOAT[2, 1, 3, 4]) -> ost.FLOAT[3, 4]:
4857
y1 = op.Gather(y0, indices, axis=0)
4958
return y1
5059
make_model_and_data(gather_shared_indices, np.random.rand(2, 1, 3, 4).astype(np.float32))
60+
61+
'''
62+
[Input] -> Greater(B=61) -> [Output]
63+
\
64+
dtype=np.int64
65+
'''
66+
@ost.script()
67+
def greater_input_dtype_int64(x: ost.FLOAT[27, 9]) ->ost.BOOL[27, 9]:
68+
y = op.Greater(x, op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.INT64, [], np.array([61], dtype=np.int64))))
69+
return y
70+
make_model_and_data(greater_input_dtype_int64, np.random.randint(0, 100, size=[27, 9], dtype=np.int64), force_saving_input_as_dtype_float32=True, force_saving_output_as_dtype_float32=True)
Binary file not shown.

0 commit comments

Comments
 (0)