Skip to content

Commit 3d78af3

Browse files
committed
[monarch] initial commit of tensor_engine submodule
Now that the actor layer is place, time to refactor the tensor engine following the same principles: - Every symbol has at most two locations: a private location (which should be exactly the same as its location in the source tree), and optionally a public location in the user-facing API. - We discourage end-users from depending on private locations via underscores, per python convention. In particular, we violate principle 2 in many places in the tensor engine—all of our top-level files/modules are visible if you do `from monarch.foo import bar`, and we are sloppy about the distinction of public and private. So this refactor will be a little bit painful. I will approach it in parts. This diff defines the `tensor_engine` submodule and moves `common` to `monarch.tensor_engine._common`. Differential Revision: [D77847587](https://our.internmc.facebook.com/intern/diff/D77847587/) ghstack-source-id: 294791348 Pull Request resolved: #458
1 parent ee50995 commit 3d78af3

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

87 files changed

+431
-382
lines changed

monarch_extension/src/convert.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ struct MessageParser<'a> {
4444
fn create_function(obj: Bound<'_, PyAny>) -> PyResult<ResolvableFunction> {
4545
let cloudpickle = obj
4646
.py()
47-
.import("monarch.common.function")?
47+
.import("monarch._src.tensor_engine.common.function")?
4848
.getattr("ResolvableFromCloudpickle")?;
4949
if obj.is_instance(&cloudpickle)? {
5050
Ok(ResolvableFunction::Cloudpickle(Cloudpickle::new(
@@ -102,7 +102,7 @@ impl<'a> MessageParser<'a> {
102102
let referenceable = self
103103
.current
104104
.py()
105-
.import("monarch.common.reference")?
105+
.import("monarch._src.tensor_engine.common.reference")?
106106
.getattr("Referenceable")?;
107107
let mut flat: Vec<Option<Ref>> = vec![];
108108
for x in output_tuple.0.try_iter()? {
@@ -198,8 +198,8 @@ static CONVERT_MAP: OnceLock<HashMap<u64, FnType>> = OnceLock::new();
198198

199199
fn create_map(py: Python) -> HashMap<u64, FnType> {
200200
let messages = py
201-
.import("monarch.common.messages")
202-
.expect("import monarch.common.messages");
201+
.import("monarch._src.tensor_engine.common.messages")
202+
.expect("import monarch._src.tensor_engine.common.messages");
203203
let mut m: HashMap<u64, FnType> = HashMap::new();
204204
let key = |name: &str| {
205205
messages

python/monarch/__init__.py

Lines changed: 46 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -34,28 +34,6 @@
3434
from monarch import timer
3535
from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator
3636
from monarch._src.actor.shape import NDSlice, Shape
37-
from monarch.common._coalescing import coalescing
38-
39-
from monarch.common.device_mesh import (
40-
DeviceMesh,
41-
get_active_mesh,
42-
no_mesh,
43-
RemoteProcessGroup,
44-
slice_mesh,
45-
to_mesh,
46-
)
47-
48-
from monarch.common.function import resolvers as function_resolvers
49-
50-
from monarch.common.future import Future
51-
52-
from monarch.common.invocation import RemoteException
53-
from monarch.common.opaque_ref import OpaqueRef
54-
from monarch.common.pipe import create_pipe, Pipe, remote_generator
55-
from monarch.common.remote import remote
56-
from monarch.common.selection import Selection
57-
from monarch.common.stream import get_active_stream, Stream
58-
from monarch.common.tensor import reduce, reduce_, Tensor
5937
from monarch.fetch import fetch_shard, inspect, show
6038
from monarch.gradient_generator import grad_function, grad_generator
6139
from monarch.notebook import mast_mesh, reserve_torchx as mast_reserve
@@ -68,33 +46,58 @@
6846
from monarch.rust_local_mesh import local_mesh, local_meshes, SocketType
6947
from monarch.simulator.config import set_meta # noqa
7048
from monarch.simulator.interface import Simulator
49+
from monarch._src.tensor_engine.common._coalescing import coalescing
50+
51+
from monarch._src.tensor_engine.common.device_mesh import (
52+
DeviceMesh,
53+
get_active_mesh,
54+
no_mesh,
55+
RemoteProcessGroup,
56+
slice_mesh,
57+
to_mesh,
58+
)
59+
60+
from monarch._src.tensor_engine.common.function import resolvers as function_resolvers
61+
62+
from monarch._src.tensor_engine.common.future import Future
63+
64+
from monarch._src.tensor_engine.common.invocation import RemoteException
65+
from monarch._src.tensor_engine.common.opaque_ref import OpaqueRef
66+
from monarch._src.tensor_engine.common.pipe import create_pipe, Pipe, remote_generator
67+
from monarch._src.tensor_engine.common.remote import remote
68+
from monarch._src.tensor_engine.common.selection import Selection
69+
from monarch._src.tensor_engine.common.stream import get_active_stream, Stream
70+
from monarch._src.tensor_engine.common.tensor import reduce, reduce_, Tensor
7171
from monarch.world_mesh import world_mesh
7272

7373

7474
_public_api = {
75-
"coalescing": ("monarch.common._coalescing", "coalescing"),
76-
"remote": ("monarch.common.remote", "remote"),
77-
"DeviceMesh": ("monarch.common.device_mesh", "DeviceMesh"),
78-
"get_active_mesh": ("monarch.common.device_mesh", "get_active_mesh"),
79-
"no_mesh": ("monarch.common.device_mesh", "no_mesh"),
80-
"RemoteProcessGroup": ("monarch.common.device_mesh", "RemoteProcessGroup"),
81-
"function_resolvers": ("monarch.common.function", "resolvers"),
82-
"Future": ("monarch.common.future", "Future"),
83-
"RemoteException": ("monarch.common.invocation", "RemoteException"),
75+
"coalescing": ("monarch._src.tensor_engine.common._coalescing", "coalescing"),
76+
"remote": ("monarch._src.tensor_engine.common.remote", "remote"),
77+
"DeviceMesh": ("monarch._src.tensor_engine.common.device_mesh", "DeviceMesh"),
78+
"get_active_mesh": ("monarch._src.tensor_engine.common.device_mesh", "get_active_mesh"),
79+
"no_mesh": ("monarch._src.tensor_engine.common.device_mesh", "no_mesh"),
80+
"RemoteProcessGroup": (
81+
"monarch._src.tensor_engine.common.device_mesh",
82+
"RemoteProcessGroup",
83+
),
84+
"function_resolvers": ("monarch._src.tensor_engine.common.function", "resolvers"),
85+
"Future": ("monarch._src.tensor_engine.common.future", "Future"),
86+
"RemoteException": ("monarch._src.tensor_engine.common.invocation", "RemoteException"),
8487
"Shape": ("monarch._src.actor.shape", "Shape"),
8588
"NDSlice": ("monarch._src.actor.shape", "NDSlice"),
86-
"Selection": ("monarch.common.selection", "Selection"),
87-
"OpaqueRef": ("monarch.common.opaque_ref", "OpaqueRef"),
88-
"create_pipe": ("monarch.common.pipe", "create_pipe"),
89-
"Pipe": ("monarch.common.pipe", "Pipe"),
90-
"remote_generator": ("monarch.common.pipe", "remote_generator"),
91-
"get_active_stream": ("monarch.common.stream", "get_active_stream"),
92-
"Stream": ("monarch.common.stream", "Stream"),
93-
"Tensor": ("monarch.common.tensor", "Tensor"),
94-
"reduce": ("monarch.common.tensor", "reduce"),
95-
"reduce_": ("monarch.common.tensor", "reduce_"),
96-
"to_mesh": ("monarch.common.device_mesh", "to_mesh"),
97-
"slice_mesh": ("monarch.common.device_mesh", "slice_mesh"),
89+
"Selection": ("monarch._src.tensor_engine.common.selection", "Selection"),
90+
"OpaqueRef": ("monarch._src.tensor_engine.common.opaque_ref", "OpaqueRef"),
91+
"create_pipe": ("monarch._src.tensor_engine.common.pipe", "create_pipe"),
92+
"Pipe": ("monarch._src.tensor_engine.common.pipe", "Pipe"),
93+
"remote_generator": ("monarch._src.tensor_engine.common.pipe", "remote_generator"),
94+
"get_active_stream": ("monarch._src.tensor_engine.common.stream", "get_active_stream"),
95+
"Stream": ("monarch._src.tensor_engine.common.stream", "Stream"),
96+
"Tensor": ("monarch._src.tensor_engine.common.tensor", "Tensor"),
97+
"reduce": ("monarch._src.tensor_engine.common.tensor", "reduce"),
98+
"reduce_": ("monarch._src.tensor_engine.common.tensor", "reduce_"),
99+
"to_mesh": ("monarch._src.tensor_engine.common.device_mesh", "to_mesh"),
100+
"slice_mesh": ("monarch._src.tensor_engine.common.device_mesh", "slice_mesh"),
98101
"call_on_shard_and_fetch": ("monarch.fetch", "call_on_shard_and_fetch"),
99102
"fetch_shard": ("monarch.fetch", "fetch_shard"),
100103
"inspect": ("monarch.fetch", "inspect"),

python/monarch/_rust_bindings/monarch_extension/client.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ from typing import Any, ClassVar, Dict, final, List, NamedTuple, Union
88

99
from monarch._rust_bindings.monarch_extension.tensor_worker import Ref
1010
from monarch._rust_bindings.monarch_messages.debugger import DebuggerActionType
11-
from monarch._src.actor._extension.monarch_hyperactor.proc import (
11+
from monarch._src.actor._extension.monarch_hyperactor.proc import ( # @manual=//monarch/python/monarch/_src/actor:actor
1212
ActorId,
1313
Proc,
1414
Serialized,

python/monarch/_src/tensor_engine/common/__init__.py

Whitespace-only changes.

python/monarch/common/_coalescing.py renamed to python/monarch/_src/tensor_engine/common/_coalescing.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,22 +24,17 @@
2424
)
2525

2626
import torch
27-
from monarch.common import messages
2827

29-
from monarch.common.fake import fake_call
30-
from monarch.common.function_caching import (
31-
hashable_tensor_flatten,
32-
TensorGroup,
33-
TensorGroupPattern,
34-
)
35-
from monarch.common.tensor import InputChecker, Tensor
36-
from monarch.common.tree import flatten
28+
from . import messages
3729

38-
if TYPE_CHECKING:
39-
from monarch.common.client import Recorder
40-
from monarch.common.recording import Recording
30+
from .fake import fake_call
31+
from .function_caching import hashable_tensor_flatten, TensorGroup, TensorGroupPattern
32+
from .tensor import InputChecker, Tensor
33+
from .tree import flatten
4134

42-
from .client import Client
35+
if TYPE_CHECKING:
36+
from .client import Client, Recorder
37+
from .recording import Recording
4338

4439
_coalescing = None
4540

0 commit comments

Comments
 (0)