@@ -506,74 +506,112 @@ pub enum StreamCreationMode {
506
506
CreateNewStream ,
507
507
}
508
508
509
- /// The kinds of errors that a CallFunction message can return with.
510
- // TODO(agallagher): We should move most variants out into `ValueError`.
509
+ /// When a worker runs any function, it may not succeed either because the function itself
510
+ /// failed (Error) or because an input to the function already had an error value
511
+ /// DependentError.
511
512
#[ derive( Error , Debug , Named ) ]
512
513
pub enum CallFunctionError {
513
- #[ error( "ref not found: {0}" ) ]
514
- RefNotFound ( Ref ) ,
514
+ #[ error( "{0}" ) ]
515
+ Error ( #[ from] anyhow:: Error ) ,
516
+ #[ error( "Computation depended on an input that failed with errror: {0}" ) ]
517
+ DependentError ( Arc < CallFunctionError > ) ,
518
+ }
515
519
516
- #[ error( "dependent error {0}" ) ]
517
- DependentError ( #[ from] Arc < CallFunctionError > ) ,
520
+ impl CallFunctionError {
521
+ /// Checks if the error is a `DependentError` and returns the underlying
522
+ /// error if so. Otherwise, returns `None`.
523
+ pub fn unwrap_dependent_error ( & self ) -> Option < Arc < CallFunctionError > > {
524
+ match self {
525
+ CallFunctionError :: DependentError ( e) => Some ( e. clone ( ) ) ,
526
+ _ => None ,
527
+ }
528
+ }
518
529
519
- #[ error( "invalid remote function: {0}" ) ]
520
- InvalidRemoteFunction ( String ) ,
530
+ // Static functions for backward compatibility with existing enum cases
521
531
522
- #[ error( "unsupported arg type for {0} function: {1}" ) ]
523
- UnsupportedArgType ( String , String ) ,
532
+ #[ allow( non_snake_case) ]
533
+ pub fn RefNotFound ( r : Ref ) -> Self {
534
+ Self :: Error ( anyhow:: anyhow!( "ref not found: {}" , r) )
535
+ }
524
536
525
- #[ error( "remote function failed: {0}" ) ]
526
- RemoteFunctionFailed ( #[ from] SerializablePyErr ) ,
537
+ #[ allow( non_snake_case) ]
538
+ pub fn InvalidRemoteFunction ( msg : String ) -> Self {
539
+ Self :: Error ( anyhow:: anyhow!( "invalid remote function: {}" , msg) )
540
+ }
527
541
528
- #[ error( "borrow failed: {0}" ) ]
529
- BorrowError ( #[ from] BorrowError ) ,
542
+ #[ allow( non_snake_case) ]
543
+ pub fn UnsupportedArgType ( function_type : String , arg_type : String ) -> Self {
544
+ Self :: Error ( anyhow:: anyhow!(
545
+ "unsupported arg type for {} function: {}" ,
546
+ function_type,
547
+ arg_type
548
+ ) )
549
+ }
530
550
531
- #[ error( "torch operator failed: {0}" ) ]
532
- OperatorFailed ( #[ from] CallOpError ) ,
551
+ #[ allow( non_snake_case) ]
552
+ pub fn RemoteFunctionFailed ( err : SerializablePyErr ) -> Self {
553
+ Self :: Error ( anyhow:: anyhow!( "remote function failed: {}" , err) )
554
+ }
533
555
534
- #[ error( "unexpected number of returns from op, expected {expected}, got {actual}" ) ]
535
- UnexpectedNumberOfReturns { expected : usize , actual : usize } ,
556
+ #[ allow( non_snake_case) ]
557
+ pub fn BorrowError ( err : BorrowError ) -> Self {
558
+ Self :: Error ( anyhow:: anyhow!( "borrow failed: {}" , err) )
559
+ }
536
560
537
- #[ error (
538
- "expected only a single arg (and no kwargs) when no function is given: {args}, {kwargs}"
539
- ) ]
540
- TooManyArgsForValue { args : String , kwargs : String } ,
561
+ #[ allow ( non_snake_case ) ]
562
+ pub fn OperatorFailed ( err : CallOpError ) -> Self {
563
+ Self :: Error ( anyhow :: anyhow! ( "torch operator failed: {}" , err ) )
564
+ }
541
565
542
- #[ error( "error: {0}" ) ]
543
- Anyhow ( #[ from] anyhow:: Error ) ,
566
+ #[ allow( non_snake_case) ]
567
+ pub fn UnexpectedNumberOfReturns ( expected : usize , actual : usize ) -> Self {
568
+ Self :: Error ( anyhow:: anyhow!(
569
+ "unexpected number of returns from op, expected {}, got {}" ,
570
+ expected,
571
+ actual
572
+ ) )
573
+ }
544
574
545
- #[ error( "recording failed at message {index}: ({message}). Error: {error}" ) ]
546
- RecordingFailed {
547
- index : usize ,
548
- message : String ,
549
- error : Arc < CallFunctionError > ,
550
- } ,
575
+ #[ allow( non_snake_case) ]
576
+ pub fn TooManyArgsForValue ( args : String , kwargs : String ) -> Self {
577
+ Self :: Error ( anyhow:: anyhow!(
578
+ "expected only a single arg (and no kwargs) when no function is given: {}, {}" ,
579
+ args,
580
+ kwargs
581
+ ) )
582
+ }
583
+
584
+ #[ allow( non_snake_case) ]
585
+ pub fn Anyhow ( err : anyhow:: Error ) -> Self {
586
+ Self :: Error ( err)
587
+ }
588
+
589
+ #[ allow( non_snake_case) ]
590
+ pub fn RecordingFailed ( index : usize , message : String , error : Arc < CallFunctionError > ) -> Self {
591
+ Self :: Error ( anyhow:: anyhow!(
592
+ "recording failed at message {}: ({}). Error: {}" ,
593
+ index,
594
+ message,
595
+ error
596
+ ) )
597
+ }
551
598
}
552
599
553
- impl CallFunctionError {
554
- /// Checks if the error is a `DependentError` and returns the underlying
555
- /// error if so. Otherwise, returns `None`.
556
- pub fn unwrap_dependent_error ( & self ) -> Option < Arc < CallFunctionError > > {
557
- match self {
558
- CallFunctionError :: DependentError ( e) => Some ( e. clone ( ) ) ,
559
- _ => None ,
560
- }
600
+ impl From < SerializablePyErr > for CallFunctionError {
601
+ fn from ( v : SerializablePyErr ) -> CallFunctionError {
602
+ CallFunctionError :: Error ( v. into ( ) )
561
603
}
562
604
}
563
605
564
- /// Errors encountered during worker operations which get propagated in the
565
- /// refs-to-values maps used to store values.
566
- #[ derive( Debug , Serialize , Deserialize , Error , Named , EnumAsInner ) ]
567
- pub enum ValueError {
568
- // TODO(agallagher): Migrate to variants for each error (after we cleanup
569
- // `CallFunctionError` to make it serializable).
570
- #[ error( "call function error: {0}" ) ]
571
- CallFunctionError ( String ) ,
606
+ impl From < BorrowError > for CallFunctionError {
607
+ fn from ( v : BorrowError ) -> CallFunctionError {
608
+ CallFunctionError :: Error ( v. into ( ) )
609
+ }
572
610
}
573
611
574
- impl From < Arc < CallFunctionError > > for ValueError {
575
- fn from ( err : Arc < CallFunctionError > ) -> Self {
576
- ValueError :: CallFunctionError ( format ! ( "{:?}" , err ) )
612
+ impl From < CallOpError > for CallFunctionError {
613
+ fn from ( v : CallOpError ) -> CallFunctionError {
614
+ CallFunctionError :: Error ( v . into ( ) )
577
615
}
578
616
}
579
617
@@ -874,7 +912,7 @@ pub enum WorkerMessage {
874
912
/// The stream to retrieve from.
875
913
stream : StreamRef ,
876
914
#[ reply]
877
- response_port : hyperactor:: OncePortRef < Option < Result < WireValue , ValueError > > > ,
915
+ response_port : hyperactor:: OncePortRef < Option < Result < WireValue , String > > > ,
878
916
} ,
879
917
}
880
918
0 commit comments