diff --git a/monarch_extension/src/convert.rs b/monarch_extension/src/convert.rs index afda10ad..117b6f8c 100644 --- a/monarch_extension/src/convert.rs +++ b/monarch_extension/src/convert.rs @@ -44,7 +44,7 @@ struct MessageParser<'a> { fn create_function(obj: Bound<'_, PyAny>) -> PyResult { let cloudpickle = obj .py() - .import("monarch.common.function")? + .import("monarch._src.tensor_engine.common.function")? .getattr("ResolvableFromCloudpickle")?; if obj.is_instance(&cloudpickle)? { Ok(ResolvableFunction::Cloudpickle(Cloudpickle::new( @@ -102,7 +102,7 @@ impl<'a> MessageParser<'a> { let referenceable = self .current .py() - .import("monarch.common.reference")? + .import("monarch._src.tensor_engine.common.reference")? .getattr("Referenceable")?; let mut flat: Vec> = vec![]; for x in output_tuple.0.try_iter()? { @@ -198,8 +198,8 @@ static CONVERT_MAP: OnceLock> = OnceLock::new(); fn create_map(py: Python) -> HashMap { let messages = py - .import("monarch.common.messages") - .expect("import monarch.common.messages"); + .import("monarch._src.tensor_engine.common.messages") + .expect("import monarch._src.tensor_engine.common.messages"); let mut m: HashMap = HashMap::new(); let key = |name: &str| { messages diff --git a/python/monarch/__init__.py b/python/monarch/__init__.py index 11b649b6..85a5d1e3 100644 --- a/python/monarch/__init__.py +++ b/python/monarch/__init__.py @@ -34,9 +34,9 @@ from monarch import timer from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator from monarch._src.actor.shape import NDSlice, Shape - from monarch.common._coalescing import coalescing + from monarch._src.tensor_engine.common._coalescing import coalescing - from monarch.common.device_mesh import ( + from monarch._src.tensor_engine.common.device_mesh import ( DeviceMesh, get_active_mesh, no_mesh, @@ -45,17 +45,23 @@ to_mesh, ) - from monarch.common.function import resolvers as function_resolvers + from monarch._src.tensor_engine.common.function import ( + resolvers as function_resolvers, + ) - from monarch.common.future import Future + from monarch._src.tensor_engine.common.future import Future - from monarch.common.invocation import RemoteException - from monarch.common.opaque_ref import OpaqueRef - from monarch.common.pipe import create_pipe, Pipe, remote_generator - from monarch.common.remote import remote - from monarch.common.selection import Selection - from monarch.common.stream import get_active_stream, Stream - from monarch.common.tensor import reduce, reduce_, Tensor + from monarch._src.tensor_engine.common.invocation import RemoteException + from monarch._src.tensor_engine.common.opaque_ref import OpaqueRef + from monarch._src.tensor_engine.common.pipe import ( + create_pipe, + Pipe, + remote_generator, + ) + from monarch._src.tensor_engine.common.remote import remote + from monarch._src.tensor_engine.common.selection import Selection + from monarch._src.tensor_engine.common.stream import get_active_stream, Stream + from monarch._src.tensor_engine.common.tensor import reduce, reduce_, Tensor from monarch.fetch import fetch_shard, inspect, show from monarch.gradient_generator import grad_function, grad_generator from monarch.notebook import mast_mesh, reserve_torchx as mast_reserve @@ -72,29 +78,41 @@ _public_api = { - "coalescing": ("monarch.common._coalescing", "coalescing"), - "remote": ("monarch.common.remote", "remote"), - "DeviceMesh": ("monarch.common.device_mesh", "DeviceMesh"), - "get_active_mesh": ("monarch.common.device_mesh", "get_active_mesh"), - "no_mesh": 
("monarch.common.device_mesh", "no_mesh"), - "RemoteProcessGroup": ("monarch.common.device_mesh", "RemoteProcessGroup"), - "function_resolvers": ("monarch.common.function", "resolvers"), - "Future": ("monarch.common.future", "Future"), - "RemoteException": ("monarch.common.invocation", "RemoteException"), + "coalescing": ("monarch._src.tensor_engine.common._coalescing", "coalescing"), + "remote": ("monarch._src.tensor_engine.common.remote", "remote"), + "DeviceMesh": ("monarch._src.tensor_engine.common.device_mesh", "DeviceMesh"), + "get_active_mesh": ( + "monarch._src.tensor_engine.common.device_mesh", + "get_active_mesh", + ), + "no_mesh": ("monarch._src.tensor_engine.common.device_mesh", "no_mesh"), + "RemoteProcessGroup": ( + "monarch._src.tensor_engine.common.device_mesh", + "RemoteProcessGroup", + ), + "function_resolvers": ("monarch._src.tensor_engine.common.function", "resolvers"), + "Future": ("monarch._src.tensor_engine.common.future", "Future"), + "RemoteException": ( + "monarch._src.tensor_engine.common.invocation", + "RemoteException", + ), "Shape": ("monarch._src.actor.shape", "Shape"), "NDSlice": ("monarch._src.actor.shape", "NDSlice"), - "Selection": ("monarch.common.selection", "Selection"), - "OpaqueRef": ("monarch.common.opaque_ref", "OpaqueRef"), - "create_pipe": ("monarch.common.pipe", "create_pipe"), - "Pipe": ("monarch.common.pipe", "Pipe"), - "remote_generator": ("monarch.common.pipe", "remote_generator"), - "get_active_stream": ("monarch.common.stream", "get_active_stream"), - "Stream": ("monarch.common.stream", "Stream"), - "Tensor": ("monarch.common.tensor", "Tensor"), - "reduce": ("monarch.common.tensor", "reduce"), - "reduce_": ("monarch.common.tensor", "reduce_"), - "to_mesh": ("monarch.common.device_mesh", "to_mesh"), - "slice_mesh": ("monarch.common.device_mesh", "slice_mesh"), + "Selection": ("monarch._src.tensor_engine.common.selection", "Selection"), + "OpaqueRef": ("monarch._src.tensor_engine.common.opaque_ref", "OpaqueRef"), + "create_pipe": ("monarch._src.tensor_engine.common.pipe", "create_pipe"), + "Pipe": ("monarch._src.tensor_engine.common.pipe", "Pipe"), + "remote_generator": ("monarch._src.tensor_engine.common.pipe", "remote_generator"), + "get_active_stream": ( + "monarch._src.tensor_engine.common.stream", + "get_active_stream", + ), + "Stream": ("monarch._src.tensor_engine.common.stream", "Stream"), + "Tensor": ("monarch._src.tensor_engine.common.tensor", "Tensor"), + "reduce": ("monarch._src.tensor_engine.common.tensor", "reduce"), + "reduce_": ("monarch._src.tensor_engine.common.tensor", "reduce_"), + "to_mesh": ("monarch._src.tensor_engine.common.device_mesh", "to_mesh"), + "slice_mesh": ("monarch._src.tensor_engine.common.device_mesh", "slice_mesh"), "call_on_shard_and_fetch": ("monarch.fetch", "call_on_shard_and_fetch"), "fetch_shard": ("monarch.fetch", "fetch_shard"), "inspect": ("monarch.fetch", "inspect"), diff --git a/python/monarch/_rust_bindings/monarch_extension/client.pyi b/python/monarch/_rust_bindings/monarch_extension/client.pyi index a38a466c..1f970465 100644 --- a/python/monarch/_rust_bindings/monarch_extension/client.pyi +++ b/python/monarch/_rust_bindings/monarch_extension/client.pyi @@ -8,7 +8,7 @@ from typing import Any, ClassVar, Dict, final, List, NamedTuple, Union from monarch._rust_bindings.monarch_extension.tensor_worker import Ref from monarch._rust_bindings.monarch_messages.debugger import DebuggerActionType -from monarch._src.actor._extension.monarch_hyperactor.proc import ( +from 
monarch._src.actor._extension.monarch_hyperactor.proc import ( # @manual=//monarch/python/monarch/_src/actor:actor ActorId, Proc, Serialized, diff --git a/python/monarch/common/__init__.py b/python/monarch/_src/tensor_engine/__init__.py similarity index 100% rename from python/monarch/common/__init__.py rename to python/monarch/_src/tensor_engine/__init__.py diff --git a/python/monarch/common/_C.pyi b/python/monarch/_src/tensor_engine/common/_C.pyi similarity index 100% rename from python/monarch/common/_C.pyi rename to python/monarch/_src/tensor_engine/common/_C.pyi diff --git a/python/monarch/_src/tensor_engine/common/__init__.py b/python/monarch/_src/tensor_engine/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/monarch/common/_coalescing.py b/python/monarch/_src/tensor_engine/common/_coalescing.py similarity index 96% rename from python/monarch/common/_coalescing.py rename to python/monarch/_src/tensor_engine/common/_coalescing.py index 080cec7b..a7c7469e 100644 --- a/python/monarch/common/_coalescing.py +++ b/python/monarch/_src/tensor_engine/common/_coalescing.py @@ -24,22 +24,17 @@ ) import torch -from monarch.common import messages -from monarch.common.fake import fake_call -from monarch.common.function_caching import ( - hashable_tensor_flatten, - TensorGroup, - TensorGroupPattern, -) -from monarch.common.tensor import InputChecker, Tensor -from monarch.common.tree import flatten +from . import messages -if TYPE_CHECKING: - from monarch.common.client import Recorder - from monarch.common.recording import Recording +from .fake import fake_call +from .function_caching import hashable_tensor_flatten, TensorGroup, TensorGroupPattern +from .tensor import InputChecker, Tensor +from .tree import flatten - from .client import Client +if TYPE_CHECKING: + from .client import Client, Recorder + from .recording import Recording _coalescing = None diff --git a/python/monarch/common/_tensor_to_table.py b/python/monarch/_src/tensor_engine/common/_tensor_to_table.py similarity index 100% rename from python/monarch/common/_tensor_to_table.py rename to python/monarch/_src/tensor_engine/common/_tensor_to_table.py diff --git a/python/monarch/common/base_tensor.py b/python/monarch/_src/tensor_engine/common/base_tensor.py similarity index 100% rename from python/monarch/common/base_tensor.py rename to python/monarch/_src/tensor_engine/common/base_tensor.py diff --git a/python/monarch/common/borrows.py b/python/monarch/_src/tensor_engine/common/borrows.py similarity index 100% rename from python/monarch/common/borrows.py rename to python/monarch/_src/tensor_engine/common/borrows.py diff --git a/python/monarch/common/client.py b/python/monarch/_src/tensor_engine/common/client.py similarity index 97% rename from python/monarch/common/client.py rename to python/monarch/_src/tensor_engine/common/client.py index c9bf546f..9743f5ae 100644 --- a/python/monarch/common/client.py +++ b/python/monarch/_src/tensor_engine/common/client.py @@ -38,21 +38,20 @@ WorldState, ) from monarch._src.actor.shape import NDSlice -from monarch.common import messages -from monarch.common.borrows import Borrow, StorageAliases -from monarch.common.controller_api import LogMessage, MessageResult, TController -from monarch.common.device_mesh import DeviceMesh -from monarch.common.future import Future -from monarch.common.invocation import DeviceException, RemoteException, Seq -from monarch.common.recording import flatten_messages, Recording +from . 
import _coalescing, messages +from .borrows import Borrow, StorageAliases +from .controller_api import LogMessage, MessageResult, TController +from .device_mesh import DeviceMesh -from monarch.common.reference import Ref, Referenceable -from monarch.common.stream import StreamRef -from monarch.common.tensor import Tensor -from monarch.common.tree import tree_map +from .future import Future +from .invocation import DeviceException, RemoteException, Seq +from .recording import flatten_messages, Recording -from . import _coalescing +from .reference import Ref, Referenceable +from .stream import StreamRef +from .tensor import Tensor +from .tree import tree_map logger = logging.getLogger(__name__) diff --git a/python/monarch/common/constants.py b/python/monarch/_src/tensor_engine/common/constants.py similarity index 100% rename from python/monarch/common/constants.py rename to python/monarch/_src/tensor_engine/common/constants.py diff --git a/python/monarch/common/context_manager.py b/python/monarch/_src/tensor_engine/common/context_manager.py similarity index 100% rename from python/monarch/common/context_manager.py rename to python/monarch/_src/tensor_engine/common/context_manager.py diff --git a/python/monarch/common/controller_api.py b/python/monarch/_src/tensor_engine/common/controller_api.py similarity index 95% rename from python/monarch/common/controller_api.py rename to python/monarch/_src/tensor_engine/common/controller_api.py index b0a42e84..f5b2ea59 100644 --- a/python/monarch/common/controller_api.py +++ b/python/monarch/_src/tensor_engine/common/controller_api.py @@ -15,9 +15,9 @@ from monarch._src.actor.shape import NDSlice -from monarch.common.invocation import DeviceException, RemoteException, Seq -from monarch.common.reference import Ref -from monarch.common.tensor import Tensor +from .invocation import DeviceException, RemoteException, Seq +from .reference import Ref +from .tensor import Tensor class LogMessage(NamedTuple): diff --git a/python/monarch/common/device_mesh.py b/python/monarch/_src/tensor_engine/common/device_mesh.py similarity index 98% rename from python/monarch/common/device_mesh.py rename to python/monarch/_src/tensor_engine/common/device_mesh.py index 86efc9e2..5d22164f 100644 --- a/python/monarch/common/device_mesh.py +++ b/python/monarch/_src/tensor_engine/common/device_mesh.py @@ -26,13 +26,14 @@ Union, ) -import monarch.common.messages as messages import torch from monarch._src.actor.shape import MeshTrait, NDSlice, Shape from torch.utils._python_dispatch import TorchDispatchMode from torch.utils._pytree import tree_map +from . 
import messages
+
 from ._tensor_to_table import tensor_to_table
 from .context_manager import activate_first_context_manager
 from .messages import Dims
@@ -41,7 +42,7 @@
 from .tensor import MeshSliceTensor, Tensor

 if TYPE_CHECKING:
-    from monarch.common.client import Client
+    from .client import Client


 logger: Logger = logging.getLogger(__name__)
@@ -382,7 +383,7 @@
 def _remote(*args, **kwargs):
     # we break the dependency to allow for separate files by
     # having device_mesh and tensor locally import the `remote`
     # entrypoint
-    from monarch.common.remote import remote
+    from .remote import remote

     return remote(*args, **kwargs)
diff --git a/python/monarch/common/fake.py b/python/monarch/_src/tensor_engine/common/fake.py
similarity index 100%
rename from python/monarch/common/fake.py
rename to python/monarch/_src/tensor_engine/common/fake.py
diff --git a/python/monarch/common/function.py b/python/monarch/_src/tensor_engine/common/function.py
similarity index 100%
rename from python/monarch/common/function.py
rename to python/monarch/_src/tensor_engine/common/function.py
diff --git a/python/monarch/common/function_caching.py b/python/monarch/_src/tensor_engine/common/function_caching.py
similarity index 100%
rename from python/monarch/common/function_caching.py
rename to python/monarch/_src/tensor_engine/common/function_caching.py
diff --git a/python/monarch/common/future.py b/python/monarch/_src/tensor_engine/common/future.py
similarity index 99%
rename from python/monarch/common/future.py
rename to python/monarch/_src/tensor_engine/common/future.py
index 7ce600d3..35129151 100644
--- a/python/monarch/common/future.py
+++ b/python/monarch/_src/tensor_engine/common/future.py
@@ -23,7 +23,7 @@
 from monarch_supervisor import TTL

 if TYPE_CHECKING:
-    from monarch.common.client import Client
+    from .client import Client

 from .invocation import RemoteException
diff --git a/python/monarch/common/init.cpp b/python/monarch/_src/tensor_engine/common/init.cpp
similarity index 100%
rename from python/monarch/common/init.cpp
rename to python/monarch/_src/tensor_engine/common/init.cpp
diff --git a/python/monarch/_src/tensor_engine/common/invocation.py b/python/monarch/_src/tensor_engine/common/invocation.py
new file mode 100644
index 00000000..1a01e492
--- /dev/null
+++ b/python/monarch/_src/tensor_engine/common/invocation.py
@@ -0,0 +1,125 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+import traceback
+from typing import Any, List, Optional, Tuple
+
+from monarch._src.actor._extension.monarch_hyperactor.proc import (  # @manual=//monarch/monarch_extension:monarch_extension
+    ActorId,
+)
+
+
+Seq = int
+
+
+class DeviceException(Exception):
+    """
+    Non-deterministic failure in the underlying worker, controller or its infrastructure.
+    For example, a worker may enter a crash loop, or its GPU may be lost
+    """
+
+    def __init__(
+        self,
+        exception: Exception,
+        frames: List[traceback.FrameSummary],
+        source_actor_id: ActorId,
+        message: str,
+    ):
+        self.exception = exception
+        self.frames = frames
+        self.source_actor_id = source_actor_id
+        self.message = message
+
+    def __str__(self):
+        try:
+            exe = str(self.exception)
+            worker_tb = "".join(traceback.format_list(self.frames))
+            return (
+                f"{self.message}\n"
+                f"Traceback of the failure on worker (most recent call last):\n{worker_tb}{type(self.exception).__name__}: {exe}"
+            )
+        except Exception as e:
+            print(e)
+            return "oops"
+
+
+class RemoteException(Exception):
+    """
+    Deterministic problem with the user's code.
+    For example, an OOM resulting in trying to allocate too much GPU memory, or violating
+    some invariant enforced by the various APIs.
+    """
+
+    def __init__(
+        self,
+        seq: Seq,
+        exception: Exception,
+        controller_frame_index: Optional[int],
+        controller_frames: Optional[List[traceback.FrameSummary]],
+        worker_frames: List[traceback.FrameSummary],
+        source_actor_id: ActorId,
+        message="A remote function has failed asynchronously.",
+    ):
+        self.exception = exception
+        self.worker_frames = worker_frames
+        self.message = message
+        self.seq = seq
+        self.controller_frame_index = controller_frame_index
+        self.source_actor_id = source_actor_id
+        self.controller_frames = controller_frames
+
+    def __str__(self):
+        try:
+            exe = str(self.exception)
+            worker_tb = "".join(traceback.format_list(self.worker_frames))
+            controller_tb = (
+                "".join(traceback.format_list(self.controller_frames))
+                if self.controller_frames is not None
+                else " \n"
+            )
+            return (
+                f"{self.message}\n"
+                f"Traceback of where the remote function was issued on controller (most recent call last):\n{controller_tb}"
+                f"Traceback of where the remote function failed on worker (most recent call last):\n{worker_tb}{type(self.exception).__name__}: {exe}"
+            )
+        except Exception as e:
+            print(e)
+            return "oops"
+
+
+class Invocation:
+    def __init__(self, seq: Seq):
+        self.seq = seq
+        self.users: Optional[set["Invocation"]] = set()
+        self.failure: Optional[RemoteException] = None
+        self.fut_value: Any = None
+
+    def __repr__(self):
+        return f""
+
+    def fail(self, remote_exception: RemoteException):
+        if self.failure is None or self.failure.seq > remote_exception.seq:
+            self.failure = remote_exception
+            return True
+        return False
+
+    def add_user(self, r: "Invocation"):
+        if self.users is not None:
+            self.users.add(r)
+        if self.failure is not None:
+            r.fail(self.failure)
+
+    def complete(self) -> Tuple[Any, Optional[RemoteException]]:
+        """
+        Complete the current invocation.
+        Return the result and exception tuple.
+        """
+        # after completion we no longer need to inform users of failures
+        # since they will just immediately get the value during add_user
+        self.users = None
+
+        return (self.fut_value if self.failure is None else None, self.failure)
diff --git a/python/monarch/common/mast.py b/python/monarch/_src/tensor_engine/common/mast.py
similarity index 100%
rename from python/monarch/common/mast.py
rename to python/monarch/_src/tensor_engine/common/mast.py
diff --git a/python/monarch/common/messages.py b/python/monarch/_src/tensor_engine/common/messages.py
similarity index 98%
rename from python/monarch/common/messages.py
rename to python/monarch/_src/tensor_engine/common/messages.py
index 5e600912..5f0b53ff 100644
--- a/python/monarch/common/messages.py
+++ b/python/monarch/_src/tensor_engine/common/messages.py
@@ -24,20 +24,20 @@
 from monarch._rust_bindings.monarch_extension import tensor_worker
 from monarch._src.actor.shape import NDSlice
-from monarch.common.function import ResolvableFromCloudpickle, ResolvableFunction
-from monarch.common.invocation import DeviceException, RemoteException
-from monarch.common.reference import Referenceable
-from monarch.common.tree import flattener
 from pyre_extensions import none_throws

+from .function import ResolvableFromCloudpickle, ResolvableFunction
+from .invocation import DeviceException, RemoteException
+from .reference import Referenceable
+
 from .tensor_factory import TensorFactory
+from .tree import flattener

 if TYPE_CHECKING:
-    from monarch.common.stream import StreamRef
-
     from .device_mesh import DeviceMesh, RemoteProcessGroup
     from .pipe import Pipe
     from .recording import Recording
+    from .stream import StreamRef
     from .tensor import Tensor
diff --git a/python/monarch/common/mock_cuda.cpp b/python/monarch/_src/tensor_engine/common/mock_cuda.cpp
similarity index 100%
rename from python/monarch/common/mock_cuda.cpp
rename to python/monarch/_src/tensor_engine/common/mock_cuda.cpp
diff --git a/python/monarch/common/mock_cuda.h b/python/monarch/_src/tensor_engine/common/mock_cuda.h
similarity index 100%
rename from python/monarch/common/mock_cuda.h
rename to python/monarch/_src/tensor_engine/common/mock_cuda.h
diff --git a/python/monarch/common/mock_cuda.py b/python/monarch/_src/tensor_engine/common/mock_cuda.py
similarity index 66%
rename from python/monarch/common/mock_cuda.py
rename to python/monarch/_src/tensor_engine/common/mock_cuda.py
index 87fca239..f6afd8a1 100644
--- a/python/monarch/common/mock_cuda.py
+++ b/python/monarch/_src/tensor_engine/common/mock_cuda.py
@@ -8,10 +8,10 @@
 from contextlib import contextmanager
 from typing import Generator, Optional

-import monarch.common._C  # @manual=//monarch/python/monarch/common:_C
+import monarch._src.tensor_engine.common._C  # @manual=//monarch/python/monarch/_src/tensor_engine/common:_C
 import torch

-monarch.common._C.patch_cuda()
+monarch._src.tensor_engine.common._C.patch_cuda()

 _mock_cuda_stream: Optional[torch.cuda.Stream] = None

@@ -27,15 +27,15 @@ def get_mock_cuda_stream() -> torch.cuda.Stream:
 def mock_cuda_guard() -> Generator[None, None, None]:
     try:
         with torch.cuda.stream(get_mock_cuda_stream()):
-            monarch.common._C.mock_cuda()
+            monarch._src.tensor_engine.common._C.mock_cuda()
             yield
     finally:
-        monarch.common._C.unmock_cuda()
+        monarch._src.tensor_engine.common._C.unmock_cuda()


 def mock_cuda() -> None:
-    monarch.common._C.mock_cuda()
+    monarch._src.tensor_engine.common._C.mock_cuda()


 def unmock_cuda() -> None:
-    monarch.common._C.unmock_cuda()
+
monarch._src.tensor_engine.common._C.unmock_cuda() diff --git a/python/monarch/common/opaque_ref.py b/python/monarch/_src/tensor_engine/common/opaque_ref.py similarity index 100% rename from python/monarch/common/opaque_ref.py rename to python/monarch/_src/tensor_engine/common/opaque_ref.py diff --git a/python/monarch/common/pipe.py b/python/monarch/_src/tensor_engine/common/pipe.py similarity index 99% rename from python/monarch/common/pipe.py rename to python/monarch/_src/tensor_engine/common/pipe.py index 5f756fa7..c91d42d2 100644 --- a/python/monarch/common/pipe.py +++ b/python/monarch/_src/tensor_engine/common/pipe.py @@ -10,12 +10,12 @@ from typing import Any, Dict import torch -from monarch.common.remote import Remote, remote from . import device_mesh, messages, stream from .fake import fake_call from .function import ResolvableFunctionFromPath from .reference import Referenceable +from .remote import Remote, remote from .tensor import dtensor_check, Tensor from .tree import flatten diff --git a/python/monarch/common/process_group.py b/python/monarch/_src/tensor_engine/common/process_group.py similarity index 100% rename from python/monarch/common/process_group.py rename to python/monarch/_src/tensor_engine/common/process_group.py diff --git a/python/monarch/common/recording.py b/python/monarch/_src/tensor_engine/common/recording.py similarity index 97% rename from python/monarch/common/recording.py rename to python/monarch/_src/tensor_engine/common/recording.py index 4417ad90..b390ac85 100644 --- a/python/monarch/common/recording.py +++ b/python/monarch/_src/tensor_engine/common/recording.py @@ -12,14 +12,14 @@ from monarch._src.actor.shape import iter_ranks -from monarch.common.reference import Ref +from . import messages -from monarch.common.tensor import InputChecker +from .reference import Ref -from . import messages +from .tensor import InputChecker if TYPE_CHECKING: - from monarch.common.client import Client + from .client import Client from monarch._src.actor.shape import NDSlice diff --git a/python/monarch/common/reference.py b/python/monarch/_src/tensor_engine/common/reference.py similarity index 100% rename from python/monarch/common/reference.py rename to python/monarch/_src/tensor_engine/common/reference.py diff --git a/python/monarch/common/remote.py b/python/monarch/_src/tensor_engine/common/remote.py similarity index 93% rename from python/monarch/common/remote.py rename to python/monarch/_src/tensor_engine/common/remote.py index 01f55e80..761ff6ae 100644 --- a/python/monarch/common/remote.py +++ b/python/monarch/_src/tensor_engine/common/remote.py @@ -24,36 +24,35 @@ TypeVar, ) -import monarch.common.messages as messages - import torch -from monarch.common import _coalescing, device_mesh, stream +from . 
import _coalescing, device_mesh, messages, stream if TYPE_CHECKING: - from monarch.common.client import Client + from .client import Client + +from torch import autograd, distributed as dist +from typing_extensions import ParamSpec -from monarch.common.device_mesh import RemoteProcessGroup -from monarch.common.fake import fake_call +from .device_mesh import RemoteProcessGroup +from .fake import fake_call -from monarch.common.function import ( +from .function import ( Propagator, resolvable_function, ResolvableFunction, ResolvableFunctionFromPath, ) -from monarch.common.function_caching import ( +from .function_caching import ( hashable_tensor_flatten, tensor_placeholder, TensorGroup, TensorPlaceholder, ) -from monarch.common.future import Future -from monarch.common.messages import Dims -from monarch.common.tensor import dtensor_check, dtensor_dispatch -from monarch.common.tree import flatten, tree_map -from torch import autograd, distributed as dist -from typing_extensions import ParamSpec +from .future import Future +from .messages import Dims +from .tensor import dtensor_check, dtensor_dispatch +from .tree import flatten, tree_map logger: Logger = logging.getLogger(__name__) diff --git a/python/monarch/common/selection.py b/python/monarch/_src/tensor_engine/common/selection.py similarity index 100% rename from python/monarch/common/selection.py rename to python/monarch/_src/tensor_engine/common/selection.py diff --git a/python/monarch/common/stream.py b/python/monarch/_src/tensor_engine/common/stream.py similarity index 98% rename from python/monarch/common/stream.py rename to python/monarch/_src/tensor_engine/common/stream.py index 643a89a6..a20031e7 100644 --- a/python/monarch/common/stream.py +++ b/python/monarch/_src/tensor_engine/common/stream.py @@ -15,7 +15,7 @@ from .reference import Referenceable if TYPE_CHECKING: - from monarch.common.client import Client # @manual + from .client import Client # @manual from .tensor import Tensor diff --git a/python/monarch/common/tensor.py b/python/monarch/_src/tensor_engine/common/tensor.py similarity index 99% rename from python/monarch/common/tensor.py rename to python/monarch/_src/tensor_engine/common/tensor.py index 29ebd6b9..0caadf3d 100644 --- a/python/monarch/common/tensor.py +++ b/python/monarch/_src/tensor_engine/common/tensor.py @@ -29,16 +29,16 @@ import torch import torch._ops -from monarch.common.function import ResolvableFunctionFromPath from torch._subclasses.fake_tensor import FakeTensor from torch.utils._pytree import tree_map from . 
import messages, stream from .base_tensor import BaseTensor from .borrows import StorageAliases +from .function import ResolvableFunctionFromPath if TYPE_CHECKING: - from monarch.common.device_mesh import DeviceMesh + from .device_mesh import DeviceMesh from monarch._src.actor.shape import NDSlice @@ -129,7 +129,7 @@ def __new__(cls, fake: torch.Tensor, mesh: "DeviceMesh", stream: "Stream"): @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - from monarch.common.remote import remote + from .remote import remote # device_mesh <-> tensor <-> remote are mututally recursive # we break the dependency to allow for separate files by diff --git a/python/monarch/common/tensor_factory.py b/python/monarch/_src/tensor_engine/common/tensor_factory.py similarity index 100% rename from python/monarch/common/tensor_factory.py rename to python/monarch/_src/tensor_engine/common/tensor_factory.py diff --git a/python/monarch/common/tree.py b/python/monarch/_src/tensor_engine/common/tree.py similarity index 100% rename from python/monarch/common/tree.py rename to python/monarch/_src/tensor_engine/common/tree.py diff --git a/python/monarch/_testing.py b/python/monarch/_testing.py index 052992ff..7175af3a 100644 --- a/python/monarch/_testing.py +++ b/python/monarch/_testing.py @@ -14,10 +14,13 @@ import monarch_supervisor from monarch._src.actor.shape import NDSlice +from monarch._src.tensor_engine.common.client import Client +from monarch._src.tensor_engine.common.device_mesh import DeviceMesh +from monarch._src.tensor_engine.common.invocation import ( + DeviceException, + RemoteException, +) from monarch.actor import proc_mesh, ProcMesh -from monarch.common.client import Client -from monarch.common.device_mesh import DeviceMesh -from monarch.common.invocation import DeviceException, RemoteException from monarch.controller.backend import ProcessBackend from monarch.mesh_controller import spawn_tensor_engine from monarch.python_local_mesh import PythonLocalContext diff --git a/python/monarch/builtins/log.py b/python/monarch/builtins/log.py index 6bede860..0a88d90c 100644 --- a/python/monarch/builtins/log.py +++ b/python/monarch/builtins/log.py @@ -6,7 +6,7 @@ import logging -from monarch.common.remote import remote +from monarch._src.tensor_engine.common.remote import remote logger = logging.getLogger(__name__) diff --git a/python/monarch/builtins/random.py b/python/monarch/builtins/random.py index cc7a3f21..502549df 100644 --- a/python/monarch/builtins/random.py +++ b/python/monarch/builtins/random.py @@ -8,7 +8,7 @@ from typing import Callable import torch -from monarch.common.remote import remote +from monarch._src.tensor_engine.common.remote import remote @remote(propagate="inspect") diff --git a/python/monarch/cached_remote_function.py b/python/monarch/cached_remote_function.py index 05fc1d08..8cd2bd31 100644 --- a/python/monarch/cached_remote_function.py +++ b/python/monarch/cached_remote_function.py @@ -12,9 +12,15 @@ from typing import Dict, List, Optional, Type, Union import torch -from monarch.common.process_group import SingleControllerProcessGroupWrapper +from monarch._src.tensor_engine.common.process_group import ( + SingleControllerProcessGroupWrapper, +) -from monarch.common.remote import DummyProcessGroup, remote, RemoteProcessGroup +from monarch._src.tensor_engine.common.remote import ( + DummyProcessGroup, + remote, + RemoteProcessGroup, +) from torch import autograd from torch.utils._pytree import tree_flatten, tree_unflatten diff --git 
a/python/monarch/common/invocation.py b/python/monarch/common/invocation.py
index 1a01e492..e89267dd 100644
--- a/python/monarch/common/invocation.py
+++ b/python/monarch/common/invocation.py
@@ -4,122 +4,22 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-# pyre-unsafe
-import traceback
-from typing import Any, List, Optional, Tuple
-
-from monarch._src.actor._extension.monarch_hyperactor.proc import (  # @manual=//monarch/monarch_extension:monarch_extension
-    ActorId,
-)
-
-
-Seq = int
-
-
-class DeviceException(Exception):
-    """
-    Non-deterministic failure in the underlying worker, controller or its infrastructure.
-    For example, a worker may enter a crash loop, or its GPU may be lost
-    """
-
-    def __init__(
-        self,
-        exception: Exception,
-        frames: List[traceback.FrameSummary],
-        source_actor_id: ActorId,
-        message: str,
-    ):
-        self.exception = exception
-        self.frames = frames
-        self.source_actor_id = source_actor_id
-        self.message = message
-
-    def __str__(self):
-        try:
-            exe = str(self.exception)
-            worker_tb = "".join(traceback.format_list(self.frames))
-            return (
-                f"{self.message}\n"
-                f"Traceback of the failure on worker (most recent call last):\n{worker_tb}{type(self.exception).__name__}: {exe}"
-            )
-        except Exception as e:
-            print(e)
-            return "oops"
+"""
+Deprecated shim for monarch.common.invocation.
+This module has been moved to monarch._src.tensor_engine.common.invocation.
+This shim is provided for backward compatibility.
+"""

-class RemoteException(Exception):
-    """
-    Deterministic problem with the user's code.
-    For example, an OOM resulting in trying to allocate too much GPU memory, or violating
-    some invariant enforced by the various APIs.
-    """
+import warnings

-    def __init__(
-        self,
-        seq: Seq,
-        exception: Exception,
-        controller_frame_index: Optional[int],
-        controller_frames: Optional[List[traceback.FrameSummary]],
-        worker_frames: List[traceback.FrameSummary],
-        source_actor_id: ActorId,
-        message="A remote function has failed asynchronously.",
-    ):
-        self.exception = exception
-        self.worker_frames = worker_frames
-        self.message = message
-        self.seq = seq
-        self.controller_frame_index = controller_frame_index
-        self.source_actor_id = source_actor_id
-        self.controller_frames = controller_frames
-
-    def __str__(self):
-        try:
-            exe = str(self.exception)
-            worker_tb = "".join(traceback.format_list(self.worker_frames))
-            controller_tb = (
-                "".join(traceback.format_list(self.controller_frames))
-                if self.controller_frames is not None
-                else " \n"
-            )
-            return (
-                f"{self.message}\n"
-                f"Traceback of where the remote function was issued on controller (most recent call last):\n{controller_tb}"
-                f"Traceback of where the remote function failed on worker (most recent call last):\n{worker_tb}{type(self.exception).__name__}: {exe}"
-            )
-        except Exception as e:
-            print(e)
-            return "oops"
-
-
-class Invocation:
-    def __init__(self, seq: Seq):
-        self.seq = seq
-        self.users: Optional[set["Invocation"]] = set()
-        self.failure: Optional[RemoteException] = None
-        self.fut_value: Any = None
-
-    def __repr__(self):
-        return f""
-
-    def fail(self, remote_exception: RemoteException):
-        if self.failure is None or self.failure.seq > remote_exception.seq:
-            self.failure = remote_exception
-            return True
-        return False
-
-    def add_user(self, r: "Invocation"):
-        if self.users is not None:
-            self.users.add(r)
-        if self.failure is not None:
-            r.fail(self.failure)
-
-    def complete(self) -> Tuple[Any, Optional[RemoteException]]:
-        """
-        Complete the current invocation.
-        Return the result and exception tuple.
-        """
-        # after completion we no longer need to inform users of failures
-        # since they will just immediately get the value during add_user
-        self.users = None
+# Issue deprecation warning
+warnings.warn(
+    "monarch.common.invocation has been moved to monarch._src.tensor_engine.common.invocation. "
+    "Please update your imports. This shim will be removed in a future version.",
+    DeprecationWarning,
+    stacklevel=2,
+)

-        return (self.fut_value if self.failure is None else None, self.failure)
+# Re-export everything from the new location
+from monarch._src.tensor_engine.common.invocation import *  # noqa: F401, F403, E402
diff --git a/python/monarch/controller/backend.py b/python/monarch/controller/backend.py
index 09b577cc..9faa16cf 100644
--- a/python/monarch/controller/backend.py
+++ b/python/monarch/controller/backend.py
@@ -15,7 +15,7 @@
 from monarch._src.actor.shape import iter_ranks, Slices as Ranks

-from monarch.common import messages
+from monarch._src.tensor_engine.common import messages

 from monarch_supervisor import (
     Context,
     FunctionCall,
diff --git a/python/monarch/controller/controller.py b/python/monarch/controller/controller.py
index a591c6d8..7c44d6f9 100644
--- a/python/monarch/controller/controller.py
+++ b/python/monarch/controller/controller.py
@@ -21,11 +21,11 @@
 from monarch._src.actor.shape import NDSlice

-from monarch.common import messages
-from monarch.common.controller_api import LogMessage, MessageResult
-from monarch.common.invocation import DeviceException, Seq
-from monarch.common.reference import Ref
-from monarch.common.tensor import Tensor
+from monarch._src.tensor_engine.common import messages
+from monarch._src.tensor_engine.common.controller_api import LogMessage, MessageResult
+from monarch._src.tensor_engine.common.invocation import DeviceException, Seq
+from monarch._src.tensor_engine.common.reference import Ref
+from monarch._src.tensor_engine.common.tensor import Tensor
 from monarch.controller import debugger

 from .backend import Backend
diff --git a/python/monarch/controller/history.py b/python/monarch/controller/history.py
index 61ca7508..ef5fd97a 100644
--- a/python/monarch/controller/history.py
+++ b/python/monarch/controller/history.py
@@ -12,12 +12,16 @@
     ActorId,
 )

-from monarch.common.controller_api import MessageResult
+from monarch._src.tensor_engine.common.controller_api import MessageResult

-from monarch.common.invocation import Invocation, RemoteException, Seq
+from monarch._src.tensor_engine.common.invocation import (
+    Invocation,
+    RemoteException,
+    Seq,
+)

 if TYPE_CHECKING:
-    from monarch.common.tensor import Tensor
+    from monarch._src.tensor_engine.common.tensor import Tensor


 class History:
diff --git a/python/monarch/controller/rust_backend/controller.py b/python/monarch/controller/rust_backend/controller.py
index 919d7fb3..a13fc7a2 100644
--- a/python/monarch/controller/rust_backend/controller.py
+++ b/python/monarch/controller/rust_backend/controller.py
@@ -31,11 +31,14 @@
 )
 from monarch._src.actor.shape import NDSlice
-from monarch.common.controller_api import LogMessage, MessageResult
-from monarch.common.device_mesh import no_mesh
-from monarch.common.invocation import DeviceException, RemoteException
-from monarch.common.messages import SupportsToRustMessage
-from monarch.common.tensor import Tensor
+from monarch._src.tensor_engine.common.controller_api import LogMessage, MessageResult
+from
monarch._src.tensor_engine.common.device_mesh import no_mesh +from monarch._src.tensor_engine.common.invocation import ( + DeviceException, + RemoteException, +) +from monarch._src.tensor_engine.common.messages import SupportsToRustMessage +from monarch._src.tensor_engine.common.tensor import Tensor from monarch.controller.debugger import read as debugger_read, write as debugger_write from pyre_extensions import none_throws diff --git a/python/monarch/fetch.py b/python/monarch/fetch.py index 95c39405..048a285e 100644 --- a/python/monarch/fetch.py +++ b/python/monarch/fetch.py @@ -11,11 +11,11 @@ from typing import TypeVar -from monarch.common.device_mesh import no_mesh +from monarch._src.tensor_engine.common.device_mesh import no_mesh -from monarch.common.future import Future +from monarch._src.tensor_engine.common.future import Future -from monarch.common.remote import _call_on_shard_and_fetch +from monarch._src.tensor_engine.common.remote import _call_on_shard_and_fetch T = TypeVar("T") diff --git a/python/monarch/gradient_generator.py b/python/monarch/gradient_generator.py index 742399fd..6d33ef9d 100644 --- a/python/monarch/gradient_generator.py +++ b/python/monarch/gradient_generator.py @@ -14,9 +14,9 @@ import torch import torch.autograd.graph -from monarch.common import device_mesh, stream -from monarch.common.tensor import Tensor -from monarch.common.tree import flatten +from monarch._src.tensor_engine.common import device_mesh, stream +from monarch._src.tensor_engine.common.tensor import Tensor +from monarch._src.tensor_engine.common.tree import flatten from monarch.gradient import GradientGenerator as _GradientGenerator from torch._C._autograd import _get_sequence_nr # @manual from torch.autograd.graph import get_gradient_edge, GradientEdge diff --git a/python/monarch/memory.py b/python/monarch/memory.py index 6a76801c..e4b4bbc3 100644 --- a/python/monarch/memory.py +++ b/python/monarch/memory.py @@ -10,7 +10,7 @@ from pathlib import Path import torch -from monarch.common.remote import remote +from monarch._src.tensor_engine.common.remote import remote PATH_KEY = "dir_snapshots" diff --git a/python/monarch/mesh_controller.py b/python/monarch/mesh_controller.py index 201d9c95..e5b47fdc 100644 --- a/python/monarch/mesh_controller.py +++ b/python/monarch/mesh_controller.py @@ -36,11 +36,11 @@ ) from monarch._src.actor.actor_mesh import Port, PortTuple from monarch._src.actor.shape import NDSlice -from monarch.common import messages -from monarch.common.controller_api import TController -from monarch.common.invocation import Seq -from monarch.common.stream import StreamRef -from monarch.common.tensor import Tensor +from monarch._src.tensor_engine.common import messages +from monarch._src.tensor_engine.common.controller_api import TController +from monarch._src.tensor_engine.common.invocation import Seq +from monarch._src.tensor_engine.common.stream import StreamRef +from monarch._src.tensor_engine.common.tensor import Tensor from monarch.tensor_worker_main import _set_trace @@ -52,11 +52,14 @@ from monarch._src.actor._extension.monarch_hyperactor.shape import Point -from monarch.common.client import Client -from monarch.common.controller_api import LogMessage, MessageResult -from monarch.common.device_mesh import DeviceMesh -from monarch.common.future import Future as OldFuture -from monarch.common.invocation import DeviceException, RemoteException +from monarch._src.tensor_engine.common.client import Client +from monarch._src.tensor_engine.common.controller_api import LogMessage, 
MessageResult +from monarch._src.tensor_engine.common.device_mesh import DeviceMesh +from monarch._src.tensor_engine.common.future import Future as OldFuture +from monarch._src.tensor_engine.common.invocation import ( + DeviceException, + RemoteException, +) from monarch.rust_local_mesh import _get_worker_exec_info logger: Logger = logging.getLogger(__name__) diff --git a/python/monarch/notebook.py b/python/monarch/notebook.py index 31ed9fcc..1cb16ad4 100644 --- a/python/monarch/notebook.py +++ b/python/monarch/notebook.py @@ -27,10 +27,10 @@ from typing import Any, List, Optional import zmq -from monarch.common.device_mesh import DeviceMesh +from monarch._src.tensor_engine.common.device_mesh import DeviceMesh -from monarch.common.mast import mast_get_jobs, MastJob -from monarch.common.remote import remote +from monarch._src.tensor_engine.common.mast import mast_get_jobs, MastJob +from monarch._src.tensor_engine.common.remote import remote from monarch.world_mesh import world_mesh from monarch_supervisor import Context, get_message_queue, HostConnected from monarch_supervisor.host import main as host_main diff --git a/python/monarch/opaque_module.py b/python/monarch/opaque_module.py index f11ebb5d..21f2345c 100644 --- a/python/monarch/opaque_module.py +++ b/python/monarch/opaque_module.py @@ -7,11 +7,14 @@ from typing import List import torch -from monarch.common.function_caching import TensorGroup, TensorGroupPattern -from monarch.common.opaque_ref import OpaqueRef -from monarch.common.remote import remote -from monarch.common.tensor_factory import TensorFactory -from monarch.common.tree import flatten +from monarch._src.tensor_engine.common.function_caching import ( + TensorGroup, + TensorGroupPattern, +) +from monarch._src.tensor_engine.common.opaque_ref import OpaqueRef +from monarch._src.tensor_engine.common.remote import remote +from monarch._src.tensor_engine.common.tensor_factory import TensorFactory +from monarch._src.tensor_engine.common.tree import flatten from monarch.opaque_object import _fresh_opaque_ref, OpaqueObject from torch.autograd.graph import get_gradient_edge diff --git a/python/monarch/opaque_object.py b/python/monarch/opaque_object.py index 5fa2279f..bc0795a6 100644 --- a/python/monarch/opaque_object.py +++ b/python/monarch/opaque_object.py @@ -7,14 +7,14 @@ import functools import torch -from monarch.common.function import ( +from monarch._src.tensor_engine.common.function import ( ConvertsToResolvable, resolvable_function, ResolvableFunction, ) -from monarch.common.opaque_ref import OpaqueRef -from monarch.common.remote import remote +from monarch._src.tensor_engine.common.opaque_ref import OpaqueRef +from monarch._src.tensor_engine.common.remote import remote def _invoke_method(obj: OpaqueRef, method_name: str, *args, **kwargs): diff --git a/python/monarch/parallel/pipelining/runtime.py b/python/monarch/parallel/pipelining/runtime.py index b8fb95e3..07d21e43 100644 --- a/python/monarch/parallel/pipelining/runtime.py +++ b/python/monarch/parallel/pipelining/runtime.py @@ -19,7 +19,7 @@ import torch.optim as optim from monarch import fetch_shard, no_mesh, OpaqueRef, remote, Stream, Tensor -from monarch.common.device_mesh import DeviceMesh +from monarch._src.tensor_engine.common.device_mesh import DeviceMesh from monarch.opaque_module import OpaqueModule from .schedule_ir import ( diff --git a/python/monarch/profiler.py b/python/monarch/profiler.py index 0915909c..41258f5f 100644 --- a/python/monarch/profiler.py +++ b/python/monarch/profiler.py @@ -13,7 +13,7 @@ 
from typing import Any, Dict, NamedTuple, Optional, Tuple import torch -from monarch.common.remote import remote +from monarch._src.tensor_engine.common.remote import remote from monarch.remote_class import ControllerRemoteClass, WorkerRemoteClass diff --git a/python/monarch/python_local_mesh.py b/python/monarch/python_local_mesh.py index c3de5980..f9eabd2b 100644 --- a/python/monarch/python_local_mesh.py +++ b/python/monarch/python_local_mesh.py @@ -12,14 +12,17 @@ import monarch_supervisor from monarch._src.actor.device_utils import _local_device_count -from monarch.common.fake import fake_call -from monarch.common.invocation import DeviceException, RemoteException +from monarch._src.tensor_engine.common.fake import fake_call +from monarch._src.tensor_engine.common.invocation import ( + DeviceException, + RemoteException, +) from monarch.world_mesh import world_mesh from monarch_supervisor import Context, HostConnected from monarch_supervisor.python_executable import PYTHON_EXECUTABLE if TYPE_CHECKING: - from monarch.common.device_mesh import DeviceMesh + from monarch._src.tensor_engine.common.device_mesh import DeviceMesh class PythonLocalContext: diff --git a/python/monarch/random.py b/python/monarch/random.py index bc9c4a8b..12bc633b 100644 --- a/python/monarch/random.py +++ b/python/monarch/random.py @@ -8,8 +8,8 @@ from typing import NamedTuple, Tuple import torch -from monarch.common.remote import remote -from monarch.common.tensor import Tensor +from monarch._src.tensor_engine.common.remote import remote +from monarch._src.tensor_engine.common.tensor import Tensor class State(NamedTuple): diff --git a/python/monarch/remote_class.py b/python/monarch/remote_class.py index 4824ad1a..88ed9353 100644 --- a/python/monarch/remote_class.py +++ b/python/monarch/remote_class.py @@ -9,8 +9,8 @@ import itertools from typing import Any, Dict -from monarch.common import device_mesh -from monarch.common.remote import remote +from monarch._src.tensor_engine.common import device_mesh +from monarch._src.tensor_engine.common.remote import remote class ControllerRemoteClass: diff --git a/python/monarch/rust_backend_mesh.py b/python/monarch/rust_backend_mesh.py index 260e26f1..cb2ddbe7 100644 --- a/python/monarch/rust_backend_mesh.py +++ b/python/monarch/rust_backend_mesh.py @@ -22,10 +22,13 @@ ) from monarch._src.actor.shape import NDSlice -from monarch.common.client import Client -from monarch.common.device_mesh import DeviceMesh, DeviceMeshStatus -from monarch.common.invocation import DeviceException, RemoteException -from monarch.common.mast import MastJob +from monarch._src.tensor_engine.common.client import Client +from monarch._src.tensor_engine.common.device_mesh import DeviceMesh, DeviceMeshStatus +from monarch._src.tensor_engine.common.invocation import ( + DeviceException, + RemoteException, +) +from monarch._src.tensor_engine.common.mast import MastJob from monarch.controller.rust_backend.controller import RustController TORCHX_MAST_TASK_GROUP_NAME = "script" diff --git a/python/monarch/rust_local_mesh.py b/python/monarch/rust_local_mesh.py index 9b9cf2cf..81b7bf7f 100644 --- a/python/monarch/rust_local_mesh.py +++ b/python/monarch/rust_local_mesh.py @@ -54,9 +54,12 @@ ActorId, ) -from monarch.common.device_mesh import DeviceMesh -from monarch.common.fake import fake_call -from monarch.common.invocation import DeviceException, RemoteException +from monarch._src.tensor_engine.common.device_mesh import DeviceMesh +from monarch._src.tensor_engine.common.fake import fake_call +from 
monarch._src.tensor_engine.common.invocation import ( + DeviceException, + RemoteException, +) from monarch.rust_backend_mesh import ( IBootstrap, MeshWorld, diff --git a/python/monarch/sim_mesh.py b/python/monarch/sim_mesh.py index ffd447fb..25973de1 100644 --- a/python/monarch/sim_mesh.py +++ b/python/monarch/sim_mesh.py @@ -42,16 +42,19 @@ ) from monarch._src.actor.shape import NDSlice -from monarch.common.client import Client -from monarch.common.constants import ( +from monarch._src.tensor_engine.common.client import Client +from monarch._src.tensor_engine.common.constants import ( SIM_MESH_CLIENT_SUPERVISION_UPDATE_INTERVAL, SIM_MESH_CLIENT_TIMEOUT, ) -from monarch.common.device_mesh import DeviceMesh -from monarch.common.fake import fake_call -from monarch.common.future import Future, T -from monarch.common.invocation import DeviceException, RemoteException -from monarch.common.messages import Dims +from monarch._src.tensor_engine.common.device_mesh import DeviceMesh +from monarch._src.tensor_engine.common.fake import fake_call +from monarch._src.tensor_engine.common.future import Future, T +from monarch._src.tensor_engine.common.invocation import ( + DeviceException, + RemoteException, +) +from monarch._src.tensor_engine.common.messages import Dims from monarch.controller.rust_backend.controller import RustController from monarch.rust_backend_mesh import MeshWorld diff --git a/python/monarch/simulator/command_history.py b/python/monarch/simulator/command_history.py index 30755579..7571cee4 100644 --- a/python/monarch/simulator/command_history.py +++ b/python/monarch/simulator/command_history.py @@ -14,7 +14,7 @@ import torch from monarch._src.actor.shape import NDSlice -from monarch.common import messages +from monarch._src.tensor_engine.common import messages from monarch.simulator.ir import IRGraph from monarch.simulator.tensor import DTensorRef from monarch.simulator.utils import clean_name, file_path_with_iter diff --git a/python/monarch/simulator/interface.py b/python/monarch/simulator/interface.py index 96b779b3..a70fba97 100644 --- a/python/monarch/simulator/interface.py +++ b/python/monarch/simulator/interface.py @@ -8,8 +8,8 @@ from monarch._src.actor.shape import NDSlice -from monarch.common.client import Client as _Client -from monarch.common.device_mesh import DeviceMesh +from monarch._src.tensor_engine.common.client import Client as _Client +from monarch._src.tensor_engine.common.device_mesh import DeviceMesh from monarch.simulator.ir import IRGraph from monarch.simulator.simulator import ( diff --git a/python/monarch/simulator/mock_controller.py b/python/monarch/simulator/mock_controller.py index ab1db332..31d8a179 100644 --- a/python/monarch/simulator/mock_controller.py +++ b/python/monarch/simulator/mock_controller.py @@ -27,16 +27,24 @@ ) from monarch._src.actor.shape import iter_ranks, NDSlice, Slices as Ranks -from monarch.common import messages +from monarch._src.tensor_engine.common import messages -from monarch.common.controller_api import DebuggerMessage, LogMessage, MessageResult -from monarch.common.device_mesh import no_mesh -from monarch.common.invocation import Invocation, RemoteException, Seq -from monarch.common.reference import Ref -from monarch.common.tree import flatten +from monarch._src.tensor_engine.common.controller_api import ( + DebuggerMessage, + LogMessage, + MessageResult, +) +from monarch._src.tensor_engine.common.device_mesh import no_mesh +from monarch._src.tensor_engine.common.invocation import ( + Invocation, + RemoteException, + Seq, 
+)
+from monarch._src.tensor_engine.common.reference import Ref
+from monarch._src.tensor_engine.common.tree import flatten
 
 if TYPE_CHECKING:
-    from monarch.common.tensor import Tensor
+    from monarch._src.tensor_engine.common.tensor import Tensor
 
 logger = logging.getLogger(__name__)
diff --git a/python/monarch/simulator/profiling.py b/python/monarch/simulator/profiling.py
index 6302231f..aca90406 100644
--- a/python/monarch/simulator/profiling.py
+++ b/python/monarch/simulator/profiling.py
@@ -31,15 +31,15 @@
 import torch
 import torch.distributed as dist
 
-from monarch.common import messages
-from monarch.common.function import resolvable_function
-from monarch.common.function_caching import (
+from monarch._src.tensor_engine.common import messages
+from monarch._src.tensor_engine.common.function import resolvable_function
+from monarch._src.tensor_engine.common.function_caching import (
     hashable_tensor_flatten,
     HashableTreeSpec,
     key_filters,
     TensorGroup,
 )
-from monarch.common.tensor_factory import TensorFactory
+from monarch._src.tensor_engine.common.tensor_factory import TensorFactory
 from monarch.simulator.command_history import CommandHistory, DTensorRef
 from torch.utils import _pytree as pytree
 from torch.utils._mode_utils import no_dispatch
diff --git a/python/monarch/simulator/simulator.py b/python/monarch/simulator/simulator.py
index f6568b7f..4d97406e 100644
--- a/python/monarch/simulator/simulator.py
+++ b/python/monarch/simulator/simulator.py
@@ -44,11 +44,14 @@
     ActorId,
 )
 from monarch._src.actor.shape import iter_ranks, NDSlice
-from monarch.common import messages
-from monarch.common.controller_api import LogMessage, MessageResult
-from monarch.common.device_mesh import DeviceMesh
-from monarch.common.function import ResolvableFunction, ResolvableFunctionFromPath
-from monarch.common.invocation import DeviceException
+from monarch._src.tensor_engine.common import messages
+from monarch._src.tensor_engine.common.controller_api import LogMessage, MessageResult
+from monarch._src.tensor_engine.common.device_mesh import DeviceMesh
+from monarch._src.tensor_engine.common.function import (
+    ResolvableFunction,
+    ResolvableFunctionFromPath,
+)
+from monarch._src.tensor_engine.common.invocation import DeviceException
 from monarch.simulator.command_history import CommandHistory, DTensorRef
 from monarch.simulator.config import META_VAL
 from monarch.simulator.ir import IRGraph
diff --git a/python/monarch/simulator/tensor.py b/python/monarch/simulator/tensor.py
index 10dd864d..749406cf 100644
--- a/python/monarch/simulator/tensor.py
+++ b/python/monarch/simulator/tensor.py
@@ -14,8 +14,8 @@
 from typing import Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union
 
 import torch
-from monarch.common.fake import fake_call
-from monarch.common.tensor_factory import TensorFactory
+from monarch._src.tensor_engine.common.fake import fake_call
+from monarch._src.tensor_engine.common.tensor_factory import TensorFactory
 from monarch.simulator.task import Task, WorkerTaskManager
 
 logger = logging.getLogger(__name__)
diff --git a/python/monarch/tensorboard.py b/python/monarch/tensorboard.py
index e3d1496c..bd15efd5 100644
--- a/python/monarch/tensorboard.py
+++ b/python/monarch/tensorboard.py
@@ -8,7 +8,8 @@
 import logging
 from typing import Any
 
-from monarch.common.device_mesh import DeviceMesh
+from monarch._src.tensor_engine.common.device_mesh import DeviceMesh
+
 from monarch.remote_class import ControllerRemoteClass, WorkerRemoteClass
 from torch.utils.tensorboard import SummaryWriter
diff --git a/python/monarch/worker/_testing_function.py b/python/monarch/worker/_testing_function.py
index 4448293e..aebc3bf3 100644
--- a/python/monarch/worker/_testing_function.py
+++ b/python/monarch/worker/_testing_function.py
@@ -20,10 +20,12 @@
     PdbActor,
 )
 from monarch._rust_bindings.monarch_messages.debugger import DebuggerAction
-from monarch.common import opaque_ref
-from monarch.common.pipe import Pipe
-from monarch.common.process_group import SingleControllerProcessGroupWrapper
-from monarch.common.remote import remote
+from monarch._src.tensor_engine.common import opaque_ref
+from monarch._src.tensor_engine.common.pipe import Pipe
+from monarch._src.tensor_engine.common.process_group import (
+    SingleControllerProcessGroupWrapper,
+)
+from monarch._src.tensor_engine.common.remote import remote
 from torch.utils.data import DataLoader, TensorDataset
diff --git a/python/monarch/worker/compiled_block.py b/python/monarch/worker/compiled_block.py
index f20400bf..d0e3ab88 100644
--- a/python/monarch/worker/compiled_block.py
+++ b/python/monarch/worker/compiled_block.py
@@ -12,8 +12,8 @@
 from typing import Any, Callable, Dict, List, Optional, Sequence, Set, TYPE_CHECKING
 
 import torch.fx
-from monarch.common.messages import DependentOnError
-from monarch.common.tree import tree_map
+from monarch._src.tensor_engine.common.messages import DependentOnError
+from monarch._src.tensor_engine.common.tree import tree_map
 from torch.fx.proxy import GraphAppendingTracer
 
 from .lines import Lines
diff --git a/python/monarch/worker/debugger.py b/python/monarch/worker/debugger.py
index 650dfe4c..89e3c374 100644
--- a/python/monarch/worker/debugger.py
+++ b/python/monarch/worker/debugger.py
@@ -13,7 +13,7 @@
 import sys
 from typing import Optional, TYPE_CHECKING
 
-from monarch.common import messages
+from monarch._src.tensor_engine.common import messages
 
 logger = logging.getLogger(__name__)
diff --git a/python/monarch/worker/worker.py b/python/monarch/worker/worker.py
index 89b48924..32817410 100644
--- a/python/monarch/worker/worker.py
+++ b/python/monarch/worker/worker.py
@@ -39,13 +39,15 @@
 import zmq.asyncio
 
 from monarch._src.actor.shape import NDSlice
-from monarch.common import messages
-from monarch.common.function import ResolvableFunction
-from monarch.common.messages import DependentOnError, Dims
-from monarch.common.process_group import SingleControllerProcessGroupWrapper
-from monarch.common.reference import Ref, Referenceable
-from monarch.common.tensor_factory import TensorFactory
-from monarch.common.tree import flatten, flattener
+from monarch._src.tensor_engine.common import messages
+from monarch._src.tensor_engine.common.function import ResolvableFunction
+from monarch._src.tensor_engine.common.messages import DependentOnError, Dims
+from monarch._src.tensor_engine.common.process_group import (
+    SingleControllerProcessGroupWrapper,
+)
+from monarch._src.tensor_engine.common.reference import Ref, Referenceable
+from monarch._src.tensor_engine.common.tensor_factory import TensorFactory
+from monarch._src.tensor_engine.common.tree import flatten, flattener
 from monarch_supervisor import get_message_queue, Letter
 from monarch_supervisor.logging import initialize_logging
diff --git a/python/monarch/world_mesh.py b/python/monarch/world_mesh.py
index 52698eb9..bc6f0cfc 100644
--- a/python/monarch/world_mesh.py
+++ b/python/monarch/world_mesh.py
@@ -10,9 +10,9 @@
 from monarch._src.actor.shape import NDSlice
 
-from monarch.common.client import Client
+from monarch._src.tensor_engine.common.client import Client
 
-from monarch.common.device_mesh import DeviceMesh
+from monarch._src.tensor_engine.common.device_mesh import DeviceMesh
 
 from monarch.controller.backend import ProcessBackend
diff --git a/python/tests/dispatch_bench.py b/python/tests/dispatch_bench.py
index ffbcbb55..42e2de56 100644
--- a/python/tests/dispatch_bench.py
+++ b/python/tests/dispatch_bench.py
@@ -11,12 +11,12 @@
 import torch
 import torch.utils.benchmark as benchmark
 
+from monarch._src.tensor_engine.common._coalescing import coalescing
+from monarch._src.tensor_engine.common.remote import remote
+
 # this function helps get a local device mesh for testing
 from monarch._testing import mock_mesh
 from monarch.builtins.log import set_logging_level_remote
-
-from monarch.common._coalescing import coalescing
-from monarch.common.remote import remote
 from monarch.fetch import fetch_shard
 from monarch.python_local_mesh import python_local_mesh
 from monarch_supervisor.logging import initialize_logging
diff --git a/python/tests/dispatch_bench_helper.py b/python/tests/dispatch_bench_helper.py
index edc8f045..16f9b8d8 100644
--- a/python/tests/dispatch_bench_helper.py
+++ b/python/tests/dispatch_bench_helper.py
@@ -7,7 +7,7 @@
 # pyre-unsafe
 import torch
 
-from monarch.common.remote import remote
+from monarch._src.tensor_engine.common.remote import remote
 
 
 def run_loop_local(n_iters, tensor_shape=(2, 2)):
diff --git a/python/tests/simulator/test_profiling.py b/python/tests/simulator/test_profiling.py
index 3b608a7e..4dcd7cbb 100644
--- a/python/tests/simulator/test_profiling.py
+++ b/python/tests/simulator/test_profiling.py
@@ -11,7 +11,7 @@
 import torch
 
-from monarch.common import messages
+from monarch._src.tensor_engine.common import messages
 from monarch.simulator.profiling import RuntimeEstimator, RuntimeProfiler, TimingType
diff --git a/python/tests/simulator/test_worker.py b/python/tests/simulator/test_worker.py
index 68d1b8e0..51ac4dd8 100644
--- a/python/tests/simulator/test_worker.py
+++ b/python/tests/simulator/test_worker.py
@@ -9,7 +9,7 @@
 from typing import Tuple
 
 import torch
-from monarch.common.fake import fake_call
+from monarch._src.tensor_engine.common.fake import fake_call
 from monarch.simulator.profiling import RuntimeEstimator
 from monarch.simulator.task import Task
diff --git a/python/tests/test_coalescing.py b/python/tests/test_coalescing.py
index 86568fc4..e581379d 100644
--- a/python/tests/test_coalescing.py
+++ b/python/tests/test_coalescing.py
@@ -27,10 +27,14 @@
     remote,
     Stream,
 )
+from monarch._src.tensor_engine.common._coalescing import _record_and_define, compile
+from monarch._src.tensor_engine.common.function_caching import (
+    AliasOf,
+    Storage,
+    TensorGroup,
+)
+from monarch._src.tensor_engine.common.tensor import Tensor
 from monarch._testing import TestingContext
-from monarch.common._coalescing import _record_and_define, compile
-from monarch.common.function_caching import AliasOf, Storage, TensorGroup
-from monarch.common.tensor import Tensor
 
 
 def _do_bogus_tensor_work(x, y, fail_rank=None):
@@ -234,7 +238,7 @@ def test_no_coalescing(self, backend_type) -> None:
     @contextmanager
     def assertRecorded(self, times: int):
         with patch(
-            "monarch.common._coalescing._record_and_define",
+            "monarch._src.tensor_engine.common._coalescing._record_and_define",
             side_effect=_record_and_define,
         ) as m:
             yield
diff --git a/python/tests/test_controller.py b/python/tests/test_controller.py
index 963dac97..002f587c 100644
--- a/python/tests/test_controller.py
+++ b/python/tests/test_controller.py
@@ -27,12 +27,12 @@
     Stream,
     Tensor,
 )
+from monarch._src.tensor_engine.common.controller_api import LogMessage
+from monarch._src.tensor_engine.common.invocation import DeviceException
+from monarch._src.tensor_engine.common.remote import remote
+from monarch._src.tensor_engine.common.tree import flattener
 from monarch._testing import BackendType, TestingContext
-from monarch.common.controller_api import LogMessage
-from monarch.common.invocation import DeviceException
-from monarch.common.remote import remote
-from monarch.common.tree import flattener
 from monarch.rust_local_mesh import (
     ControllerParams,
     local_mesh,
diff --git a/python/tests/test_device_mesh.py b/python/tests/test_device_mesh.py
index c84bb10c..c26a50d9 100644
--- a/python/tests/test_device_mesh.py
+++ b/python/tests/test_device_mesh.py
@@ -8,7 +8,7 @@
 import pytest
 
 from monarch import DeviceMesh, NDSlice
-from monarch.common.client import Client
+from monarch._src.tensor_engine.common.client import Client
 from monarch.simulator.mock_controller import MockController
diff --git a/python/tests/test_fault_tolerance.py b/python/tests/test_fault_tolerance.py
index c27d8f84..b4a24500 100644
--- a/python/tests/test_fault_tolerance.py
+++ b/python/tests/test_fault_tolerance.py
@@ -17,8 +17,11 @@
 from unittest import TestCase
 
 from monarch import fetch_shard, no_mesh, remote
-from monarch.common.device_mesh import DeviceMesh, DeviceMeshStatus
-from monarch.common.invocation import DeviceException, RemoteException
+from monarch._src.tensor_engine.common.device_mesh import DeviceMesh, DeviceMeshStatus
+from monarch._src.tensor_engine.common.invocation import (
+    DeviceException,
+    RemoteException,
+)
 from monarch.rust_backend_mesh import MeshWorld, PoolDeviceMeshProvider
 from monarch.rust_local_mesh import (
     Bootstrap,
diff --git a/python/tests/test_future.py b/python/tests/test_future.py
index e2dc9847..1c403232 100644
--- a/python/tests/test_future.py
+++ b/python/tests/test_future.py
@@ -13,8 +13,8 @@
 from monarch._src.actor._extension.monarch_hyperactor.proc import (  # @manual=//monarch/monarch_extension:monarch_extension
     ActorId,
 )
-from monarch.common import future
-from monarch.common.client import Client
+from monarch._src.tensor_engine.common import future
+from monarch._src.tensor_engine.common.client import Client
 
 
 class TestFuture:
diff --git a/python/tests/test_mock_cuda.py b/python/tests/test_mock_cuda.py
index 26d8bd74..746c1987 100644
--- a/python/tests/test_mock_cuda.py
+++ b/python/tests/test_mock_cuda.py
@@ -9,7 +9,7 @@
 import pytest
 import torch
 
-import monarch.common.mock_cuda  # usort: skip
+import monarch._src.tensor_engine.common.mock_cuda  # usort: skip
 
 
 def simple_forward_backward(device: str) -> None:
@@ -37,7 +37,7 @@ def setUp(self) -> None:
         return super().setUp()
 
     def test_output_is_garbage(self):
-        with monarch.common.mock_cuda.mock_cuda_guard():
+        with monarch._src.tensor_engine.common.mock_cuda.mock_cuda_guard():
             x = torch.arange(9, device="cuda", dtype=torch.float32).reshape(3, 3)
             y = 2 * torch.eye(3, device="cuda")
             true_output = torch.tensor(
@@ -58,7 +58,7 @@ def test_turn_mock_on_and_off(self):
         self.assertTrue(torch.allclose(cpu_dw, real_dw.cpu()))
         self.assertTrue(torch.allclose(cpu_db, real_db.cpu()))
 
-        with monarch.common.mock_cuda.mock_cuda_guard():
+        with monarch._src.tensor_engine.common.mock_cuda.mock_cuda_guard():
             mocked_y, mocked_dw, mocked_db = simple_forward_backward("cuda")
             self.assertFalse(torch.allclose(cpu_y, mocked_y.cpu()))
             self.assertFalse(torch.allclose(cpu_dw, mocked_dw.cpu()))
diff --git a/python/tests/test_remote_functions.py b/python/tests/test_remote_functions.py
index 7b1b3398..c65cb36b 100644
--- a/python/tests/test_remote_functions.py
+++ b/python/tests/test_remote_functions.py
@@ -28,14 +28,14 @@
     RemoteException as OldRemoteException,
     Stream,
 )
+from monarch._src.tensor_engine.common import remote as remote_module
+from monarch._src.tensor_engine.common.device_mesh import DeviceMesh
+from monarch._src.tensor_engine.common.remote import Remote
 from monarch._testing import BackendType, TestingContext
 from monarch.builtins.log import log_remote
 from monarch.builtins.random import set_manual_seed_remote
 from monarch.cached_remote_function import remote_autograd_function
-from monarch.common import remote as remote_module
-from monarch.common.device_mesh import DeviceMesh
-from monarch.common.remote import Remote
 from monarch.mesh_controller import RemoteException as NewRemoteException
 from monarch.opaque_module import OpaqueModule
diff --git a/python/tests/test_rust_backend.py b/python/tests/test_rust_backend.py
index 3e2a925b..49b10b37 100644
--- a/python/tests/test_rust_backend.py
+++ b/python/tests/test_rust_backend.py
@@ -16,7 +16,7 @@
 import torch
 import torch.utils._python_dispatch
 from monarch import fetch_shard, no_mesh, remote, Stream
-from monarch.common.device_mesh import DeviceMesh
+from monarch._src.tensor_engine.common.device_mesh import DeviceMesh
 from monarch.rust_local_mesh import local_meshes, LoggingLocation, SocketType
 from torch.nn.attention import sdpa_kernel, SDPBackend
 from torch.nn.functional import scaled_dot_product_attention
@@ -186,7 +186,10 @@ def test_ivalue_problems(self) -> None:
         with local_mesh(hosts=1, gpu_per_host=1):
            from typing import cast

-            from monarch.common.messages import CallFunction, CommandGroup
+            from monarch._src.tensor_engine.common.messages import (
+                CallFunction,
+                CommandGroup,
+            )
 
             a = cast(monarch.Tensor, torch.rand(3, 4))
             result = monarch.Tensor(a._fake, a.mesh, a.stream)
@@ -194,7 +197,7 @@ def test_ivalue_problems(self) -> None:
                 0,
                 result,
                 (),
-                monarch.common.function.ResolvableFunctionFromPath(
+                monarch._src.tensor_engine.common.function.ResolvableFunctionFromPath(
                     "torch.ops.aten.mul.Tensor"
                 ),
                 (2, a),
diff --git a/python/tests/test_sim_backend.py b/python/tests/test_sim_backend.py
index 1ea89cd5..f192b508 100644
--- a/python/tests/test_sim_backend.py
+++ b/python/tests/test_sim_backend.py
@@ -14,7 +14,7 @@
 import torch
 
 from monarch import fetch_shard
-from monarch.common.device_mesh import DeviceMesh
+from monarch._src.tensor_engine.common.device_mesh import DeviceMesh
 from monarch.sim_mesh import sim_mesh
diff --git a/setup.py b/setup.py
index 85cf3ed9..10ce6d0a 100644
--- a/setup.py
+++ b/setup.py
@@ -32,7 +32,7 @@
     monarch_cpp_src.append("python/monarch/common/mock_cuda.cpp")
 
 common_C = CppExtension(
-    "monarch.common._C",
+    "monarch._src.tensor_engine.common._C",
     monarch_cpp_src,
     extra_compile_args=["-g", "-O3"],
     libraries=["dl"],