Skip to content

Commit d0f2285

Browse files
committed
[monarch] initial commit of tensor_engine submodule
Pull Request resolved: #458 Now that the actor layer is place, time to refactor the tensor engine following the same principles: - Every symbol has at most two locations: a private location (which should be exactly the same as its location in the source tree), and optionally a public location in the user-facing API. - We discourage end-users from depending on private locations via underscores, per python convention. In particular, we violate principle 2 in many places in the tensor engine—all of our top-level files/modules are visible if you do `from monarch.foo import bar`, and we are sloppy about the distinction of public and private. So this refactor will be a little bit painful. I will approach it in parts. This diff defines the `tensor_engine` submodule and moves `common` to `monarch.tensor_engine._common`. Differential Revision: [D77847587](https://our.internmc.facebook.com/intern/diff/D77847587/) ghstack-source-id: 294872880
1 parent 704dd94 commit d0f2285

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

87 files changed

+456
-359
lines changed

monarch_extension/src/convert.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ struct MessageParser<'a> {
4444
fn create_function(obj: Bound<'_, PyAny>) -> PyResult<ResolvableFunction> {
4545
let cloudpickle = obj
4646
.py()
47-
.import("monarch.common.function")?
47+
.import("monarch._src.tensor_engine.common.function")?
4848
.getattr("ResolvableFromCloudpickle")?;
4949
if obj.is_instance(&cloudpickle)? {
5050
Ok(ResolvableFunction::Cloudpickle(Cloudpickle::new(
@@ -102,7 +102,7 @@ impl<'a> MessageParser<'a> {
102102
let referenceable = self
103103
.current
104104
.py()
105-
.import("monarch.common.reference")?
105+
.import("monarch._src.tensor_engine.common.reference")?
106106
.getattr("Referenceable")?;
107107
let mut flat: Vec<Option<Ref>> = vec![];
108108
for x in output_tuple.0.try_iter()? {
@@ -198,8 +198,8 @@ static CONVERT_MAP: OnceLock<HashMap<u64, FnType>> = OnceLock::new();
198198

199199
fn create_map(py: Python) -> HashMap<u64, FnType> {
200200
let messages = py
201-
.import("monarch.common.messages")
202-
.expect("import monarch.common.messages");
201+
.import("monarch._src.tensor_engine.common.messages")
202+
.expect("import monarch._src.tensor_engine.common.messages");
203203
let mut m: HashMap<u64, FnType> = HashMap::new();
204204
let key = |name: &str| {
205205
messages

python/monarch/__init__.py

Lines changed: 50 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@
3434
from monarch import timer
3535
from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator
3636
from monarch._src.actor.shape import NDSlice, Shape
37-
from monarch.common._coalescing import coalescing
37+
from monarch._src.tensor_engine.common._coalescing import coalescing
3838

39-
from monarch.common.device_mesh import (
39+
from monarch._src.tensor_engine.common.device_mesh import (
4040
DeviceMesh,
4141
get_active_mesh,
4242
no_mesh,
@@ -45,17 +45,23 @@
4545
to_mesh,
4646
)
4747

48-
from monarch.common.function import resolvers as function_resolvers
48+
from monarch._src.tensor_engine.common.function import (
49+
resolvers as function_resolvers,
50+
)
4951

50-
from monarch.common.future import Future
52+
from monarch._src.tensor_engine.common.future import Future
5153

52-
from monarch.common.invocation import RemoteException
53-
from monarch.common.opaque_ref import OpaqueRef
54-
from monarch.common.pipe import create_pipe, Pipe, remote_generator
55-
from monarch.common.remote import remote
56-
from monarch.common.selection import Selection
57-
from monarch.common.stream import get_active_stream, Stream
58-
from monarch.common.tensor import reduce, reduce_, Tensor
54+
from monarch._src.tensor_engine.common.invocation import RemoteException
55+
from monarch._src.tensor_engine.common.opaque_ref import OpaqueRef
56+
from monarch._src.tensor_engine.common.pipe import (
57+
create_pipe,
58+
Pipe,
59+
remote_generator,
60+
)
61+
from monarch._src.tensor_engine.common.remote import remote
62+
from monarch._src.tensor_engine.common.selection import Selection
63+
from monarch._src.tensor_engine.common.stream import get_active_stream, Stream
64+
from monarch._src.tensor_engine.common.tensor import reduce, reduce_, Tensor
5965
from monarch.fetch import fetch_shard, inspect, show
6066
from monarch.gradient_generator import grad_function, grad_generator
6167
from monarch.notebook import mast_mesh, reserve_torchx as mast_reserve
@@ -72,29 +78,41 @@
7278

7379

7480
_public_api = {
75-
"coalescing": ("monarch.common._coalescing", "coalescing"),
76-
"remote": ("monarch.common.remote", "remote"),
77-
"DeviceMesh": ("monarch.common.device_mesh", "DeviceMesh"),
78-
"get_active_mesh": ("monarch.common.device_mesh", "get_active_mesh"),
79-
"no_mesh": ("monarch.common.device_mesh", "no_mesh"),
80-
"RemoteProcessGroup": ("monarch.common.device_mesh", "RemoteProcessGroup"),
81-
"function_resolvers": ("monarch.common.function", "resolvers"),
82-
"Future": ("monarch.common.future", "Future"),
83-
"RemoteException": ("monarch.common.invocation", "RemoteException"),
81+
"coalescing": ("monarch._src.tensor_engine.common._coalescing", "coalescing"),
82+
"remote": ("monarch._src.tensor_engine.common.remote", "remote"),
83+
"DeviceMesh": ("monarch._src.tensor_engine.common.device_mesh", "DeviceMesh"),
84+
"get_active_mesh": (
85+
"monarch._src.tensor_engine.common.device_mesh",
86+
"get_active_mesh",
87+
),
88+
"no_mesh": ("monarch._src.tensor_engine.common.device_mesh", "no_mesh"),
89+
"RemoteProcessGroup": (
90+
"monarch._src.tensor_engine.common.device_mesh",
91+
"RemoteProcessGroup",
92+
),
93+
"function_resolvers": ("monarch._src.tensor_engine.common.function", "resolvers"),
94+
"Future": ("monarch._src.tensor_engine.common.future", "Future"),
95+
"RemoteException": (
96+
"monarch._src.tensor_engine.common.invocation",
97+
"RemoteException",
98+
),
8499
"Shape": ("monarch._src.actor.shape", "Shape"),
85100
"NDSlice": ("monarch._src.actor.shape", "NDSlice"),
86-
"Selection": ("monarch.common.selection", "Selection"),
87-
"OpaqueRef": ("monarch.common.opaque_ref", "OpaqueRef"),
88-
"create_pipe": ("monarch.common.pipe", "create_pipe"),
89-
"Pipe": ("monarch.common.pipe", "Pipe"),
90-
"remote_generator": ("monarch.common.pipe", "remote_generator"),
91-
"get_active_stream": ("monarch.common.stream", "get_active_stream"),
92-
"Stream": ("monarch.common.stream", "Stream"),
93-
"Tensor": ("monarch.common.tensor", "Tensor"),
94-
"reduce": ("monarch.common.tensor", "reduce"),
95-
"reduce_": ("monarch.common.tensor", "reduce_"),
96-
"to_mesh": ("monarch.common.device_mesh", "to_mesh"),
97-
"slice_mesh": ("monarch.common.device_mesh", "slice_mesh"),
101+
"Selection": ("monarch._src.tensor_engine.common.selection", "Selection"),
102+
"OpaqueRef": ("monarch._src.tensor_engine.common.opaque_ref", "OpaqueRef"),
103+
"create_pipe": ("monarch._src.tensor_engine.common.pipe", "create_pipe"),
104+
"Pipe": ("monarch._src.tensor_engine.common.pipe", "Pipe"),
105+
"remote_generator": ("monarch._src.tensor_engine.common.pipe", "remote_generator"),
106+
"get_active_stream": (
107+
"monarch._src.tensor_engine.common.stream",
108+
"get_active_stream",
109+
),
110+
"Stream": ("monarch._src.tensor_engine.common.stream", "Stream"),
111+
"Tensor": ("monarch._src.tensor_engine.common.tensor", "Tensor"),
112+
"reduce": ("monarch._src.tensor_engine.common.tensor", "reduce"),
113+
"reduce_": ("monarch._src.tensor_engine.common.tensor", "reduce_"),
114+
"to_mesh": ("monarch._src.tensor_engine.common.device_mesh", "to_mesh"),
115+
"slice_mesh": ("monarch._src.tensor_engine.common.device_mesh", "slice_mesh"),
98116
"call_on_shard_and_fetch": ("monarch.fetch", "call_on_shard_and_fetch"),
99117
"fetch_shard": ("monarch.fetch", "fetch_shard"),
100118
"inspect": ("monarch.fetch", "inspect"),

python/monarch/_rust_bindings/monarch_extension/client.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ from typing import Any, ClassVar, Dict, final, List, NamedTuple, Union
88

99
from monarch._rust_bindings.monarch_extension.tensor_worker import Ref
1010
from monarch._rust_bindings.monarch_messages.debugger import DebuggerActionType
11-
from monarch._src.actor._extension.monarch_hyperactor.proc import (
11+
from monarch._src.actor._extension.monarch_hyperactor.proc import ( # @manual=//monarch/python/monarch/_src/actor:actor
1212
ActorId,
1313
Proc,
1414
Serialized,

python/monarch/_src/tensor_engine/common/__init__.py

Whitespace-only changes.

python/monarch/common/_coalescing.py renamed to python/monarch/_src/tensor_engine/common/_coalescing.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,22 +24,17 @@
2424
)
2525

2626
import torch
27-
from monarch.common import messages
2827

29-
from monarch.common.fake import fake_call
30-
from monarch.common.function_caching import (
31-
hashable_tensor_flatten,
32-
TensorGroup,
33-
TensorGroupPattern,
34-
)
35-
from monarch.common.tensor import InputChecker, Tensor
36-
from monarch.common.tree import flatten
28+
from . import messages
3729

38-
if TYPE_CHECKING:
39-
from monarch.common.client import Recorder
40-
from monarch.common.recording import Recording
30+
from .fake import fake_call
31+
from .function_caching import hashable_tensor_flatten, TensorGroup, TensorGroupPattern
32+
from .tensor import InputChecker, Tensor
33+
from .tree import flatten
4134

42-
from .client import Client
35+
if TYPE_CHECKING:
36+
from .client import Client, Recorder
37+
from .recording import Recording
4338

4439
_coalescing = None
4540

0 commit comments

Comments
 (0)