Skip to content

Commit be4c97d

Browse files
zdevito and facebook-github-bot
authored and committed
tensor engine: Simplify CallFunctionError (#552)
Summary: Pull Request resolved: #552 The only case we ever 'catch' is DependentError. Everything else is just a textual description. So this collapses those cases into just the two cases. This retains the enum variants as function calls so that we can also reify them again if we need to catch them, but having them as enum variants obscures the places where DependentError needs to be handled differently. Reviewed By: suo Differential Revision: D78363663 fbshipit-source-id: f1d2ba31dfe65e1b1202ea4d2b67d224cfa76b5d
1 parent 9efd1c8 commit be4c97d

File tree

5 files changed

+194
-133
lines changed

5 files changed

+194
-133
lines changed

monarch_messages/src/worker.rs

Lines changed: 88 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -506,74 +506,112 @@ pub enum StreamCreationMode {
506506
CreateNewStream,
507507
}
508508

509-
/// The kinds of errors that a CallFunction message can return with.
510-
// TODO(agallagher): We should move most variants out into `ValueError`.
509+
/// When a worker runs any function, it may not succeed either because the function itself
510+
/// failed (Error) or because an input to the function already had an error value
511+
/// (DependentError).
511512
#[derive(Error, Debug, Named)]
512513
pub enum CallFunctionError {
513-
#[error("ref not found: {0}")]
514-
RefNotFound(Ref),
514+
#[error("{0}")]
515+
Error(#[from] anyhow::Error),
516+
#[error("Computation depended on an input that failed with errror: {0}")]
517+
DependentError(Arc<CallFunctionError>),
518+
}
515519

516-
#[error("dependent error {0}")]
517-
DependentError(#[from] Arc<CallFunctionError>),
520+
impl CallFunctionError {
521+
/// Checks if the error is a `DependentError` and returns the underlying
522+
/// error if so. Otherwise, returns `None`.
523+
pub fn unwrap_dependent_error(&self) -> Option<Arc<CallFunctionError>> {
524+
match self {
525+
CallFunctionError::DependentError(e) => Some(e.clone()),
526+
_ => None,
527+
}
528+
}
518529

519-
#[error("invalid remote function: {0}")]
520-
InvalidRemoteFunction(String),
530+
// Static functions for backward compatibility with existing enum cases
521531

522-
#[error("unsupported arg type for {0} function: {1}")]
523-
UnsupportedArgType(String, String),
532+
#[allow(non_snake_case)]
533+
pub fn RefNotFound(r: Ref) -> Self {
534+
Self::Error(anyhow::anyhow!("ref not found: {}", r))
535+
}
524536

525-
#[error("remote function failed: {0}")]
526-
RemoteFunctionFailed(#[from] SerializablePyErr),
537+
#[allow(non_snake_case)]
538+
pub fn InvalidRemoteFunction(msg: String) -> Self {
539+
Self::Error(anyhow::anyhow!("invalid remote function: {}", msg))
540+
}
527541

528-
#[error("borrow failed: {0}")]
529-
BorrowError(#[from] BorrowError),
542+
#[allow(non_snake_case)]
543+
pub fn UnsupportedArgType(function_type: String, arg_type: String) -> Self {
544+
Self::Error(anyhow::anyhow!(
545+
"unsupported arg type for {} function: {}",
546+
function_type,
547+
arg_type
548+
))
549+
}
530550

531-
#[error("torch operator failed: {0}")]
532-
OperatorFailed(#[from] CallOpError),
551+
#[allow(non_snake_case)]
552+
pub fn RemoteFunctionFailed(err: SerializablePyErr) -> Self {
553+
Self::Error(anyhow::anyhow!("remote function failed: {}", err))
554+
}
533555

534-
#[error("unexpected number of returns from op, expected {expected}, got {actual}")]
535-
UnexpectedNumberOfReturns { expected: usize, actual: usize },
556+
#[allow(non_snake_case)]
557+
pub fn BorrowError(err: BorrowError) -> Self {
558+
Self::Error(anyhow::anyhow!("borrow failed: {}", err))
559+
}
536560

537-
#[error(
538-
"expected only a single arg (and no kwargs) when no function is given: {args}, {kwargs}"
539-
)]
540-
TooManyArgsForValue { args: String, kwargs: String },
561+
#[allow(non_snake_case)]
562+
pub fn OperatorFailed(err: CallOpError) -> Self {
563+
Self::Error(anyhow::anyhow!("torch operator failed: {}", err))
564+
}
541565

542-
#[error("error: {0}")]
543-
Anyhow(#[from] anyhow::Error),
566+
#[allow(non_snake_case)]
567+
pub fn UnexpectedNumberOfReturns(expected: usize, actual: usize) -> Self {
568+
Self::Error(anyhow::anyhow!(
569+
"unexpected number of returns from op, expected {}, got {}",
570+
expected,
571+
actual
572+
))
573+
}
544574

545-
#[error("recording failed at message {index}: ({message}). Error: {error}")]
546-
RecordingFailed {
547-
index: usize,
548-
message: String,
549-
error: Arc<CallFunctionError>,
550-
},
575+
#[allow(non_snake_case)]
576+
pub fn TooManyArgsForValue(args: String, kwargs: String) -> Self {
577+
Self::Error(anyhow::anyhow!(
578+
"expected only a single arg (and no kwargs) when no function is given: {}, {}",
579+
args,
580+
kwargs
581+
))
582+
}
583+
584+
#[allow(non_snake_case)]
585+
pub fn Anyhow(err: anyhow::Error) -> Self {
586+
Self::Error(err)
587+
}
588+
589+
#[allow(non_snake_case)]
590+
pub fn RecordingFailed(index: usize, message: String, error: Arc<CallFunctionError>) -> Self {
591+
Self::Error(anyhow::anyhow!(
592+
"recording failed at message {}: ({}). Error: {}",
593+
index,
594+
message,
595+
error
596+
))
597+
}
551598
}
552599

553-
impl CallFunctionError {
554-
/// Checks if the error is a `DependentError` and returns the underlying
555-
/// error if so. Otherwise, returns `None`.
556-
pub fn unwrap_dependent_error(&self) -> Option<Arc<CallFunctionError>> {
557-
match self {
558-
CallFunctionError::DependentError(e) => Some(e.clone()),
559-
_ => None,
560-
}
600+
impl From<SerializablePyErr> for CallFunctionError {
601+
fn from(v: SerializablePyErr) -> CallFunctionError {
602+
CallFunctionError::Error(v.into())
561603
}
562604
}
563605

564-
/// Errors encountered during worker operations which get propagated in the
565-
/// refs-to-values maps used to store values.
566-
#[derive(Debug, Serialize, Deserialize, Error, Named, EnumAsInner)]
567-
pub enum ValueError {
568-
// TODO(agallagher): Migrate to variants for each error (after we cleanup
569-
// `CallFunctionError` to make it serializable).
570-
#[error("call function error: {0}")]
571-
CallFunctionError(String),
606+
impl From<BorrowError> for CallFunctionError {
607+
fn from(v: BorrowError) -> CallFunctionError {
608+
CallFunctionError::Error(v.into())
609+
}
572610
}
573611

574-
impl From<Arc<CallFunctionError>> for ValueError {
575-
fn from(err: Arc<CallFunctionError>) -> Self {
576-
ValueError::CallFunctionError(format!("{:?}", err))
612+
impl From<CallOpError> for CallFunctionError {
613+
fn from(v: CallOpError) -> CallFunctionError {
614+
CallFunctionError::Error(v.into())
577615
}
578616
}
579617

@@ -874,7 +912,7 @@ pub enum WorkerMessage {
874912
/// The stream to retrieve from.
875913
stream: StreamRef,
876914
#[reply]
877-
response_port: hyperactor::OncePortRef<Option<Result<WireValue, ValueError>>>,
915+
response_port: hyperactor::OncePortRef<Option<Result<WireValue, String>>>,
878916
},
879917
}
880918

monarch_simulator/src/worker.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -639,7 +639,7 @@ impl WorkerMessageHandler for WorkerActor {
639639
_cx: &hyperactor::Context<Self>,
640640
_ref_id: Ref,
641641
_stream: StreamRef,
642-
) -> Result<Option<Result<WireValue, ValueError>>> {
642+
) -> Result<Option<Result<WireValue, String>>> {
643643
bail!("unimplemented: get_ref_unit_tests_only")
644644
}
645645

monarch_tensor_worker/src/borrow.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -422,10 +422,9 @@ mod tests {
422422
let error = result
423423
.context("no such ref")?
424424
.err()
425-
.context("expected error")?
426-
.into_call_function_error()?;
425+
.context("expected error")?;
427426
assert!(
428-
error.starts_with("DependentError"),
427+
error.contains("Computation depended on an input that failed"),
429428
"If a borrowed value contains an error, downstream calls should propagate that error (unexpected error string: {})",
430429
error,
431430
);

monarch_tensor_worker/src/lib.rs

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ use monarch_messages::worker::Ref;
8080
use monarch_messages::worker::ResolvableFunction;
8181
use monarch_messages::worker::StreamCreationMode;
8282
use monarch_messages::worker::StreamRef;
83-
use monarch_messages::worker::ValueError;
8483
use monarch_messages::worker::WorkerMessage;
8584
use monarch_messages::worker::WorkerMessageHandler;
8685
use monarch_messages::worker::WorkerParams;
@@ -679,10 +678,9 @@ impl WorkerMessageHandler for WorkerActor {
679678
(k, RValue::PyObject(object.into_py_object().unwrap()).into())
680679
})
681680
.collect();
682-
let device_mesh = self
683-
.device_meshes
684-
.get(&device_mesh)
685-
.ok_or_else(|| CallFunctionError::RefNotFound(device_mesh))?;
681+
let device_mesh = self.device_meshes.get(&device_mesh).ok_or_else(|| {
682+
CallFunctionError::Error(anyhow::anyhow!("ref not found: {}", device_mesh))
683+
})?;
686684
let pipe = PipeActor::spawn(
687685
cx,
688686
PipeParams {
@@ -989,10 +987,15 @@ impl WorkerMessageHandler for WorkerActor {
989987

990988
// Get a port for the pipe
991989
let pipe = match self.pipes.get(&pipe) {
992-
None => Err(Arc::new(CallFunctionError::RefNotFound(pipe))),
990+
None => Err(Arc::new(CallFunctionError::Error(anyhow::anyhow!(
991+
"ref not found: {}",
992+
pipe
993+
)))),
993994
Some(pipe) => match pipe.as_ref() {
994995
Ok(pipe) => Ok(pipe.port()),
995-
Err(e) => Err(Arc::new(CallFunctionError::DependentError(e.clone()))),
996+
Err(e) => Err(Arc::new(CallFunctionError::DependentError(
997+
e.unwrap_dependent_error().unwrap_or(e.clone()),
998+
))),
996999
},
9971000
};
9981001

@@ -1020,7 +1023,7 @@ impl WorkerMessageHandler for WorkerActor {
10201023
cx: &hyperactor::Context<Self>,
10211024
ref_id: Ref,
10221025
stream: StreamRef,
1023-
) -> Result<Option<Result<WireValue, ValueError>>> {
1026+
) -> Result<Option<Result<WireValue, String>>> {
10241027
let stream = self.try_get_stream(stream)?;
10251028
Ok(stream
10261029
.get_ref_unit_tests_only(cx, ref_id.clone())
@@ -1437,8 +1440,7 @@ mod tests {
14371440
let mutated_ref = result
14381441
.context("no such ref")?
14391442
.err()
1440-
.context("expected error")?
1441-
.into_call_function_error()?;
1443+
.context("expected error")?;
14421444
assert!(mutated_ref.contains("InvalidRemoteFunction"));
14431445

14441446
let responses = controller_rx.drain();
@@ -2542,7 +2544,7 @@ mod tests {
25422544
]))
25432545
.unwrap();
25442546

2545-
let value: Result<_, ValueError> = handle
2547+
let value: Result<_, String> = handle
25462548
.get_ref_unit_tests_only(&client, 3.into(), 0.into())
25472549
.await
25482550
.unwrap()

0 commit comments

Comments
 (0)