Skip to content

Commit 23b3a1d

Browse files
: selection: handle dim mismatch in reify_view
Summary: improve `reify_view` to support dimension mismatches by falling back to `Selection::of_ranks`. this enables `reify_views` to be used in `mesh_controller`, simplifying cast selection logic. also improves ergonomics by allowing `reify_views` to accept slices directly (e.g., `&[Slice]` or `Vec<Slice>`). Differential Revision: D78035345
1 parent 0a5e8b1 commit 23b3a1d

File tree

3 files changed

+47
-43
lines changed

3 files changed

+47
-43
lines changed

hyperactor_mesh/src/reference.rs

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -112,15 +112,6 @@ impl<A: RemoteActor> ActorMeshRef<A> {
112112
&self.shape
113113
}
114114

115-
/// Shape of the underlying Proc Mesh.
116-
fn proc_mesh_shape(&self) -> &Shape {
117-
&self.proc_mesh_shape
118-
}
119-
120-
fn name(&self) -> &str {
121-
&self.mesh_id.1
122-
}
123-
124115
/// Cast an [`M`]-typed message to the ranks selected by `sel`
125116
/// in this ActorMesh.
126117
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.

monarch_extension/src/mesh_controller.rs

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
*/
88

99
use std::collections::BTreeMap;
10-
use std::collections::BTreeSet;
1110
use std::collections::HashMap;
1211
use std::collections::HashSet;
1312
use std::collections::VecDeque;
@@ -55,9 +54,9 @@ use monarch_messages::worker::WorkerMessage;
5554
use monarch_messages::worker::WorkerParams;
5655
use monarch_tensor_worker::AssignRankMessage;
5756
use monarch_tensor_worker::WorkerActor;
58-
use ndslice::Selection;
5957
use ndslice::Slice;
6058
use ndslice::selection;
59+
use ndslice::selection::ReifyView;
6160
use pyo3::exceptions::PyValueError;
6261
use pyo3::prelude::*;
6362
use tokio::sync::Mutex;
@@ -808,19 +807,8 @@ impl Handler<ClientToControllerMessage> for MeshControllerActor {
808807
) -> anyhow::Result<()> {
809808
match message {
810809
ClientToControllerMessage::Send { slices, message } => {
811-
let selection = slices
812-
.iter()
813-
.map(|slice| {
814-
Selection::of_ranks(
815-
self.workers().shape().slice(),
816-
&slice.iter().collect::<BTreeSet<usize>>(),
817-
)
818-
})
819-
.collect::<Result<Vec<_>, _>>()?
820-
.into_iter()
821-
.reduce(selection::dsl::union)
822-
.unwrap_or_else(selection::dsl::false_);
823-
self.workers().cast(selection, message.clone())?;
810+
let sel = self.workers().shape().slice().reify_views(slices)?;
811+
self.workers().cast(sel, message)?;
824812
}
825813
ClientToControllerMessage::Node {
826814
seq,

ndslice/src/selection.rs

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,7 +1013,7 @@ pub trait ReifyView: sealed::Sealed {
10131013

10141014
/// Reify multiple views as a union of selections in the
10151015
/// coordinate system of `self`.
1016-
fn reify_views(&self, views: &[&Slice]) -> Result<Selection, SliceError>;
1016+
fn reify_views<V: AsRef<[Slice]>>(&self, views: V) -> Result<Selection, SliceError>;
10171017
}
10181018

10191019
impl ReifyView for Slice {
@@ -1030,7 +1030,6 @@ impl ReifyView for Slice {
10301030
/// # Errors
10311031
///
10321032
/// Returns an error if:
1033-
/// - The number of dimensions in the view does not match the base
10341033
/// - The view lies outside the bounds of the base slice
10351034
///
10361035
/// # Example
@@ -1044,16 +1043,14 @@ impl ReifyView for Slice {
10441043
/// let selection = base.reify_view(view).unwrap();
10451044
/// ```
10461045
fn reify_view(&self, view: &Slice) -> Result<Selection, SliceError> {
1047-
if view.num_dim() != self.num_dim() {
1048-
return Err(SliceError::InvalidDims {
1049-
expected: self.num_dim(),
1050-
got: view.num_dim(),
1051-
});
1052-
}
10531046
if view.is_empty() {
10541047
return Ok(dsl::false_());
10551048
}
10561049

1050+
if view.num_dim() != self.num_dim() {
1051+
return Selection::of_ranks(self, &view.iter().collect::<BTreeSet<usize>>());
1052+
}
1053+
10571054
let origin = self.coordinates(view.offset())?;
10581055
let mut acc = dsl::true_();
10591056
for (&start, &len) in origin.iter().zip(view.sizes()).rev() {
@@ -1076,7 +1073,6 @@ impl ReifyView for Slice {
10761073
/// # Errors
10771074
///
10781075
/// Returns an error if any view:
1079-
/// - Has a different number of dimensions than the base slice
10801076
/// - Refers to coordinates not contained within the base
10811077
///
10821078
/// # Example
@@ -1087,15 +1083,22 @@ impl ReifyView for Slice {
10871083
/// let shape = ndslice::shape!(x = 4, y = 4);
10881084
/// let base = shape.slice();
10891085
///
1090-
/// let a = ndslice::select!(shape, x = 0..2, y = 0..2).unwrap();
1091-
/// let b = ndslice::select!(shape, x = 2..4, y = 2..4).unwrap();
1086+
/// let a = ndslice::select!(shape, x = 0..2, y = 0..2)
1087+
/// .unwrap()
1088+
/// .slice()
1089+
/// .clone();
1090+
/// let b = ndslice::select!(shape, x = 2..4, y = 2..4)
1091+
/// .unwrap()
1092+
/// .slice()
1093+
/// .clone();
10921094
///
1093-
/// let sel = base.reify_views(&[a.slice(), b.slice()]).unwrap();
1095+
/// let sel = base.reify_views(&[a, b]).unwrap();
10941096
/// ```
1095-
fn reify_views(&self, views: &[&Slice]) -> Result<Selection, SliceError> {
1097+
fn reify_views<V: AsRef<[Slice]>>(&self, views: V) -> Result<Selection, SliceError> {
1098+
let views = views.as_ref();
10961099
let mut selections = Vec::with_capacity(views.len());
10971100

1098-
for &view in views {
1101+
for view in views {
10991102
if view.is_empty() {
11001103
continue;
11011104
}
@@ -2157,6 +2160,28 @@ mod tests {
21572160
);
21582161
}
21592162

2163+
#[test]
2164+
fn test_reify_view_dimension_mismatch() {
2165+
let shape = shape!(host = 2, gpu = 4);
2166+
let base = shape.slice();
2167+
2168+
// Select the 3rd GPU (index 2) across both hosts i.e. flat
2169+
// indices [2, 6]
2170+
let indices = vec![
2171+
base.location(&[0, 2]).unwrap(),
2172+
base.location(&[1, 2]).unwrap(),
2173+
];
2174+
2175+
let view = Slice::new(indices[0], vec![indices.len()], vec![4]).unwrap();
2176+
let selection = base.reify_view(&view).unwrap();
2177+
2178+
let expected = Selection::of_ranks(base, &indices.iter().cloned().collect()).unwrap();
2179+
assert_structurally_eq!(&selection, expected);
2180+
2181+
let actual: Vec<_> = selection.eval(&EvalOpts::strict(), base).unwrap().collect();
2182+
assert_eq!(actual, indices);
2183+
}
2184+
21602185
#[test]
21612186
fn test_union_of_slices_empty() {
21622187
let base = Slice::new_row_major([2]);
@@ -2175,7 +2200,7 @@ mod tests {
21752200
let shape = shape!(x = 3);
21762201
let base = shape.slice();
21772202
let selected = select!(shape, x = 1).unwrap();
2178-
let view = selected.slice();
2203+
let view = selected.slice().clone();
21792204

21802205
let selection = base.reify_views(&[view]).unwrap();
21812206
let expected = range(1..=1, true_());
@@ -2197,11 +2222,11 @@ mod tests {
21972222

21982223
// View A: (0, *)
21992224
let a = select!(shape, x = 0).unwrap();
2200-
let view_a = a.slice();
2225+
let view_a = a.slice().clone();
22012226

22022227
// View B: (1, *)
22032228
let b = select!(shape, x = 1).unwrap();
2204-
let view_b = b.slice();
2229+
let view_b = b.slice().clone();
22052230

22062231
let selection = base.reify_views(&[view_a, view_b]).unwrap();
22072232
let expected = union(
@@ -2224,10 +2249,10 @@ mod tests {
22242249
let base = shape.slice();
22252250

22262251
let selected1 = select!(shape, y = 0..2).unwrap();
2227-
let view1 = selected1.slice();
2252+
let view1 = selected1.slice().clone();
22282253

22292254
let selected2 = select!(shape, y = 1..4).unwrap();
2230-
let view2 = selected2.slice();
2255+
let view2 = selected2.slice().clone();
22312256

22322257
let selection = base.reify_views(&[view1, view2]).unwrap();
22332258
let expected = union(

0 commit comments

Comments
 (0)