4
4
import onnxscript as ost
5
5
from onnxscript import opset19 as op # opset19 is the lastest by 202309
6
6
7
+ np .random .seed (0 )
8
+
7
9
def make_model_and_data (model , * args , ** kwargs ):
8
10
name = model ._name
9
11
10
12
# TODO: support multiple outputs
11
- output = model (* args , ** kwargs ) # eager mode
13
+ output = model (* args ) # eager mode
12
14
13
15
# Save model
14
16
model_proto = model .to_model_proto ()
@@ -24,13 +26,20 @@ def make_model_and_data(model, *args, **kwargs):
24
26
onnx .save (model_proto_ , save_path )
25
27
26
28
# Save inputs and output
29
+ inputs = args
30
+ if "force_saving_input_as_dtype_float32" in kwargs and kwargs ["force_saving_input_as_dtype_float32" ]:
31
+ inputs = []
32
+ for input in args :
33
+ inputs .append (input .astype (np .float32 ))
27
34
if len (args ) == 1 :
28
35
input_file = os .path .join ("data" , "input_" + name )
29
- np .save (input_file , args [0 ])
36
+ np .save (input_file , inputs [0 ])
30
37
else :
31
- for idx , input in enumerate (args , start = 0 ):
38
+ for idx , input in enumerate (inputs , start = 0 ):
32
39
input_files = os .path .join ("data" , "input_" + name + "_" + str (index ))
33
40
np .save (input_files , input )
41
+ if "force_saving_output_as_dtype_float32" in kwargs and kwargs ["force_saving_output_as_dtype_float32" ]:
42
+ output = output .astype (np .float32 )
34
43
output_files = os .path .join ("data" , "output_" + name )
35
44
np .save (output_files , output )
36
45
@@ -48,3 +57,14 @@ def gather_shared_indices(x: ost.FLOAT[2, 1, 3, 4]) -> ost.FLOAT[3, 4]:
48
57
y1 = op .Gather (y0 , indices , axis = 0 )
49
58
return y1
50
59
make_model_and_data (gather_shared_indices , np .random .rand (2 , 1 , 3 , 4 ).astype (np .float32 ))
60
+
61
+ '''
62
+ [Input] -> Greater(B=61) -> [Output]
63
+ \
64
+ dtype=np.int64
65
+ '''
66
+ @ost .script ()
67
+ def greater_input_dtype_int64 (x : ost .FLOAT [27 , 9 ]) -> ost .BOOL [27 , 9 ]:
68
+ y = op .Greater (x , op .Constant (value = onnx .helper .make_tensor ("" , onnx .TensorProto .INT64 , [], np .array ([61 ], dtype = np .int64 ))))
69
+ return y
70
+ make_model_and_data (greater_input_dtype_int64 , np .random .randint (0 , 100 , size = [27 , 9 ], dtype = np .int64 ), force_saving_input_as_dtype_float32 = True , force_saving_output_as_dtype_float32 = True )
0 commit comments