Skip to content

Commit 2be7205

Browse files
superbobryjax authors
authored andcommitted
lax.Precision now uses _missing_ to handle aliases
Note that pytype does not support _missing_, so unfortunately we still have to have a separate definition for type checkers for the time being. PiperOrigin-RevId: 623820876
1 parent 477a44f commit 2be7205

File tree

1 file changed

+14
-15
lines changed

1 file changed

+14
-15
lines changed

jax/_src/lax/lax.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import itertools
2323
import math
2424
import operator
25-
from typing import Any, Callable, TypeVar, Union, cast as type_cast, overload, TYPE_CHECKING
25+
from typing import Any, Callable, ClassVar, TypeVar, Union, cast as type_cast, overload, TYPE_CHECKING
2626
import warnings
2727

2828
import numpy as np
@@ -624,14 +624,14 @@ def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array:
624624

625625
_precision_strings: dict[Any, Precision] = {}
626626

627-
# TODO(b/328046715): pytype appears unable to handle overriding __new__ in an
628-
# enum class. Doing this crashes Pytype. For now, just write an explicit type
629-
# for type checkers.
627+
# TODO(b/333851820): pytype does not properly handle _missing_ in enums.
628+
# We work around that by defining `Precision` as a normal class.
630629
if TYPE_CHECKING:
630+
631631
class Precision:
632-
DEFAULT: Precision
633-
HIGH: Precision
634-
HIGHEST: Precision
632+
DEFAULT: ClassVar[Precision]
633+
HIGH: ClassVar[Precision]
634+
HIGHEST: ClassVar[Precision]
635635

636636
def __new__(cls, value: Precision | int | str | None) -> Precision:
637637
raise NotImplementedError
@@ -645,6 +645,7 @@ def value(self) -> int:
645645
raise NotImplementedError
646646

647647
else:
648+
648649
class Precision(enum.Enum):
649650
"""Precision enum for lax functions
650651
@@ -663,23 +664,21 @@ class Precision(enum.Enum):
663664
Slowest but most accurate. Performs computations in float32 or float64
664665
as applicable. Aliases: ``'highest'``, ``'float32'``.
665666
"""
667+
666668
DEFAULT = 0
667669
HIGH = 1
668670
HIGHEST = 2
669671

672+
@classmethod
673+
def _missing_(cls, value: object) -> Precision | None:
674+
return _precision_strings.get(value)
675+
670676
def __repr__(self) -> str:
671-
return f"{self.__class__.__name__}.{self.name}"
677+
return f'{self.__class__.__name__}.{self.name}'
672678

673679
def __str__(self) -> str:
674680
return self.name
675681

676-
# You can't define __new__ on an enum class directly, but you can monkey-patch
677-
# it after the fact. Another way to do this might be using a metaclass.
678-
def _precision_new(cls, value: Precision | int | str | None) -> Precision:
679-
return super(Precision, cls).__new__(cls, _precision_strings.get(value, value))
680-
681-
Precision.__new__ = _precision_new
682-
683682

684683
_precision_strings['highest'] = Precision.HIGHEST
685684
_precision_strings['float32'] = Precision.HIGHEST

0 commit comments

Comments
 (0)