Skip to content

InputGen: introduce random manager #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ jobs:

- name: Run tests
run: |
python -m unittest discover -s test -p "*.py"
python -m unittest discover -s test/inputgen -p "*.py"
35 changes: 35 additions & 0 deletions examples/random_seed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from inputgen.argtuple.gen import ArgumentTupleGenerator
from inputgen.utils.random_manager import random_manager
from specdb.db import SpecDictDB


def main():
# example to seed all random number generators
random_manager.seed(1729)

spec = SpecDictDB["add.Tensor"]
op = torch.ops.aten.add.Tensor
for ix, (posargs, inkwargs, outargs) in enumerate(
ArgumentTupleGenerator(spec).gen()
):
op(*posargs, **inkwargs, **outargs)
print(
posargs[0].shape,
posargs[0].dtype,
posargs[1].shape,
posargs[1].dtype,
inkwargs["alpha"],
)
if ix == 1:
print(posargs[0])


if __name__ == "__main__":
main()
8 changes: 5 additions & 3 deletions inputgen/argument/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import random
from typing import Any, List, Optional, Tuple, Union

import torch
Expand All @@ -13,6 +12,7 @@
from inputgen.attribute.model import Attribute
from inputgen.attribute.solve import AttributeSolver
from inputgen.specs.model import Constraint, ConstraintSuffix
from inputgen.utils.random_manager import random_manager as rm
from inputgen.variable.type import ScalarDtype


Expand Down Expand Up @@ -60,7 +60,9 @@ def gen_structure_with_depth_and_length(
yield from self.gen_structure_with_depth(depth, focus, length)
return

focus_ixs = range(length) if focus == attr else (random.choice(range(length)),)
focus_ixs = (
range(length) if focus == attr else (rm.get_random().choice(range(length)),)
)
for focus_ix in focus_ixs:
values = [()]
for ix in range(length):
Expand Down Expand Up @@ -241,7 +243,7 @@ def gen_value_spaces(self, focus, dtype, struct):
if focus == Attribute.VALUE:
return [v.space for v in variables]
else:
return [random.choice(variables).space]
return [rm.get_random().choice(variables).space]

def gen(self, focus):
# TODO(mcandales): Enable Tensor List generation
Expand Down
39 changes: 33 additions & 6 deletions inputgen/argument/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import torch
from inputgen.argument.engine import MetaArg
from inputgen.utils.random_manager import random_manager
from inputgen.variable.gen import VariableGenerator
from inputgen.variable.space import VariableSpace
from torch.testing._internal.common_dtype import floating_types, integral_types
Expand Down Expand Up @@ -41,6 +42,8 @@ def gen(self):
)

def get_random_tensor(self, size, dtype, high=None, low=None):
torch_rng = random_manager.get_torch()

if low is None and high is None:
low = -100
high = 100
Expand All @@ -55,7 +58,9 @@ def get_random_tensor(self, size, dtype, high=None, low=None):
elif not self.space.contains(1):
return torch.full(size, False, dtype=dtype)
else:
return torch.randint(low=0, high=2, size=size, dtype=dtype)
return torch.randint(
low=0, high=2, size=size, dtype=dtype, generator=torch_rng
)

if dtype in integral_types():
low = math.ceil(low)
Expand All @@ -68,16 +73,38 @@ def get_random_tensor(self, size, dtype, high=None, low=None):

if dtype == torch.uint8:
if not self.space.contains(0):
return torch.randint(low=max(1, low), high=high, size=size, dtype=dtype)
return torch.randint(
low=max(1, low),
high=high,
size=size,
dtype=dtype,
generator=torch_rng,
)
else:
return torch.randint(low=max(0, low), high=high, size=size, dtype=dtype)
return torch.randint(
low=max(0, low),
high=high,
size=size,
dtype=dtype,
generator=torch_rng,
)

t = torch.randint(low=low, high=high, size=size, dtype=dtype)
t = torch.randint(
low=low, high=high, size=size, dtype=dtype, generator=torch_rng
)
if not self.space.contains(0):
if high > 0:
pos = torch.randint(low=max(1, low), high=high, size=size, dtype=dtype)
pos = torch.randint(
low=max(1, low),
high=high,
size=size,
dtype=dtype,
generator=torch_rng,
)
else:
pos = torch.randint(low=low, high=0, size=size, dtype=dtype)
pos = torch.randint(
low=low, high=0, size=size, dtype=dtype, generator=torch_rng
)
t = torch.where(t == 0, pos, t)

if dtype in integral_types():
Expand Down
4 changes: 2 additions & 2 deletions inputgen/attribute/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from inputgen.attribute.solve import AttributeSolver
from inputgen.specs.model import Constraint
from inputgen.variable.gen import VariableGenerator
from inputgen.variable.type import ScalarDtype
from inputgen.variable.type import ScalarDtype, sort_values_of_type


class AttributeEngine(AttributeSolver):
Expand Down Expand Up @@ -51,4 +51,4 @@ def gen(self, focus: Attribute, *args):
if len(vals) == 0:
vals = VariableGenerator(variable.space).gen(num)
gen_vals.update(vals)
return gen_vals
return sort_values_of_type(self.vtype, gen_vals)
Empty file added inputgen/utils/__init__.py
Empty file.
33 changes: 33 additions & 0 deletions inputgen/utils/random_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import random

import torch


class RandomManager:
def __init__(self):
self._rng = random.Random()
self._torch_rng = torch.Generator()

def seed(self, seed):
"""
Seeds the random number generators for random and torch.
"""
self._rng.seed(seed)
self._torch_rng.manual_seed(seed)

def get_random(self):
# self._rng.seed(42)
return self._rng

def get_torch(self):
# self._torch_rng.manual_seed(42)
return self._torch_rng


random_manager = RandomManager()
28 changes: 16 additions & 12 deletions inputgen/variable/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
# LICENSE file in the root directory of this source tree.

import math
import random
from typing import Any, List, Optional, Set, Union

from inputgen.utils.random_manager import random_manager as rm
from inputgen.variable.constants import BOUND_ON_INF, INT64_MAX, INT64_MIN
from inputgen.variable.space import Interval, Intervals, VariableSpace
from inputgen.variable.type import sort_values_of_type
from inputgen.variable.utils import nextdown, nextup


Expand Down Expand Up @@ -51,7 +52,7 @@ def gen_float_from_interval(r: Interval) -> Optional[float]:
elif lower > upper:
return None
else:
return random.uniform(lower, upper)
return rm.get_random().uniform(lower, upper)


def gen_min_float_from_intervals(rs: Intervals) -> Optional[float]:
Expand All @@ -69,7 +70,7 @@ def gen_max_float_from_intervals(rs: Intervals) -> Optional[float]:
def gen_float_from_intervals(rs: Intervals) -> Optional[float]:
if rs.empty():
return None
r = random.choice(rs.intervals)
r = rm.get_random().choice(rs.intervals)
return gen_float_from_interval(r)


Expand Down Expand Up @@ -112,7 +113,7 @@ def gen_int_from_interval(r: Interval) -> Optional[int]:
elif upper is None:
upper = max(lower, 0) + BOUND_ON_INF
assert lower is not None and upper is not None
return random.randint(lower, upper)
return rm.get_random().randint(lower, upper)


def gen_min_int_from_intervals(rs: Intervals) -> Optional[int]:
Expand All @@ -133,7 +134,7 @@ def gen_int_from_intervals(rs: Intervals) -> Optional[int]:
intervals_with_ints = [r for r in rs.intervals if r.contains_int()]
if len(intervals_with_ints) == 0:
return None
r = random.choice(intervals_with_ints)
r = rm.get_random().choice(intervals_with_ints)
return gen_int_from_interval(r)


Expand All @@ -147,6 +148,12 @@ def __init__(self, space: VariableSpace):
self.vtype = space.vtype
self.space = space

def _sorted(self, values: Set[Any]) -> List[Any]:
return sort_values_of_type(self.vtype, values)

def _sample(self, values: Set[Any], num: int) -> List[Any]:
return rm.get_random().sample(self._sorted(values), num)

def gen_min(self) -> Any:
"""Returns the minimum value of the space."""
if self.space.empty() or self.vtype not in [bool, int, float]:
Expand Down Expand Up @@ -221,7 +228,7 @@ def gen_edges_non_extreme(self, num: int = 2) -> Set[Any]:
edges_not_extreme = self.gen_edges() - self.gen_extremes()
if num >= len(edges_not_extreme):
return edges_not_extreme
return set(random.sample(list(edges_not_extreme), num))
return set(self._sample(edges_not_extreme, num))

def gen_non_edges(self, num: int = 2) -> Set[Any]:
"""Generates non-edge (or interior) values of the space."""
Expand All @@ -232,7 +239,7 @@ def gen_non_edges(self, num: int = 2) -> Set[Any]:
if self.space.discrete.initialized:
vals = self.space.discrete.values - edge_or_extreme_vals
if num < len(vals):
vals = set(random.sample(list(vals), num))
vals = set(self._sample(vals, num))
else:
for _ in range(100):
v: Optional[Union[int, float]] = None
Expand Down Expand Up @@ -269,11 +276,8 @@ def gen_balanced(self, num: int = 6) -> Set[Any]:

if num >= len(balanced):
return balanced
return set(random.sample(list(balanced), num))
return set(self._sample(balanced, num))

def gen(self, num: int = 6) -> List[Any]:
"""Generates a sorted (if applicable), balanced sample of the space."""
vals = list(self.gen_balanced(num))
if self.vtype in [bool, int, float, str]:
return sorted(vals)
return vals
return sort_values_of_type(self.vtype, self.gen_balanced(num))
12 changes: 11 additions & 1 deletion inputgen/variable/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import math
from enum import Enum
from typing import Any
from typing import Any, List, Set

import torch

Expand Down Expand Up @@ -93,3 +93,13 @@ def convert_to_vtype(vtype: type, v: Any) -> Any:
if vtype == float:
return float(v)
return v


def sort_values_of_type(vtype: type, values: Set[Any]) -> List[Any]:
if vtype in [bool, int, float, str, tuple]:
return sorted(values)
if vtype == torch.dtype:
return [v for v in SUPPORTED_TENSOR_DTYPES if v in values]
if vtype == ScalarDtype:
return [v for v in ScalarDtype if v in values]
return list(values)
Loading