@@ -683,7 +683,10 @@ async def test_start_operation_happy_path(
683
683
with_service_definition : bool ,
684
684
env : WorkflowEnvironment ,
685
685
):
686
- await _test_start_operation (test_case , with_service_definition , env )
686
+ if with_service_definition :
687
+ await _test_start_operation_with_service_definition (test_case , env )
688
+ else :
689
+ await _test_start_operation_without_service_definition (test_case , env )
687
690
688
691
689
692
@pytest .mark .parametrize (
@@ -702,7 +705,7 @@ async def test_start_operation_happy_path(
702
705
async def test_start_operation_protocol_level_failures (
703
706
test_case : Type [_TestCase ], env : WorkflowEnvironment
704
707
):
705
- await _test_start_operation (test_case , True , env )
708
+ await _test_start_operation_with_service_definition (test_case , env )
706
709
707
710
708
711
@pytest .mark .parametrize (
@@ -716,12 +719,11 @@ async def test_start_operation_protocol_level_failures(
716
719
async def test_start_operation_operation_failures (
717
720
test_case : Type [_TestCase ], env : WorkflowEnvironment
718
721
):
719
- await _test_start_operation (test_case , True , env )
722
+ await _test_start_operation_with_service_definition (test_case , env )
720
723
721
724
722
- async def _test_start_operation (
725
+ async def _test_start_operation_with_service_definition (
723
726
test_case : Type [_TestCase ],
724
- with_service_definition : bool ,
725
727
env : WorkflowEnvironment ,
726
728
):
727
729
if test_case .skip :
@@ -731,19 +733,45 @@ async def _test_start_operation(
731
733
service_client = ServiceClient (
732
734
server_address = server_address (env ),
733
735
endpoint = endpoint ,
734
- service = (
735
- test_case .service_defn
736
- if with_service_definition
737
- else MyServiceHandler .__name__
738
- ),
736
+ service = (test_case .service_defn ),
739
737
)
740
738
741
739
with pytest .WarningsRecorder () as warnings :
742
- decorator = (
743
- service_handler (service = MyService )
744
- if with_service_definition
745
- else service_handler
746
- )
740
+ decorator = service_handler (service = MyService )
741
+ user_service_handler = decorator (MyServiceHandler )()
742
+
743
+ async with Worker (
744
+ env .client ,
745
+ task_queue = task_queue ,
746
+ nexus_service_handlers = [user_service_handler ],
747
+ nexus_task_executor = concurrent .futures .ThreadPoolExecutor (),
748
+ ):
749
+ response = await service_client .start_operation (
750
+ test_case .operation ,
751
+ dataclass_as_dict (test_case .input ),
752
+ test_case .headers ,
753
+ )
754
+ test_case .check_response (response , with_service_definition = True )
755
+
756
+ assert not any (warnings ), [w .message for w in warnings ]
757
+
758
+
759
+ async def _test_start_operation_without_service_definition (
760
+ test_case : Type [_TestCase ],
761
+ env : WorkflowEnvironment ,
762
+ ):
763
+ if test_case .skip :
764
+ pytest .skip (test_case .skip )
765
+ task_queue = str (uuid .uuid4 ())
766
+ endpoint = (await create_nexus_endpoint (task_queue , env .client )).endpoint .id
767
+ service_client = ServiceClient (
768
+ server_address = server_address (env ),
769
+ endpoint = endpoint ,
770
+ service = MyServiceHandler .__name__ ,
771
+ )
772
+
773
+ with pytest .WarningsRecorder () as warnings :
774
+ decorator = service_handler
747
775
user_service_handler = decorator (MyServiceHandler )()
748
776
749
777
async with Worker (
@@ -757,7 +785,7 @@ async def _test_start_operation(
757
785
dataclass_as_dict (test_case .input ),
758
786
test_case .headers ,
759
787
)
760
- test_case .check_response (response , with_service_definition )
788
+ test_case .check_response (response , with_service_definition = False )
761
789
762
790
assert not any (warnings ), [w .message for w in warnings ]
763
791
0 commit comments