
Commit 6013e3a

chinhuang007 authored and tjingrant committed
Add Greater and others for v9 (#270)
Add v9 support for Greater, Constant, Flatten, Gemm, MatMul, and PRelu. No logic changes are needed; the handlers only have to declare support for v9.
1 parent 026bd29 commit 6013e3a
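
Every handler changed in this commit follows the same delegation pattern: the existing op logic either already lives in a shared `_common` helper or is moved into one, and a thin `version_9` classmethod delegates to it (or repeats the same one-line builder call), so declaring opset-9 support requires no behavioral change. A minimal sketch of the pattern, with `ExampleHandler` as an illustrative stand-in rather than code from this commit:

class ExampleHandler(object):
  """Stand-in for a BackendHandler/FrontendHandler subclass."""

  @classmethod
  def _common(cls, node, **kwargs):
    # Real handlers build the TF tensor or ONNX node here; elided in this sketch.
    raise NotImplementedError

  @classmethod
  def version_1(cls, node, **kwargs):
    return cls._common(node, **kwargs)

  @classmethod
  def version_9(cls, node, **kwargs):
    # Opset 9 does not change the semantics of these ops, so reuse _common.
    return cls._common(node, **kwargs)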

10 files changed: +60 −15 lines changed

doc/support_status.md (6 additions, 6 deletions)

@@ -22,7 +22,7 @@ ______
 |Ceil|1, 6|
 |Clip|1, 6|
 |Concat|1, 4|
-|Constant|1|
+|Constant|1, 9|
 |ConstantFill|1|
 |ConstantLike|9|
 |Conv|1|
@@ -38,17 +38,17 @@ ______
 |Exp|1, 6|
 |Expand|8|
 |EyeLike|N/A|
-|Flatten|1|
+|Flatten|1, 9|
 |Floor|1, 6|
 |GRU|1, 3, 7|
 |GRUUnit|N/A|
 |Gather|1|
-|Gemm|1, 6, 7|
+|Gemm|1, 6, 7, 9|
 |GivenTensorFill|N/A|
 |GlobalAveragePool|1|
 |GlobalLpPool|1, 2|
 |GlobalMaxPool|1|
-|Greater|1, 7|
+|Greater|1, 7, 9|
 |HardSigmoid|1, 6|
 |Hardmax|1|
 |Identity|1|
@@ -64,7 +64,7 @@ ______
 |Loop|N/A|
 |LpNormalization|1|
 |LpPool|N/A|
-|MatMul|1|
+|MatMul|1, 9|
 |Max|1, 6, 8|
 |MaxPool|1, 8|
 |MaxRoiPool|N/A|
@@ -76,7 +76,7 @@ ______
 |Neg|1, 6|
 |Not|1|
 |Or|1, 7|
-|PRelu|1, 6, 7|
+|PRelu|1, 6, 7, 9|
 |Pad|1, 2|
 |ParametricSoftplus|N/A|
 |Pow|1, 7|

onnx_tf/handlers/backend/constant.py (9 additions, 1 deletion)

@@ -12,11 +12,19 @@
 class Constant(BackendHandler):
 
   @classmethod
-  def version_1(cls, node, **kwargs):
+  def _common(cls, node, **kwargs):
     attr_value = node.attrs["value"]
     dtype = data_type.onnx2tf(attr_value.data_type)
     value = numpy_helper.to_array(attr_value)
     return [
         cls.make_tensor_from_onnx_node(
             node, inputs=[value], attrs={"dtype": dtype})
     ]
+
+  @classmethod
+  def version_1(cls, node, **kwargs):
+    return cls._common(node, **kwargs)
+
+  @classmethod
+  def version_9(cls, node, **kwargs):
+    return cls._common(node, **kwargs)

onnx_tf/handlers/backend/flatten.py (10 additions, 1 deletion)

@@ -10,7 +10,7 @@
 class Flatten(BackendHandler):
 
   @classmethod
-  def version_1(cls, node, **kwargs):
+  def _common(cls, node, **kwargs):
     x = kwargs["tensor_dict"][node.inputs[0]]
     shape = tf.shape(x)
     x_rank = len(x.shape)
@@ -25,3 +25,12 @@ def version_1(cls, node, **kwargs):
     cal_shape = (tf.reduce_prod(shape[0:axis]),
                  tf.reduce_prod(shape[axis:tf.size(shape)]))
     return [tf.reshape(x, cal_shape)]
+
+  @classmethod
+  def version_1(cls, node, **kwargs):
+    return cls._common(node, **kwargs)
+
+  @classmethod
+  def version_9(cls, node, **kwargs):
+    return cls._common(node, **kwargs)
+

onnx_tf/handlers/backend/gemm.py (4 additions, 0 deletions)

@@ -33,3 +33,7 @@ def version_6(cls, node, **kwargs):
   @classmethod
   def version_7(cls, node, **kwargs):
     return cls._common(node, **kwargs)
+
+  @classmethod
+  def version_9(cls, node, **kwargs):
+    return cls._common(node, **kwargs)

onnx_tf/handlers/backend/greater.py (4 additions, 0 deletions)

@@ -17,3 +17,7 @@ def version_1(cls, node, **kwargs):
   @classmethod
   def version_7(cls, node, **kwargs):
     return [cls.make_tensor_from_onnx_node(node, **kwargs)]
+
+  @classmethod
+  def version_9(cls, node, **kwargs):
+    return [cls.make_tensor_from_onnx_node(node, **kwargs)]

onnx_tf/handlers/backend/mat_mul.py (4 additions, 0 deletions)

@@ -12,3 +12,7 @@ class MatMul(BackendHandler):
   @classmethod
   def version_1(cls, node, **kwargs):
     return [cls.make_tensor_from_onnx_node(node, **kwargs)]
+
+  @classmethod
+  def version_9(cls, node, **kwargs):
+    return [cls.make_tensor_from_onnx_node(node, **kwargs)]

onnx_tf/handlers/backend/p_relu.py (4 additions, 0 deletions)

@@ -32,3 +32,7 @@ def version_6(cls, node, **kwargs):
   @classmethod
   def version_7(cls, node, **kwargs):
     return cls._common(node, **kwargs)
+
+  @classmethod
+  def version_9(cls, node, **kwargs):
+    return cls._common(node, **kwargs)

onnx_tf/handlers/frontend/greater.py (4 additions, 0 deletions)

@@ -15,3 +15,7 @@ def version_1(cls, node, **kwargs):
   @classmethod
   def version_7(cls, node, **kwargs):
     return cls.comparison_op(node, **kwargs)
+
+  @classmethod
+  def version_9(cls, node, **kwargs):
+    return cls.comparison_op(node, **kwargs)

onnx_tf/handlers/frontend/matmul.py (9 additions, 1 deletion)

@@ -9,7 +9,7 @@
 class Matmul(FrontendHandler):
 
   @classmethod
-  def version_1(cls, node, **kwargs):
+  def _common(cls, node, **kwargs):
     transpose_a = node.attr.get("transpose_a", False)
     transpose_b = node.attr.get("transpose_b", False)
     input_a = node.inputs[0]
@@ -33,3 +33,11 @@ def version_1(cls, node, **kwargs):
     nodes.append(transposed_b)
     nodes.append(cls.make_node_from_tf_node(node, [input_a, input_b]))
     return nodes
+
+  @classmethod
+  def version_1(cls, node, **kwargs):
+    return cls._common(node, **kwargs)
+
+  @classmethod
+  def version_9(cls, node, **kwargs):
+    return cls._common(node, **kwargs)

onnx_tf/opset_version.py (6 additions, 6 deletions)

@@ -15,7 +15,7 @@
     'Ceil': [1, 6],
     'Clip': [1, 6],
     'Concat': [1, 4],
-    'Constant': [1],
+    'Constant': [1, 9],
     'ConstantFill': [1],
     'ConstantLike': [9],
     'Conv': [1],
@@ -31,17 +31,17 @@
     'Exp': [1, 6],
     'Expand': [8],
     'EyeLike': [],
-    'Flatten': [1],
+    'Flatten': [1, 9],
     'Floor': [1, 6],
     'GRU': [1, 3, 7],
     'GRUUnit': [],
     'Gather': [1],
-    'Gemm': [1, 6, 7],
+    'Gemm': [1, 6, 7, 9],
     'GivenTensorFill': [],
     'GlobalAveragePool': [1],
     'GlobalLpPool': [1, 2],
     'GlobalMaxPool': [1],
-    'Greater': [1, 7],
+    'Greater': [1, 7, 9],
     'HardSigmoid': [1, 6],
     'Hardmax': [1],
     'Identity': [1],
@@ -57,7 +57,7 @@
     'Loop': [],
     'LpNormalization': [1],
     'LpPool': [],
-    'MatMul': [1],
+    'MatMul': [1, 9],
    'Max': [1, 6, 8],
     'MaxPool': [1, 8],
     'MaxRoiPool': [],
@@ -69,7 +69,7 @@
     'Neg': [1, 6],
     'Not': [1],
     'Or': [1, 7],
-    'PRelu': [1, 6, 7],
+    'PRelu': [1, 6, 7, 9],
     'Pad': [1, 2],
     'ParametricSoftplus': [],
     'Pow': [1, 7],
