22
22
import itertools
23
23
import math
24
24
import operator
25
- from typing import Any , Callable , TypeVar , Union , cast as type_cast , overload , TYPE_CHECKING
25
+ from typing import Any , Callable , ClassVar , TypeVar , Union , cast as type_cast , overload , TYPE_CHECKING
26
26
import warnings
27
27
28
28
import numpy as np
@@ -624,14 +624,14 @@ def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array:
624
624
625
625
_precision_strings : dict [Any , Precision ] = {}
626
626
627
- # TODO(b/328046715): pytype appears unable to handle overriding __new__ in an
628
- # enum class. Doing this crashes Pytype. For now, just write an explicit type
629
- # for type checkers.
627
+ # TODO(b/333851820): pytype does not properly handle _missing_ in enums.
628
+ # We work around that by defining `Precision` as a normal class.
630
629
if TYPE_CHECKING :
630
+
631
631
class Precision :
632
- DEFAULT : Precision
633
- HIGH : Precision
634
- HIGHEST : Precision
632
+ DEFAULT : ClassVar [ Precision ]
633
+ HIGH : ClassVar [ Precision ]
634
+ HIGHEST : ClassVar [ Precision ]
635
635
636
636
def __new__ (cls , value : Precision | int | str | None ) -> Precision :
637
637
raise NotImplementedError
@@ -645,6 +645,7 @@ def value(self) -> int:
645
645
raise NotImplementedError
646
646
647
647
else :
648
+
648
649
class Precision (enum .Enum ):
649
650
"""Precision enum for lax functions
650
651
@@ -663,23 +664,21 @@ class Precision(enum.Enum):
663
664
Slowest but most accurate. Performs computations in float32 or float64
664
665
as applicable. Aliases: ``'highest'``, ``'float32'``.
665
666
"""
667
+
666
668
DEFAULT = 0
667
669
HIGH = 1
668
670
HIGHEST = 2
669
671
672
+ @classmethod
673
+ def _missing_ (cls , value : object ) -> Precision | None :
674
+ return _precision_strings .get (value )
675
+
670
676
def __repr__ (self ) -> str :
671
- return f" { self .__class__ .__name__ } .{ self .name } "
677
+ return f' { self .__class__ .__name__ } .{ self .name } '
672
678
673
679
def __str__ (self ) -> str :
674
680
return self .name
675
681
676
- # You can't define __new__ on an enum class directly, but you can monkey-patch
677
- # it after the fact. Another way to do this might be using a metaclass.
678
- def _precision_new (cls , value : Precision | int | str | None ) -> Precision :
679
- return super (Precision , cls ).__new__ (cls , _precision_strings .get (value , value ))
680
-
681
- Precision .__new__ = _precision_new
682
-
683
682
684
683
_precision_strings ['highest' ] = Precision .HIGHEST
685
684
_precision_strings ['float32' ] = Precision .HIGHEST
0 commit comments