Skip to content

Commit 9824a23

Browse files
committed
Split test
1 parent 036a02e commit 9824a23

File tree

1 file changed

+44
-16
lines changed

1 file changed

+44
-16
lines changed

tests/nexus/test_handler.py

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -683,7 +683,10 @@ async def test_start_operation_happy_path(
683683
with_service_definition: bool,
684684
env: WorkflowEnvironment,
685685
):
686-
await _test_start_operation(test_case, with_service_definition, env)
686+
if with_service_definition:
687+
await _test_start_operation_with_service_definition(test_case, env)
688+
else:
689+
await _test_start_operation_without_service_definition(test_case, env)
687690

688691

689692
@pytest.mark.parametrize(
@@ -702,7 +705,7 @@ async def test_start_operation_happy_path(
702705
async def test_start_operation_protocol_level_failures(
703706
test_case: Type[_TestCase], env: WorkflowEnvironment
704707
):
705-
await _test_start_operation(test_case, True, env)
708+
await _test_start_operation_with_service_definition(test_case, env)
706709

707710

708711
@pytest.mark.parametrize(
@@ -716,12 +719,11 @@ async def test_start_operation_protocol_level_failures(
716719
async def test_start_operation_operation_failures(
717720
test_case: Type[_TestCase], env: WorkflowEnvironment
718721
):
719-
await _test_start_operation(test_case, True, env)
722+
await _test_start_operation_with_service_definition(test_case, env)
720723

721724

722-
async def _test_start_operation(
725+
async def _test_start_operation_with_service_definition(
723726
test_case: Type[_TestCase],
724-
with_service_definition: bool,
725727
env: WorkflowEnvironment,
726728
):
727729
if test_case.skip:
@@ -731,19 +733,45 @@ async def _test_start_operation(
731733
service_client = ServiceClient(
732734
server_address=server_address(env),
733735
endpoint=endpoint,
734-
service=(
735-
test_case.service_defn
736-
if with_service_definition
737-
else MyServiceHandler.__name__
738-
),
736+
service=(test_case.service_defn),
739737
)
740738

741739
with pytest.WarningsRecorder() as warnings:
742-
decorator = (
743-
service_handler(service=MyService)
744-
if with_service_definition
745-
else service_handler
746-
)
740+
decorator = service_handler(service=MyService)
741+
user_service_handler = decorator(MyServiceHandler)()
742+
743+
async with Worker(
744+
env.client,
745+
task_queue=task_queue,
746+
nexus_service_handlers=[user_service_handler],
747+
nexus_task_executor=concurrent.futures.ThreadPoolExecutor(),
748+
):
749+
response = await service_client.start_operation(
750+
test_case.operation,
751+
dataclass_as_dict(test_case.input),
752+
test_case.headers,
753+
)
754+
test_case.check_response(response, with_service_definition=True)
755+
756+
assert not any(warnings), [w.message for w in warnings]
757+
758+
759+
async def _test_start_operation_without_service_definition(
760+
test_case: Type[_TestCase],
761+
env: WorkflowEnvironment,
762+
):
763+
if test_case.skip:
764+
pytest.skip(test_case.skip)
765+
task_queue = str(uuid.uuid4())
766+
endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id
767+
service_client = ServiceClient(
768+
server_address=server_address(env),
769+
endpoint=endpoint,
770+
service=MyServiceHandler.__name__,
771+
)
772+
773+
with pytest.WarningsRecorder() as warnings:
774+
decorator = service_handler
747775
user_service_handler = decorator(MyServiceHandler)()
748776

749777
async with Worker(
@@ -757,7 +785,7 @@ async def _test_start_operation(
757785
dataclass_as_dict(test_case.input),
758786
test_case.headers,
759787
)
760-
test_case.check_response(response, with_service_definition)
788+
test_case.check_response(response, with_service_definition=False)
761789

762790
assert not any(warnings), [w.message for w in warnings]
763791

0 commit comments

Comments
 (0)