|
| 1 | +"""Time zone name validation and serialization module.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +import importlib |
| 6 | +import sys |
| 7 | +import warnings |
| 8 | +from typing import Any, Callable, List, Set, Type, cast |
| 9 | + |
| 10 | +from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler |
| 11 | +from pydantic_core import PydanticCustomError, core_schema |
| 12 | + |
| 13 | + |
| 14 | +def _is_available(name: str) -> bool: |
| 15 | + """Check if a module is available for import.""" |
| 16 | + try: |
| 17 | + importlib.import_module(name) |
| 18 | + return True |
| 19 | + except ModuleNotFoundError: # pragma: no cover |
| 20 | + return False |
| 21 | + |
| 22 | + |
| 23 | +def _tz_provider_from_zone_info() -> Set[str]: # pragma: no cover |
| 24 | + """Get timezones from the zoneinfo module.""" |
| 25 | + from zoneinfo import available_timezones |
| 26 | + |
| 27 | + return set(available_timezones()) |
| 28 | + |
| 29 | + |
| 30 | +def _tz_provider_from_pytz() -> Set[str]: # pragma: no cover |
| 31 | + """Get timezones from the pytz module.""" |
| 32 | + from pytz import all_timezones |
| 33 | + |
| 34 | + return set(all_timezones) |
| 35 | + |
| 36 | + |
| 37 | +def _warn_about_pytz_usage() -> None: |
| 38 | + """Warn about using pytz with Python 3.9 or later.""" |
| 39 | + warnings.warn( # pragma: no cover |
| 40 | + 'Projects using Python 3.9 or later should be using the support now included as part of the standard library. ' |
| 41 | + 'Please consider switching to the standard library (zoneinfo) module.' |
| 42 | + ) |
| 43 | + |
| 44 | + |
| 45 | +def get_timezones() -> Set[str]: |
| 46 | + """Determine the timezone provider and return available timezones.""" |
| 47 | + if _is_available('zoneinfo') and _is_available('tzdata'): # pragma: no cover |
| 48 | + return _tz_provider_from_zone_info() |
| 49 | + elif _is_available('pytz'): # pragma: no cover |
| 50 | + if sys.version_info[:2] > (3, 8): |
| 51 | + _warn_about_pytz_usage() |
| 52 | + return _tz_provider_from_pytz() |
| 53 | + else: # pragma: no cover |
| 54 | + if sys.version_info[:2] == (3, 8): |
| 55 | + raise ImportError('No pytz module found. Please install it with "pip install pytz"') |
| 56 | + raise ImportError('No timezone provider found. Please install tzdata with "pip install tzdata"') |
| 57 | + |
| 58 | + |
| 59 | +class TimeZoneNameSettings(type): |
| 60 | + def __new__(cls, name: str, bases: tuple[type, ...], dct: dict[str, Any], **kwargs: Any) -> Type[TimeZoneName]: |
| 61 | + dct['strict'] = kwargs.pop('strict', True) |
| 62 | + return cast(Type[TimeZoneName], super().__new__(cls, name, bases, dct)) |
| 63 | + |
| 64 | + def __init__(cls, name: str, bases: tuple[type, ...], dct: dict[str, Any], **kwargs: Any) -> None: |
| 65 | + super().__init__(name, bases, dct) |
| 66 | + cls.strict = kwargs.get('strict', True) |
| 67 | + |
| 68 | + |
| 69 | +def timezone_name_settings(**kwargs: Any) -> Callable[[Type[TimeZoneName]], Type[TimeZoneName]]: |
| 70 | + def wrapper(cls: Type[TimeZoneName]) -> Type[TimeZoneName]: |
| 71 | + cls.strict = kwargs.get('strict', True) |
| 72 | + return cls |
| 73 | + |
| 74 | + return wrapper |
| 75 | + |
| 76 | + |
| 77 | +@timezone_name_settings(strict=True) |
| 78 | +class TimeZoneName(str): |
| 79 | + """ |
| 80 | + TimeZoneName is a custom string subclass for validating and serializing timezone names. |
| 81 | +
|
| 82 | + The TimeZoneName class uses the IANA Time Zone Database for validation. |
| 83 | + It supports both strict and non-strict modes for timezone name validation. |
| 84 | +
|
| 85 | +
|
| 86 | + ## Examples: |
| 87 | +
|
| 88 | + Some examples of using the TimeZoneName class: |
| 89 | +
|
| 90 | + ### Normal usage: |
| 91 | +
|
| 92 | + ```python |
| 93 | + from pydantic_extra_types.timezone_name import TimeZoneName |
| 94 | + from pydantic import BaseModel |
| 95 | + class Location(BaseModel): |
| 96 | + city: str |
| 97 | + timezone: TimeZoneName |
| 98 | +
|
| 99 | + loc = Location(city="New York", timezone="America/New_York") |
| 100 | + print(loc.timezone) |
| 101 | +
|
| 102 | + >> America/New_York |
| 103 | +
|
| 104 | + ``` |
| 105 | +
|
| 106 | + ### Non-strict mode: |
| 107 | +
|
| 108 | + ```python |
| 109 | +
|
| 110 | + from pydantic_extra_types.timezone_name import TimeZoneName, timezone_name_settings |
| 111 | +
|
| 112 | + @timezone_name_settings(strict=False) |
| 113 | + class TZNonStrict(TimeZoneName): |
| 114 | + pass |
| 115 | +
|
| 116 | + tz = TZNonStrict("america/new_york") |
| 117 | +
|
| 118 | + print(tz) |
| 119 | +
|
| 120 | + >> america/new_york |
| 121 | +
|
| 122 | + ``` |
| 123 | + """ |
| 124 | + |
| 125 | + __slots__: List[str] = [] |
| 126 | + allowed_values: Set[str] = set(get_timezones()) |
| 127 | + allowed_values_list: List[str] = sorted(allowed_values) |
| 128 | + allowed_values_upper_to_correct: dict[str, str] = {val.upper(): val for val in allowed_values} |
| 129 | + strict: bool |
| 130 | + |
| 131 | + @classmethod |
| 132 | + def _validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> TimeZoneName: |
| 133 | + """ |
| 134 | + Validate a time zone name from the provided str value. |
| 135 | +
|
| 136 | + Args: |
| 137 | + __input_value: The str value to be validated. |
| 138 | + _: The Pydantic ValidationInfo. |
| 139 | +
|
| 140 | + Returns: |
| 141 | + The validated time zone name. |
| 142 | +
|
| 143 | + Raises: |
| 144 | + PydanticCustomError: If the timezone name is not valid. |
| 145 | + """ |
| 146 | + if __input_value not in cls.allowed_values: # be fast for the most common case |
| 147 | + if not cls.strict: |
| 148 | + upper_value = __input_value.strip().upper() |
| 149 | + if upper_value in cls.allowed_values_upper_to_correct: |
| 150 | + return cls(cls.allowed_values_upper_to_correct[upper_value]) |
| 151 | + raise PydanticCustomError('TimeZoneName', 'Invalid timezone name.') |
| 152 | + return cls(__input_value) |
| 153 | + |
| 154 | + @classmethod |
| 155 | + def __get_pydantic_core_schema__( |
| 156 | + cls, _: Type[Any], __: GetCoreSchemaHandler |
| 157 | + ) -> core_schema.AfterValidatorFunctionSchema: |
| 158 | + """ |
| 159 | + Return a Pydantic CoreSchema with the timezone name validation. |
| 160 | +
|
| 161 | + Args: |
| 162 | + _: The source type. |
| 163 | + __: The handler to get the CoreSchema. |
| 164 | +
|
| 165 | + Returns: |
| 166 | + A Pydantic CoreSchema with the timezone name validation. |
| 167 | + """ |
| 168 | + return core_schema.with_info_after_validator_function( |
| 169 | + cls._validate, |
| 170 | + core_schema.str_schema(min_length=1), |
| 171 | + ) |
| 172 | + |
| 173 | + @classmethod |
| 174 | + def __get_pydantic_json_schema__( |
| 175 | + cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler |
| 176 | + ) -> dict[str, Any]: |
| 177 | + """ |
| 178 | + Return a Pydantic JSON Schema with the timezone name validation. |
| 179 | +
|
| 180 | + Args: |
| 181 | + schema: The Pydantic CoreSchema. |
| 182 | + handler: The handler to get the JSON Schema. |
| 183 | +
|
| 184 | + Returns: |
| 185 | + A Pydantic JSON Schema with the timezone name validation. |
| 186 | + """ |
| 187 | + json_schema = handler(schema) |
| 188 | + json_schema.update({'enum': cls.allowed_values_list}) |
| 189 | + return json_schema |
0 commit comments