Skip to content

Commit 3d6cde0

Browse files
authored
Merge pull request #1158 from fengyuentau/matmul_bias
* make model and data for biased matmul using onnxruntime for inference * add note about onnxscript producing incorrect results in eager mode * update script to put random numpy arrays outside of function
1 parent f9f7143 commit 3d6cde0

File tree

4 files changed

+26
-0
lines changed

4 files changed

+26
-0
lines changed
4.13 KB
Binary file not shown.
16.1 KB
Binary file not shown.

testdata/dnn/onnx/generate_onnx_models_with_onnxscript.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@
66
from onnxscript import opset11
77
from onnxscript import opset13
88

9+
###############
10+
### CAUTION!!!
11+
### Be sure to put constant numpy arrays out of @ost.script() decorated fucntion.
12+
### Otherwise random values change each time eager mode is enter.
13+
### See discussions in https://github.com/microsoft/onnxscript/issues/1313
14+
###############
15+
916
np.random.seed(0)
1017

1118
def make_model_and_data(model, *args, **kwargs):
@@ -339,3 +346,22 @@ def layer_norm_no_fusion(x: ost.FLOAT[n, c, h, w]) -> ost.FLOAT[n, c, h, w]:
339346

340347
return add
341348
make_model_and_data(layer_norm_no_fusion, np.random.rand(n, c, h, w).astype(np.float32))
349+
350+
351+
''' Subgraph: [Input] -> MatMul<B> -> Add<A> -> [Output]
352+
'''
353+
354+
b = 2
355+
m = 32
356+
n = 64
357+
k = 16
358+
weight_data = np.random.rand(k, n).astype(np.float32)
359+
bias_data = np.random.rand(n).astype(np.float32)
360+
361+
@ost.script()
362+
def biased_matmul(x: ost.FLOAT[b, m, k]) -> ost.FLOAT[b, m, n]:
363+
weight = op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.FLOAT, [k, n], weight_data))
364+
matmul = op.MatMul(x, weight)
365+
bias = op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.FLOAT, [n], bias_data))
366+
return op.Add(bias, matmul)
367+
make_model_and_data(biased_matmul, np.random.rand(b, m, k).astype(np.float32), use_ort=True, ort_input_keys=["x"])
4.58 KB
Binary file not shown.

0 commit comments

Comments
 (0)