Skip to content

Commit 29b76df

Browse files
hwangdeyufatcat-z
andauthored
Add TensorScatterAdd op for opset 16 (#1949)
* add TensorScatterAdd op conversion in opset 16 Signed-off-by: Deyu Huang <deyhuang@microsoft.com> * add unit test Signed-off-by: Deyu Huang <deyhuang@microsoft.com> * fix pylint name Signed-off-by: Deyu Huang <deyhuang@microsoft.com> * fix typo and test op check Signed-off-by: Deyu Huang <deyhuang@microsoft.com> Co-authored-by: Jay Zhang <jiz@microsoft.com> Co-authored-by: Jay Zhang <jiz@microsoft.com>
1 parent 6f5a673 commit 29b76df

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

tests/test_backend.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4698,6 +4698,17 @@ def func(x, y, z):
46984698
return tf.identity(x_, name=_TFOUTPUT)
46994699
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val, _INPUT2: z_val})
47004700

4701+
@check_opset_min_version(16, "ScatterND")
4702+
def test_scatternd_add(self):
4703+
x_val = np.array([10, 20, 30, 40], dtype=np.int32).reshape((4))
4704+
y_val = np.array([0, 2], dtype=np.int64).reshape((2, 1))
4705+
z_val = np.array([20, 30], dtype=np.int32).reshape((2))
4706+
4707+
def func(x, y, z):
4708+
x_ = tf.tensor_scatter_nd_add(x, y, z)
4709+
return tf.identity(x_, name=_TFOUTPUT)
4710+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val, _INPUT2: z_val})
4711+
47014712
@check_opset_min_version(11, "ScatterND")
47024713
def test_scatternd_1d(self):
47034714
x_val = np.array([4, 3, 1, 7], dtype=np.int32).reshape((4, 1))

tf2onnx/onnx_opset/tensor.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,16 @@ def version_11(cls, ctx, node, **kwargs):
655655
ctx.replace_inputs(node, [node.input[2], node.input[0], node.input[1]])
656656

657657

658+
@tf_op("TensorScatterAdd", onnx_op="ScatterND")
659+
class TensorScatterAdd:
660+
@classmethod
661+
def version_16(cls, ctx, node, **kwargs):
662+
# indicies input must be int64 in ONNX.
663+
if ctx.get_dtype(node.input[1]) != TensorProto.INT64:
664+
ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=TensorProto.INT64)
665+
node.set_attr("reduction", 'add')
666+
667+
658668
@tf_op("TensorScatterUpdate", onnx_op="ScatterND")
659669
class TensorScatterUpdate:
660670
@classmethod

0 commit comments

Comments
 (0)