Skip to content

[monarch] initial commit of tensor_engine submodule #458

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: gh/suo/51/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions monarch_extension/src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ struct MessageParser<'a> {
fn create_function(obj: Bound<'_, PyAny>) -> PyResult<ResolvableFunction> {
let cloudpickle = obj
.py()
.import("monarch.common.function")?
.import("monarch._src.tensor_engine.common.function")?
.getattr("ResolvableFromCloudpickle")?;
if obj.is_instance(&cloudpickle)? {
Ok(ResolvableFunction::Cloudpickle(Cloudpickle::new(
Expand Down Expand Up @@ -102,7 +102,7 @@ impl<'a> MessageParser<'a> {
let referenceable = self
.current
.py()
.import("monarch.common.reference")?
.import("monarch._src.tensor_engine.common.reference")?
.getattr("Referenceable")?;
let mut flat: Vec<Option<Ref>> = vec![];
for x in output_tuple.0.try_iter()? {
Expand Down Expand Up @@ -198,8 +198,8 @@ static CONVERT_MAP: OnceLock<HashMap<u64, FnType>> = OnceLock::new();

fn create_map(py: Python) -> HashMap<u64, FnType> {
let messages = py
.import("monarch.common.messages")
.expect("import monarch.common.messages");
.import("monarch._src.tensor_engine.common.messages")
.expect("import monarch._src.tensor_engine.common.messages");
let mut m: HashMap<u64, FnType> = HashMap::new();
let key = |name: &str| {
messages
Expand Down
82 changes: 50 additions & 32 deletions python/monarch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@
from monarch import timer
from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator
from monarch._src.actor.shape import NDSlice, Shape
from monarch.common._coalescing import coalescing
from monarch._src.tensor_engine.common._coalescing import coalescing

from monarch.common.device_mesh import (
from monarch._src.tensor_engine.common.device_mesh import (
DeviceMesh,
get_active_mesh,
no_mesh,
Expand All @@ -45,17 +45,23 @@
to_mesh,
)

from monarch.common.function import resolvers as function_resolvers
from monarch._src.tensor_engine.common.function import (
resolvers as function_resolvers,
)

from monarch.common.future import Future
from monarch._src.tensor_engine.common.future import Future

from monarch.common.invocation import RemoteException
from monarch.common.opaque_ref import OpaqueRef
from monarch.common.pipe import create_pipe, Pipe, remote_generator
from monarch.common.remote import remote
from monarch.common.selection import Selection
from monarch.common.stream import get_active_stream, Stream
from monarch.common.tensor import reduce, reduce_, Tensor
from monarch._src.tensor_engine.common.invocation import RemoteException
from monarch._src.tensor_engine.common.opaque_ref import OpaqueRef
from monarch._src.tensor_engine.common.pipe import (
create_pipe,
Pipe,
remote_generator,
)
from monarch._src.tensor_engine.common.remote import remote
from monarch._src.tensor_engine.common.selection import Selection
from monarch._src.tensor_engine.common.stream import get_active_stream, Stream
from monarch._src.tensor_engine.common.tensor import reduce, reduce_, Tensor
from monarch.fetch import fetch_shard, inspect, show
from monarch.gradient_generator import grad_function, grad_generator
from monarch.notebook import mast_mesh, reserve_torchx as mast_reserve
Expand All @@ -72,29 +78,41 @@


_public_api = {
"coalescing": ("monarch.common._coalescing", "coalescing"),
"remote": ("monarch.common.remote", "remote"),
"DeviceMesh": ("monarch.common.device_mesh", "DeviceMesh"),
"get_active_mesh": ("monarch.common.device_mesh", "get_active_mesh"),
"no_mesh": ("monarch.common.device_mesh", "no_mesh"),
"RemoteProcessGroup": ("monarch.common.device_mesh", "RemoteProcessGroup"),
"function_resolvers": ("monarch.common.function", "resolvers"),
"Future": ("monarch.common.future", "Future"),
"RemoteException": ("monarch.common.invocation", "RemoteException"),
"coalescing": ("monarch._src.tensor_engine.common._coalescing", "coalescing"),
"remote": ("monarch._src.tensor_engine.common.remote", "remote"),
"DeviceMesh": ("monarch._src.tensor_engine.common.device_mesh", "DeviceMesh"),
"get_active_mesh": (
"monarch._src.tensor_engine.common.device_mesh",
"get_active_mesh",
),
"no_mesh": ("monarch._src.tensor_engine.common.device_mesh", "no_mesh"),
"RemoteProcessGroup": (
"monarch._src.tensor_engine.common.device_mesh",
"RemoteProcessGroup",
),
"function_resolvers": ("monarch._src.tensor_engine.common.function", "resolvers"),
"Future": ("monarch._src.tensor_engine.common.future", "Future"),
"RemoteException": (
"monarch._src.tensor_engine.common.invocation",
"RemoteException",
),
"Shape": ("monarch._src.actor.shape", "Shape"),
"NDSlice": ("monarch._src.actor.shape", "NDSlice"),
"Selection": ("monarch.common.selection", "Selection"),
"OpaqueRef": ("monarch.common.opaque_ref", "OpaqueRef"),
"create_pipe": ("monarch.common.pipe", "create_pipe"),
"Pipe": ("monarch.common.pipe", "Pipe"),
"remote_generator": ("monarch.common.pipe", "remote_generator"),
"get_active_stream": ("monarch.common.stream", "get_active_stream"),
"Stream": ("monarch.common.stream", "Stream"),
"Tensor": ("monarch.common.tensor", "Tensor"),
"reduce": ("monarch.common.tensor", "reduce"),
"reduce_": ("monarch.common.tensor", "reduce_"),
"to_mesh": ("monarch.common.device_mesh", "to_mesh"),
"slice_mesh": ("monarch.common.device_mesh", "slice_mesh"),
"Selection": ("monarch._src.tensor_engine.common.selection", "Selection"),
"OpaqueRef": ("monarch._src.tensor_engine.common.opaque_ref", "OpaqueRef"),
"create_pipe": ("monarch._src.tensor_engine.common.pipe", "create_pipe"),
"Pipe": ("monarch._src.tensor_engine.common.pipe", "Pipe"),
"remote_generator": ("monarch._src.tensor_engine.common.pipe", "remote_generator"),
"get_active_stream": (
"monarch._src.tensor_engine.common.stream",
"get_active_stream",
),
"Stream": ("monarch._src.tensor_engine.common.stream", "Stream"),
"Tensor": ("monarch._src.tensor_engine.common.tensor", "Tensor"),
"reduce": ("monarch._src.tensor_engine.common.tensor", "reduce"),
"reduce_": ("monarch._src.tensor_engine.common.tensor", "reduce_"),
"to_mesh": ("monarch._src.tensor_engine.common.device_mesh", "to_mesh"),
"slice_mesh": ("monarch._src.tensor_engine.common.device_mesh", "slice_mesh"),
"call_on_shard_and_fetch": ("monarch.fetch", "call_on_shard_and_fetch"),
"fetch_shard": ("monarch.fetch", "fetch_shard"),
"inspect": ("monarch.fetch", "inspect"),
Expand Down
2 changes: 1 addition & 1 deletion python/monarch/_rust_bindings/monarch_extension/client.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ from typing import Any, ClassVar, Dict, final, List, NamedTuple, Union

from monarch._rust_bindings.monarch_extension.tensor_worker import Ref
from monarch._rust_bindings.monarch_messages.debugger import DebuggerActionType
from monarch._src.actor._extension.monarch_hyperactor.proc import (
from monarch._src.actor._extension.monarch_hyperactor.proc import ( # @manual=//monarch/python/monarch/_src/actor:actor
ActorId,
Proc,
Serialized,
Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,17 @@
)

import torch
from monarch.common import messages

from monarch.common.fake import fake_call
from monarch.common.function_caching import (
hashable_tensor_flatten,
TensorGroup,
TensorGroupPattern,
)
from monarch.common.tensor import InputChecker, Tensor
from monarch.common.tree import flatten
from . import messages

if TYPE_CHECKING:
from monarch.common.client import Recorder
from monarch.common.recording import Recording
from .fake import fake_call
from .function_caching import hashable_tensor_flatten, TensorGroup, TensorGroupPattern
from .tensor import InputChecker, Tensor
from .tree import flatten

from .client import Client
if TYPE_CHECKING:
from .client import Client, Recorder
from .recording import Recording

_coalescing = None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,20 @@
WorldState,
)
from monarch._src.actor.shape import NDSlice
from monarch.common import messages
from monarch.common.borrows import Borrow, StorageAliases
from monarch.common.controller_api import LogMessage, MessageResult, TController
from monarch.common.device_mesh import DeviceMesh

from monarch.common.future import Future
from monarch.common.invocation import DeviceException, RemoteException, Seq
from monarch.common.recording import flatten_messages, Recording
from . import _coalescing, messages
from .borrows import Borrow, StorageAliases
from .controller_api import LogMessage, MessageResult, TController
from .device_mesh import DeviceMesh

from monarch.common.reference import Ref, Referenceable
from monarch.common.stream import StreamRef
from monarch.common.tensor import Tensor
from monarch.common.tree import tree_map
from .future import Future
from .invocation import DeviceException, RemoteException, Seq
from .recording import flatten_messages, Recording

from . import _coalescing
from .reference import Ref, Referenceable
from .stream import StreamRef
from .tensor import Tensor
from .tree import tree_map


logger = logging.getLogger(__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@

from monarch._src.actor.shape import NDSlice

from monarch.common.invocation import DeviceException, RemoteException, Seq
from monarch.common.reference import Ref
from monarch.common.tensor import Tensor
from .invocation import DeviceException, RemoteException, Seq
from .reference import Ref
from .tensor import Tensor


class LogMessage(NamedTuple):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@
Union,
)

import monarch.common.messages as messages
import torch
from monarch._src.actor.shape import MeshTrait, NDSlice, Shape

from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map

from . import messages

from ._tensor_to_table import tensor_to_table
from .context_manager import activate_first_context_manager
from .messages import Dims
Expand All @@ -41,7 +42,7 @@
from .tensor import MeshSliceTensor, Tensor

if TYPE_CHECKING:
from monarch.common.client import Client
from .client import Client

logger: Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -382,7 +383,7 @@ def _remote(*args, **kwargs):
# we break the dependency to allow for separate files by
# having device_mesh and tensor locally import the `remote`
# entrypoint
from monarch.common.remote import remote
from .remote import remote

return remote(*args, **kwargs)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from monarch_supervisor import TTL

if TYPE_CHECKING:
from monarch.common.client import Client
from .client import Client

from .invocation import RemoteException

Expand Down
125 changes: 125 additions & 0 deletions python/monarch/_src/tensor_engine/common/invocation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
import traceback
from typing import Any, List, Optional, Tuple

from monarch._src.actor._extension.monarch_hyperactor.proc import ( # @manual=//monarch/monarch_extension:monarch_extension
ActorId,
)


Seq = int


class DeviceException(Exception):
    """
    Signals a non-deterministic failure in a worker, the controller, or the
    infrastructure underneath them -- for instance a worker stuck in a crash
    loop, or a GPU that has been lost.
    """

    def __init__(
        self,
        exception: Exception,
        frames: List[traceback.FrameSummary],
        source_actor_id: ActorId,
        message: str,
    ):
        # All constructor arguments are kept verbatim as instance attributes.
        self.message = message
        self.source_actor_id = source_actor_id
        self.frames = frames
        self.exception = exception

    def __str__(self):
        # Rendering is best-effort: if anything goes wrong while formatting,
        # the error is printed and a placeholder string is returned instead
        # of raising from __str__.
        try:
            cause_text = str(self.exception)
            worker_trace = "".join(traceback.format_list(self.frames))
            header = f"{self.message}\n"
            detail = f"Traceback of the failure on worker (most recent call last):\n{worker_trace}{type(self.exception).__name__}: {cause_text}"
            return header + detail
        except Exception as err:
            print(err)
            return "oops"


class RemoteException(Exception):
    """
    Signals a deterministic problem in the user's code -- for instance an
    out-of-memory error caused by requesting too much GPU memory, or the
    violation of an invariant enforced by one of the APIs.
    """

    def __init__(
        self,
        seq: Seq,
        exception: Exception,
        controller_frame_index: Optional[int],
        controller_frames: Optional[List[traceback.FrameSummary]],
        worker_frames: List[traceback.FrameSummary],
        source_actor_id: ActorId,
        message="A remote function has failed asynchronously.",
    ):
        # All constructor arguments are kept verbatim as instance attributes.
        self.seq = seq
        self.exception = exception
        self.message = message
        self.source_actor_id = source_actor_id
        self.worker_frames = worker_frames
        self.controller_frames = controller_frames
        self.controller_frame_index = controller_frame_index

    def __str__(self):
        # Rendering is best-effort: if anything goes wrong while formatting,
        # the error is printed and a placeholder string is returned instead
        # of raising from __str__.
        try:
            cause_text = str(self.exception)
            worker_trace = "".join(traceback.format_list(self.worker_frames))
            if self.controller_frames is None:
                # No controller-side frames were recorded for this failure.
                controller_trace = " <not related to a specific invocation>\n"
            else:
                controller_trace = "".join(
                    traceback.format_list(self.controller_frames)
                )
            sections = (
                f"{self.message}\n",
                f"Traceback of where the remote function was issued on controller (most recent call last):\n{controller_trace}",
                f"Traceback of where the remote function failed on worker (most recent call last):\n{worker_trace}{type(self.exception).__name__}: {cause_text}",
            )
            return "".join(sections)
        except Exception as err:
            print(err)
            return "oops"


class Invocation:
    """
    Bookkeeping for a single sequenced invocation.

    Holds the sequence number, the set of dependent invocations (``users``)
    that should be told about failures, the recorded failure with the lowest
    sequence number seen so far, and the value to hand back on completion.
    """

    def __init__(self, seq: Seq):
        self.seq = seq
        # Becomes None once complete() has been called.
        self.users: Optional[set["Invocation"]] = set()
        self.failure: Optional[RemoteException] = None
        self.fut_value: Any = None

    def __repr__(self):
        return "<Invocation {}>".format(self.seq)

    def fail(self, remote_exception: RemoteException):
        """
        Record ``remote_exception`` unless a failure that is not later in
        sequence is already stored. Returns True when the stored failure
        was replaced, False otherwise.
        """
        current = self.failure
        if current is not None and not (current.seq > remote_exception.seq):
            return False
        self.failure = remote_exception
        return True

    def add_user(self, r: "Invocation"):
        """
        Register ``r`` as a dependent (while dependents are still tracked)
        and forward any failure already recorded on this invocation to it.
        """
        dependents = self.users
        if dependents is not None:
            dependents.add(r)
        known_failure = self.failure
        if known_failure is not None:
            r.fail(known_failure)

    def complete(self) -> Tuple[Any, Optional[RemoteException]]:
        """
        Mark this invocation as complete and return ``(value, failure)``.
        The value slot is None whenever a failure has been recorded.
        """
        # Dependents no longer need to be tracked after completion: add_user
        # still forwards a recorded failure to any late-registered dependent.
        self.users = None

        known_failure = self.failure
        if known_failure is None:
            return (self.fut_value, None)
        return (None, known_failure)
Loading
Loading