Skip to content

Commit a5b9661

Browse files
authored
Create commands after payload conversion (#591)
Fixes #540 Fixes #564
1 parent 4b93d1a commit a5b9661

File tree

2 files changed

+114
-44
lines changed

2 files changed

+114
-44
lines changed

temporalio/worker/_workflow_instance.py

Lines changed: 63 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -599,8 +599,6 @@ def _apply_query_workflow(
599599
) -> None:
600600
# Wrap entire bunch of work in a task
601601
async def run_query() -> None:
602-
command = self._add_command()
603-
command.respond_to_query.query_id = job.query_id
604602
try:
605603
with self._as_read_only():
606604
# Named query or dynamic
@@ -632,11 +630,13 @@ async def run_query() -> None:
632630
raise ValueError(
633631
f"Expected 1 result payload, got {len(result_payloads)}"
634632
)
635-
command.respond_to_query.succeeded.response.CopyFrom(
636-
result_payloads[0]
637-
)
633+
command = self._add_command()
634+
command.respond_to_query.query_id = job.query_id
635+
command.respond_to_query.succeeded.response.CopyFrom(result_payloads[0])
638636
except Exception as err:
639637
try:
638+
command = self._add_command()
639+
command.respond_to_query.query_id = job.query_id
640640
self._failure_converter.to_failure(
641641
err,
642642
self._payload_converter,
@@ -1427,7 +1427,7 @@ async def run_activity() -> Any:
14271427
await asyncio.sleep(
14281428
err.backoff.backoff_duration.ToTimedelta().total_seconds()
14291429
)
1430-
handle._apply_schedule_command(self._add_command(), err.backoff)
1430+
handle._apply_schedule_command(err.backoff)
14311431
# We have to put the handle back on the pending activity
14321432
# dict with its new seq
14331433
self._pending_activities[handle._seq] = handle
@@ -1437,35 +1437,41 @@ async def run_activity() -> Any:
14371437

14381438
# Create the handle and set as pending
14391439
handle = _ActivityHandle(self, input, run_activity())
1440-
handle._apply_schedule_command(self._add_command())
1440+
handle._apply_schedule_command()
14411441
self._pending_activities[handle._seq] = handle
14421442
return handle
14431443

14441444
async def _outbound_signal_child_workflow(
14451445
self, input: SignalChildWorkflowInput
14461446
) -> None:
1447+
payloads = (
1448+
self._payload_converter.to_payloads(input.args) if input.args else None
1449+
)
14471450
command = self._add_command()
14481451
v = command.signal_external_workflow_execution
14491452
v.child_workflow_id = input.child_workflow_id
14501453
v.signal_name = input.signal
1451-
if input.args:
1452-
v.args.extend(self._payload_converter.to_payloads(input.args))
1454+
if payloads:
1455+
v.args.extend(payloads)
14531456
if input.headers:
14541457
temporalio.common._apply_headers(input.headers, v.headers)
14551458
await self._signal_external_workflow(command)
14561459

14571460
async def _outbound_signal_external_workflow(
14581461
self, input: SignalExternalWorkflowInput
14591462
) -> None:
1463+
payloads = (
1464+
self._payload_converter.to_payloads(input.args) if input.args else None
1465+
)
14601466
command = self._add_command()
14611467
v = command.signal_external_workflow_execution
14621468
v.workflow_execution.namespace = input.namespace
14631469
v.workflow_execution.workflow_id = input.workflow_id
14641470
if input.workflow_run_id:
14651471
v.workflow_execution.run_id = input.workflow_run_id
14661472
v.signal_name = input.signal
1467-
if input.args:
1468-
v.args.extend(self._payload_converter.to_payloads(input.args))
1473+
if payloads:
1474+
v.args.extend(payloads)
14691475
if input.headers:
14701476
temporalio.common._apply_headers(input.headers, v.headers)
14711477
await self._signal_external_workflow(command)
@@ -1510,7 +1516,7 @@ async def run_child() -> Any:
15101516
handle = _ChildWorkflowHandle(
15111517
self, self._next_seq("child_workflow"), input, run_child()
15121518
)
1513-
handle._apply_start_command(self._add_command())
1519+
handle._apply_start_command()
15141520
self._pending_child_workflows[handle._seq] = handle
15151521

15161522
# Wait on start before returning
@@ -1761,7 +1767,7 @@ async def _run_top_level_workflow_function(self, coro: Awaitable[None]) -> None:
17611767
await coro
17621768
except _ContinueAsNewError as err:
17631769
logger.debug("Workflow requested continue as new")
1764-
err._apply_command(self._add_command())
1770+
err._apply_command()
17651771
except (Exception, asyncio.CancelledError) as err:
17661772
# During tear down we can ignore exceptions. Technically the
17671773
# command-adding done later would throw a not-in-workflow exception
@@ -1776,7 +1782,7 @@ async def _run_top_level_workflow_function(self, coro: Awaitable[None]) -> None:
17761782
# Handle continue as new
17771783
if isinstance(err, _ContinueAsNewError):
17781784
logger.debug("Workflow requested continue as new")
1779-
err._apply_command(self._add_command())
1785+
err._apply_command()
17801786
return
17811787

17821788
logger.debug(
@@ -2261,11 +2267,18 @@ def _resolve_backoff(
22612267

22622268
def _apply_schedule_command(
22632269
self,
2264-
command: temporalio.bridge.proto.workflow_commands.WorkflowCommand,
22652270
local_backoff: Optional[
22662271
temporalio.bridge.proto.activity_result.DoBackoff
22672272
] = None,
22682273
) -> None:
2274+
# Convert arguments before creating command in case it raises error
2275+
payloads = (
2276+
self._instance._payload_converter.to_payloads(self._input.args)
2277+
if self._input.args
2278+
else None
2279+
)
2280+
2281+
command = self._instance._add_command()
22692282
# TODO(cretz): Why can't MyPy infer this?
22702283
v: Union[
22712284
temporalio.bridge.proto.workflow_commands.ScheduleActivity,
@@ -2280,10 +2293,8 @@ def _apply_schedule_command(
22802293
v.activity_type = self._input.activity
22812294
if self._input.headers:
22822295
temporalio.common._apply_headers(self._input.headers, v.headers)
2283-
if self._input.args:
2284-
v.arguments.extend(
2285-
self._instance._payload_converter.to_payloads(self._input.args)
2286-
)
2296+
if payloads:
2297+
v.arguments.extend(payloads)
22872298
if self._input.schedule_to_close_timeout:
22882299
v.schedule_to_close_timeout.FromTimedelta(
22892300
self._input.schedule_to_close_timeout
@@ -2403,20 +2414,23 @@ def _resolve_failure(self, err: BaseException) -> None:
24032414
# future
24042415
self._result_fut.set_result(None)
24052416

2406-
def _apply_start_command(
2407-
self,
2408-
command: temporalio.bridge.proto.workflow_commands.WorkflowCommand,
2409-
) -> None:
2417+
def _apply_start_command(self) -> None:
2418+
# Convert arguments before creating command in case it raises error
2419+
payloads = (
2420+
self._instance._payload_converter.to_payloads(self._input.args)
2421+
if self._input.args
2422+
else None
2423+
)
2424+
2425+
command = self._instance._add_command()
24102426
v = command.start_child_workflow_execution
24112427
v.seq = self._seq
24122428
v.namespace = self._instance._info.namespace
24132429
v.workflow_id = self._input.id
24142430
v.workflow_type = self._input.workflow
24152431
v.task_queue = self._input.task_queue or self._instance._info.task_queue
2416-
if self._input.args:
2417-
v.input.extend(
2418-
self._instance._payload_converter.to_payloads(self._input.args)
2419-
)
2432+
if payloads:
2433+
v.input.extend(payloads)
24202434
if self._input.execution_timeout:
24212435
v.workflow_execution_timeout.FromTimedelta(self._input.execution_timeout)
24222436
if self._input.run_timeout:
@@ -2520,19 +2534,31 @@ def __init__(
25202534
self._instance = instance
25212535
self._input = input
25222536

2523-
def _apply_command(
2524-
self, command: temporalio.bridge.proto.workflow_commands.WorkflowCommand
2525-
) -> None:
2537+
def _apply_command(self) -> None:
2538+
# Convert arguments before creating command in case it raises error
2539+
payloads = (
2540+
self._instance._payload_converter.to_payloads(self._input.args)
2541+
if self._input.args
2542+
else None
2543+
)
2544+
memo_payloads = (
2545+
{
2546+
k: self._instance._payload_converter.to_payloads([val])[0]
2547+
for k, val in self._input.memo.items()
2548+
}
2549+
if self._input.memo
2550+
else None
2551+
)
2552+
2553+
command = self._instance._add_command()
25262554
v = command.continue_as_new_workflow_execution
25272555
v.SetInParent()
25282556
if self._input.workflow:
25292557
v.workflow_type = self._input.workflow
25302558
if self._input.task_queue:
25312559
v.task_queue = self._input.task_queue
2532-
if self._input.args:
2533-
v.arguments.extend(
2534-
self._instance._payload_converter.to_payloads(self._input.args)
2535-
)
2560+
if payloads:
2561+
v.arguments.extend(payloads)
25362562
if self._input.run_timeout:
25372563
v.workflow_run_timeout.FromTimedelta(self._input.run_timeout)
25382564
if self._input.task_timeout:
@@ -2541,11 +2567,9 @@ def _apply_command(
25412567
temporalio.common._apply_headers(self._input.headers, v.headers)
25422568
if self._input.retry_policy:
25432569
self._input.retry_policy.apply_to_proto(v.retry_policy)
2544-
if self._input.memo:
2545-
for k, val in self._input.memo.items():
2546-
v.memo[k].CopyFrom(
2547-
self._instance._payload_converter.to_payloads([val])[0]
2548-
)
2570+
if memo_payloads:
2571+
for k, val in memo_payloads.items():
2572+
v.memo[k].CopyFrom(val)
25492573
if self._input.search_attributes:
25502574
_encode_search_attributes(
25512575
self._input.search_attributes, v.search_attributes

tests/worker/test_workflow.py

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3350,15 +3350,27 @@ async def test_workflow_optional_param(client: Client):
33503350

33513351

33523352
class ExceptionRaisingPayloadConverter(DefaultPayloadConverter):
3353-
bad_str = "bad-payload-str"
3353+
bad_outbound_str = "bad-outbound-payload-str"
3354+
bad_inbound_str = "bad-inbound-payload-str"
3355+
3356+
def to_payloads(self, values: Sequence[Any]) -> List[Payload]:
3357+
if any(
3358+
value == ExceptionRaisingPayloadConverter.bad_outbound_str
3359+
for value in values
3360+
):
3361+
raise ApplicationError("Intentional outbound converter failure")
3362+
return super().to_payloads(values)
33543363

33553364
def from_payloads(
33563365
self, payloads: Sequence[Payload], type_hints: Optional[List] = None
33573366
) -> List[Any]:
33583367
# Check if any payloads contain the bad data
33593368
for payload in payloads:
3360-
if ExceptionRaisingPayloadConverter.bad_str.encode() in payload.data:
3361-
raise ApplicationError("Intentional converter failure")
3369+
if (
3370+
ExceptionRaisingPayloadConverter.bad_inbound_str.encode()
3371+
in payload.data
3372+
):
3373+
raise ApplicationError("Intentional inbound converter failure")
33623374
return super().from_payloads(payloads, type_hints)
33633375

33643376

@@ -3383,12 +3395,46 @@ async def test_exception_raising_converter_param(client: Client):
33833395
with pytest.raises(WorkflowFailureError) as err:
33843396
await client.execute_workflow(
33853397
ExceptionRaisingConverterWorkflow.run,
3386-
ExceptionRaisingPayloadConverter.bad_str,
3398+
ExceptionRaisingPayloadConverter.bad_inbound_str,
33873399
id=f"workflow-{uuid.uuid4()}",
33883400
task_queue=worker.task_queue,
33893401
)
33903402
assert isinstance(err.value.cause, ApplicationError)
3391-
assert "Intentional converter failure" in str(err.value.cause)
3403+
assert "Intentional inbound converter failure" in str(err.value.cause)
3404+
3405+
3406+
@workflow.defn
3407+
class ActivityOutboundConversionFailureWorkflow:
3408+
@workflow.run
3409+
async def run(self) -> None:
3410+
await workflow.execute_activity(
3411+
"some-activity",
3412+
ExceptionRaisingPayloadConverter.bad_outbound_str,
3413+
start_to_close_timeout=timedelta(seconds=10),
3414+
)
3415+
3416+
3417+
async def test_workflow_activity_outbound_conversion_failure(client: Client):
3418+
# This test used to fail because we created commands _before_ we attempted
3419+
# to convert the arguments thereby causing half-built commands to get sent
3420+
# to the server.
3421+
3422+
# Clone the client but change the data converter to use our converter
3423+
config = client.config()
3424+
config["data_converter"] = dataclasses.replace(
3425+
config["data_converter"],
3426+
payload_converter_class=ExceptionRaisingPayloadConverter,
3427+
)
3428+
client = Client(**config)
3429+
async with new_worker(client, ActivityOutboundConversionFailureWorkflow) as worker:
3430+
with pytest.raises(WorkflowFailureError) as err:
3431+
await client.execute_workflow(
3432+
ActivityOutboundConversionFailureWorkflow.run,
3433+
id=f"wf-{uuid.uuid4()}",
3434+
task_queue=worker.task_queue,
3435+
)
3436+
assert isinstance(err.value.cause, ApplicationError)
3437+
assert "Intentional outbound converter failure" in str(err.value.cause)
33923438

33933439

33943440
@dataclass

0 commit comments

Comments
 (0)