Skip to content

Commit 1b032da

Browse files
committed
add model generation with onnxscript; add model and test data
1 parent 6608980 commit 1b032da

File tree

4 files changed

+50
-0
lines changed

4 files changed

+50
-0
lines changed
Binary file not shown.
Binary file not shown.
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import os
2+
import numpy as np
3+
import onnx
4+
import onnxscript as ost
5+
from onnxscript import opset19 as op # opset19 is the lastest by 202309
6+
7+
def make_model_and_data(model, *args, **kwargs):
8+
name = model._name
9+
10+
# TODO: support multiple outputs
11+
output = model(*args, **kwargs) # eager mode
12+
13+
# Save model
14+
model_proto = model.to_model_proto()
15+
try:
16+
onnx.checker.check_model(model_proto)
17+
except onnx.checker.ValidationError as e:
18+
print(f"Model {name} is invalid: {e}. Skipping ...")
19+
return False
20+
else:
21+
save_path = "./models/{}.onnx".format(name)
22+
print(f"Model {name} is valid! Saved to {save_path}")
23+
model_proto_ = onnx.shape_inference.infer_shapes(model_proto)
24+
onnx.save(model_proto_, save_path)
25+
26+
# Save inputs and output
27+
if len(args) == 1:
28+
input_file = os.path.join("data", "input_" + name)
29+
np.save(input_file, args[0])
30+
else:
31+
for idx, input in enumerate(args, start=0):
32+
input_files = os.path.join("data", "input_" + name + "_" + str(index))
33+
np.save(input_files, input)
34+
output_files = os.path.join("data", "output_" + name)
35+
np.save(output_files, output)
36+
37+
'''
38+
It builds a model with two Gather ops sharing a single same indices:
39+
40+
[Input] -> Gather(indices=0) -> Gather(indices=0) -> [Output]
41+
42+
, where the two indices constants have the same name.
43+
'''
44+
@ost.script()
45+
def gather_shared_indices(x: ost.FLOAT[2, 1, 3, 4]) -> ost.FLOAT[3, 4]:
46+
indices = op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.INT64, [], np.array([0], dtype=np.int64)))
47+
y0 = op.Gather(x, indices, axis=0)
48+
y1 = op.Gather(y0, indices, axis=0)
49+
return y1
50+
make_model_and_data(gather_shared_indices, np.random.rand(2, 1, 3, 4).astype(np.float32))
269 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)