Skip to content

Adding encoding of payload headers, currently defaults to existing behavior, not encoding #939

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jul 8, 2025
Merged
58 changes: 58 additions & 0 deletions temporalio/bridge/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Awaitable,
Callable,
List,
Mapping,
Optional,
Sequence,
Set,
Expand Down Expand Up @@ -287,6 +288,35 @@ async def finalize_shutdown(self) -> None:
)


async def _apply_to_headers(
headers: Mapping[str, temporalio.api.common.v1.Payload],
cb: Callable[
[Sequence[temporalio.api.common.v1.Payload]],
Awaitable[List[temporalio.api.common.v1.Payload]],
],
) -> None:
"""Apply API payload callback to headers."""
for payload in headers.values():
new_payload = (await cb([payload]))[0]
payload.CopyFrom(new_payload)


async def _decode_headers(
headers: Mapping[str, temporalio.api.common.v1.Payload],
codec: temporalio.converter.PayloadCodec,
) -> None:
"""Decode headers with the given codec."""
return await _apply_to_headers(headers, codec.decode)


async def _encode_headers(
headers: Mapping[str, temporalio.api.common.v1.Payload],
codec: temporalio.converter.PayloadCodec,
) -> None:
"""Encode headers with the given codec."""
return await _apply_to_headers(headers, codec.encode)


async def _apply_to_payloads(
payloads: PayloadContainer,
cb: Callable[
Expand Down Expand Up @@ -352,11 +382,14 @@ async def _encode_payload(
async def decode_activation(
act: temporalio.bridge.proto.workflow_activation.WorkflowActivation,
codec: temporalio.converter.PayloadCodec,
decode_headers: bool,
) -> None:
"""Decode the given activation with the codec."""
for job in act.jobs:
if job.HasField("query_workflow"):
await _decode_payloads(job.query_workflow.arguments, codec)
if decode_headers:
await _decode_headers(job.query_workflow.headers, codec)
elif job.HasField("resolve_activity"):
if job.resolve_activity.result.HasField("cancelled"):
await codec.decode_failure(
Expand Down Expand Up @@ -401,8 +434,12 @@ async def decode_activation(
await codec.decode_failure(job.resolve_signal_external_workflow.failure)
elif job.HasField("signal_workflow"):
await _decode_payloads(job.signal_workflow.input, codec)
if decode_headers:
await _decode_headers(job.signal_workflow.headers, codec)
elif job.HasField("initialize_workflow"):
await _decode_payloads(job.initialize_workflow.arguments, codec)
if decode_headers:
await _decode_headers(job.initialize_workflow.headers, codec)
if job.initialize_workflow.HasField("continued_failure"):
await codec.decode_failure(job.initialize_workflow.continued_failure)
for val in job.initialize_workflow.memo.fields.values():
Expand All @@ -416,11 +453,14 @@ async def decode_activation(
val.data = new_payload.data
elif job.HasField("do_update"):
await _decode_payloads(job.do_update.input, codec)
if decode_headers:
await _decode_headers(job.do_update.headers, codec)


async def encode_completion(
comp: temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion,
codec: temporalio.converter.PayloadCodec,
encode_headers: bool,
) -> None:
"""Recursively encode the given completion with the codec."""
if comp.HasField("failed"):
Expand All @@ -436,6 +476,10 @@ async def encode_completion(
await _encode_payloads(
command.continue_as_new_workflow_execution.arguments, codec
)
if encode_headers:
await _encode_headers(
command.continue_as_new_workflow_execution.headers, codec
)
for val in command.continue_as_new_workflow_execution.memo.values():
await _encode_payload(val, codec)
elif command.HasField("fail_workflow_execution"):
Expand All @@ -451,16 +495,30 @@ async def encode_completion(
)
elif command.HasField("schedule_activity"):
await _encode_payloads(command.schedule_activity.arguments, codec)
if encode_headers:
await _encode_headers(command.schedule_activity.headers, codec)
elif command.HasField("schedule_local_activity"):
await _encode_payloads(command.schedule_local_activity.arguments, codec)
if encode_headers:
await _encode_headers(
command.schedule_local_activity.headers, codec
)
elif command.HasField("signal_external_workflow_execution"):
await _encode_payloads(
command.signal_external_workflow_execution.args, codec
)
if encode_headers:
await _encode_headers(
command.signal_external_workflow_execution.headers, codec
)
elif command.HasField("start_child_workflow_execution"):
await _encode_payloads(
command.start_child_workflow_execution.input, codec
)
if encode_headers:
await _encode_headers(
command.start_child_workflow_execution.headers, codec
)
for val in command.start_child_workflow_execution.memo.values():
await _encode_payload(val, codec)
elif command.HasField("update_response"):
Expand Down
62 changes: 53 additions & 9 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
Mapping,
Optional,
Sequence,
Text,
Tuple,
Type,
Union,
Expand All @@ -37,6 +38,7 @@
import google.protobuf.duration_pb2
import google.protobuf.json_format
import google.protobuf.timestamp_pb2
from google.protobuf.internal.containers import MessageMap
from typing_extensions import Concatenate, Required, TypedDict

import temporalio.api.common.v1
Expand Down Expand Up @@ -66,6 +68,7 @@
TLSConfig,
)

from .common import HeaderCodecBehavior
from .types import (
AnyType,
LocalReturnType,
Expand Down Expand Up @@ -116,6 +119,7 @@ async def connect(
lazy: bool = False,
runtime: Optional[temporalio.runtime.Runtime] = None,
http_connect_proxy_config: Optional[HttpConnectProxyConfig] = None,
header_codec_behavior: HeaderCodecBehavior = HeaderCodecBehavior.NO_CODEC,
) -> Client:
"""Connect to a Temporal server.
Expand Down Expand Up @@ -160,6 +164,7 @@ async def connect(
used for workers.
runtime: The runtime for this client, or the default if unset.
http_connect_proxy_config: Configuration for HTTP CONNECT proxy.
header_codec_behavior: Encoding behavior for headers sent by the client.
"""
connect_config = temporalio.service.ConnectConfig(
target_host=target_host,
Expand All @@ -179,6 +184,7 @@ async def connect(
data_converter=data_converter,
interceptors=interceptors,
default_workflow_query_reject_condition=default_workflow_query_reject_condition,
header_codec_behavior=header_codec_behavior,
)

def __init__(
Expand All @@ -191,6 +197,7 @@ def __init__(
default_workflow_query_reject_condition: Optional[
temporalio.common.QueryRejectCondition
] = None,
header_codec_behavior: HeaderCodecBehavior = HeaderCodecBehavior.NO_CODEC,
):
"""Create a Temporal client from a service client.
Expand All @@ -208,6 +215,7 @@ def __init__(
data_converter=data_converter,
interceptors=interceptors,
default_workflow_query_reject_condition=default_workflow_query_reject_condition,
header_codec_behavior=header_codec_behavior,
)

def config(self) -> ClientConfig:
Expand Down Expand Up @@ -1501,6 +1509,7 @@ class ClientConfig(TypedDict, total=False):
default_workflow_query_reject_condition: Required[
Optional[temporalio.common.QueryRejectCondition]
]
header_codec_behavior: Required[HeaderCodecBehavior]


class WorkflowHistoryEventFilterType(IntEnum):
Expand Down Expand Up @@ -3859,6 +3868,10 @@ class ScheduleActionStartWorkflow(ScheduleAction):
priority: temporalio.common.Priority

headers: Optional[Mapping[str, temporalio.api.common.v1.Payload]]
"""
Headers may still be encoded by the payload codec if present.
"""
_from_raw: bool = dataclasses.field(compare=False, init=False)

@staticmethod
def _from_proto( # pyright: ignore
Expand Down Expand Up @@ -3985,6 +3998,7 @@ def __init__(
"""
super().__init__()
if raw_info:
self._from_raw = True
# Ignore other fields
self.workflow = raw_info.workflow_type.name
self.args = raw_info.input.payloads if raw_info.input else []
Expand Down Expand Up @@ -4044,6 +4058,7 @@ def __init__(
else temporalio.common.Priority.default
)
else:
self._from_raw = False
if not id:
raise ValueError("ID required")
if not task_queue:
Expand All @@ -4067,7 +4082,7 @@ def __init__(
self.memo = memo
self.typed_search_attributes = typed_search_attributes
self.untyped_search_attributes = untyped_search_attributes
self.headers = headers
self.headers = headers # encode here
self.static_summary = static_summary
self.static_details = static_details
self.priority = priority
Expand Down Expand Up @@ -4145,8 +4160,12 @@ async def _to_proto(
self.typed_search_attributes, action.start_workflow.search_attributes
)
if self.headers:
temporalio.common._apply_headers(
self.headers, action.start_workflow.header.fields
await _apply_headers(
self.headers,
action.start_workflow.header.fields,
client.config()["header_codec_behavior"] == HeaderCodecBehavior.CODEC
and not self._from_raw,
client.data_converter.payload_codec,
)
return action

Expand Down Expand Up @@ -5920,7 +5939,7 @@ async def _populate_start_workflow_execution_request(
if input.start_delay is not None:
req.workflow_start_delay.FromTimedelta(input.start_delay)
if input.headers is not None:
temporalio.common._apply_headers(input.headers, req.header.fields)
await self._apply_headers(input.headers, req.header.fields)
if input.priority is not None:
req.priority.CopyFrom(input.priority._to_proto())
if input.versioning_override is not None:
Expand Down Expand Up @@ -6006,7 +6025,7 @@ async def query_workflow(self, input: QueryWorkflowInput) -> Any:
await self._client.data_converter.encode(input.args)
)
if input.headers is not None:
temporalio.common._apply_headers(input.headers, req.query.header.fields)
await self._apply_headers(input.headers, req.query.header.fields)
try:
resp = await self._client.workflow_service.query_workflow(
req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout
Expand Down Expand Up @@ -6052,7 +6071,7 @@ async def signal_workflow(self, input: SignalWorkflowInput) -> None:
await self._client.data_converter.encode(input.args)
)
if input.headers is not None:
temporalio.common._apply_headers(input.headers, req.header.fields)
await self._apply_headers(input.headers, req.header.fields)
await self._client.workflow_service.signal_workflow_execution(
req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout
)
Expand Down Expand Up @@ -6163,9 +6182,7 @@ async def _build_update_workflow_execution_request(
await self._client.data_converter.encode(input.args)
)
if input.headers is not None:
temporalio.common._apply_headers(
input.headers, req.request.input.header.fields
)
await self._apply_headers(input.headers, req.request.input.header.fields)
return req

async def start_update_with_start_workflow(
Expand Down Expand Up @@ -6721,6 +6738,33 @@ async def get_worker_task_reachability(
)
return WorkerTaskReachability._from_proto(resp)

async def _apply_headers(
self,
source: Optional[Mapping[str, temporalio.api.common.v1.Payload]],
dest: MessageMap[Text, temporalio.api.common.v1.Payload],
) -> None:
await _apply_headers(
source,
dest,
self._client.config()["header_codec_behavior"] == HeaderCodecBehavior.CODEC,
self._client.data_converter.payload_codec,
)


async def _apply_headers(
source: Optional[Mapping[str, temporalio.api.common.v1.Payload]],
dest: MessageMap[Text, temporalio.api.common.v1.Payload],
encode_headers: bool,
codec: Optional[temporalio.converter.PayloadCodec],
) -> None:
if source is None:
return
if encode_headers and codec is not None:
for payload in source.values():
new_payload = (await codec.encode([payload]))[0]
payload.CopyFrom(new_payload)
temporalio.common._apply_headers(source, dest)


def _history_from_json(
history: Union[str, Dict[str, Any]],
Expand Down
11 changes: 11 additions & 0 deletions temporalio/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1230,3 +1230,14 @@ def _type_hints_from_func(
# necessarily
args.append(arg_hint) # type: ignore
return args, ret


class HeaderCodecBehavior(IntEnum):
"""Different ways to handle header encoding"""

NO_CODEC = 1
"""Don't encode or decode any headers automatically"""
CODEC = 2
"""Encode and decode all headers automatically"""
WORKFLOW_ONLY_CODEC = 3
"""Only automatically encode and decode headers in workflow activation encoding and decoding."""
10 changes: 10 additions & 0 deletions temporalio/worker/_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(
data_converter: temporalio.converter.DataConverter,
interceptors: Sequence[Interceptor],
metric_meter: temporalio.common.MetricMeter,
encode_headers: bool,
) -> None:
self._bridge_worker = bridge_worker
self._task_queue = task_queue
Expand All @@ -78,6 +79,7 @@ def __init__(
self._data_converter = data_converter
self._interceptors = interceptors
self._metric_meter = metric_meter
self._encode_headers = encode_headers
self._fail_worker_exception_queue: asyncio.Queue[Exception] = asyncio.Queue()
# Lazily created on first activity
self._worker_shutdown_event: Optional[temporalio.activity._CompositeEvent] = (
Expand Down Expand Up @@ -543,6 +545,14 @@ async def _execute_activity(
workflow_type=start.workflow_type,
priority=temporalio.common.Priority._from_proto(start.priority),
)

if self._encode_headers and self._data_converter.payload_codec is not None:
for payload in start.header_fields.values():
new_payload = (
await self._data_converter.payload_codec.decode([payload])
)[0]
payload.CopyFrom(new_payload)

running_activity.info = info
input = ExecuteActivityInput(
fn=activity_def.fn,
Expand Down
Loading