Skip to content

Commit 291a5cd

Browse files
yueshengysjax authors
authored andcommitted
[PJRT][IFRT] Update PJRT, IFRT, and Py executable getters to return PjRtLayouts
PiperOrigin-RevId: 617889924
1 parent 383ae41 commit 291a5cd

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

jax/BUILD

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -720,11 +720,7 @@ pytype_strict_library(
720720
pytype_strict_library(
721721
name = "layout",
722722
srcs = ["_src/layout.py"],
723-
deps = [
724-
":util",
725-
":xla_bridge",
726-
"//jax/_src/lib",
727-
],
723+
deps = ["//jax/_src/lib"],
728724
)
729725

730726
pytype_strict_library(

jax/_src/layout.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
from __future__ import annotations
1616

17+
import re
18+
1719
from jax._src.lib import xla_client as xc
1820

1921

@@ -33,8 +35,7 @@ class SpecifiedLayout(XLACompatibleLayout):
3335

3436
def __init__(self, layout: xc.Layout):
3537
self._layout = layout
36-
self._layout_str = self._layout.to_string()
37-
self._minor_to_major = self._layout.minor_to_major()
38+
self._layout_str = str(self._layout)
3839

3940
def __repr__(self):
4041
return f'SpecifiedLayout({self._layout_str})'
@@ -50,6 +51,15 @@ def __eq__(self, other):
5051
def _to_xla_layout(self) -> str:
5152
return self._layout_str
5253

54+
@property
55+
def _minor_to_major(self):
56+
m = re.search("{([0-9,]*):", str(self))
57+
assert m is not None
58+
m2m_str = m.group(1)
59+
if m2m_str == "":
60+
return ()
61+
return tuple(int(x) for x in m2m_str.split(","))
62+
5363

5464
class LayoutRequest:
5565

0 commit comments

Comments
 (0)