Skip to content

Commit 100bfde

Browse files
committed
[Util] break out test input generation function & allow seed setting
1 parent a3451c5 commit 100bfde

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

src/qonnx/util/test.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,15 +145,20 @@ def qonnx_download_model():
145145
clize.run(download_model)
146146

147147

148-
def get_golden_in_and_output(test_model):
149-
model = download_model(test_model, do_cleanup=True, return_modelwrapper=True)
150-
rng = np.random.RandomState(42)
148+
def get_random_input(test_model, seed=42):
149+
rng = np.random.RandomState(seed)
151150
input_shape = test_model_details[test_model]["input_shape"]
152151
(low, high) = test_model_details[test_model]["input_range"]
153152
size = np.prod(np.asarray(input_shape))
154153
input_tensor = rng.uniform(low=low, high=high, size=size)
155154
input_tensor = input_tensor.astype(np.float32)
156155
input_tensor = input_tensor.reshape(input_shape)
156+
return input_tensor
157+
158+
159+
def get_golden_in_and_output(test_model, seed=42):
160+
model = download_model(test_model, do_cleanup=True, return_modelwrapper=True)
161+
input_tensor = get_random_input(test_model, seed=seed)
157162
input_dict = {model.graph.input[0].name: input_tensor}
158163
golden_output_dict = oxe.execute_onnx(model, input_dict)
159164
golden_result = golden_output_dict[model.graph.output[0].name]

0 commit comments

Comments
 (0)