diff --git a/monarch_messages/src/worker.rs b/monarch_messages/src/worker.rs index 0192796f..8dda5535 100644 --- a/monarch_messages/src/worker.rs +++ b/monarch_messages/src/worker.rs @@ -506,74 +506,112 @@ pub enum StreamCreationMode { CreateNewStream, } -/// The kinds of errors that a CallFunction message can return with. -// TODO(agallagher): We should move most variants out into `ValueError`. +/// When a worker runs any function, it may not succeed either because the function itself +/// failed (Error) or because an input to the function already had an error value +/// DependentError. #[derive(Error, Debug, Named)] pub enum CallFunctionError { - #[error("ref not found: {0}")] - RefNotFound(Ref), + #[error("{0}")] + Error(#[from] anyhow::Error), + #[error("Computation depended on an input that failed with errror: {0}")] + DependentError(Arc), +} - #[error("dependent error {0}")] - DependentError(#[from] Arc), +impl CallFunctionError { + /// Checks if the error is a `DependentError` and returns the underlying + /// error if so. Otherwise, returns `None`. + pub fn unwrap_dependent_error(&self) -> Option> { + match self { + CallFunctionError::DependentError(e) => Some(e.clone()), + _ => None, + } + } - #[error("invalid remote function: {0}")] - InvalidRemoteFunction(String), + // Static functions for backward compatibility with existing enum cases - #[error("unsupported arg type for {0} function: {1}")] - UnsupportedArgType(String, String), + #[allow(non_snake_case)] + pub fn RefNotFound(r: Ref) -> Self { + Self::Error(anyhow::anyhow!("ref not found: {}", r)) + } - #[error("remote function failed: {0}")] - RemoteFunctionFailed(#[from] SerializablePyErr), + #[allow(non_snake_case)] + pub fn InvalidRemoteFunction(msg: String) -> Self { + Self::Error(anyhow::anyhow!("invalid remote function: {}", msg)) + } - #[error("borrow failed: {0}")] - BorrowError(#[from] BorrowError), + #[allow(non_snake_case)] + pub fn UnsupportedArgType(function_type: String, arg_type: String) -> Self { + Self::Error(anyhow::anyhow!( + "unsupported arg type for {} function: {}", + function_type, + arg_type + )) + } - #[error("torch operator failed: {0}")] - OperatorFailed(#[from] CallOpError), + #[allow(non_snake_case)] + pub fn RemoteFunctionFailed(err: SerializablePyErr) -> Self { + Self::Error(anyhow::anyhow!("remote function failed: {}", err)) + } - #[error("unexpected number of returns from op, expected {expected}, got {actual}")] - UnexpectedNumberOfReturns { expected: usize, actual: usize }, + #[allow(non_snake_case)] + pub fn BorrowError(err: BorrowError) -> Self { + Self::Error(anyhow::anyhow!("borrow failed: {}", err)) + } - #[error( - "expected only a single arg (and no kwargs) when no function is given: {args}, {kwargs}" - )] - TooManyArgsForValue { args: String, kwargs: String }, + #[allow(non_snake_case)] + pub fn OperatorFailed(err: CallOpError) -> Self { + Self::Error(anyhow::anyhow!("torch operator failed: {}", err)) + } - #[error("error: {0}")] - Anyhow(#[from] anyhow::Error), + #[allow(non_snake_case)] + pub fn UnexpectedNumberOfReturns(expected: usize, actual: usize) -> Self { + Self::Error(anyhow::anyhow!( + "unexpected number of returns from op, expected {}, got {}", + expected, + actual + )) + } - #[error("recording failed at message {index}: ({message}). Error: {error}")] - RecordingFailed { - index: usize, - message: String, - error: Arc, - }, + #[allow(non_snake_case)] + pub fn TooManyArgsForValue(args: String, kwargs: String) -> Self { + Self::Error(anyhow::anyhow!( + "expected only a single arg (and no kwargs) when no function is given: {}, {}", + args, + kwargs + )) + } + + #[allow(non_snake_case)] + pub fn Anyhow(err: anyhow::Error) -> Self { + Self::Error(err) + } + + #[allow(non_snake_case)] + pub fn RecordingFailed(index: usize, message: String, error: Arc) -> Self { + Self::Error(anyhow::anyhow!( + "recording failed at message {}: ({}). Error: {}", + index, + message, + error + )) + } } -impl CallFunctionError { - /// Checks if the error is a `DependentError` and returns the underlying - /// error if so. Otherwise, returns `None`. - pub fn unwrap_dependent_error(&self) -> Option> { - match self { - CallFunctionError::DependentError(e) => Some(e.clone()), - _ => None, - } +impl From for CallFunctionError { + fn from(v: SerializablePyErr) -> CallFunctionError { + CallFunctionError::Error(v.into()) } } -/// Errors encountered during worker operations which get propagated in the -/// refs-to-values maps used to store values. -#[derive(Debug, Serialize, Deserialize, Error, Named, EnumAsInner)] -pub enum ValueError { - // TODO(agallagher): Migrate to variants for each error (after we cleanup - // `CallFunctionError` to make it serializable). - #[error("call function error: {0}")] - CallFunctionError(String), +impl From for CallFunctionError { + fn from(v: BorrowError) -> CallFunctionError { + CallFunctionError::Error(v.into()) + } } -impl From> for ValueError { - fn from(err: Arc) -> Self { - ValueError::CallFunctionError(format!("{:?}", err)) +impl From for CallFunctionError { + fn from(v: CallOpError) -> CallFunctionError { + CallFunctionError::Error(v.into()) } } @@ -874,7 +912,7 @@ pub enum WorkerMessage { /// The stream to retrieve from. stream: StreamRef, #[reply] - response_port: hyperactor::OncePortRef>>, + response_port: hyperactor::OncePortRef>>, }, } diff --git a/monarch_simulator/src/worker.rs b/monarch_simulator/src/worker.rs index e1c528ac..ad5e459d 100644 --- a/monarch_simulator/src/worker.rs +++ b/monarch_simulator/src/worker.rs @@ -639,7 +639,7 @@ impl WorkerMessageHandler for WorkerActor { _cx: &hyperactor::Context, _ref_id: Ref, _stream: StreamRef, - ) -> Result>> { + ) -> Result>> { bail!("unimplemented: get_ref_unit_tests_only") } diff --git a/monarch_tensor_worker/src/borrow.rs b/monarch_tensor_worker/src/borrow.rs index e36f6293..7b40b0d9 100644 --- a/monarch_tensor_worker/src/borrow.rs +++ b/monarch_tensor_worker/src/borrow.rs @@ -422,10 +422,9 @@ mod tests { let error = result .context("no such ref")? .err() - .context("expected error")? - .into_call_function_error()?; + .context("expected error")?; assert!( - error.starts_with("DependentError"), + error.contains("Computation depended on an input that failed"), "If a borrowed value contains an error, downstream calls should propagate that error (unexpected error string: {})", error, ); diff --git a/monarch_tensor_worker/src/lib.rs b/monarch_tensor_worker/src/lib.rs index 772a8b19..33cf65c8 100644 --- a/monarch_tensor_worker/src/lib.rs +++ b/monarch_tensor_worker/src/lib.rs @@ -80,7 +80,6 @@ use monarch_messages::worker::Ref; use monarch_messages::worker::ResolvableFunction; use monarch_messages::worker::StreamCreationMode; use monarch_messages::worker::StreamRef; -use monarch_messages::worker::ValueError; use monarch_messages::worker::WorkerMessage; use monarch_messages::worker::WorkerMessageHandler; use monarch_messages::worker::WorkerParams; @@ -679,10 +678,9 @@ impl WorkerMessageHandler for WorkerActor { (k, RValue::PyObject(object.into_py_object().unwrap()).into()) }) .collect(); - let device_mesh = self - .device_meshes - .get(&device_mesh) - .ok_or_else(|| CallFunctionError::RefNotFound(device_mesh))?; + let device_mesh = self.device_meshes.get(&device_mesh).ok_or_else(|| { + CallFunctionError::Error(anyhow::anyhow!("ref not found: {}", device_mesh)) + })?; let pipe = PipeActor::spawn( cx, PipeParams { @@ -989,10 +987,15 @@ impl WorkerMessageHandler for WorkerActor { // Get a port for the pipe let pipe = match self.pipes.get(&pipe) { - None => Err(Arc::new(CallFunctionError::RefNotFound(pipe))), + None => Err(Arc::new(CallFunctionError::Error(anyhow::anyhow!( + "ref not found: {}", + pipe + )))), Some(pipe) => match pipe.as_ref() { Ok(pipe) => Ok(pipe.port()), - Err(e) => Err(Arc::new(CallFunctionError::DependentError(e.clone()))), + Err(e) => Err(Arc::new(CallFunctionError::DependentError( + e.unwrap_dependent_error().unwrap_or(e.clone()), + ))), }, }; @@ -1020,7 +1023,7 @@ impl WorkerMessageHandler for WorkerActor { cx: &hyperactor::Context, ref_id: Ref, stream: StreamRef, - ) -> Result>> { + ) -> Result>> { let stream = self.try_get_stream(stream)?; Ok(stream .get_ref_unit_tests_only(cx, ref_id.clone()) @@ -1437,8 +1440,7 @@ mod tests { let mutated_ref = result .context("no such ref")? .err() - .context("expected error")? - .into_call_function_error()?; + .context("expected error")?; assert!(mutated_ref.contains("InvalidRemoteFunction")); let responses = controller_rx.drain(); @@ -2542,7 +2544,7 @@ mod tests { ])) .unwrap(); - let value: Result<_, ValueError> = handle + let value: Result<_, String> = handle .get_ref_unit_tests_only(&client, 3.into(), 0.into()) .await .unwrap() diff --git a/monarch_tensor_worker/src/stream.rs b/monarch_tensor_worker/src/stream.rs index 7a5305c5..2e23deef 100644 --- a/monarch_tensor_worker/src/stream.rs +++ b/monarch_tensor_worker/src/stream.rs @@ -235,7 +235,7 @@ pub enum StreamMessage { GetRefUnitTestsOnly( Ref, // value - #[reply] OncePortHandle>>>, + #[reply] OncePortHandle>>, ), GetTensorRefUnitTestsOnly(Ref, #[reply] OncePortHandle>), @@ -661,10 +661,10 @@ impl StreamActor { .filter_map(|(result, ref_)| ref_.map(|ref_| (ref_, result))) .collect::>()) } else { - Err(CallFunctionError::UnexpectedNumberOfReturns { - expected: result_refs.len(), - actual: actual_results.len(), - }) + Err(CallFunctionError::UnexpectedNumberOfReturns( + result_refs.len(), + actual_results.len(), + )) } }); @@ -1620,10 +1620,10 @@ impl StreamMessageHandler for StreamActor { } }) }), - _ => Err(CallFunctionError::TooManyArgsForValue { - args: format!("{:?}", args), - kwargs: format!("{:?}", kwargs), - }), + _ => Err(CallFunctionError::TooManyArgsForValue( + format!("{:?}", args), + format!("{:?}", kwargs), + )), } }; @@ -1769,7 +1769,6 @@ impl StreamMessageHandler for StreamActor { CallFunctionError::DependentError(dep_err) => { Err(CallFunctionError::DependentError(dep_err.clone())) } - CallFunctionError::RefNotFound(ref_) => Err(CallFunctionError::RefNotFound(*ref_)), _ => bail!("unexpected error for pipe in set_value: {:?}", err), }, }; @@ -2034,11 +2033,11 @@ impl StreamMessageHandler for StreamActor { .messages .get(index) .unwrap(); - error = Some(Arc::new(CallFunctionError::RecordingFailed { + error = Some(Arc::new(CallFunctionError::RecordingFailed( index, - message: format!("{message:?}"), - error: err.clone(), - })); + format!("{message:?}"), + err.clone(), + ))); // Report failure to the controller. self.controller_actor .remote_function_failed( @@ -2117,7 +2116,7 @@ impl StreamMessageHandler for StreamActor { &mut self, _cx: &Context, reference: Ref, - ) -> Result>>> { + ) -> Result>> { /// For testing only, doesn't support Tensor or TensorList. fn rvalue_to_wire( value: Result>, @@ -2139,7 +2138,7 @@ impl StreamMessageHandler for StreamActor { Ok(self .env .get(&reference) - .map(|rvalue| rvalue_to_wire(rvalue.clone()))) + .map(|rvalue| rvalue_to_wire(rvalue.clone()).map_err(|err| err.to_string()))) } async fn get_tensor_ref_unit_tests_only( @@ -2253,11 +2252,12 @@ mod tests { .unwrap() .unwrap() .unwrap(); - allclose( + let x = allclose( &factory_float_tensor(data, "cpu".try_into().unwrap()), &actual.borrow(), ) - .unwrap() + .unwrap(); + x } async fn validate_dependent_error( @@ -2734,17 +2734,19 @@ mod tests { .await .unwrap(); let error = result.unwrap().unwrap_err(); - match error.as_ref() { - CallFunctionError::RecordingFailed { - error: inner_error, .. - } => match inner_error.as_ref() { - CallFunctionError::RefNotFound(err_ref) => { - assert_eq!(*err_ref, nonexistent_ref) - } - _ => panic!("Unexpected error inside RecordingFailed: {:?}", inner_error), - }, - _ => panic!("Unexpected error instead of RecordingFailed: {:?}", error), - }; + + // Check that the error contains the expected strings + let error_str = format!("{:?}", error); + assert!( + error_str.contains("recording failed"), + "Error should contain 'recording failed': {}", + error_str + ); + assert!( + error_str.contains("ref not found"), + "Error should contain 'ref not found': {}", + error_str + ); } assert_refs_do_not_exist(&test_setup, &[formal0_ref, formal1_ref]).await; @@ -3031,16 +3033,18 @@ mod tests { .await? .unwrap() .unwrap_err(); - match result_error.as_ref() { - CallFunctionError::RecordingFailed { error, .. } => match error.as_ref() { - CallFunctionError::OperatorFailed(_) => (), - _ => panic!("Unexpected error inside RecordingFailed: {:?}", error), - }, - _ => panic!( - "Unexpected error instead of RecordingFailed: {:?}", - result_error - ), - } + // Check that the error contains the expected strings + let error_str = format!("{:?}", result_error); + assert!( + error_str.contains("recording failed"), + "Error should contain 'recording failed': {}", + error_str + ); + assert!( + error_str.contains("torch operator failed"), + "Error should contain 'torch operator failed': {}", + error_str + ); } let controller_msg = test_setup.controller_rx.recv().await.unwrap(); @@ -3091,16 +3095,18 @@ mod tests { .await? .unwrap() .unwrap_err(); - match result_error.as_ref() { - CallFunctionError::DependentError(dep_err) => match dep_err.as_ref() { - CallFunctionError::RecordingFailed { .. } => (), - _ => panic!("Unexpected error inside DependentError: {:?}", dep_err), - }, - _ => panic!( - "Unexpected error instead of DependentError: {:?}", - result_error - ), - } + // Check that the error contains the expected strings + let error_str = format!("{:?}", result_error); + assert!( + error_str.contains("Computation depended on an input that failed"), + "Error should contain dependency message: {}", + error_str + ); + assert!( + error_str.contains("recording failed"), + "Error should contain 'recording failed': {}", + error_str + ); } // This tests that the DependentError was never reported to the controller. @@ -3470,12 +3476,22 @@ mod tests { .unwrap() .unwrap_err(); - match result_error.as_ref() { - CallFunctionError::DependentError(dep_err) => { - assert!(Arc::ptr_eq(dep_err, &input_error)); - } - _ => panic!("Unexpected error: {:?}", result_error), - } + // Check that the error contains the expected strings + let error_str = format!("{:?}", result_error); + assert!( + error_str.contains("Computation depended on an input that failed"), + "Error should contain dependency message: {}", + error_str + ); + + // Since we're checking for pointer equality in the original code, we need to ensure + // the error is propagated correctly. We can check that the original error message is contained. + let input_error_str = format!("{:?}", input_error); + assert!( + error_str.contains(&input_error_str), + "Error should contain the original error: {}", + error_str + ); // Verify that neither stream sends a failure message to the controller. check_fetch_result_error( @@ -4215,10 +4231,13 @@ mod tests { .await? .unwrap() .unwrap_err(); - assert!(matches!( - real_result_err.as_ref(), - CallFunctionError::RecordingFailed { .. } - )); + // Check that the error contains the expected string + let error_str = format!("{:?}", real_result_err); + assert!( + error_str.contains("recording failed"), + "Error should contain 'recording failed': {}", + error_str + ); let controller_msg = test_setup.controller_rx.recv().await.unwrap(); match controller_msg { @@ -4289,15 +4308,18 @@ mod tests { .await? .unwrap() .unwrap_err(); - match real_result_err.as_ref() { - CallFunctionError::DependentError(err) => match err.as_ref() { - CallFunctionError::Anyhow(err) => { - assert!(err.to_string().contains("bad pipe")); - } - _ => panic!("Unexpected error: {:?}", real_result_err), - }, - _ => panic!("Unexpected error: {:?}", real_result_err), - } + // Check that the error contains the expected strings + let error_str = format!("{:?}", real_result_err); + assert!( + error_str.contains("Computation depended on an input that failed"), + "Error should contain dependency message: {}", + error_str + ); + assert!( + error_str.contains("bad pipe"), + "Error should contain 'bad pipe': {}", + error_str + ); check_fetch_result_error( &test_setup.client,