Skip to content

Commit a1facc5

Browse files
authored
hardswish support for tflite (#1735)
Signed-off-by: Guenther Schmuelling <guschmue@microsoft.com>
1 parent 52ff39e commit a1facc5

File tree

3 files changed

+22
-0
lines changed

3 files changed

+22
-0
lines changed

tests/test_backend.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5696,6 +5696,19 @@ def func(x):
56965696
# can't check the values because in onnx they are padded with 0, in tf they are not
56975697
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, check_value=False)
56985698

5699+
@check_tf_min_version("2.5")
5700+
@check_opset_min_version(14, "hardswish")
5701+
@skip_tfjs("not supported in tfjs")
5702+
def test_hardswish(self):
5703+
def func(x):
5704+
# there is no hardswich in tf but toco will optimize to it
5705+
op_ = x * tf.nn.relu6(x + np.float32(3)) * np.float32(1. / 6.)
5706+
return tf.identity(op_, name=_TFOUTPUT)
5707+
5708+
# tf gets this wrong and returns fp32 instead of int
5709+
x_val = np.array([0.5, 1.0, -0.5, -1.0], dtype=np.float32).reshape((2, 2))
5710+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
5711+
56995712

57005713
if __name__ == '__main__':
57015714
unittest_main()

tf2onnx/onnx_opset/math.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -806,3 +806,11 @@ def version_11(cls, ctx, node, **kwargs):
806806
shapes=[shape], dtypes=[onnx_dtype])
807807

808808
ctx.replace_all_inputs(node.output[0], last_node.output[0]) # ops=ctx.get_nodes()
809+
810+
811+
@tf_op(["HardSwish"])
812+
class HardSwish:
813+
# Note: this doesn't really exist in tensorflow but it does in tflite
814+
@classmethod
815+
def version_14(cls, ctx, node, **kwargs):
816+
pass

tf2onnx/tflite_handlers/tfl_direct.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
@tfl_op("TFL_CUMSUM", tf_op="Cumsum")
8383
@tfl_op("TFL_RFFT2D", tf_op="RFFT2D")
8484
@tfl_op("TFL_COMPLEX_ABS", tf_op="ComplexAbs")
85+
@tfl_op("TFL_HARD_SWISH", tf_op="HardSwish")
8586
class TflDirectOp:
8687
@classmethod
8788
def to_tf(cls, ctx, node, **kwargs):

0 commit comments

Comments
 (0)