Skip to content

Commit e17146f

Browse files
authored
Adding encoding of payload headers, currently defaults to existing behavior, not encoding (#939)
* Adding encoding of payload headers, currently defaults to true * Switching to enum for header codec behavior * Linting * Address comments * Fix connect argument * Exclude _from_raw from eq comparison * Exclude _from_raw from init * Debugging CI * Skip on time skipping server
1 parent 61af3ea commit e17146f

File tree

8 files changed

+342
-12
lines changed

8 files changed

+342
-12
lines changed

temporalio/bridge/worker.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
Awaitable,
1212
Callable,
1313
List,
14+
Mapping,
1415
Optional,
1516
Sequence,
1617
Set,
@@ -287,6 +288,35 @@ async def finalize_shutdown(self) -> None:
287288
)
288289

289290

291+
async def _apply_to_headers(
292+
headers: Mapping[str, temporalio.api.common.v1.Payload],
293+
cb: Callable[
294+
[Sequence[temporalio.api.common.v1.Payload]],
295+
Awaitable[List[temporalio.api.common.v1.Payload]],
296+
],
297+
) -> None:
298+
"""Apply API payload callback to headers."""
299+
for payload in headers.values():
300+
new_payload = (await cb([payload]))[0]
301+
payload.CopyFrom(new_payload)
302+
303+
304+
async def _decode_headers(
305+
headers: Mapping[str, temporalio.api.common.v1.Payload],
306+
codec: temporalio.converter.PayloadCodec,
307+
) -> None:
308+
"""Decode headers with the given codec."""
309+
return await _apply_to_headers(headers, codec.decode)
310+
311+
312+
async def _encode_headers(
313+
headers: Mapping[str, temporalio.api.common.v1.Payload],
314+
codec: temporalio.converter.PayloadCodec,
315+
) -> None:
316+
"""Encode headers with the given codec."""
317+
return await _apply_to_headers(headers, codec.encode)
318+
319+
290320
async def _apply_to_payloads(
291321
payloads: PayloadContainer,
292322
cb: Callable[
@@ -352,11 +382,14 @@ async def _encode_payload(
352382
async def decode_activation(
353383
act: temporalio.bridge.proto.workflow_activation.WorkflowActivation,
354384
codec: temporalio.converter.PayloadCodec,
385+
decode_headers: bool,
355386
) -> None:
356387
"""Decode the given activation with the codec."""
357388
for job in act.jobs:
358389
if job.HasField("query_workflow"):
359390
await _decode_payloads(job.query_workflow.arguments, codec)
391+
if decode_headers:
392+
await _decode_headers(job.query_workflow.headers, codec)
360393
elif job.HasField("resolve_activity"):
361394
if job.resolve_activity.result.HasField("cancelled"):
362395
await codec.decode_failure(
@@ -401,8 +434,12 @@ async def decode_activation(
401434
await codec.decode_failure(job.resolve_signal_external_workflow.failure)
402435
elif job.HasField("signal_workflow"):
403436
await _decode_payloads(job.signal_workflow.input, codec)
437+
if decode_headers:
438+
await _decode_headers(job.signal_workflow.headers, codec)
404439
elif job.HasField("initialize_workflow"):
405440
await _decode_payloads(job.initialize_workflow.arguments, codec)
441+
if decode_headers:
442+
await _decode_headers(job.initialize_workflow.headers, codec)
406443
if job.initialize_workflow.HasField("continued_failure"):
407444
await codec.decode_failure(job.initialize_workflow.continued_failure)
408445
for val in job.initialize_workflow.memo.fields.values():
@@ -416,11 +453,14 @@ async def decode_activation(
416453
val.data = new_payload.data
417454
elif job.HasField("do_update"):
418455
await _decode_payloads(job.do_update.input, codec)
456+
if decode_headers:
457+
await _decode_headers(job.do_update.headers, codec)
419458

420459

421460
async def encode_completion(
422461
comp: temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion,
423462
codec: temporalio.converter.PayloadCodec,
463+
encode_headers: bool,
424464
) -> None:
425465
"""Recursively encode the given completion with the codec."""
426466
if comp.HasField("failed"):
@@ -436,6 +476,10 @@ async def encode_completion(
436476
await _encode_payloads(
437477
command.continue_as_new_workflow_execution.arguments, codec
438478
)
479+
if encode_headers:
480+
await _encode_headers(
481+
command.continue_as_new_workflow_execution.headers, codec
482+
)
439483
for val in command.continue_as_new_workflow_execution.memo.values():
440484
await _encode_payload(val, codec)
441485
elif command.HasField("fail_workflow_execution"):
@@ -451,16 +495,30 @@ async def encode_completion(
451495
)
452496
elif command.HasField("schedule_activity"):
453497
await _encode_payloads(command.schedule_activity.arguments, codec)
498+
if encode_headers:
499+
await _encode_headers(command.schedule_activity.headers, codec)
454500
elif command.HasField("schedule_local_activity"):
455501
await _encode_payloads(command.schedule_local_activity.arguments, codec)
502+
if encode_headers:
503+
await _encode_headers(
504+
command.schedule_local_activity.headers, codec
505+
)
456506
elif command.HasField("signal_external_workflow_execution"):
457507
await _encode_payloads(
458508
command.signal_external_workflow_execution.args, codec
459509
)
510+
if encode_headers:
511+
await _encode_headers(
512+
command.signal_external_workflow_execution.headers, codec
513+
)
460514
elif command.HasField("start_child_workflow_execution"):
461515
await _encode_payloads(
462516
command.start_child_workflow_execution.input, codec
463517
)
518+
if encode_headers:
519+
await _encode_headers(
520+
command.start_child_workflow_execution.headers, codec
521+
)
464522
for val in command.start_child_workflow_execution.memo.values():
465523
await _encode_payload(val, codec)
466524
elif command.HasField("update_response"):

temporalio/client.py

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
Mapping,
2828
Optional,
2929
Sequence,
30+
Text,
3031
Tuple,
3132
Type,
3233
Union,
@@ -37,6 +38,7 @@
3738
import google.protobuf.duration_pb2
3839
import google.protobuf.json_format
3940
import google.protobuf.timestamp_pb2
41+
from google.protobuf.internal.containers import MessageMap
4042
from typing_extensions import Concatenate, Required, TypedDict
4143

4244
import temporalio.api.common.v1
@@ -66,6 +68,7 @@
6668
TLSConfig,
6769
)
6870

71+
from .common import HeaderCodecBehavior
6972
from .types import (
7073
AnyType,
7174
LocalReturnType,
@@ -116,6 +119,7 @@ async def connect(
116119
lazy: bool = False,
117120
runtime: Optional[temporalio.runtime.Runtime] = None,
118121
http_connect_proxy_config: Optional[HttpConnectProxyConfig] = None,
122+
header_codec_behavior: HeaderCodecBehavior = HeaderCodecBehavior.NO_CODEC,
119123
) -> Client:
120124
"""Connect to a Temporal server.
121125
@@ -160,6 +164,7 @@ async def connect(
160164
used for workers.
161165
runtime: The runtime for this client, or the default if unset.
162166
http_connect_proxy_config: Configuration for HTTP CONNECT proxy.
167+
header_codec_behavior: Encoding behavior for headers sent by the client.
163168
"""
164169
connect_config = temporalio.service.ConnectConfig(
165170
target_host=target_host,
@@ -179,6 +184,7 @@ async def connect(
179184
data_converter=data_converter,
180185
interceptors=interceptors,
181186
default_workflow_query_reject_condition=default_workflow_query_reject_condition,
187+
header_codec_behavior=header_codec_behavior,
182188
)
183189

184190
def __init__(
@@ -191,6 +197,7 @@ def __init__(
191197
default_workflow_query_reject_condition: Optional[
192198
temporalio.common.QueryRejectCondition
193199
] = None,
200+
header_codec_behavior: HeaderCodecBehavior = HeaderCodecBehavior.NO_CODEC,
194201
):
195202
"""Create a Temporal client from a service client.
196203
@@ -208,6 +215,7 @@ def __init__(
208215
data_converter=data_converter,
209216
interceptors=interceptors,
210217
default_workflow_query_reject_condition=default_workflow_query_reject_condition,
218+
header_codec_behavior=header_codec_behavior,
211219
)
212220

213221
def config(self) -> ClientConfig:
@@ -1501,6 +1509,7 @@ class ClientConfig(TypedDict, total=False):
15011509
default_workflow_query_reject_condition: Required[
15021510
Optional[temporalio.common.QueryRejectCondition]
15031511
]
1512+
header_codec_behavior: Required[HeaderCodecBehavior]
15041513

15051514

15061515
class WorkflowHistoryEventFilterType(IntEnum):
@@ -3859,6 +3868,10 @@ class ScheduleActionStartWorkflow(ScheduleAction):
38593868
priority: temporalio.common.Priority
38603869

38613870
headers: Optional[Mapping[str, temporalio.api.common.v1.Payload]]
3871+
"""
3872+
Headers may still be encoded by the payload codec if present.
3873+
"""
3874+
_from_raw: bool = dataclasses.field(compare=False, init=False)
38623875

38633876
@staticmethod
38643877
def _from_proto( # pyright: ignore
@@ -3985,6 +3998,7 @@ def __init__(
39853998
"""
39863999
super().__init__()
39874000
if raw_info:
4001+
self._from_raw = True
39884002
# Ignore other fields
39894003
self.workflow = raw_info.workflow_type.name
39904004
self.args = raw_info.input.payloads if raw_info.input else []
@@ -4044,6 +4058,7 @@ def __init__(
40444058
else temporalio.common.Priority.default
40454059
)
40464060
else:
4061+
self._from_raw = False
40474062
if not id:
40484063
raise ValueError("ID required")
40494064
if not task_queue:
@@ -4067,7 +4082,7 @@ def __init__(
40674082
self.memo = memo
40684083
self.typed_search_attributes = typed_search_attributes
40694084
self.untyped_search_attributes = untyped_search_attributes
4070-
self.headers = headers
4085+
self.headers = headers # encode here
40714086
self.static_summary = static_summary
40724087
self.static_details = static_details
40734088
self.priority = priority
@@ -4145,8 +4160,12 @@ async def _to_proto(
41454160
self.typed_search_attributes, action.start_workflow.search_attributes
41464161
)
41474162
if self.headers:
4148-
temporalio.common._apply_headers(
4149-
self.headers, action.start_workflow.header.fields
4163+
await _apply_headers(
4164+
self.headers,
4165+
action.start_workflow.header.fields,
4166+
client.config()["header_codec_behavior"] == HeaderCodecBehavior.CODEC
4167+
and not self._from_raw,
4168+
client.data_converter.payload_codec,
41504169
)
41514170
return action
41524171

@@ -5920,7 +5939,7 @@ async def _populate_start_workflow_execution_request(
59205939
if input.start_delay is not None:
59215940
req.workflow_start_delay.FromTimedelta(input.start_delay)
59225941
if input.headers is not None:
5923-
temporalio.common._apply_headers(input.headers, req.header.fields)
5942+
await self._apply_headers(input.headers, req.header.fields)
59245943
if input.priority is not None:
59255944
req.priority.CopyFrom(input.priority._to_proto())
59265945
if input.versioning_override is not None:
@@ -6006,7 +6025,7 @@ async def query_workflow(self, input: QueryWorkflowInput) -> Any:
60066025
await self._client.data_converter.encode(input.args)
60076026
)
60086027
if input.headers is not None:
6009-
temporalio.common._apply_headers(input.headers, req.query.header.fields)
6028+
await self._apply_headers(input.headers, req.query.header.fields)
60106029
try:
60116030
resp = await self._client.workflow_service.query_workflow(
60126031
req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout
@@ -6052,7 +6071,7 @@ async def signal_workflow(self, input: SignalWorkflowInput) -> None:
60526071
await self._client.data_converter.encode(input.args)
60536072
)
60546073
if input.headers is not None:
6055-
temporalio.common._apply_headers(input.headers, req.header.fields)
6074+
await self._apply_headers(input.headers, req.header.fields)
60566075
await self._client.workflow_service.signal_workflow_execution(
60576076
req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout
60586077
)
@@ -6163,9 +6182,7 @@ async def _build_update_workflow_execution_request(
61636182
await self._client.data_converter.encode(input.args)
61646183
)
61656184
if input.headers is not None:
6166-
temporalio.common._apply_headers(
6167-
input.headers, req.request.input.header.fields
6168-
)
6185+
await self._apply_headers(input.headers, req.request.input.header.fields)
61696186
return req
61706187

61716188
async def start_update_with_start_workflow(
@@ -6721,6 +6738,33 @@ async def get_worker_task_reachability(
67216738
)
67226739
return WorkerTaskReachability._from_proto(resp)
67236740

6741+
async def _apply_headers(
6742+
self,
6743+
source: Optional[Mapping[str, temporalio.api.common.v1.Payload]],
6744+
dest: MessageMap[Text, temporalio.api.common.v1.Payload],
6745+
) -> None:
6746+
await _apply_headers(
6747+
source,
6748+
dest,
6749+
self._client.config()["header_codec_behavior"] == HeaderCodecBehavior.CODEC,
6750+
self._client.data_converter.payload_codec,
6751+
)
6752+
6753+
6754+
async def _apply_headers(
6755+
source: Optional[Mapping[str, temporalio.api.common.v1.Payload]],
6756+
dest: MessageMap[Text, temporalio.api.common.v1.Payload],
6757+
encode_headers: bool,
6758+
codec: Optional[temporalio.converter.PayloadCodec],
6759+
) -> None:
6760+
if source is None:
6761+
return
6762+
if encode_headers and codec is not None:
6763+
for payload in source.values():
6764+
new_payload = (await codec.encode([payload]))[0]
6765+
payload.CopyFrom(new_payload)
6766+
temporalio.common._apply_headers(source, dest)
6767+
67246768

67256769
def _history_from_json(
67266770
history: Union[str, Dict[str, Any]],

temporalio/common.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1230,3 +1230,14 @@ def _type_hints_from_func(
12301230
# necessarily
12311231
args.append(arg_hint) # type: ignore
12321232
return args, ret
1233+
1234+
1235+
class HeaderCodecBehavior(IntEnum):
1236+
"""Different ways to handle header encoding"""
1237+
1238+
NO_CODEC = 1
1239+
"""Don't encode or decode any headers automatically"""
1240+
CODEC = 2
1241+
"""Encode and decode all headers automatically"""
1242+
WORKFLOW_ONLY_CODEC = 3
1243+
"""Only automatically encode and decode headers in workflow activation encoding and decoding."""

temporalio/worker/_activity.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def __init__(
6969
data_converter: temporalio.converter.DataConverter,
7070
interceptors: Sequence[Interceptor],
7171
metric_meter: temporalio.common.MetricMeter,
72+
encode_headers: bool,
7273
) -> None:
7374
self._bridge_worker = bridge_worker
7475
self._task_queue = task_queue
@@ -78,6 +79,7 @@ def __init__(
7879
self._data_converter = data_converter
7980
self._interceptors = interceptors
8081
self._metric_meter = metric_meter
82+
self._encode_headers = encode_headers
8183
self._fail_worker_exception_queue: asyncio.Queue[Exception] = asyncio.Queue()
8284
# Lazily created on first activity
8385
self._worker_shutdown_event: Optional[temporalio.activity._CompositeEvent] = (
@@ -543,6 +545,14 @@ async def _execute_activity(
543545
workflow_type=start.workflow_type,
544546
priority=temporalio.common.Priority._from_proto(start.priority),
545547
)
548+
549+
if self._encode_headers and self._data_converter.payload_codec is not None:
550+
for payload in start.header_fields.values():
551+
new_payload = (
552+
await self._data_converter.payload_codec.decode([payload])
553+
)[0]
554+
payload.CopyFrom(new_payload)
555+
546556
running_activity.info = info
547557
input = ExecuteActivityInput(
548558
fn=activity_def.fn,

0 commit comments

Comments
 (0)