From 5d473a6077f6769a8588ad7355e7093d25f67e2b Mon Sep 17 00:00:00 2001 From: zdevito Date: Wed, 16 Jul 2025 10:36:57 -0700 Subject: [PATCH] [18/n] tensor engine: Simplify CallFunctionError The only case we ever 'catch' is DependentError. Everything else is just a textual description. So this collapses those cases into just the two cases. This retains the enum variants as function calls so that we can also reify them again if we need to catch them, but having them obscures the places where DependentError needs to be handled differently. Differential Revision: [D78363663](https://our.internmc.facebook.com/intern/diff/D78363663/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D78363663/)! [ghstack-poisoned] --- monarch_messages/src/worker.rs | 138 +++++++++++++++--------- monarch_simulator/src/worker.rs | 2 +- monarch_tensor_worker/src/borrow.rs | 5 +- monarch_tensor_worker/src/lib.rs | 24 +++-- monarch_tensor_worker/src/stream.rs | 158 ++++++++++++++++------------ 5 files changed, 194 insertions(+), 133 deletions(-) diff --git a/monarch_messages/src/worker.rs b/monarch_messages/src/worker.rs index 0192796f..8dda5535 100644 --- a/monarch_messages/src/worker.rs +++ b/monarch_messages/src/worker.rs @@ -506,74 +506,112 @@ pub enum StreamCreationMode { CreateNewStream, } -/// The kinds of errors that a CallFunction message can return with. -// TODO(agallagher): We should move most variants out into `ValueError`. +/// When a worker runs any function, it may not succeed either because the function itself +/// failed (Error) or because an input to the function already had an error value +/// (DependentError). 
#[derive(Error, Debug, Named)] pub enum CallFunctionError { - #[error("ref not found: {0}")] - RefNotFound(Ref), + #[error("{0}")] + Error(#[from] anyhow::Error), + #[error("Computation depended on an input that failed with errror: {0}")] + DependentError(Arc), +} - #[error("dependent error {0}")] - DependentError(#[from] Arc), +impl CallFunctionError { + /// Checks if the error is a `DependentError` and returns the underlying + /// error if so. Otherwise, returns `None`. + pub fn unwrap_dependent_error(&self) -> Option> { + match self { + CallFunctionError::DependentError(e) => Some(e.clone()), + _ => None, + } + } - #[error("invalid remote function: {0}")] - InvalidRemoteFunction(String), + // Static functions for backward compatibility with existing enum cases - #[error("unsupported arg type for {0} function: {1}")] - UnsupportedArgType(String, String), + #[allow(non_snake_case)] + pub fn RefNotFound(r: Ref) -> Self { + Self::Error(anyhow::anyhow!("ref not found: {}", r)) + } - #[error("remote function failed: {0}")] - RemoteFunctionFailed(#[from] SerializablePyErr), + #[allow(non_snake_case)] + pub fn InvalidRemoteFunction(msg: String) -> Self { + Self::Error(anyhow::anyhow!("invalid remote function: {}", msg)) + } - #[error("borrow failed: {0}")] - BorrowError(#[from] BorrowError), + #[allow(non_snake_case)] + pub fn UnsupportedArgType(function_type: String, arg_type: String) -> Self { + Self::Error(anyhow::anyhow!( + "unsupported arg type for {} function: {}", + function_type, + arg_type + )) + } - #[error("torch operator failed: {0}")] - OperatorFailed(#[from] CallOpError), + #[allow(non_snake_case)] + pub fn RemoteFunctionFailed(err: SerializablePyErr) -> Self { + Self::Error(anyhow::anyhow!("remote function failed: {}", err)) + } - #[error("unexpected number of returns from op, expected {expected}, got {actual}")] - UnexpectedNumberOfReturns { expected: usize, actual: usize }, + #[allow(non_snake_case)] + pub fn BorrowError(err: BorrowError) -> Self { + 
Self::Error(anyhow::anyhow!("borrow failed: {}", err)) + } - #[error( - "expected only a single arg (and no kwargs) when no function is given: {args}, {kwargs}" - )] - TooManyArgsForValue { args: String, kwargs: String }, + #[allow(non_snake_case)] + pub fn OperatorFailed(err: CallOpError) -> Self { + Self::Error(anyhow::anyhow!("torch operator failed: {}", err)) + } - #[error("error: {0}")] - Anyhow(#[from] anyhow::Error), + #[allow(non_snake_case)] + pub fn UnexpectedNumberOfReturns(expected: usize, actual: usize) -> Self { + Self::Error(anyhow::anyhow!( + "unexpected number of returns from op, expected {}, got {}", + expected, + actual + )) + } - #[error("recording failed at message {index}: ({message}). Error: {error}")] - RecordingFailed { - index: usize, - message: String, - error: Arc, - }, + #[allow(non_snake_case)] + pub fn TooManyArgsForValue(args: String, kwargs: String) -> Self { + Self::Error(anyhow::anyhow!( + "expected only a single arg (and no kwargs) when no function is given: {}, {}", + args, + kwargs + )) + } + + #[allow(non_snake_case)] + pub fn Anyhow(err: anyhow::Error) -> Self { + Self::Error(err) + } + + #[allow(non_snake_case)] + pub fn RecordingFailed(index: usize, message: String, error: Arc) -> Self { + Self::Error(anyhow::anyhow!( + "recording failed at message {}: ({}). Error: {}", + index, + message, + error + )) + } } -impl CallFunctionError { - /// Checks if the error is a `DependentError` and returns the underlying - /// error if so. Otherwise, returns `None`. - pub fn unwrap_dependent_error(&self) -> Option> { - match self { - CallFunctionError::DependentError(e) => Some(e.clone()), - _ => None, - } +impl From for CallFunctionError { + fn from(v: SerializablePyErr) -> CallFunctionError { + CallFunctionError::Error(v.into()) } } -/// Errors encountered during worker operations which get propagated in the -/// refs-to-values maps used to store values. 
-#[derive(Debug, Serialize, Deserialize, Error, Named, EnumAsInner)] -pub enum ValueError { - // TODO(agallagher): Migrate to variants for each error (after we cleanup - // `CallFunctionError` to make it serializable). - #[error("call function error: {0}")] - CallFunctionError(String), +impl From for CallFunctionError { + fn from(v: BorrowError) -> CallFunctionError { + CallFunctionError::Error(v.into()) + } } -impl From> for ValueError { - fn from(err: Arc) -> Self { - ValueError::CallFunctionError(format!("{:?}", err)) +impl From for CallFunctionError { + fn from(v: CallOpError) -> CallFunctionError { + CallFunctionError::Error(v.into()) } } @@ -874,7 +912,7 @@ pub enum WorkerMessage { /// The stream to retrieve from. stream: StreamRef, #[reply] - response_port: hyperactor::OncePortRef>>, + response_port: hyperactor::OncePortRef>>, }, } diff --git a/monarch_simulator/src/worker.rs b/monarch_simulator/src/worker.rs index e1c528ac..ad5e459d 100644 --- a/monarch_simulator/src/worker.rs +++ b/monarch_simulator/src/worker.rs @@ -639,7 +639,7 @@ impl WorkerMessageHandler for WorkerActor { _cx: &hyperactor::Context, _ref_id: Ref, _stream: StreamRef, - ) -> Result>> { + ) -> Result>> { bail!("unimplemented: get_ref_unit_tests_only") } diff --git a/monarch_tensor_worker/src/borrow.rs b/monarch_tensor_worker/src/borrow.rs index e36f6293..7b40b0d9 100644 --- a/monarch_tensor_worker/src/borrow.rs +++ b/monarch_tensor_worker/src/borrow.rs @@ -422,10 +422,9 @@ mod tests { let error = result .context("no such ref")? .err() - .context("expected error")? 
- .into_call_function_error()?; + .context("expected error")?; assert!( - error.starts_with("DependentError"), + error.contains("Computation depended on an input that failed"), "If a borrowed value contains an error, downstream calls should propagate that error (unexpected error string: {})", error, ); diff --git a/monarch_tensor_worker/src/lib.rs b/monarch_tensor_worker/src/lib.rs index 772a8b19..33cf65c8 100644 --- a/monarch_tensor_worker/src/lib.rs +++ b/monarch_tensor_worker/src/lib.rs @@ -80,7 +80,6 @@ use monarch_messages::worker::Ref; use monarch_messages::worker::ResolvableFunction; use monarch_messages::worker::StreamCreationMode; use monarch_messages::worker::StreamRef; -use monarch_messages::worker::ValueError; use monarch_messages::worker::WorkerMessage; use monarch_messages::worker::WorkerMessageHandler; use monarch_messages::worker::WorkerParams; @@ -679,10 +678,9 @@ impl WorkerMessageHandler for WorkerActor { (k, RValue::PyObject(object.into_py_object().unwrap()).into()) }) .collect(); - let device_mesh = self - .device_meshes - .get(&device_mesh) - .ok_or_else(|| CallFunctionError::RefNotFound(device_mesh))?; + let device_mesh = self.device_meshes.get(&device_mesh).ok_or_else(|| { + CallFunctionError::Error(anyhow::anyhow!("ref not found: {}", device_mesh)) + })?; let pipe = PipeActor::spawn( cx, PipeParams { @@ -989,10 +987,15 @@ impl WorkerMessageHandler for WorkerActor { // Get a port for the pipe let pipe = match self.pipes.get(&pipe) { - None => Err(Arc::new(CallFunctionError::RefNotFound(pipe))), + None => Err(Arc::new(CallFunctionError::Error(anyhow::anyhow!( + "ref not found: {}", + pipe + )))), Some(pipe) => match pipe.as_ref() { Ok(pipe) => Ok(pipe.port()), - Err(e) => Err(Arc::new(CallFunctionError::DependentError(e.clone()))), + Err(e) => Err(Arc::new(CallFunctionError::DependentError( + e.unwrap_dependent_error().unwrap_or(e.clone()), + ))), }, }; @@ -1020,7 +1023,7 @@ impl WorkerMessageHandler for WorkerActor { cx: 
&hyperactor::Context, ref_id: Ref, stream: StreamRef, - ) -> Result>> { + ) -> Result>> { let stream = self.try_get_stream(stream)?; Ok(stream .get_ref_unit_tests_only(cx, ref_id.clone()) @@ -1437,8 +1440,7 @@ mod tests { let mutated_ref = result .context("no such ref")? .err() - .context("expected error")? - .into_call_function_error()?; + .context("expected error")?; assert!(mutated_ref.contains("InvalidRemoteFunction")); let responses = controller_rx.drain(); @@ -2542,7 +2544,7 @@ mod tests { ])) .unwrap(); - let value: Result<_, ValueError> = handle + let value: Result<_, String> = handle .get_ref_unit_tests_only(&client, 3.into(), 0.into()) .await .unwrap() diff --git a/monarch_tensor_worker/src/stream.rs b/monarch_tensor_worker/src/stream.rs index 7a5305c5..2e23deef 100644 --- a/monarch_tensor_worker/src/stream.rs +++ b/monarch_tensor_worker/src/stream.rs @@ -235,7 +235,7 @@ pub enum StreamMessage { GetRefUnitTestsOnly( Ref, // value - #[reply] OncePortHandle>>>, + #[reply] OncePortHandle>>, ), GetTensorRefUnitTestsOnly(Ref, #[reply] OncePortHandle>), @@ -661,10 +661,10 @@ impl StreamActor { .filter_map(|(result, ref_)| ref_.map(|ref_| (ref_, result))) .collect::>()) } else { - Err(CallFunctionError::UnexpectedNumberOfReturns { - expected: result_refs.len(), - actual: actual_results.len(), - }) + Err(CallFunctionError::UnexpectedNumberOfReturns( + result_refs.len(), + actual_results.len(), + )) } }); @@ -1620,10 +1620,10 @@ impl StreamMessageHandler for StreamActor { } }) }), - _ => Err(CallFunctionError::TooManyArgsForValue { - args: format!("{:?}", args), - kwargs: format!("{:?}", kwargs), - }), + _ => Err(CallFunctionError::TooManyArgsForValue( + format!("{:?}", args), + format!("{:?}", kwargs), + )), } }; @@ -1769,7 +1769,6 @@ impl StreamMessageHandler for StreamActor { CallFunctionError::DependentError(dep_err) => { Err(CallFunctionError::DependentError(dep_err.clone())) } - CallFunctionError::RefNotFound(ref_) => 
Err(CallFunctionError::RefNotFound(*ref_)), _ => bail!("unexpected error for pipe in set_value: {:?}", err), }, }; @@ -2034,11 +2033,11 @@ impl StreamMessageHandler for StreamActor { .messages .get(index) .unwrap(); - error = Some(Arc::new(CallFunctionError::RecordingFailed { + error = Some(Arc::new(CallFunctionError::RecordingFailed( index, - message: format!("{message:?}"), - error: err.clone(), - })); + format!("{message:?}"), + err.clone(), + ))); // Report failure to the controller. self.controller_actor .remote_function_failed( @@ -2117,7 +2116,7 @@ impl StreamMessageHandler for StreamActor { &mut self, _cx: &Context, reference: Ref, - ) -> Result>>> { + ) -> Result>> { /// For testing only, doesn't support Tensor or TensorList. fn rvalue_to_wire( value: Result>, @@ -2139,7 +2138,7 @@ impl StreamMessageHandler for StreamActor { Ok(self .env .get(&reference) - .map(|rvalue| rvalue_to_wire(rvalue.clone()))) + .map(|rvalue| rvalue_to_wire(rvalue.clone()).map_err(|err| err.to_string()))) } async fn get_tensor_ref_unit_tests_only( @@ -2253,11 +2252,12 @@ mod tests { .unwrap() .unwrap() .unwrap(); - allclose( + let x = allclose( &factory_float_tensor(data, "cpu".try_into().unwrap()), &actual.borrow(), ) - .unwrap() + .unwrap(); + x } async fn validate_dependent_error( @@ -2734,17 +2734,19 @@ mod tests { .await .unwrap(); let error = result.unwrap().unwrap_err(); - match error.as_ref() { - CallFunctionError::RecordingFailed { - error: inner_error, .. 
- } => match inner_error.as_ref() { - CallFunctionError::RefNotFound(err_ref) => { - assert_eq!(*err_ref, nonexistent_ref) - } - _ => panic!("Unexpected error inside RecordingFailed: {:?}", inner_error), - }, - _ => panic!("Unexpected error instead of RecordingFailed: {:?}", error), - }; + + // Check that the error contains the expected strings + let error_str = format!("{:?}", error); + assert!( + error_str.contains("recording failed"), + "Error should contain 'recording failed': {}", + error_str + ); + assert!( + error_str.contains("ref not found"), + "Error should contain 'ref not found': {}", + error_str + ); } assert_refs_do_not_exist(&test_setup, &[formal0_ref, formal1_ref]).await; @@ -3031,16 +3033,18 @@ mod tests { .await? .unwrap() .unwrap_err(); - match result_error.as_ref() { - CallFunctionError::RecordingFailed { error, .. } => match error.as_ref() { - CallFunctionError::OperatorFailed(_) => (), - _ => panic!("Unexpected error inside RecordingFailed: {:?}", error), - }, - _ => panic!( - "Unexpected error instead of RecordingFailed: {:?}", - result_error - ), - } + // Check that the error contains the expected strings + let error_str = format!("{:?}", result_error); + assert!( + error_str.contains("recording failed"), + "Error should contain 'recording failed': {}", + error_str + ); + assert!( + error_str.contains("torch operator failed"), + "Error should contain 'torch operator failed': {}", + error_str + ); } let controller_msg = test_setup.controller_rx.recv().await.unwrap(); @@ -3091,16 +3095,18 @@ mod tests { .await? .unwrap() .unwrap_err(); - match result_error.as_ref() { - CallFunctionError::DependentError(dep_err) => match dep_err.as_ref() { - CallFunctionError::RecordingFailed { .. 
} => (), - _ => panic!("Unexpected error inside DependentError: {:?}", dep_err), - }, - _ => panic!( - "Unexpected error instead of DependentError: {:?}", - result_error - ), - } + // Check that the error contains the expected strings + let error_str = format!("{:?}", result_error); + assert!( + error_str.contains("Computation depended on an input that failed"), + "Error should contain dependency message: {}", + error_str + ); + assert!( + error_str.contains("recording failed"), + "Error should contain 'recording failed': {}", + error_str + ); } // This tests that the DependentError was never reported to the controller. @@ -3470,12 +3476,22 @@ mod tests { .unwrap() .unwrap_err(); - match result_error.as_ref() { - CallFunctionError::DependentError(dep_err) => { - assert!(Arc::ptr_eq(dep_err, &input_error)); - } - _ => panic!("Unexpected error: {:?}", result_error), - } + // Check that the error contains the expected strings + let error_str = format!("{:?}", result_error); + assert!( + error_str.contains("Computation depended on an input that failed"), + "Error should contain dependency message: {}", + error_str + ); + + // Since we're checking for pointer equality in the original code, we need to ensure + // the error is propagated correctly. We can check that the original error message is contained. + let input_error_str = format!("{:?}", input_error); + assert!( + error_str.contains(&input_error_str), + "Error should contain the original error: {}", + error_str + ); // Verify that neither stream sends a failure message to the controller. check_fetch_result_error( @@ -4215,10 +4231,13 @@ mod tests { .await? .unwrap() .unwrap_err(); - assert!(matches!( - real_result_err.as_ref(), - CallFunctionError::RecordingFailed { .. 
} - )); + // Check that the error contains the expected string + let error_str = format!("{:?}", real_result_err); + assert!( + error_str.contains("recording failed"), + "Error should contain 'recording failed': {}", + error_str + ); let controller_msg = test_setup.controller_rx.recv().await.unwrap(); match controller_msg { @@ -4289,15 +4308,18 @@ mod tests { .await? .unwrap() .unwrap_err(); - match real_result_err.as_ref() { - CallFunctionError::DependentError(err) => match err.as_ref() { - CallFunctionError::Anyhow(err) => { - assert!(err.to_string().contains("bad pipe")); - } - _ => panic!("Unexpected error: {:?}", real_result_err), - }, - _ => panic!("Unexpected error: {:?}", real_result_err), - } + // Check that the error contains the expected strings + let error_str = format!("{:?}", real_result_err); + assert!( + error_str.contains("Computation depended on an input that failed"), + "Error should contain dependency message: {}", + error_str + ); + assert!( + error_str.contains("bad pipe"), + "Error should contain 'bad pipe': {}", + error_str + ); check_fetch_result_error( &test_setup.client,