Skip to content

[18/n] tensor engine: Simplify CallFunctionError #552

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 88 additions & 50 deletions monarch_messages/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -506,74 +506,112 @@ pub enum StreamCreationMode {
CreateNewStream,
}

/// The kinds of errors that a CallFunction message can return with.
// TODO(agallagher): We should move most variants out into `ValueError`.
/// When a worker runs any function, it may not succeed either because the function itself
/// failed (`Error`) or because an input to the function already had an error value
/// (`DependentError`).
#[derive(Error, Debug, Named)]
pub enum CallFunctionError {
#[error("ref not found: {0}")]
RefNotFound(Ref),
#[error("{0}")]
Error(#[from] anyhow::Error),
#[error("Computation depended on an input that failed with errror: {0}")]
DependentError(Arc<CallFunctionError>),
}

#[error("dependent error {0}")]
DependentError(#[from] Arc<CallFunctionError>),
impl CallFunctionError {
/// Checks if the error is a `DependentError` and returns the underlying
/// error if so. Otherwise, returns `None`.
pub fn unwrap_dependent_error(&self) -> Option<Arc<CallFunctionError>> {
match self {
CallFunctionError::DependentError(e) => Some(e.clone()),
_ => None,
}
}

#[error("invalid remote function: {0}")]
InvalidRemoteFunction(String),
// Static functions for backward compatibility with existing enum cases

#[error("unsupported arg type for {0} function: {1}")]
UnsupportedArgType(String, String),
#[allow(non_snake_case)]
pub fn RefNotFound(r: Ref) -> Self {
Self::Error(anyhow::anyhow!("ref not found: {}", r))
}

#[error("remote function failed: {0}")]
RemoteFunctionFailed(#[from] SerializablePyErr),
#[allow(non_snake_case)]
pub fn InvalidRemoteFunction(msg: String) -> Self {
Self::Error(anyhow::anyhow!("invalid remote function: {}", msg))
}

#[error("borrow failed: {0}")]
BorrowError(#[from] BorrowError),
#[allow(non_snake_case)]
pub fn UnsupportedArgType(function_type: String, arg_type: String) -> Self {
Self::Error(anyhow::anyhow!(
"unsupported arg type for {} function: {}",
function_type,
arg_type
))
}

#[error("torch operator failed: {0}")]
OperatorFailed(#[from] CallOpError),
#[allow(non_snake_case)]
pub fn RemoteFunctionFailed(err: SerializablePyErr) -> Self {
Self::Error(anyhow::anyhow!("remote function failed: {}", err))
}

#[error("unexpected number of returns from op, expected {expected}, got {actual}")]
UnexpectedNumberOfReturns { expected: usize, actual: usize },
#[allow(non_snake_case)]
pub fn BorrowError(err: BorrowError) -> Self {
Self::Error(anyhow::anyhow!("borrow failed: {}", err))
}

#[error(
"expected only a single arg (and no kwargs) when no function is given: {args}, {kwargs}"
)]
TooManyArgsForValue { args: String, kwargs: String },
#[allow(non_snake_case)]
pub fn OperatorFailed(err: CallOpError) -> Self {
Self::Error(anyhow::anyhow!("torch operator failed: {}", err))
}

#[error("error: {0}")]
Anyhow(#[from] anyhow::Error),
#[allow(non_snake_case)]
pub fn UnexpectedNumberOfReturns(expected: usize, actual: usize) -> Self {
Self::Error(anyhow::anyhow!(
"unexpected number of returns from op, expected {}, got {}",
expected,
actual
))
}

#[error("recording failed at message {index}: ({message}). Error: {error}")]
RecordingFailed {
index: usize,
message: String,
error: Arc<CallFunctionError>,
},
#[allow(non_snake_case)]
pub fn TooManyArgsForValue(args: String, kwargs: String) -> Self {
Self::Error(anyhow::anyhow!(
"expected only a single arg (and no kwargs) when no function is given: {}, {}",
args,
kwargs
))
}

#[allow(non_snake_case)]
pub fn Anyhow(err: anyhow::Error) -> Self {
Self::Error(err)
}

#[allow(non_snake_case)]
pub fn RecordingFailed(index: usize, message: String, error: Arc<CallFunctionError>) -> Self {
Self::Error(anyhow::anyhow!(
"recording failed at message {}: ({}). Error: {}",
index,
message,
error
))
}
}

impl CallFunctionError {
/// Checks if the error is a `DependentError` and returns the underlying
/// error if so. Otherwise, returns `None`.
pub fn unwrap_dependent_error(&self) -> Option<Arc<CallFunctionError>> {
match self {
CallFunctionError::DependentError(e) => Some(e.clone()),
_ => None,
}
impl From<SerializablePyErr> for CallFunctionError {
fn from(v: SerializablePyErr) -> CallFunctionError {
CallFunctionError::Error(v.into())
}
}

/// Errors encountered during worker operations which get propagated in the
/// refs-to-values maps used to store values.
#[derive(Debug, Serialize, Deserialize, Error, Named, EnumAsInner)]
pub enum ValueError {
// TODO(agallagher): Migrate to variants for each error (after we cleanup
// `CallFunctionError` to make it serializable).
#[error("call function error: {0}")]
CallFunctionError(String),
impl From<BorrowError> for CallFunctionError {
fn from(v: BorrowError) -> CallFunctionError {
CallFunctionError::Error(v.into())
}
}

impl From<Arc<CallFunctionError>> for ValueError {
fn from(err: Arc<CallFunctionError>) -> Self {
ValueError::CallFunctionError(format!("{:?}", err))
impl From<CallOpError> for CallFunctionError {
fn from(v: CallOpError) -> CallFunctionError {
CallFunctionError::Error(v.into())
}
}

Expand Down Expand Up @@ -874,7 +912,7 @@ pub enum WorkerMessage {
/// The stream to retrieve from.
stream: StreamRef,
#[reply]
response_port: hyperactor::OncePortRef<Option<Result<WireValue, ValueError>>>,
response_port: hyperactor::OncePortRef<Option<Result<WireValue, String>>>,
},
}

Expand Down
2 changes: 1 addition & 1 deletion monarch_simulator/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,7 @@ impl WorkerMessageHandler for WorkerActor {
_cx: &hyperactor::Context<Self>,
_ref_id: Ref,
_stream: StreamRef,
) -> Result<Option<Result<WireValue, ValueError>>> {
) -> Result<Option<Result<WireValue, String>>> {
bail!("unimplemented: get_ref_unit_tests_only")
}

Expand Down
5 changes: 2 additions & 3 deletions monarch_tensor_worker/src/borrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -422,10 +422,9 @@ mod tests {
let error = result
.context("no such ref")?
.err()
.context("expected error")?
.into_call_function_error()?;
.context("expected error")?;
assert!(
error.starts_with("DependentError"),
error.contains("Computation depended on an input that failed"),
"If a borrowed value contains an error, downstream calls should propagate that error (unexpected error string: {})",
error,
);
Expand Down
24 changes: 13 additions & 11 deletions monarch_tensor_worker/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ use monarch_messages::worker::Ref;
use monarch_messages::worker::ResolvableFunction;
use monarch_messages::worker::StreamCreationMode;
use monarch_messages::worker::StreamRef;
use monarch_messages::worker::ValueError;
use monarch_messages::worker::WorkerMessage;
use monarch_messages::worker::WorkerMessageHandler;
use monarch_messages::worker::WorkerParams;
Expand Down Expand Up @@ -679,10 +678,9 @@ impl WorkerMessageHandler for WorkerActor {
(k, RValue::PyObject(object.into_py_object().unwrap()).into())
})
.collect();
let device_mesh = self
.device_meshes
.get(&device_mesh)
.ok_or_else(|| CallFunctionError::RefNotFound(device_mesh))?;
let device_mesh = self.device_meshes.get(&device_mesh).ok_or_else(|| {
CallFunctionError::Error(anyhow::anyhow!("ref not found: {}", device_mesh))
})?;
let pipe = PipeActor::spawn(
cx,
PipeParams {
Expand Down Expand Up @@ -989,10 +987,15 @@ impl WorkerMessageHandler for WorkerActor {

// Get a port for the pipe
let pipe = match self.pipes.get(&pipe) {
None => Err(Arc::new(CallFunctionError::RefNotFound(pipe))),
None => Err(Arc::new(CallFunctionError::Error(anyhow::anyhow!(
"ref not found: {}",
pipe
)))),
Some(pipe) => match pipe.as_ref() {
Ok(pipe) => Ok(pipe.port()),
Err(e) => Err(Arc::new(CallFunctionError::DependentError(e.clone()))),
Err(e) => Err(Arc::new(CallFunctionError::DependentError(
e.unwrap_dependent_error().unwrap_or(e.clone()),
))),
},
};

Expand Down Expand Up @@ -1020,7 +1023,7 @@ impl WorkerMessageHandler for WorkerActor {
cx: &hyperactor::Context<Self>,
ref_id: Ref,
stream: StreamRef,
) -> Result<Option<Result<WireValue, ValueError>>> {
) -> Result<Option<Result<WireValue, String>>> {
let stream = self.try_get_stream(stream)?;
Ok(stream
.get_ref_unit_tests_only(cx, ref_id.clone())
Expand Down Expand Up @@ -1437,8 +1440,7 @@ mod tests {
let mutated_ref = result
.context("no such ref")?
.err()
.context("expected error")?
.into_call_function_error()?;
.context("expected error")?;
assert!(mutated_ref.contains("InvalidRemoteFunction"));

let responses = controller_rx.drain();
Expand Down Expand Up @@ -2542,7 +2544,7 @@ mod tests {
]))
.unwrap();

let value: Result<_, ValueError> = handle
let value: Result<_, String> = handle
.get_ref_unit_tests_only(&client, 3.into(), 0.into())
.await
.unwrap()
Expand Down
Loading
Loading