Skip to content

Commit 82f723b

Browse files
authored
Properly encode failure encoded attributes (#251)
Fixes #247
1 parent f22efea commit 82f723b

File tree

4 files changed

+106
-26
lines changed

4 files changed

+106
-26
lines changed

poetry.lock

Lines changed: 5 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ typing-extensions = "^4.2.0"
3838
black = "^22.3.0"
3939
cibuildwheel = "^2.11.0"
4040
grpcio-tools = "^1.48.0"
41-
isort = "^5.10.1"
41+
isort = "^5.11.3"
4242
mypy = "^0.971"
4343
mypy-protobuf = "^3.3.0"
4444
protoc-wheel-0 = "^21.1"

temporalio/converter.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,13 @@ async def _apply_to_failure_payloads(
573573
failure: temporalio.api.failure.v1.Failure,
574574
cb: Callable[[temporalio.api.common.v1.Payloads], Awaitable[None]],
575575
) -> None:
576+
if failure.HasField("encoded_attributes"):
577+
# Wrap in payloads and merge back
578+
payloads = temporalio.api.common.v1.Payloads(
579+
payloads=[failure.encoded_attributes]
580+
)
581+
await cb(payloads)
582+
failure.encoded_attributes.CopyFrom(payloads.payloads[0])
576583
if failure.HasField(
577584
"application_failure_info"
578585
) and failure.application_failure_info.HasField("details"):

tests/test_converter.py

Lines changed: 93 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,19 @@
3333

3434
import temporalio.api.common.v1
3535
import temporalio.common
36-
import temporalio.converter
36+
from temporalio.api.common.v1 import Payload
3737
from temporalio.api.common.v1 import Payload as AnotherNameForPayload
38+
from temporalio.api.common.v1 import Payloads
3839
from temporalio.api.failure.v1 import Failure
40+
from temporalio.converter import (
41+
BinaryProtoPayloadConverter,
42+
DataConverter,
43+
DefaultFailureConverterWithEncodedAttributes,
44+
JSONPlainPayloadConverter,
45+
PayloadCodec,
46+
decode_search_attributes,
47+
encode_search_attribute_values,
48+
)
3949
from temporalio.exceptions import ApplicationError, FailureError
4050

4151
# StrEnum is available in 3.11+
@@ -77,7 +87,7 @@ async def assert_payload(
7787
expected_decoded_input=None,
7888
type_hint=None,
7989
):
80-
payloads = await temporalio.converter.DataConverter().encode([input])
90+
payloads = await DataConverter().encode([input])
8191
# Check encoding and data
8292
assert len(payloads) == 1
8393
if isinstance(expected_encoding, str):
@@ -87,9 +97,7 @@ async def assert_payload(
8797
expected_data = expected_data.encode()
8898
assert payloads[0].data == expected_data
8999
# Decode and check
90-
actual_inputs = await temporalio.converter.DataConverter().decode(
91-
payloads, [type_hint]
92-
)
100+
actual_inputs = await DataConverter().decode(payloads, [type_hint])
93101
assert len(actual_inputs) == 1
94102
if expected_decoded_input is None:
95103
expected_decoded_input = input
@@ -158,7 +166,7 @@ async def assert_payload(
158166
def test_binary_proto():
159167
# We have to test this separately because by default it never encodes
160168
# anything since JSON proto takes precedence
161-
conv = temporalio.converter.BinaryProtoPayloadConverter()
169+
conv = BinaryProtoPayloadConverter()
162170
proto = temporalio.api.common.v1.WorkflowExecution(workflow_id="id1", run_id="id2")
163171
payload = conv.to_payload(proto)
164172
assert payload.metadata["encoding"] == b"binary/protobuf"
@@ -172,11 +180,11 @@ def test_binary_proto():
172180

173181
def test_encode_search_attribute_values():
174182
with pytest.raises(TypeError, match="of type tuple not one of"):
175-
temporalio.converter.encode_search_attribute_values([("bad type",)])
183+
encode_search_attribute_values([("bad type",)])
176184
with pytest.raises(ValueError, match="Timezone must be present"):
177-
temporalio.converter.encode_search_attribute_values([datetime.utcnow()])
185+
encode_search_attribute_values([datetime.utcnow()])
178186
with pytest.raises(TypeError, match="must have the same type"):
179-
temporalio.converter.encode_search_attribute_values(["foo", 123])
187+
encode_search_attribute_values(["foo", 123])
180188

181189

182190
def test_decode_search_attributes():
@@ -192,25 +200,23 @@ def payload(key, dtype, data, encoding=None):
192200
return temporalio.api.common.v1.SearchAttributes(indexed_fields={key: check})
193201

194202
# Check basic keyword parsing works
195-
kw_check = temporalio.converter.decode_search_attributes(
196-
payload("kw", "Keyword", '"test-id"')
197-
)
203+
kw_check = decode_search_attributes(payload("kw", "Keyword", '"test-id"'))
198204
assert kw_check["kw"][0] == "test-id"
199205

200206
# Ensure original DT functionality works
201-
dt_check = temporalio.converter.decode_search_attributes(
207+
dt_check = decode_search_attributes(
202208
payload("dt", "Datetime", '"2020-01-01T00:00:00"')
203209
)
204210
assert dt_check["dt"][0] == datetime(2020, 1, 1, 0, 0, 0)
205211

206212
# Check timezone aware works as server is using ISO 8601
207-
dttz_check = temporalio.converter.decode_search_attributes(
213+
dttz_check = decode_search_attributes(
208214
payload("dt", "Datetime", '"2020-01-01T00:00:00Z"')
209215
)
210216
assert dttz_check["dt"][0] == datetime(2020, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
211217

212218
# Check timezone aware, hour offset
213-
dttz_check = temporalio.converter.decode_search_attributes(
219+
dttz_check = decode_search_attributes(
214220
payload("dt", "Datetime", '"2020-01-01T00:00:00+00:00"')
215221
)
216222
assert dttz_check["dt"][0] == datetime(2020, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
@@ -245,7 +251,7 @@ class MyPydanticClass(pydantic.BaseModel):
245251

246252

247253
def test_json_type_hints():
248-
converter = temporalio.converter.JSONPlainPayloadConverter()
254+
converter = JSONPlainPayloadConverter()
249255

250256
def ok(
251257
hint: Type, value: Any, expected_result: Any = temporalio.common._arg_unset
@@ -415,10 +421,8 @@ async def test_exception_format():
415421

416422
# Convert to failure and back
417423
failure = Failure()
418-
await temporalio.converter.DataConverter.default.encode_failure(actual_err, failure)
419-
failure_error = await temporalio.converter.DataConverter.default.decode_failure(
420-
failure
421-
)
424+
await DataConverter.default.encode_failure(actual_err, failure)
425+
failure_error = await DataConverter.default.decode_failure(failure)
422426
# Confirm type is prepended
423427
assert isinstance(failure_error, ApplicationError)
424428
assert "RuntimeError: error2" == str(failure_error)
@@ -440,3 +444,72 @@ async def test_exception_format():
440444
logging.getLogger(__name__).debug(
441445
"Showing appended exception", exc_info=failure_error
442446
)
447+
448+
449+
# Just serializes in a "payloads" wrapper
450+
class SimpleCodec(PayloadCodec):
451+
async def encode(self, payloads: Sequence[Payload]) -> List[Payload]:
452+
wrapper = Payloads(payloads=payloads)
453+
return [
454+
Payload(
455+
metadata={"simple-codec": b"true"}, data=wrapper.SerializeToString()
456+
)
457+
]
458+
459+
async def decode(self, payloads: Sequence[Payload]) -> List[Payload]:
460+
payloads = list(payloads)
461+
if len(payloads) != 1:
462+
raise RuntimeError("Expected only a single payload")
463+
elif payloads[0].metadata.get("simple-codec") != b"true":
464+
raise RuntimeError("Not encoded with this codec")
465+
wrapper = Payloads()
466+
wrapper.ParseFromString(payloads[0].data)
467+
return list(wrapper.payloads)
468+
469+
470+
async def test_failure_encoded_attributes():
471+
try:
472+
raise ApplicationError("some message", "some detail")
473+
except ApplicationError as err:
474+
some_err = err
475+
476+
conv = DataConverter(
477+
failure_converter_class=DefaultFailureConverterWithEncodedAttributes,
478+
payload_codec=SimpleCodec(),
479+
)
480+
481+
# Check failure
482+
failure = Failure()
483+
conv.failure_converter.to_failure(some_err, conv.payload_converter, failure)
484+
assert failure.message == "Encoded failure"
485+
assert failure.stack_trace == ""
486+
assert conv.payload_converter.from_payloads(
487+
failure.application_failure_info.details.payloads
488+
) == ["some detail"]
489+
encoded_attr = conv.payload_converter.from_payloads([failure.encoded_attributes])[0]
490+
assert encoded_attr["message"] == "some message"
491+
assert "test_converter" in encoded_attr["stack_trace"]
492+
493+
# Encode it and check encoded
494+
orig_failure = Failure()
495+
orig_failure.CopyFrom(failure)
496+
await conv.payload_codec.encode_failure(failure)
497+
assert "encoding" not in failure.encoded_attributes.metadata
498+
assert "simple-codec" in failure.encoded_attributes.metadata
499+
assert (
500+
"encoding" not in failure.application_failure_info.details.payloads[0].metadata
501+
)
502+
assert (
503+
"simple-codec" in failure.application_failure_info.details.payloads[0].metadata
504+
)
505+
506+
# Decode and check
507+
await conv.payload_codec.decode_failure(failure)
508+
assert "encoding" in failure.encoded_attributes.metadata
509+
assert "simple-codec" not in failure.encoded_attributes.metadata
510+
assert "encoding" in failure.application_failure_info.details.payloads[0].metadata
511+
assert (
512+
"simple-codec"
513+
not in failure.application_failure_info.details.payloads[0].metadata
514+
)
515+
assert failure == orig_failure

0 commit comments

Comments
 (0)