Skip to content

Commit a47dd76

Browse files
hmc-cs-mdrissiMehdi Drissi
andauthored
Tensorflow keras layer (#9707)
Co-authored-by: Mehdi Drissi <mdrissi@snapchat.com>
1 parent 2fe634f commit a47dd76

File tree

10 files changed

+398
-2
lines changed

10 files changed

+398
-2
lines changed

stubs/tensorflow/@tests/stubtest_allowlist.txt

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,52 @@ tensorflow.DType.__getattr__
1515
tensorflow.Graph.__getattr__
1616
tensorflow.Operation.__getattr__
1717
tensorflow.Variable.__getattr__
18+
tensorflow.keras.layers.Layer.__getattr__
1819
# Internal undocumented API
1920
tensorflow.RaggedTensor.__init__
2021
# Has an undocumented extra argument that tf.Variable which acts like subclass
2122
# (by dynamically patching tf.Tensor methods) does not preserve.
2223
tensorflow.Tensor.__getitem__
24+
# stub internal utility
25+
tensorflow._aliases
26+
27+
# Tensorflow imports are cursed.
28+
# import tensorflow.initializers
29+
# import tensorflow as tf
30+
# tf.initializers
31+
# Usually these two ways are same module, but for tensorflow the first way
32+
# often does not work and the second way does. The documentation describes
33+
# tf.initializers as module and has that type if accessed the second way,
34+
# but the real module file is completely different name (even package) and dynamically handled.
35+
# tf.initializers at runtime is <module 'keras.api._v2.keras.initializers' from '...'>
36+
tensorflow.initializers
37+
38+
# Layer constructor's always have **kwargs, but only allow a few specific values. PEP 692
39+
# would allow us to specify this with **kwargs and remove the need for these exceptions.
40+
tensorflow.keras.layers.*.__init__
41+
42+
# __call__ in tensorflow classes often allow keyword usage, but
43+
# when you subclass those classes it is not expected to handle keyword case. As an example,
44+
# class MyLayer(tf.keras.layers.Layer):
45+
# def call(self, x):
46+
# ...
47+
# is common even though Layer.call is defined like def call(self, inputs). Treating inputs as
48+
# a keyword argument would lead to many false positives with typical subclass usage.
49+
# Additional awkwardness for Layer's is call may optionally have training/mask as keyword arguments and some
50+
# layers do while others do not. At runtime call is not intended to be used directly by users,
51+
# but instead through __call__ which extracts out the training/mask arguments. Trying to describe
52+
# this better in stubs would similarly add many false positive Liskov violations.
53+
tensorflow.keras.layers.*.call
54+
tensorflow.keras.regularizers.Regularizer.__call__
55+
tensorflow.keras.constraints.Constraint.__call__
56+
57+
# Layer class does good deal of __new__ magic and actually returns one of two different internal
58+
# types depending on tensorflow execution mode. This feels like implementation internal.
59+
tensorflow.keras.layers.Layer.__new__
60+
61+
# build/compute_output_shape are marked positional only in stubs
62+
# as argument name is inconsistent across layer's and looks like
63+
# an implementation detail as documentation never mentions the
64+
# disagreements.
65+
tensorflow.keras.layers.*.build
66+
tensorflow.keras.layers.*.compute_output_shape

stubs/tensorflow/tensorflow/__init__.pyi

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@ from builtins import bool as _bool
44
from collections.abc import Callable, Iterable, Iterator, Sequence
55
from contextlib import contextmanager
66
from enum import Enum
7-
from typing import Any, NoReturn, overload
8-
from typing_extensions import Self, TypeAlias
7+
from types import TracebackType
8+
from typing import Any, NoReturn, TypeVar, overload
9+
from typing_extensions import ParamSpec, Self, TypeAlias
910

1011
import numpy
12+
from tensorflow import initializers as initializers, keras as keras, math as math
1113

1214
# Explicit import of DType is covered by the wildcard, but
1315
# is necessary to avoid a crash in pytype.
@@ -253,4 +255,31 @@ class IndexedSlices(metaclass=ABCMeta):
253255
def __neg__(self) -> IndexedSlices: ...
254256
def consumers(self) -> list[Operation]: ...
255257

258+
class name_scope:
259+
def __init__(self, name: str) -> None: ...
260+
def __enter__(self) -> str: ...
261+
def __exit__(self, typ: type[BaseException] | None, value: BaseException | None, traceback: TracebackType | None) -> None: ...
262+
263+
_P = ParamSpec("_P")
264+
_R = TypeVar("_R")
265+
266+
class Module:
267+
def __init__(self, name: str | None = None) -> None: ...
268+
@property
269+
def name(self) -> str: ...
270+
@property
271+
def name_scope(self) -> name_scope: ...
272+
# Documentation only specifies these as returning Sequence. Actual
273+
# implementation does tuple.
274+
@property
275+
def variables(self) -> Sequence[Variable]: ...
276+
@property
277+
def trainable_variables(self) -> Sequence[Variable]: ...
278+
@property
279+
def non_trainable_variables(self) -> Sequence[Variable]: ...
280+
@property
281+
def submodules(self) -> Sequence[Module]: ...
282+
@classmethod
283+
def with_name_scope(cls, method: Callable[_P, _R]) -> Callable[_P, _R]: ...
284+
256285
def __getattr__(name: str) -> Incomplete: ...
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Commonly used type aliases.
2+
# Everything in this module is private for stubs. There is no runtime
3+
# equivalent.
4+
5+
from collections.abc import Mapping, Sequence
6+
from typing import Any, TypeVar
7+
from typing_extensions import TypeAlias
8+
9+
import numpy
10+
11+
_T1 = TypeVar("_T1")
12+
ContainerGeneric: TypeAlias = Mapping[str, ContainerGeneric[_T1]] | Sequence[ContainerGeneric[_T1]] | _T1
13+
14+
AnyArray: TypeAlias = numpy.ndarray[Any, Any]
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from tensorflow.keras.initializers import *
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from _typeshed import Incomplete
2+
3+
from tensorflow.keras import (
4+
activations as activations,
5+
constraints as constraints,
6+
initializers as initializers,
7+
layers as layers,
8+
regularizers as regularizers,
9+
)
10+
11+
def __getattr__(name: str) -> Incomplete: ...
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from _typeshed import Incomplete
2+
from collections.abc import Callable
3+
from typing import Any
4+
from typing_extensions import TypeAlias
5+
6+
from tensorflow import Tensor
7+
8+
# The implementation uses isinstance so it must be dict and not any Mapping.
9+
_Activation: TypeAlias = str | None | Callable[[Tensor], Tensor] | dict[str, Any]
10+
11+
def get(identifier: _Activation) -> Callable[[Tensor], Tensor]: ...
12+
def __getattr__(name: str) -> Incomplete: ...
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from _typeshed import Incomplete
2+
from collections.abc import Callable
3+
from typing import Any, overload
4+
5+
from tensorflow import Tensor
6+
7+
class Constraint:
8+
def get_config(self) -> dict[str, Any]: ...
9+
def __call__(self, __w: Tensor) -> Tensor: ...
10+
11+
@overload
12+
def get(identifier: None) -> None: ...
13+
@overload
14+
def get(identifier: str | dict[str, Any] | Constraint) -> Constraint: ...
15+
@overload
16+
def get(identifier: Callable[[Tensor], Tensor]) -> Callable[[Tensor], Tensor]: ...
17+
def __getattr__(name: str) -> Incomplete: ...
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from _typeshed import Incomplete
2+
from collections.abc import Callable
3+
from typing import Any, overload
4+
from typing_extensions import Self, TypeAlias
5+
6+
from tensorflow import Tensor, _DTypeLike, _ShapeLike, _TensorCompatible
7+
8+
class Initializer:
9+
def __call__(self, shape: _ShapeLike, dtype: _DTypeLike | None = None) -> Tensor: ...
10+
def get_config(self) -> dict[str, Any]: ...
11+
@classmethod
12+
def from_config(cls, config: dict[str, Any]) -> Self: ...
13+
14+
class Constant(Initializer):
15+
def __init__(self, value: _TensorCompatible = 0) -> None: ...
16+
17+
class GlorotNormal(Initializer):
18+
def __init__(self, seed: int | None = None) -> None: ...
19+
20+
class GlorotUniform(Initializer):
21+
def __init__(self, seed: int | None = None) -> None: ...
22+
23+
class TruncatedNormal(Initializer):
24+
def __init__(self, mean: _TensorCompatible = 0.0, stddev: _TensorCompatible = 0.05, seed: int | None = None) -> None: ...
25+
26+
class RandomNormal(Initializer):
27+
def __init__(self, mean: _TensorCompatible = 0.0, stddev: _TensorCompatible = 0.05, seed: int | None = None) -> None: ...
28+
29+
class RandomUniform(Initializer):
30+
def __init__(self, minval: _TensorCompatible = -0.05, maxval: _TensorCompatible = 0.05, seed: int | None = None) -> None: ...
31+
32+
class Zeros(Initializer): ...
33+
34+
constant = Constant
35+
glorot_normal = GlorotNormal
36+
glorot_uniform = GlorotUniform
37+
truncated_normal = TruncatedNormal
38+
zeros = Zeros
39+
40+
_Initializer: TypeAlias = ( # noqa: Y047
41+
str | Initializer | type[Initializer] | Callable[[_ShapeLike], Tensor] | dict[str, Any] | None
42+
)
43+
44+
@overload
45+
def get(identifier: None) -> None: ...
46+
@overload
47+
def get(identifier: str | Initializer | dict[str, Any] | type[Initializer]) -> Initializer: ...
48+
@overload
49+
def get(identifier: Callable[[_ShapeLike], Tensor]) -> Callable[[_ShapeLike], Tensor]: ...
50+
def __getattr__(name: str) -> Incomplete: ...

0 commit comments

Comments
 (0)