Skip to content

Commit 3319ed9

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
InputGen: introduce random manager (#7)
Summary: Pull Request resolved: #7 Introduces a random manager in InputGen. This allows to generate reproducible data, by seeding the random manager. ``` from inputgen.utils.random_manager import random_manager random_manager.seed(1729) ``` Reviewed By: zonglinpengmeta Differential Revision: D59668295 fbshipit-source-id: b69337d1f1a4b29e8589dbe4d26c0d7832466399
1 parent a88dd7e commit 3319ed9

File tree

9 files changed

+136
-25
lines changed

9 files changed

+136
-25
lines changed

.github/workflows/python-app.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@ jobs:
3030
3131
- name: Run tests
3232
run: |
33-
python -m unittest discover -s test -p "*.py"
33+
python -m unittest discover -s test/inputgen -p "*.py"

examples/random_seed.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from inputgen.argtuple.gen import ArgumentTupleGenerator
9+
from inputgen.utils.random_manager import random_manager
10+
from specdb.db import SpecDictDB
11+
12+
13+
def main():
14+
# example to seed all random number generators
15+
random_manager.seed(1729)
16+
17+
spec = SpecDictDB["add.Tensor"]
18+
op = torch.ops.aten.add.Tensor
19+
for ix, (posargs, inkwargs, outargs) in enumerate(
20+
ArgumentTupleGenerator(spec).gen()
21+
):
22+
op(*posargs, **inkwargs, **outargs)
23+
print(
24+
posargs[0].shape,
25+
posargs[0].dtype,
26+
posargs[1].shape,
27+
posargs[1].dtype,
28+
inkwargs["alpha"],
29+
)
30+
if ix == 1:
31+
print(posargs[0])
32+
33+
34+
if __name__ == "__main__":
35+
main()

inputgen/argument/engine.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# This source code is licensed under the license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import random
87
from typing import Any, List, Optional, Tuple, Union
98

109
import torch
@@ -13,6 +12,7 @@
1312
from inputgen.attribute.model import Attribute
1413
from inputgen.attribute.solve import AttributeSolver
1514
from inputgen.specs.model import Constraint, ConstraintSuffix
15+
from inputgen.utils.random_manager import random_manager as rm
1616
from inputgen.variable.type import ScalarDtype
1717

1818

@@ -60,7 +60,9 @@ def gen_structure_with_depth_and_length(
6060
yield from self.gen_structure_with_depth(depth, focus, length)
6161
return
6262

63-
focus_ixs = range(length) if focus == attr else (random.choice(range(length)),)
63+
focus_ixs = (
64+
range(length) if focus == attr else (rm.get_random().choice(range(length)),)
65+
)
6466
for focus_ix in focus_ixs:
6567
values = [()]
6668
for ix in range(length):
@@ -241,7 +243,7 @@ def gen_value_spaces(self, focus, dtype, struct):
241243
if focus == Attribute.VALUE:
242244
return [v.space for v in variables]
243245
else:
244-
return [random.choice(variables).space]
246+
return [rm.get_random().choice(variables).space]
245247

246248
def gen(self, focus):
247249
# TODO(mcandales): Enable Tensor List generation

inputgen/argument/gen.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import torch
1111
from inputgen.argument.engine import MetaArg
12+
from inputgen.utils.random_manager import random_manager
1213
from inputgen.variable.gen import VariableGenerator
1314
from inputgen.variable.space import VariableSpace
1415
from torch.testing._internal.common_dtype import floating_types, integral_types
@@ -41,6 +42,8 @@ def gen(self):
4142
)
4243

4344
def get_random_tensor(self, size, dtype, high=None, low=None):
45+
torch_rng = random_manager.get_torch()
46+
4447
if low is None and high is None:
4548
low = -100
4649
high = 100
@@ -55,7 +58,9 @@ def get_random_tensor(self, size, dtype, high=None, low=None):
5558
elif not self.space.contains(1):
5659
return torch.full(size, False, dtype=dtype)
5760
else:
58-
return torch.randint(low=0, high=2, size=size, dtype=dtype)
61+
return torch.randint(
62+
low=0, high=2, size=size, dtype=dtype, generator=torch_rng
63+
)
5964

6065
if dtype in integral_types():
6166
low = math.ceil(low)
@@ -68,16 +73,38 @@ def get_random_tensor(self, size, dtype, high=None, low=None):
6873

6974
if dtype == torch.uint8:
7075
if not self.space.contains(0):
71-
return torch.randint(low=max(1, low), high=high, size=size, dtype=dtype)
76+
return torch.randint(
77+
low=max(1, low),
78+
high=high,
79+
size=size,
80+
dtype=dtype,
81+
generator=torch_rng,
82+
)
7283
else:
73-
return torch.randint(low=max(0, low), high=high, size=size, dtype=dtype)
84+
return torch.randint(
85+
low=max(0, low),
86+
high=high,
87+
size=size,
88+
dtype=dtype,
89+
generator=torch_rng,
90+
)
7491

75-
t = torch.randint(low=low, high=high, size=size, dtype=dtype)
92+
t = torch.randint(
93+
low=low, high=high, size=size, dtype=dtype, generator=torch_rng
94+
)
7695
if not self.space.contains(0):
7796
if high > 0:
78-
pos = torch.randint(low=max(1, low), high=high, size=size, dtype=dtype)
97+
pos = torch.randint(
98+
low=max(1, low),
99+
high=high,
100+
size=size,
101+
dtype=dtype,
102+
generator=torch_rng,
103+
)
79104
else:
80-
pos = torch.randint(low=low, high=0, size=size, dtype=dtype)
105+
pos = torch.randint(
106+
low=low, high=0, size=size, dtype=dtype, generator=torch_rng
107+
)
81108
t = torch.where(t == 0, pos, t)
82109

83110
if dtype in integral_types():

inputgen/attribute/engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from inputgen.attribute.solve import AttributeSolver
1313
from inputgen.specs.model import Constraint
1414
from inputgen.variable.gen import VariableGenerator
15-
from inputgen.variable.type import ScalarDtype
15+
from inputgen.variable.type import ScalarDtype, sort_values_of_type
1616

1717

1818
class AttributeEngine(AttributeSolver):
@@ -51,4 +51,4 @@ def gen(self, focus: Attribute, *args):
5151
if len(vals) == 0:
5252
vals = VariableGenerator(variable.space).gen(num)
5353
gen_vals.update(vals)
54-
return gen_vals
54+
return sort_values_of_type(self.vtype, gen_vals)

inputgen/utils/__init__.py

Whitespace-only changes.

inputgen/utils/random_manager.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import random
8+
9+
import torch
10+
11+
12+
class RandomManager:
13+
def __init__(self):
14+
self._rng = random.Random()
15+
self._torch_rng = torch.Generator()
16+
17+
def seed(self, seed):
18+
"""
19+
Seeds the random number generators for random and torch.
20+
"""
21+
self._rng.seed(seed)
22+
self._torch_rng.manual_seed(seed)
23+
24+
def get_random(self):
25+
# self._rng.seed(42)
26+
return self._rng
27+
28+
def get_torch(self):
29+
# self._torch_rng.manual_seed(42)
30+
return self._torch_rng
31+
32+
33+
random_manager = RandomManager()

inputgen/variable/gen.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import math
8-
import random
98
from typing import Any, List, Optional, Set, Union
109

10+
from inputgen.utils.random_manager import random_manager as rm
1111
from inputgen.variable.constants import BOUND_ON_INF, INT64_MAX, INT64_MIN
1212
from inputgen.variable.space import Interval, Intervals, VariableSpace
13+
from inputgen.variable.type import sort_values_of_type
1314
from inputgen.variable.utils import nextdown, nextup
1415

1516

@@ -51,7 +52,7 @@ def gen_float_from_interval(r: Interval) -> Optional[float]:
5152
elif lower > upper:
5253
return None
5354
else:
54-
return random.uniform(lower, upper)
55+
return rm.get_random().uniform(lower, upper)
5556

5657

5758
def gen_min_float_from_intervals(rs: Intervals) -> Optional[float]:
@@ -69,7 +70,7 @@ def gen_max_float_from_intervals(rs: Intervals) -> Optional[float]:
6970
def gen_float_from_intervals(rs: Intervals) -> Optional[float]:
7071
if rs.empty():
7172
return None
72-
r = random.choice(rs.intervals)
73+
r = rm.get_random().choice(rs.intervals)
7374
return gen_float_from_interval(r)
7475

7576

@@ -112,7 +113,7 @@ def gen_int_from_interval(r: Interval) -> Optional[int]:
112113
elif upper is None:
113114
upper = max(lower, 0) + BOUND_ON_INF
114115
assert lower is not None and upper is not None
115-
return random.randint(lower, upper)
116+
return rm.get_random().randint(lower, upper)
116117

117118

118119
def gen_min_int_from_intervals(rs: Intervals) -> Optional[int]:
@@ -133,7 +134,7 @@ def gen_int_from_intervals(rs: Intervals) -> Optional[int]:
133134
intervals_with_ints = [r for r in rs.intervals if r.contains_int()]
134135
if len(intervals_with_ints) == 0:
135136
return None
136-
r = random.choice(intervals_with_ints)
137+
r = rm.get_random().choice(intervals_with_ints)
137138
return gen_int_from_interval(r)
138139

139140

@@ -147,6 +148,12 @@ def __init__(self, space: VariableSpace):
147148
self.vtype = space.vtype
148149
self.space = space
149150

151+
def _sorted(self, values: Set[Any]) -> List[Any]:
152+
return sort_values_of_type(self.vtype, values)
153+
154+
def _sample(self, values: Set[Any], num: int) -> List[Any]:
155+
return rm.get_random().sample(self._sorted(values), num)
156+
150157
def gen_min(self) -> Any:
151158
"""Returns the minimum value of the space."""
152159
if self.space.empty() or self.vtype not in [bool, int, float]:
@@ -221,7 +228,7 @@ def gen_edges_non_extreme(self, num: int = 2) -> Set[Any]:
221228
edges_not_extreme = self.gen_edges() - self.gen_extremes()
222229
if num >= len(edges_not_extreme):
223230
return edges_not_extreme
224-
return set(random.sample(list(edges_not_extreme), num))
231+
return set(self._sample(edges_not_extreme, num))
225232

226233
def gen_non_edges(self, num: int = 2) -> Set[Any]:
227234
"""Generates non-edge (or interior) values of the space."""
@@ -232,7 +239,7 @@ def gen_non_edges(self, num: int = 2) -> Set[Any]:
232239
if self.space.discrete.initialized:
233240
vals = self.space.discrete.values - edge_or_extreme_vals
234241
if num < len(vals):
235-
vals = set(random.sample(list(vals), num))
242+
vals = set(self._sample(vals, num))
236243
else:
237244
for _ in range(100):
238245
v: Optional[Union[int, float]] = None
@@ -269,11 +276,8 @@ def gen_balanced(self, num: int = 6) -> Set[Any]:
269276

270277
if num >= len(balanced):
271278
return balanced
272-
return set(random.sample(list(balanced), num))
279+
return set(self._sample(balanced, num))
273280

274281
def gen(self, num: int = 6) -> List[Any]:
275282
"""Generates a sorted (if applicable), balanced sample of the space."""
276-
vals = list(self.gen_balanced(num))
277-
if self.vtype in [bool, int, float, str]:
278-
return sorted(vals)
279-
return vals
283+
return sort_values_of_type(self.vtype, self.gen_balanced(num))

inputgen/variable/type.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import math
88
from enum import Enum
9-
from typing import Any
9+
from typing import Any, List, Set
1010

1111
import torch
1212

@@ -93,3 +93,13 @@ def convert_to_vtype(vtype: type, v: Any) -> Any:
9393
if vtype == float:
9494
return float(v)
9595
return v
96+
97+
98+
def sort_values_of_type(vtype: type, values: Set[Any]) -> List[Any]:
99+
if vtype in [bool, int, float, str, tuple]:
100+
return sorted(values)
101+
if vtype == torch.dtype:
102+
return [v for v in SUPPORTED_TENSOR_DTYPES if v in values]
103+
if vtype == ScalarDtype:
104+
return [v for v in ScalarDtype if v in values]
105+
return list(values)

0 commit comments

Comments
 (0)