Skip to content

Commit 216063f

Browse files
andrewjcg authored and facebook-github-bot committed
Refactor workspace location enum (#404)
Summary: Pull Request resolved: #404

This does some cleanup of the workspace location enum we use to represent the remote-side location, generalizing it for use in other locations (right now it is overly tied to the rsync module).

Reviewed By: suo
Differential Revision: D77456893
fbshipit-source-id: 7ef89c7ccdcc8fe3d10feded31a23d2a14d7532b
1 parent 9061e09 commit 216063f

File tree

8 files changed, with 94 additions and 28 deletions.

hyperactor_mesh/src/code_sync.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,6 @@
77
*/
88

99
pub mod rsync;
10+
mod workspace;
11+
12+
pub use workspace::WorkspaceLocation;

hyperactor_mesh/src/code_sync/rsync.rs

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ use tokio::process::Child;
3737
use tokio::process::Command;
3838

3939
use crate::actor_mesh::ActorMesh;
40+
use crate::code_sync::WorkspaceLocation;
4041
use crate::connect::Connect;
4142
use crate::connect::accept;
4243
use crate::connect::connect_mesh;
@@ -155,15 +156,9 @@ impl RsyncDaemon {
155156
}
156157
}
157158

158-
#[derive(Clone, Debug, Named, Serialize, Deserialize)]
159-
pub enum Workspace {
160-
Constant(PathBuf),
161-
FromEnvVar(String),
162-
}
163-
164159
#[derive(Debug, Named, Serialize, Deserialize)]
165160
pub struct RsyncParams {
166-
pub workspace: Workspace,
161+
pub workspace: WorkspaceLocation,
167162
}
168163

169164
#[derive(Debug)]
@@ -180,10 +175,7 @@ impl Actor for RsyncActor {
180175
type Params = RsyncParams;
181176

182177
async fn new(RsyncParams { workspace }: Self::Params) -> Result<Self> {
183-
let workspace = match workspace {
184-
Workspace::Constant(p) => p,
185-
Workspace::FromEnvVar(v) => PathBuf::from(std::env::var(v)?),
186-
};
178+
let workspace = workspace.resolve()?;
187179
let daemon = RsyncDaemon::spawn(TcpListener::bind(("::1", 0)).await?, &workspace).await?;
188180
Ok(Self { daemon })
189181
}

hyperactor_mesh/src/code_sync/workspace.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
use std::path::PathBuf;
10+
11+
use anyhow::Result;
12+
use serde::Deserialize;
13+
use serde::Serialize;
14+
15+
#[derive(Clone, Debug, Serialize, Deserialize)]
16+
pub enum WorkspaceLocation {
17+
Constant(PathBuf),
18+
FromEnvVar(String),
19+
}
20+
21+
impl WorkspaceLocation {
22+
pub fn resolve(&self) -> Result<PathBuf> {
23+
Ok(match self {
24+
WorkspaceLocation::Constant(p) => p.clone(),
25+
WorkspaceLocation::FromEnvVar(v) => PathBuf::from(std::env::var(v)?),
26+
})
27+
}
28+
}

monarch_extension/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ crate-type = ["cdylib"]
1515

1616
[dependencies]
1717
anyhow = "1.0.98"
18+
bincode = "1.3.3"
1819
clap = { version = "4.5.38", features = ["derive", "env", "string", "unicode", "wrap_help"] }
1920
controller = { version = "0.0.0", path = "../controller", optional = true }
2021
hyperactor = { version = "0.0.0", path = "../hyperactor" }
@@ -31,6 +32,7 @@ nccl-sys = { path = "../nccl-sys", optional = true }
3132
ndslice = { version = "0.0.0", path = "../ndslice" }
3233
pyo3 = { version = "0.24", features = ["anyhow", "multiple-pymethods"] }
3334
pyo3-async-runtimes = { git = "https://github.com/PyO3/pyo3-async-runtimes", rev = "c5a3746f110b4d246556b0f6c29f5f555919eee3", features = ["attributes", "tokio-runtime"] }
35+
serde = { version = "1.0.185", features = ["derive", "rc"] }
3436
tokio = { version = "1.45.0", features = ["full", "test-util", "tracing"] }
3537
torch-sys = { version = "0.0.0", path = "../torch-sys", optional = true }
3638
torch-sys-cuda = { version = "0.0.0", path = "../torch-sys-cuda", optional = true }

monarch_extension/src/code_sync.rs

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,32 +13,66 @@ use std::sync::Arc;
1313

1414
use hyperactor_mesh::RootActorMesh;
1515
use hyperactor_mesh::SlicedActorMesh;
16+
use hyperactor_mesh::code_sync::WorkspaceLocation;
1617
use hyperactor_mesh::code_sync::rsync;
1718
use hyperactor_mesh::proc_mesh::SharedSpawnable;
1819
use hyperactor_mesh::shape::Shape;
1920
use monarch_hyperactor::proc_mesh::PyProcMesh;
2021
use monarch_hyperactor::runtime::signal_safe_block_on;
2122
use monarch_hyperactor::shape::PyShape;
2223
use pyo3::Bound;
24+
use pyo3::exceptions::PyRuntimeError;
25+
use pyo3::exceptions::PyValueError;
2326
use pyo3::prelude::*;
27+
use pyo3::types::PyBytes;
2428
use pyo3::types::PyModule;
29+
use serde::Deserialize;
30+
use serde::Serialize;
2531

26-
#[pyclass(frozen, module = "monarch._rust_bindings.monarch_extension.code_sync")]
27-
#[derive(Clone, Debug)]
28-
enum RemoteWorkspace {
32+
#[pyclass(
33+
frozen,
34+
name = "WorkspaceLocation",
35+
module = "monarch._rust_bindings.monarch_extension.code_sync"
36+
)]
37+
#[derive(Clone, Debug, Serialize, Deserialize)]
38+
enum PyWorkspaceLocation {
2939
Constant(PathBuf),
3040
FromEnvVar(String),
3141
}
3242

33-
impl From<RemoteWorkspace> for rsync::Workspace {
34-
fn from(workspace: RemoteWorkspace) -> rsync::Workspace {
43+
impl From<PyWorkspaceLocation> for WorkspaceLocation {
44+
fn from(workspace: PyWorkspaceLocation) -> WorkspaceLocation {
3545
match workspace {
36-
RemoteWorkspace::Constant(v) => rsync::Workspace::Constant(v),
37-
RemoteWorkspace::FromEnvVar(v) => rsync::Workspace::FromEnvVar(v),
46+
PyWorkspaceLocation::Constant(v) => WorkspaceLocation::Constant(v),
47+
PyWorkspaceLocation::FromEnvVar(v) => WorkspaceLocation::FromEnvVar(v),
3848
}
3949
}
4050
}
4151

52+
#[pymethods]
53+
impl PyWorkspaceLocation {
54+
#[staticmethod]
55+
fn from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<Self> {
56+
bincode::deserialize(bytes.as_bytes())
57+
.map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))
58+
}
59+
60+
fn __reduce__<'py>(
61+
slf: &Bound<'py, Self>,
62+
) -> PyResult<(Bound<'py, PyAny>, (Bound<'py, PyBytes>,))> {
63+
let bytes = bincode::serialize(&*slf.borrow())
64+
.map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
65+
let py_bytes = PyBytes::new(slf.py(), &bytes);
66+
Ok((slf.as_any().getattr("from_bytes")?, (py_bytes,)))
67+
}
68+
69+
fn resolve(&self) -> PyResult<PathBuf> {
70+
let loc: WorkspaceLocation = self.clone().into();
71+
loc.resolve()
72+
.map_err(|e| PyRuntimeError::new_err(format!("{}", e)))
73+
}
74+
}
75+
4276
#[pyclass(
4377
frozen,
4478
name = "RsyncMeshClient",
@@ -59,7 +93,7 @@ impl RsyncMeshClient {
5993
proc_mesh: &PyProcMesh,
6094
shape: &PyShape,
6195
local_workspace: PathBuf,
62-
remote_workspace: RemoteWorkspace,
96+
remote_workspace: PyWorkspaceLocation,
6397
) -> PyResult<Self> {
6498
let proc_mesh = Arc::clone(&proc_mesh.inner);
6599
let shape = shape.get_inner().clone();
@@ -92,7 +126,7 @@ impl RsyncMeshClient {
92126
}
93127

94128
pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
95-
module.add_class::<RemoteWorkspace>()?;
129+
module.add_class::<PyWorkspaceLocation>()?;
96130
module.add_class::<RsyncMeshClient>()?;
97131
Ok(())
98132
}

python/monarch/_rust_bindings/monarch_extension/code_sync.pyi

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,31 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from pathlib import Path
78
from typing import final
89

910
from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ProcMesh
1011

1112
from monarch._rust_bindings.monarch_hyperactor.shape import Shape
1213

13-
class RemoteWorkspace:
14+
class WorkspaceLocation:
1415
"""
15-
Python binding for the Rust RemoteWorkspace enum.
16+
Python binding for the Rust WorkspaceLocation enum.
1617
"""
1718
@final
18-
class Constant(RemoteWorkspace):
19+
class Constant(WorkspaceLocation):
1920
def __init__(self, path) -> None: ...
2021

2122
@final
22-
class FromEnvVar(RemoteWorkspace):
23+
class FromEnvVar(WorkspaceLocation):
2324
def __init__(self, var) -> None: ...
2425

26+
def resolve(self) -> Path:
27+
"""
28+
Resolve the workspace location to a Path.
29+
"""
30+
...
31+
2532
@final
2633
class RsyncMeshClient:
2734
"""
@@ -32,6 +39,6 @@ class RsyncMeshClient:
3239
proc_mesh: ProcMesh,
3340
shape: Shape,
3441
local_workspace: str,
35-
remote_workspace: RemoteWorkspace,
42+
remote_workspace: WorkspaceLocation,
3643
) -> RsyncMeshClient: ...
3744
async def sync_workspace(self) -> None: ...

python/monarch/code_sync/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from monarch._rust_bindings.monarch_extension.code_sync import ( # noqa: F401
8-
RemoteWorkspace,
98
RsyncMeshClient,
9+
WorkspaceLocation,
1010
)

python/monarch/proc_mesh.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
from monarch._rust_bindings.monarch_hyperactor.shape import Shape, Slice
4646
from monarch.actor_mesh import _Actor, _ActorMeshRefImpl, Actor, ActorMeshRef
4747

48-
from monarch.code_sync import RemoteWorkspace, RsyncMeshClient
48+
from monarch.code_sync import RsyncMeshClient, WorkspaceLocation
4949
from monarch.common._device_utils import _local_device_count
5050
from monarch.common.shape import MeshTrait
5151
from monarch.rdma import RDMAManager
@@ -231,7 +231,7 @@ async def sync_workspace(self) -> None:
231231
# TODO(agallagher): Is there a better way to infer/set the local
232232
# workspace dir, rather than use PWD?
233233
local_workspace=os.getcwd(),
234-
remote_workspace=RemoteWorkspace.FromEnvVar("WORKSPACE_DIR"),
234+
remote_workspace=WorkspaceLocation.FromEnvVar("WORKSPACE_DIR"),
235235
)
236236
await self._rsync_mesh_client.sync_workspace()
237237

0 commit comments

Comments
 (0)