Skip to content

Commit 7293451

Browse files
pzhan9facebook-github-bot
authored andcommitted
Add def slice to PythonActorMesh and PythonActorMeshRef (#551)
Summary: Pull Request resolved: #551 This diffs adds `def slice` method to both `PythonActorMesh` and `PythonActorMeshRef`. With this method, we can: 1. slice a `PythonActorMesh` object into a `PythonActorMeshRef`; 1. slice a `PythonActorMeshRef` into another `PythonActorMeshRef`. Tests are added to demo that we can cast to the sliced mesh ref. Reviewed By: shayne-fletcher Differential Revision: D78292490 fbshipit-source-id: 3497ea72a5cf2e41b7bbbcf30e202430001a750d
1 parent 5bd2302 commit 7293451

File tree

5 files changed

+284
-54
lines changed

5 files changed

+284
-54
lines changed

hyperactor_mesh/src/actor_mesh.rs

Lines changed: 54 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ use ndslice::Range;
3636
use ndslice::Selection;
3737
use ndslice::Shape;
3838
use ndslice::ShapeError;
39+
use ndslice::SliceError;
3940
use ndslice::selection;
4041
use ndslice::selection::EvalOpts;
4142
use ndslice::selection::ReifyView;
@@ -95,6 +96,47 @@ where
9596
Ok(())
9697
}
9798

99+
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
100+
pub(crate) fn cast_to_sliced_mesh<A, M>(
101+
caps: &impl cap::CanSend,
102+
actor_mesh_id: ActorMeshId,
103+
sender: &ActorId,
104+
comm_actor_ref: &ActorRef<CommActor>,
105+
sel_of_sliced: &Selection,
106+
message: M,
107+
sliced_shape: &Shape,
108+
base_shape: &Shape,
109+
) -> Result<(), CastError>
110+
where
111+
A: RemoteActor + RemoteHandles<IndexedErasedUnbound<M>>,
112+
M: Castable + RemoteMessage,
113+
{
114+
let base_slice = base_shape.slice();
115+
116+
// Casting to `*`?
117+
let sel_of_base = if selection::normalize(sel_of_sliced) == normal::NormalizedSelection::True {
118+
// Reify this view into base.
119+
base_slice.reify_view(sliced_shape.slice())?
120+
} else {
121+
// No, fall back on `of_ranks`.
122+
let ranks = sel_of_sliced
123+
.eval(&EvalOpts::strict(), sliced_shape.slice())?
124+
.collect::<BTreeSet<_>>();
125+
Selection::of_ranks(base_slice, &ranks)?
126+
};
127+
128+
// Cast.
129+
actor_mesh_cast::<A, M>(
130+
caps,
131+
actor_mesh_id,
132+
base_shape,
133+
sender,
134+
comm_actor_ref,
135+
sel_of_base,
136+
message,
137+
)
138+
}
139+
98140
/// A mesh of actors, all of which reside on the same [`ProcMesh`].
99141
pub trait ActorMesh: Mesh<Id = ActorMeshId> {
100142
/// The type of actor in the mesh.
@@ -350,31 +392,15 @@ impl<A: RemoteActor> ActorMesh for SlicedActorMesh<'_, A> {
350392
Self::Actor: RemoteHandles<IndexedErasedUnbound<M>>,
351393
M: Castable + RemoteMessage,
352394
{
353-
let base_shape = self.0.shape();
354-
let base_slice = base_shape.slice();
355-
356-
// Casting to `*`?
357-
let selection = if selection::normalize(&sel) == normal::NormalizedSelection::True {
358-
// Reify this view into base.
359-
base_slice.reify_view(self.shape().slice()).unwrap()
360-
} else {
361-
// No, fall back on `of_ranks`.
362-
let ranks = sel
363-
.eval(&EvalOpts::strict(), self.shape().slice())
364-
.unwrap()
365-
.collect::<BTreeSet<_>>();
366-
Selection::of_ranks(base_slice, &ranks).unwrap()
367-
};
368-
369-
// Cast.
370-
actor_mesh_cast::<A, M>(
371-
self.proc_mesh().client(), // send capability
372-
self.id(), // actor mesh id (destination mesh)
373-
base_shape, // actor mesh shape
374-
self.proc_mesh().client().actor_id(), // sender
375-
self.proc_mesh().comm_actor(), // comm actor
376-
selection, // the selected actors
377-
message, // the message
395+
cast_to_sliced_mesh::<A, M>(
396+
/*caps=*/ self.proc_mesh().client(),
397+
/*actor_mesh_id=*/ self.id(),
398+
/*sender=*/ self.proc_mesh().client().actor_id(),
399+
/*comm_actor_ref*/ self.proc_mesh().comm_actor(),
400+
/*sel_of_sliced=*/ &sel,
401+
/*message=*/ message,
402+
/*sliced_shape=*/ self.shape(),
403+
/*base_shape=*/ self.0.shape(),
378404
)
379405
}
380406
}
@@ -394,6 +420,9 @@ pub enum CastError {
394420
#[error(transparent)]
395421
ShapeError(#[from] ShapeError),
396422

423+
#[error(transparent)]
424+
SliceError(#[from] SliceError),
425+
397426
#[error(transparent)]
398427
SerializationError(#[from] bincode::Error),
399428

hyperactor_mesh/src/reference.rs

Lines changed: 51 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,17 @@ use hyperactor::actor::RemoteActor;
1919
use hyperactor::cap;
2020
use hyperactor::message::Castable;
2121
use hyperactor::message::IndexedErasedUnbound;
22+
use ndslice::Range;
2223
use ndslice::Selection;
2324
use ndslice::Shape;
25+
use ndslice::ShapeError;
2426
use serde::Deserialize;
2527
use serde::Serialize;
2628

2729
use crate::CommActor;
2830
use crate::actor_mesh::CastError;
2931
use crate::actor_mesh::actor_mesh_cast;
32+
use crate::actor_mesh::cast_to_sliced_mesh;
3033

3134
#[macro_export]
3235
macro_rules! mesh_id {
@@ -71,10 +74,15 @@ pub struct ProcMeshId(pub String);
7174
pub struct ActorMeshId(pub ProcMeshId, pub String);
7275

7376
/// Types references to Actor Meshes.
74-
#[derive(Debug, Serialize, Deserialize)]
77+
#[derive(Debug, Serialize, Deserialize, PartialEq)]
7578
pub struct ActorMeshRef<A: RemoteActor> {
7679
pub(crate) mesh_id: ActorMeshId,
77-
shape: Shape,
80+
/// The shape of the root mesh.
81+
root: Shape,
82+
/// If some, it mean this mesh ref points to a sliced mesh, and this field
83+
/// is this sliced mesh's shape. If None, it means this mesh ref points to
84+
/// the root mesh.
85+
sliced: Option<Shape>,
7886
/// The reference to the comm actor of the underlying Proc Mesh.
7987
comm_actor_ref: ActorRef<CommActor>,
8088
phantom: PhantomData<A>,
@@ -87,12 +95,13 @@ impl<A: RemoteActor> ActorMeshRef<A> {
8795
/// line argument) is a valid reference.
8896
pub(crate) fn attest(
8997
mesh_id: ActorMeshId,
90-
shape: Shape,
98+
root: Shape,
9199
comm_actor_ref: ActorRef<CommActor>,
92100
) -> Self {
93101
Self {
94102
mesh_id,
95-
shape,
103+
root,
104+
sliced: None,
96105
comm_actor_ref,
97106
phantom: PhantomData,
98107
}
@@ -105,7 +114,10 @@ impl<A: RemoteActor> ActorMeshRef<A> {
105114

106115
/// Shape of the Actor Mesh.
107116
pub fn shape(&self) -> &Shape {
108-
&self.shape
117+
match &self.sliced {
118+
Some(s) => s,
119+
None => &self.root,
120+
}
109121
}
110122

111123
/// Cast an [`M`]-typed message to the ranks selected by `sel`
@@ -121,37 +133,53 @@ impl<A: RemoteActor> ActorMeshRef<A> {
121133
A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
122134
M: Castable + RemoteMessage,
123135
{
124-
actor_mesh_cast::<A, M>(
125-
caps,
126-
self.mesh_id.clone(),
127-
self.shape(),
128-
caps.mailbox().actor_id(),
129-
&self.comm_actor_ref,
130-
selection,
131-
message,
132-
)
136+
match &self.sliced {
137+
Some(sliced_shape) => cast_to_sliced_mesh::<A, M>(
138+
caps,
139+
self.mesh_id.clone(),
140+
caps.mailbox().actor_id(),
141+
&self.comm_actor_ref,
142+
&selection,
143+
message,
144+
sliced_shape,
145+
&self.root,
146+
),
147+
None => actor_mesh_cast::<A, M>(
148+
caps,
149+
self.mesh_id.clone(),
150+
&self.root,
151+
caps.mailbox().actor_id(),
152+
&self.comm_actor_ref,
153+
selection,
154+
message,
155+
),
156+
}
157+
}
158+
159+
pub fn select<R: Into<Range>>(&self, label: &str, range: R) -> Result<Self, ShapeError> {
160+
let sliced = self.shape().select(label, range)?;
161+
Ok(Self {
162+
mesh_id: self.mesh_id.clone(),
163+
root: self.root.clone(),
164+
sliced: Some(sliced),
165+
comm_actor_ref: self.comm_actor_ref.clone(),
166+
phantom: PhantomData,
167+
})
133168
}
134169
}
135170

136171
impl<A: RemoteActor> Clone for ActorMeshRef<A> {
137172
fn clone(&self) -> Self {
138173
Self {
139174
mesh_id: self.mesh_id.clone(),
140-
shape: self.shape.clone(),
175+
root: self.root.clone(),
176+
sliced: self.sliced.clone(),
141177
comm_actor_ref: self.comm_actor_ref.clone(),
142178
phantom: PhantomData,
143179
}
144180
}
145181
}
146182

147-
impl<A: RemoteActor> PartialEq for ActorMeshRef<A> {
148-
fn eq(&self, other: &Self) -> bool {
149-
self.mesh_id == other.mesh_id && self.shape == other.shape
150-
}
151-
}
152-
153-
impl<A: RemoteActor> Eq for ActorMeshRef<A> {}
154-
155183
#[cfg(test)]
156184
mod tests {
157185
use async_trait::async_trait;

monarch_hyperactor/src/actor_mesh.rs

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ use pyo3::exceptions::PyRuntimeError;
2727
use pyo3::exceptions::PyValueError;
2828
use pyo3::prelude::*;
2929
use pyo3::types::PyBytes;
30+
use pyo3::types::PyDict;
31+
use pyo3::types::PySlice;
3032
use serde::Deserialize;
3133
use serde::Serialize;
3234
use tokio::sync::Mutex;
@@ -178,6 +180,11 @@ impl PythonActorMesh {
178180
Ok(monitor_instance.into_py(py))
179181
}
180182

183+
#[pyo3(signature = (**kwargs))]
184+
fn slice(&self, kwargs: Option<&Bound<'_, PyDict>>) -> PyResult<PythonActorMeshRef> {
185+
self.bind()?.slice(kwargs)
186+
}
187+
181188
#[getter]
182189
pub fn client(&self) -> PyMailbox {
183190
self.client.clone()
@@ -222,6 +229,75 @@ impl PythonActorMeshRef {
222229
Ok(())
223230
}
224231

232+
#[pyo3(signature = (**kwargs))]
233+
fn slice(&self, kwargs: Option<&Bound<'_, PyDict>>) -> PyResult<Self> {
234+
// When the input type is `int`, convert it into `ndslice::Range`.
235+
fn convert_int(index: isize) -> PyResult<ndslice::Range> {
236+
if index < 0 {
237+
return Err(PyException::new_err(format!(
238+
"does not support negative index in selection: {}",
239+
index
240+
)));
241+
}
242+
Ok(ndslice::Range::from(index as usize))
243+
}
244+
245+
// When the input type is `slice`, convert it into `ndslice::Range`.
246+
fn convert_py_slice<'py>(s: &Bound<'py, PySlice>) -> PyResult<ndslice::Range> {
247+
fn get_attr<'py>(s: &Bound<'py, PySlice>, attr: &str) -> PyResult<Option<isize>> {
248+
let v = s.getattr(attr)?.extract::<Option<isize>>()?;
249+
if v.is_some() && v.unwrap() < 0 {
250+
return Err(PyException::new_err(format!(
251+
"does not support negative {} in slice: {}",
252+
attr,
253+
v.unwrap(),
254+
)));
255+
}
256+
Ok(v)
257+
}
258+
259+
let start = get_attr(s, "start")?.unwrap_or(0);
260+
let stop: Option<isize> = get_attr(s, "stop")?;
261+
let step = get_attr(s, "step")?.unwrap_or(1);
262+
Ok(ndslice::Range(
263+
start as usize,
264+
stop.map(|s| s as usize),
265+
step as usize,
266+
))
267+
}
268+
269+
if kwargs.is_none() || kwargs.unwrap().is_empty() {
270+
return Err(PyException::new_err("selection cannot be empty"));
271+
}
272+
273+
let mut sliced = self.inner.clone();
274+
275+
for entry in kwargs.unwrap().items() {
276+
let label = entry.get_item(0)?.str()?;
277+
let label_str = label.to_str()?;
278+
279+
let value = entry.get_item(1)?;
280+
281+
let range = if let Ok(index) = value.extract::<isize>() {
282+
convert_int(index)?
283+
} else if let Ok(s) = value.downcast::<PySlice>() {
284+
convert_py_slice(s)?
285+
} else {
286+
return Err(PyException::new_err(
287+
"selection only supports type int or slice",
288+
));
289+
};
290+
sliced = sliced.select(label_str, range).map_err(|err| {
291+
PyException::new_err(format!(
292+
"failed to select label {}; error is: {}",
293+
label_str, err
294+
))
295+
})?;
296+
}
297+
298+
Ok(Self { inner: sliced })
299+
}
300+
225301
#[getter]
226302
fn shape(&self) -> PyShape {
227303
PyShape::from(self.inner.shape().clone())

python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
# pyre-strict
88

9-
from collections.abc import Mapping
109
from typing import AsyncIterator, final
1110

1211
from monarch._rust_bindings.monarch_hyperactor.actor import PythonMessage
@@ -18,6 +17,7 @@ from monarch._rust_bindings.monarch_hyperactor.mailbox import (
1817
from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
1918
from monarch._rust_bindings.monarch_hyperactor.selection import Selection
2019
from monarch._rust_bindings.monarch_hyperactor.shape import Shape
20+
from typing_extensions import Self
2121

2222
@final
2323
class PythonActorMeshRef:
@@ -31,6 +31,12 @@ class PythonActorMeshRef:
3131
"""Cast a message to the selected actors in the mesh."""
3232
...
3333

34+
def slice(self, **kwargs: int | slice[int | None, int | None, int | None]) -> Self:
35+
"""
36+
See PythonActorMeshRef.slice for documentation.
37+
"""
38+
...
39+
3440
@property
3541
def shape(self) -> Shape:
3642
"""
@@ -53,6 +59,22 @@ class PythonActorMesh:
5359
"""
5460
Cast a message to the selected actors in the mesh.
5561
"""
62+
...
63+
64+
def slice(
65+
self, **kwargs: int | slice[int | None, int | None, int | None]
66+
) -> PythonActorMeshRef:
67+
"""
68+
Slice the mesh into a new mesh ref with the given selection. The reason
69+
it returns a mesh ref, rather than the mesh object itself, is because
70+
sliced mesh is a view of the original mesh, and does not own the mesh's
71+
resources.
72+
73+
Arguments:
74+
- `kwargs`: argument name is the label, and argument value is how to
75+
slice the mesh along the dimension of that label.
76+
"""
77+
...
5678

5779
def get_supervision_event(self) -> ActorSupervisionEvent | None:
5880
"""

0 commit comments

Comments
 (0)