Skip to content

Commit e099356

Browse files
authored
Update the way to check input_signature in from_function(). (#1947)
* Update the way to check input_signature in from_function(). Signed-off-by: Jay Zhang <jiz@microsoft.com>
1 parent 880754e commit e099356

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

tests/test_api.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,23 @@ def func(foo, a, x, b, w):
173173
res_onnx = self.run_onnxruntime(output_path, {"x": x, "w": w}, output_names)
174174
self.assertAllClose(res_tf, res_onnx[0], rtol=1e-5, atol=1e-5)
175175

176+
@check_tf_min_version("2.0")
177+
def test_function_nparray(self):
178+
@tf.function
179+
def func(x):
180+
return tf.math.sqrt(x)
181+
182+
output_path = os.path.join(self.test_data_directory, "model.onnx")
183+
x = np.asarray([1.0, 2.0])
184+
185+
res_tf = func(x)
186+
spec = np.asarray([[1.0, 2.0]])
187+
model_proto, _ = tf2onnx.convert.from_function(func, input_signature=spec,
188+
opset=self.config.opset, output_path=output_path)
189+
output_names = [n.name for n in model_proto.graph.output]
190+
res_onnx = self.run_onnxruntime(output_path, {'x': x}, output_names)
191+
self.assertAllClose(res_tf, res_onnx[0], rtol=1e-5, atol=1e-5)
192+
176193
@check_tf_min_version("1.15")
177194
def _test_graphdef(self):
178195
def func(x, y):

tf2onnx/convert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -535,7 +535,7 @@ def from_function(function, input_signature=None, opset=None, custom_ops=None, c
535535
if LooseVersion(tf.__version__) < "2.0":
536536
raise NotImplementedError("from_function requires tf-2.0 or newer")
537537

538-
if not input_signature:
538+
if input_signature is None:
539539
raise ValueError("from_function requires input_signature")
540540

541541
concrete_func = function.get_concrete_function(*input_signature)

0 commit comments

Comments
 (0)