File tree Expand file tree Collapse file tree 2 files changed +13
-7
lines changed Expand file tree Collapse file tree 2 files changed +13
-7
lines changed Original file line number Diff line number Diff line change @@ -720,11 +720,7 @@ pytype_strict_library(
720
720
pytype_strict_library (
721
721
name = "layout" ,
722
722
srcs = ["_src/layout.py" ],
723
- deps = [
724
- ":util" ,
725
- ":xla_bridge" ,
726
- "//jax/_src/lib" ,
727
- ],
723
+ deps = ["//jax/_src/lib" ],
728
724
)
729
725
730
726
pytype_strict_library (
Original file line number Diff line number Diff line change 14
14
15
15
from __future__ import annotations
16
16
17
+ import re
18
+
17
19
from jax ._src .lib import xla_client as xc
18
20
19
21
@@ -33,8 +35,7 @@ class SpecifiedLayout(XLACompatibleLayout):
33
35
34
36
def __init__ (self , layout : xc .Layout ):
35
37
self ._layout = layout
36
- self ._layout_str = self ._layout .to_string ()
37
- self ._minor_to_major = self ._layout .minor_to_major ()
38
+ self ._layout_str = str (self ._layout )
38
39
39
40
def __repr__ (self ):
40
41
return f'SpecifiedLayout({ self ._layout_str } )'
@@ -50,6 +51,15 @@ def __eq__(self, other):
50
51
def _to_xla_layout (self ) -> str :
51
52
return self ._layout_str
52
53
54
+ @property
55
+ def _minor_to_major (self ):
56
+ m = re .search ("{([0-9,]*):" , str (self ))
57
+ assert m is not None
58
+ m2m_str = m .group (1 )
59
+ if m2m_str == "" :
60
+ return ()
61
+ return tuple (int (x ) for x in m2m_str .split ("," ))
62
+
53
63
54
64
class LayoutRequest :
55
65
You can’t perform that action at this time.
0 commit comments