Skip to content

Commit 8694a6d

Browse files
committed
[Util] handle per-channel ranges in get_random_input
1 parent ee7464f commit 8694a6d

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

src/qonnx/util/test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,11 @@ def get_random_input(test_model, seed=42):
220220
rng = np.random.RandomState(seed)
221221
input_shape = test_model_details[test_model]["input_shape"]
222222
(low, high) = test_model_details[test_model]["input_range"]
223+
# some models spec per-channel ranges, be conservative for those
224+
if isinstance(low, np.ndarray):
225+
low = low.max()
226+
if isinstance(high, np.ndarray):
227+
high = high.min()
223228
size = np.prod(np.asarray(input_shape))
224229
input_tensor = rng.uniform(low=low, high=high, size=size)
225230
input_tensor = input_tensor.astype(np.float32)

0 commit comments

Comments
 (0)