Skip to content

Commit 9de1436

Browse files
committed
[11/n] tensor engine, fix bugs to make test_remote_functions pass
Fixes: (1) formatting of error messages; (2) transposition of all_gathers, via a workaround — the real fix relies on fixing the nccl-comm actor to respect the order of dimensions in process groups; (3) cloning of WireValue, which incorrectly drops the is_wrapped_number property. Differential Revision: [D77880706](https://our.internmc.facebook.com/intern/diff/D77880706/) ghstack-source-id: 294754402 Pull Request resolved: #455
1 parent 0e4f4a9 commit 9de1436

File tree

7 files changed

+22
-13
lines changed

7 files changed

+22
-13
lines changed

Cargo.toml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,3 @@ members = [
1414
"nccl-sys",
1515
"torch-sys",
1616
]
17-
18-
[profile.release]
19-
incremental = true

monarch_tensor_worker/src/stream.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -984,7 +984,7 @@ impl StreamActor {
984984
}
985985
let err = err.unwrap_dependent_error().unwrap_or(err);
986986
WorkerError {
987-
backtrace: format!("{:?}", err),
987+
backtrace: err.to_string(),
988988
worker_actor_id: worker_actor_id.clone(),
989989
}
990990
})

python/monarch/mesh_controller.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@ def _initialize_env(worker_point: Point, proc_id: str) -> None:
124124
}
125125
os.environ.update(process_env)
126126
pdb.set_trace = _set_trace
127+
# workaround for set_manual_seed somehow not working if cuda is not initialized
128+
torch.cuda.init()
127129
except Exception:
128130
traceback.print_exc()
129131
raise
@@ -248,7 +250,7 @@ def __str__(self):
248250
return (
249251
f"A remote function has failed asynchronously on rank {self.rank}.\n"
250252
f"Traceback of where the remote function was issued on controller (most recent call last):\n{controller_tb}"
251-
f"Error as reported from worker!!!!!!!:\n{self.worker_error_string}"
253+
f"Error as reported from worker:\n{self.worker_error_string}"
252254
)
253255
except Exception:
254256
traceback.print_exc()

python/monarch/proc_mesh.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,10 @@ async def proc_mesh_nonblocking(
299299
) -> ProcMesh:
300300
if gpus is None:
301301
gpus = _local_device_count()
302-
spec = AllocSpec(AllocConstraints(), gpus=gpus, hosts=hosts)
302+
# gpus must come last in this order, otherwise test_remote_function_all_gather
303+
# fails, because it expects that hosts comes before gpus
304+
# in the order of the dimensions.
305+
spec = AllocSpec(AllocConstraints(), hosts=hosts, gpus=gpus)
303306
env = env or {}
304307
cmd, args, base_env = _get_bootstrap_args()
305308
env.update(base_env)
@@ -313,7 +316,10 @@ def proc_mesh_blocking(
313316
) -> ProcMesh:
314317
if gpus is None:
315318
gpus = _local_device_count()
316-
spec = AllocSpec(AllocConstraints(), gpus=gpus, hosts=hosts)
319+
# gpus must come last in this order, otherwise test_remote_function_all_gather
320+
# fails, because it expects that hosts comes before gpus
321+
# in the order of the dimensions.
322+
spec = AllocSpec(AllocConstraints(), hosts=hosts, gpus=gpus)
317323
env = env or {}
318324
cmd, args, base_env = _get_bootstrap_args()
319325
env.update(base_env)

python/tests/test_remote_functions.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,9 @@ def local_device_mesh(
185185
# out is not counted as a failure, so we set a more restrictive timeout to
186186
# ensure we see a hard failure in CI.
187187
@pytest.mark.timeout(120)
188-
@pytest.mark.parametrize("backend_type", [BackendType.PY, BackendType.RS])
188+
@pytest.mark.parametrize(
189+
"backend_type", [BackendType.PY, BackendType.RS, BackendType.MESH]
190+
)
189191
class TestRemoteFunctions(RemoteFunctionsTestBase):
190192
@classmethod
191193
def do_test_reduce_scatter_tensor(cls, backend_type, reduce_op, expected_tensor):
@@ -952,10 +954,13 @@ def test_remote_function_failure_message_contains_traceback(self, backend_type):
952954
x = outer_remote_function_that_calls_inner()
953955
try:
954956
inspect(x)
955-
except RemoteException as e:
957+
except OldRemoteException as e:
956958
backtrace = "\n".join([frame.name for frame in e.worker_frames])
957959
assert "outer_remote_function" in backtrace
958960
assert "inner_remote_function" in backtrace
961+
except NewRemoteException as e:
962+
assert "outer_remote_function" in e.worker_error_string
963+
assert "inner_remote_function" in e.worker_error_string
959964

960965
def test_remote_function_broadcast(self, backend_type):
961966
with self.local_device_mesh(2, 2, backend_type) as device_mesh:

torch-sys/Cargo.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ links = "torch"
1212
anyhow = "1.0.98"
1313
async-trait = "0.1.86"
1414
atomic_refcell = "0.1.13"
15+
bincode = "1.3.3"
1516
cxx = "1.0.119"
1617
derive_more = { version = "1.0.0", features = ["full"] }
1718
monarch_types = { version = "0.0.0", path = "../monarch_types" }
@@ -24,9 +25,6 @@ thiserror = "2.0.12"
2425
tokio = { version = "1.45.0", features = ["full", "test-util", "tracing"] }
2526
tracing = { version = "0.1.41", features = ["attributes", "valuable"] }
2627

27-
[dev-dependencies]
28-
bincode = "1.3.3"
29-
3028
[build-dependencies]
3129
bindgen = "0.70.1"
3230
cxx-build = "1.0.119"

torch-sys/src/ivalue.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,8 @@ impl Clone for OpaqueIValue {
150150
/// This creates a deep copy of the underlying data and can be expensive.
151151
/// It might also panic if the `IValue` is not cloneable.
152152
fn clone(&self) -> Self {
153-
Self(ffi::ivalue_deepcopy(&self.0).unwrap())
153+
let serialized = bincode::serialize(&self.0).unwrap();
154+
bincode::deserialize(&serialized).unwrap()
154155
}
155156
}
156157

0 commit comments

Comments
 (0)