From 58bae691899fd1359dcbc4b81b550759a9ffb28e Mon Sep 17 00:00:00 2001 From: Andrew Snare Date: Thu, 24 Oct 2024 15:11:33 +0200 Subject: [PATCH] Update HistoryEncoding to handle string-based type annotations, in addition to the normal class-based ones. This is needed for correct functioning if a dataclass is declared in a __future__.__annotations__ context: here python converts all type hints to strings (at which point the dataclass metaclass captures them) prior to deferred resolution. --- src/databricks/labs/ucx/progress/history.py | 9 ++++++++- tests/unit/progress/test_history.py | 6 +++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/databricks/labs/ucx/progress/history.py b/src/databricks/labs/ucx/progress/history.py index f7c4909556..1dd5efa68a 100644 --- a/src/databricks/labs/ucx/progress/history.py +++ b/src/databricks/labs/ucx/progress/history.py @@ -1,6 +1,7 @@ from __future__ import annotations import dataclasses import datetime as dt +import typing from enum import Enum, EnumMeta import json import logging @@ -100,7 +101,13 @@ def _get_field_names_with_types(cls, klass: type[Record]) -> tuple[dict[str, typ - A dictionary of fields to include in the object data, and their type. - The type of the failures field, if present. """ - field_names_with_types = {field.name: field.type for field in dataclasses.fields(klass)} + # Ignore the field types returned by dataclasses.fields(): it doesn't resolve string-based annotations (which + # are produced automatically in a __future__.__annotations__ context). Unfortunately the dataclass mechanism + # captures the type hints prior to resolution (which happens later in the class initialization process). + # As such, we rely on dataclasses.fields() for the set of field names, but not the types which we fetch directly. + klass_type_hints = typing.get_type_hints(klass) + field_names = [field.name for field in dataclasses.fields(klass)] + field_names_with_types = {field_name: klass_type_hints[field_name] for field_name in field_names} if "failures" not in field_names_with_types: failures_type = None else: diff --git a/tests/unit/progress/test_history.py b/tests/unit/progress/test_history.py index afa2e406d6..01d6284023 100644 --- a/tests/unit/progress/test_history.py +++ b/tests/unit/progress/test_history.py @@ -118,7 +118,7 @@ def test_historical_encoder_object_id(ownership) -> None: class _CompoundKey: a_field: str = "field-a" b_field: str = "field-b" - c_field: str = "field-c" + c_field: "str" = "field-c" # Annotations can be strings as well. @property def d_property(self) -> str: @@ -270,7 +270,7 @@ def test_historical_encoder_object_data_values_strings_as_is(ownership) -> None: @dataclass class _AClass: a_field: str = "value" - existing_json_field: str = "[1, 2, 3]" + existing_json_field: "str" = "[1, 2, 3]" optional_string_field: str | None = "value" __id_attributes__: ClassVar = ("a_field",) @@ -481,7 +481,7 @@ class _BrokenFailures2: __id_attributes__: ClassVar = ("a_field",) -@pytest.mark.parametrize("klass,broken_type", ((_BrokenFailures1, list[int]), (_BrokenFailures2, None))) +@pytest.mark.parametrize("klass,broken_type", ((_BrokenFailures1, list[int]), (_BrokenFailures2, type(None)))) def test_historical_encoder_failures_verification( ownership, klass: type[DataclassWithIdAttributes],