Skip to content

Commit df14fc1

Browse files
James Sunfacebook-github-bot
authored andcommitted
(3/n) Specialize state actor bootstrap for each alloc (#459)
Summary: Different alloc should have different setup of state actor. The remote allocator should bootstrap the state actor inside the initializer. This will be done in the follow-up diffs. Reviewed By: kaiyuan-li Differential Revision: D77914042
1 parent 7fb56d0 commit df14fc1

File tree

9 files changed

+583
-32
lines changed

9 files changed

+583
-32
lines changed

hyperactor_extension/src/alloc.rs

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
use std::collections::HashMap;
10+
use std::sync::Arc;
11+
use std::sync::Mutex;
12+
13+
use async_trait::async_trait;
14+
use hyperactor::WorldId;
15+
use hyperactor::channel::ChannelTransport;
16+
use hyperactor_mesh::alloc::Alloc;
17+
use hyperactor_mesh::alloc::AllocConstraints;
18+
use hyperactor_mesh::alloc::AllocSpec;
19+
use hyperactor_mesh::alloc::AllocatorError;
20+
use hyperactor_mesh::alloc::ProcState;
21+
use hyperactor_mesh::log_source::LogSource;
22+
use hyperactor_mesh::shape::Shape;
23+
use ndslice::Slice;
24+
use pyo3::exceptions::PyValueError;
25+
use pyo3::prelude::*;
26+
use pyo3::types::PyDict;
27+
28+
/// A python class that wraps a Rust Alloc trait object. It represents what
29+
/// is shown on the python side. Internals are not exposed.
30+
/// It ensures that the Alloc is only used once (i.e. moved) in rust.
31+
#[pyclass(
32+
name = "Alloc",
33+
module = "monarch._rust_bindings.hyperactor_extension.alloc"
34+
)]
35+
pub struct PyAlloc {
36+
pub inner: Arc<Mutex<Option<PyAllocWrapper>>>,
37+
}
38+
39+
impl PyAlloc {
40+
/// Create a new PyAlloc with provided boxed trait.
41+
pub fn new(inner: Box<dyn Alloc + Sync + Send>) -> Self {
42+
Self {
43+
inner: Arc::new(Mutex::new(Some(PyAllocWrapper { inner }))),
44+
}
45+
}
46+
47+
/// Take the internal Alloc object.
48+
pub fn take(&self) -> Option<PyAllocWrapper> {
49+
self.inner.lock().unwrap().take()
50+
}
51+
}
52+
53+
#[pymethods]
54+
impl PyAlloc {
55+
fn __repr__(&self) -> PyResult<String> {
56+
let data = self.inner.lock().unwrap();
57+
match &*data {
58+
None => Ok("Alloc(None)".to_string()),
59+
Some(wrapper) => Ok(format!("Alloc({})", wrapper.shape())),
60+
}
61+
}
62+
}
63+
64+
/// Internal wrapper to translate from a dyn Alloc to an impl Alloc. Used
65+
/// to support polymorphism in the Python bindings.
66+
pub struct PyAllocWrapper {
67+
inner: Box<dyn Alloc + Sync + Send>,
68+
}
69+
70+
#[async_trait]
71+
impl Alloc for PyAllocWrapper {
72+
async fn next(&mut self) -> Option<ProcState> {
73+
self.inner.next().await
74+
}
75+
76+
fn shape(&self) -> &Shape {
77+
self.inner.shape()
78+
}
79+
80+
fn world_id(&self) -> &WorldId {
81+
self.inner.world_id()
82+
}
83+
84+
fn transport(&self) -> ChannelTransport {
85+
self.inner.transport()
86+
}
87+
88+
async fn log_source(&self) -> Result<LogSource, AllocatorError> {
89+
self.inner.log_source().await
90+
}
91+
92+
async fn stop(&mut self) -> Result<(), AllocatorError> {
93+
self.inner.stop().await
94+
}
95+
}
96+
97+
#[pyclass(
98+
name = "AllocConstraints",
99+
module = "monarch._rust_bindings.hyperactor_extension.alloc"
100+
)]
101+
pub struct PyAllocConstraints {
102+
inner: AllocConstraints,
103+
}
104+
105+
#[pymethods]
106+
impl PyAllocConstraints {
107+
#[new]
108+
#[pyo3(signature = (match_labels=None))]
109+
fn new(match_labels: Option<HashMap<String, String>>) -> PyResult<Self> {
110+
let mut constraints = AllocConstraints::default();
111+
if let Some(match_lables) = match_labels {
112+
constraints.match_labels = match_lables;
113+
}
114+
115+
Ok(Self { inner: constraints })
116+
}
117+
}
118+
119+
#[pyclass(
120+
name = "AllocSpec",
121+
module = "monarch._rust_bindings.hyperactor_extension.alloc"
122+
)]
123+
pub struct PyAllocSpec {
124+
pub inner: AllocSpec,
125+
}
126+
127+
#[pymethods]
128+
impl PyAllocSpec {
129+
#[new]
130+
#[pyo3(signature = (constraints, **kwargs))]
131+
fn new(constraints: &PyAllocConstraints, kwargs: Option<&Bound<'_, PyAny>>) -> PyResult<Self> {
132+
let Some(kwargs) = kwargs else {
133+
return Err(PyValueError::new_err(
134+
"Shape must have at least one dimension",
135+
));
136+
};
137+
let shape_dict = kwargs.downcast::<PyDict>()?;
138+
139+
let mut keys = Vec::new();
140+
let mut values = Vec::new();
141+
for (key, value) in shape_dict {
142+
keys.push(key.clone());
143+
values.push(value.clone());
144+
}
145+
146+
let shape = Shape::new(
147+
keys.into_iter()
148+
.map(|key| key.extract::<String>())
149+
.collect::<PyResult<Vec<String>>>()?,
150+
Slice::new_row_major(
151+
values
152+
.into_iter()
153+
.map(|key| key.extract::<usize>())
154+
.collect::<PyResult<Vec<usize>>>()?,
155+
),
156+
)
157+
.map_err(|e| PyValueError::new_err(format!("Invalid shape: {:?}", e)))?;
158+
159+
Ok(Self {
160+
inner: AllocSpec {
161+
shape,
162+
constraints: constraints.inner.clone(),
163+
},
164+
})
165+
}
166+
}
167+
168+
pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
169+
module.add_class::<PyAlloc>()?;
170+
module.add_class::<PyAllocConstraints>()?;
171+
module.add_class::<PyAllocSpec>()?;
172+
173+
Ok(())
174+
}

hyperactor_mesh/src/alloc.rs

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -223,15 +223,7 @@ pub trait Alloc {
223223
/// It allows remote processes to stream stdout and stderr back to the client.
224224
/// A client can connect to the log source to obtain the streamed logs.
225225
/// A log source is allocation specific. Each allocator can decide how to stream the logs back.
226-
async fn log_source(&self) -> Result<LogSource, AllocatorError> {
227-
// TODO: this should be implemented based on different allocators.
228-
// Having this temporarily here so that the client can connect to the log source.
229-
// But the client will not get anything.
230-
// The following diffs will gradually implement this for different allocators.
231-
LogSource::new_with_local_actor()
232-
.await
233-
.map_err(AllocatorError::from)
234-
}
226+
async fn log_source(&self) -> Result<LogSource, AllocatorError>;
235227

236228
/// Stop this alloc, shutting down all of its procs. A clean
237229
/// shutdown should result in Stop events from all allocs,
@@ -367,6 +359,10 @@ pub mod test_utils {
367359
self.alloc.transport()
368360
}
369361

362+
async fn log_source(&self) -> Result<LogSource, AllocatorError> {
363+
self.alloc.log_source().await
364+
}
365+
370366
async fn stop(&mut self) -> Result<(), AllocatorError> {
371367
self.alloc.stop().await
372368
}

hyperactor_mesh/src/alloc/local.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ use crate::alloc::AllocSpec;
3333
use crate::alloc::Allocator;
3434
use crate::alloc::AllocatorError;
3535
use crate::alloc::ProcState;
36+
use crate::log_source::LogSource;
3637
use crate::proc_mesh::mesh_agent::MeshAgent;
3738
use crate::shortuuid::ShortUuid;
3839

@@ -252,6 +253,14 @@ impl Alloc for LocalAlloc {
252253
ChannelTransport::Local
253254
}
254255

256+
async fn log_source(&self) -> Result<LogSource, AllocatorError> {
257+
// Local alloc does not need to stream logs back.
258+
// The client can subscribe to it but local actors will not stream logs into it.
259+
LogSource::new_with_local_actor()
260+
.await
261+
.map_err(AllocatorError::from)
262+
}
263+
255264
async fn stop(&mut self) -> Result<(), AllocatorError> {
256265
for rank in 0..self.size() {
257266
self.todo_tx

hyperactor_mesh/src/alloc/process.rs

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ use hyperactor::channel::ChannelTx;
2929
use hyperactor::channel::Rx;
3030
use hyperactor::channel::Tx;
3131
use hyperactor::channel::TxStatus;
32-
use hyperactor::id;
3332
use hyperactor::sync::flag;
3433
use hyperactor::sync::monitor;
3534
use hyperactor_state::state_actor::StateActor;
@@ -53,6 +52,8 @@ use crate::bootstrap;
5352
use crate::bootstrap::Allocator2Process;
5453
use crate::bootstrap::Process2Allocator;
5554
use crate::bootstrap::Process2AllocatorMessage;
55+
use crate::log_source::LogSource;
56+
use crate::log_source::StateServerInfo;
5657
use crate::shortuuid::ShortUuid;
5758

5859
/// The maximum number of log lines to tail keep for managed processes.
@@ -89,6 +90,9 @@ impl Allocator for ProcessAllocator {
8990
let (bootstrap_addr, rx) = channel::serve(ChannelAddr::any(ChannelTransport::Unix))
9091
.await
9192
.map_err(anyhow::Error::from)?;
93+
let log_source = LogSource::new_with_local_actor()
94+
.await
95+
.map_err(AllocatorError::from)?;
9296

9397
let name = ShortUuid::generate();
9498
let n = spec.shape.slice().len();
@@ -97,6 +101,7 @@ impl Allocator for ProcessAllocator {
97101
world_id: WorldId(name.to_string()),
98102
spec: spec.clone(),
99103
bootstrap_addr,
104+
log_source,
100105
rx,
101106
index: 0,
102107
active: HashMap::new(),
@@ -115,6 +120,7 @@ pub struct ProcessAlloc {
115120
world_id: WorldId, // to provide storage
116121
spec: AllocSpec,
117122
bootstrap_addr: ChannelAddr,
123+
log_source: LogSource,
118124
rx: channel::ChannelRx<Process2Allocator>,
119125
index: usize,
120126
active: HashMap<usize, Child>,
@@ -145,13 +151,14 @@ struct Child {
145151
impl Child {
146152
fn monitored(
147153
mut process: tokio::process::Child,
154+
state_server_info: StateServerInfo,
148155
) -> (Self, impl Future<Output = ProcStopReason>) {
149156
let (group, handle) = monitor::group();
150157
let (exit_flag, exit_guard) = flag::guarded();
151158
let stop_reason = Arc::new(OnceLock::new());
152159

153160
// TODO(lky): enable state actor branch and remove this flag
154-
let use_state_actor = false;
161+
let use_state_actor = true;
155162

156163
// Set up stdout and stderr writers
157164
let mut stdout_tee: Box<dyn io::AsyncWrite + Send + Unpin + 'static> =
@@ -161,24 +168,20 @@ impl Child {
161168

162169
// If state actor is enabled, try to set up LogWriter instances
163170
if use_state_actor {
164-
let state_actor_ref = ActorRef::<StateActor>::attest(id!(state_server[0].state[0]));
165-
// Parse the state actor address
166-
if let Ok(state_actor_addr) = "tcp![::]:3000".parse::<ChannelAddr>() {
167-
// Use the helper function to create both writers at once
168-
match hyperactor_state::log_writer::create_log_writers(
169-
state_actor_addr,
170-
state_actor_ref,
171-
) {
172-
Ok((stdout_writer, stderr_writer)) => {
173-
stdout_tee = stdout_writer;
174-
stderr_tee = stderr_writer;
175-
}
176-
Err(e) => {
177-
tracing::error!("failed to create log writers: {}", e);
178-
}
171+
let state_actor_ref = ActorRef::<StateActor>::attest(state_server_info.state_actor_id);
172+
let state_actor_addr = state_server_info.state_proc_addr;
173+
// Use the helper function to create both writers at once
174+
match hyperactor_state::log_writer::create_log_writers(
175+
state_actor_addr,
176+
state_actor_ref,
177+
) {
178+
Ok((stdout_writer, stderr_writer)) => {
179+
stdout_tee = stdout_writer;
180+
stderr_tee = stderr_writer;
181+
}
182+
Err(e) => {
183+
tracing::error!("failed to create log writers: {}", e);
179184
}
180-
} else {
181-
tracing::error!("failed to parse state actor address");
182185
}
183186
}
184187

@@ -394,7 +397,8 @@ impl ProcessAlloc {
394397
None
395398
}
396399
Ok(rank) => {
397-
let (handle, monitor) = Child::monitored(process);
400+
let (handle, monitor) =
401+
Child::monitored(process, self.log_source.server_info());
398402
self.children.spawn(async move { (index, monitor.await) });
399403
self.active.insert(index, handle);
400404
// Adjust for shape slice offset for non-zero shapes (sub-shapes).
@@ -498,6 +502,10 @@ impl Alloc for ProcessAlloc {
498502
ChannelTransport::Unix
499503
}
500504

505+
async fn log_source(&self) -> Result<LogSource, AllocatorError> {
506+
Ok(self.log_source.clone())
507+
}
508+
501509
async fn stop(&mut self) -> Result<(), AllocatorError> {
502510
// We rely on the teardown here, and that the process should
503511
// exit on its own. We should have a hard timeout here as well,

0 commit comments

Comments
 (0)