Skip to content

Commit ece2a97

Browse files
authored
Tech debt: let HistoryEncoding work with string-based type annotations in addition to the normal type-based ones (#3068)
## Changes This PR cherry-picks some changes from #3039 that updated the `HistoryEncoder` to work correctly with databases that are declared with `__future__.__annotations__` in effect. When this annotation is in effect, python converts all type-hints during import/declaration into strings and then performs deferred resolution at a later stage. (This is why forward references work.) Unfortunately the dataclass mechanism captures field types prior to deferred resolution. This PR ensures that our type checking works anyway. ### Linked issues Cherry-picks from #3039. ### Tests - updated unit tests
1 parent 3b6bcf3 commit ece2a97

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

src/databricks/labs/ucx/progress/history.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22
import dataclasses
33
import datetime as dt
4+
import typing
45
from enum import Enum, EnumMeta
56
import json
67
import logging
@@ -100,7 +101,13 @@ def _get_field_names_with_types(cls, klass: type[Record]) -> tuple[dict[str, typ
100101
- A dictionary of fields to include in the object data, and their type.
101102
- The type of the failures field, if present.
102103
"""
103-
field_names_with_types = {field.name: field.type for field in dataclasses.fields(klass)}
104+
# Ignore the field types returned by dataclasses.fields(): it doesn't resolve string-based annotations (which
105+
# are produced automatically in a __future__.__annotations__ context). Unfortunately the dataclass mechanism
106+
# captures the type hints prior to resolution (which happens later in the class initialization process).
107+
# As such, we rely on dataclasses.fields() for the set of field names, but not the types which we fetch directly.
108+
klass_type_hints = typing.get_type_hints(klass)
109+
field_names = [field.name for field in dataclasses.fields(klass)]
110+
field_names_with_types = {field_name: klass_type_hints[field_name] for field_name in field_names}
104111
if "failures" not in field_names_with_types:
105112
failures_type = None
106113
else:

tests/unit/progress/test_history.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def test_historical_encoder_object_id(ownership) -> None:
118118
class _CompoundKey:
119119
a_field: str = "field-a"
120120
b_field: str = "field-b"
121-
c_field: str = "field-c"
121+
c_field: "str" = "field-c" # Annotations can be strings as well.
122122

123123
@property
124124
def d_property(self) -> str:
@@ -270,7 +270,7 @@ def test_historical_encoder_object_data_values_strings_as_is(ownership) -> None:
270270
@dataclass
271271
class _AClass:
272272
a_field: str = "value"
273-
existing_json_field: str = "[1, 2, 3]"
273+
existing_json_field: "str" = "[1, 2, 3]"
274274
optional_string_field: str | None = "value"
275275

276276
__id_attributes__: ClassVar = ("a_field",)
@@ -481,7 +481,7 @@ class _BrokenFailures2:
481481
__id_attributes__: ClassVar = ("a_field",)
482482

483483

484-
@pytest.mark.parametrize("klass,broken_type", ((_BrokenFailures1, list[int]), (_BrokenFailures2, None)))
484+
@pytest.mark.parametrize("klass,broken_type", ((_BrokenFailures1, list[int]), (_BrokenFailures2, type(None))))
485485
def test_historical_encoder_failures_verification(
486486
ownership,
487487
klass: type[DataclassWithIdAttributes],

0 commit comments

Comments
 (0)