70
70
from jax ._src .sharding import Sharding
71
71
from jax ._src .sharding_impls import (PmapSharding , TransferToMemoryKind ,
72
72
XLACompatibleSharding )
73
- from jax ._src .layout import Layout
73
+ from jax ._src .layout import Layout , AutoLayout
74
74
from jax ._src .traceback_util import api_boundary
75
75
from jax ._src import tree_util
76
76
from jax ._src .util import unzip2 , safe_map , safe_zip , wrap_name , wraps
@@ -2710,22 +2710,34 @@ class ShapeDtypeStruct:
2710
2710
named_shape: (optional) a dictionary representing a named shape
2711
2711
sharding: (optional) a :class:`jax.Sharding` object
2712
2712
"""
2713
- __slots__ = ["shape" , "dtype" , "named_shape" , "sharding" ]
2713
+ __slots__ = ["shape" , "dtype" , "named_shape" , "sharding" , "_dll" ]
2714
+
2714
2715
def __init__ (self , shape , dtype , named_shape = None , sharding = None ):
2715
2716
self .shape = tuple (shape )
2716
2717
if dtype is None :
2717
2718
raise ValueError ("ShapeDtypeStruct: dtype must be specified." )
2718
2719
self .dtype = dtype if dtypes .issubdtype (dtype , dtypes .extended ) else np .dtype (dtype )
2719
- if sharding is not None and not isinstance (sharding , Sharding ):
2720
+ if sharding is not None and not isinstance (sharding , ( Sharding , Layout ) ):
2720
2721
raise ValueError (
2721
- "sharding should be an instance of `jax.sharding.Sharding`. "
2722
- f"Got { sharding } of type { type (sharding )} ." )
2723
- self .sharding = sharding
2722
+ "sharding should be an instance of `jax.sharding.Sharding` or"
2723
+ f" `jax.experimental.layout.Layout`. Got { sharding } of type"
2724
+ f" { type (sharding )} ." )
2725
+ if (isinstance (sharding , Layout ) and
2726
+ isinstance (sharding .device_local_layout , AutoLayout )):
2727
+ raise TypeError (
2728
+ "`DeviceLocalLayout.AUTO` cannot be used in place of a device-local"
2729
+ f" layout in a `ShapeDtypeStruct`. Got { sharding } " )
2730
+ self .sharding = sharding .sharding if isinstance (sharding , Layout ) else sharding
2731
+ self ._dll = sharding .device_local_layout if isinstance (sharding , Layout ) else None
2724
2732
self .named_shape = {} if named_shape is None else dict (named_shape )
2725
2733
2726
2734
size = property (lambda self : math .prod (self .shape ))
2727
2735
ndim = property (lambda self : len (self .shape ))
2728
2736
2737
+ @property
2738
+ def layout (self ):
2739
+ return Layout (self ._dll , self .sharding )
2740
+
2729
2741
def __len__ (self ):
2730
2742
try :
2731
2743
return self .shape [0 ]
@@ -2735,28 +2747,31 @@ def __len__(self):
2735
2747
def __repr__ (self ):
2736
2748
ns = f", named_shape={ self .named_shape } " if self .named_shape else ""
2737
2749
sh = f", sharding={ self .sharding } " if self .sharding is not None else ""
2750
+ l = f", layout={ self .layout } " if self ._dll is not None else ""
2738
2751
return (f"{ type (self ).__name__ } (shape={ self .shape } , "
2739
- f"dtype={ self .dtype .name } { ns } { sh } )" )
2752
+ f"dtype={ self .dtype .name } { ns } { sh } { l } )" )
2740
2753
2741
2754
__str__ = __repr__
2742
2755
2743
2756
def __eq__ (self , other ):
2744
2757
if not isinstance (other , ShapeDtypeStruct ):
2745
2758
return False
2746
2759
else :
2747
- return ((other .shape , other .dtype , other .named_shape , other .sharding ) ==
2748
- (self .shape , self .dtype , self .named_shape , self .sharding ))
2760
+ return ((other .shape , other .dtype , other .named_shape , other .sharding , other . layout ) ==
2761
+ (self .shape , self .dtype , self .named_shape , self .sharding , self . layout ))
2749
2762
2750
2763
def __hash__ (self ):
2751
2764
# TODO(frostig): avoid the conversion from dict by addressing
2752
2765
# https://github.com/google/jax/issues/8182
2753
2766
named = frozenset (self .named_shape .items ())
2754
- return hash ((self .shape , self .dtype , named , self .sharding ))
2767
+ return hash ((self .shape , self .dtype , named , self .sharding , self .layout ))
2768
+
2755
2769
2756
2770
core .pytype_aval_mappings [ShapeDtypeStruct ] = (
2757
2771
lambda x : ShapedArray (x .shape , dtypes .canonicalize_dtype (x .dtype , allow_extended_dtype = True ),
2758
2772
weak_type = False , named_shape = x .named_shape ))
2759
2773
2774
+
2760
2775
@api_boundary
2761
2776
def eval_shape (fun : Callable , * args , ** kwargs ):
2762
2777
"""Compute the shape/dtype of ``fun`` without any FLOPs.
0 commit comments