Skip to content

Commit c7bb753

Browse files
committed
feat: Support Decimal type in coordinates
1 parent 6a66fa7 commit c7bb753

File tree

2 files changed

+127
-23
lines changed

2 files changed

+127
-23
lines changed

pydantic_extra_types/coordinate.py

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,25 @@
66
from __future__ import annotations
77

88
from dataclasses import dataclass
9-
from typing import Any, ClassVar, Tuple
9+
from decimal import Decimal
10+
from typing import Any, ClassVar, Tuple, Union
1011

1112
from pydantic import GetCoreSchemaHandler
1213
from pydantic._internal import _repr
1314
from pydantic_core import ArgsKwargs, PydanticCustomError, core_schema
1415

16+
LatitudeType = Union[float, Decimal]
17+
LongitudeType = Union[float, Decimal]
18+
CoordinateType = Tuple[LatitudeType, LongitudeType]
19+
1520

1621
class Latitude(float):
1722
"""Latitude value should be between -90 and 90, inclusive.
1823
24+
Supports both float and Decimal types.
25+
1926
```py
27+
from decimal import Decimal
2028
from pydantic import BaseModel
2129
from pydantic_extra_types.coordinate import Latitude
2230
@@ -25,9 +33,10 @@ class Location(BaseModel):
2533
latitude: Latitude
2634
2735
28-
location = Location(latitude=41.40338)
29-
print(location)
30-
# > latitude=41.40338
36+
# Using float
37+
location1 = Location(latitude=41.40338)
38+
# Using Decimal
39+
location2 = Location(latitude=Decimal('41.40338'))
3140
```
3241
"""
3342

@@ -36,13 +45,21 @@ class Location(BaseModel):
3645

3746
@classmethod
3847
def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
39-
return core_schema.float_schema(ge=cls.min, le=cls.max)
48+
return core_schema.union_schema(
49+
[
50+
core_schema.float_schema(ge=cls.min, le=cls.max),
51+
core_schema.decimal_schema(ge=Decimal(cls.min), le=Decimal(cls.max)),
52+
]
53+
)
4054

4155

4256
class Longitude(float):
4357
"""Longitude value should be between -180 and 180, inclusive.
4458
59+
Supports both float and Decimal types.
60+
4561
```py
62+
from decimal import Decimal
4663
from pydantic import BaseModel
4764
4865
from pydantic_extra_types.coordinate import Longitude
@@ -52,9 +69,10 @@ class Location(BaseModel):
5269
longitude: Longitude
5370
5471
55-
location = Location(longitude=2.17403)
56-
print(location)
57-
# > longitude=2.17403
72+
# Using float
73+
location1 = Location(longitude=2.17403)
74+
# Using Decimal
75+
location2 = Location(longitude=Decimal('2.17403'))
5876
```
5977
"""
6078

@@ -63,7 +81,12 @@ class Location(BaseModel):
6381

6482
@classmethod
6583
def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
66-
return core_schema.float_schema(ge=cls.min, le=cls.max)
84+
return core_schema.union_schema(
85+
[
86+
core_schema.float_schema(ge=cls.min, le=cls.max),
87+
core_schema.decimal_schema(ge=Decimal(cls.min), le=Decimal(cls.max)),
88+
]
89+
)
6790

6891

6992
@dataclass
@@ -73,10 +96,11 @@ class Coordinate(_repr.Representation):
7396
You can use the `Coordinate` data type for storing coordinates. Coordinates can be
7497
defined using one of the following formats:
7598
76-
1. Tuple: `(Latitude, Longitude)`. For example: `(41.40338, 2.17403)`.
99+
1. Tuple: `(Latitude, Longitude)`. For example: `(41.40338, 2.17403)` or `(Decimal('41.40338'), Decimal('2.17403'))`.
77100
2. `Coordinate` instance: `Coordinate(latitude=Latitude, longitude=Longitude)`.
78101
79102
```py
103+
from decimal import Decimal
80104
from pydantic import BaseModel
81105
82106
from pydantic_extra_types.coordinate import Coordinate
@@ -86,7 +110,12 @@ class Location(BaseModel):
86110
coordinate: Coordinate
87111
88112
89-
location = Location(coordinate=(41.40338, 2.17403))
113+
# Using float values
114+
location1 = Location(coordinate=(41.40338, 2.17403))
115+
# > coordinate=Coordinate(latitude=41.40338, longitude=2.17403)
116+
117+
# Using Decimal values
118+
location2 = Location(coordinate=(Decimal('41.40338'), Decimal('2.17403')))
90119
# > coordinate=Coordinate(latitude=41.40338, longitude=2.17403)
91120
```
92121
"""
@@ -102,7 +131,7 @@ def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaH
102131
core_schema.no_info_wrap_validator_function(cls._parse_str, core_schema.str_schema()),
103132
core_schema.no_info_wrap_validator_function(
104133
cls._parse_tuple,
105-
handler.generate_schema(Tuple[float, float]),
134+
handler.generate_schema(CoordinateType),
106135
),
107136
handler(source),
108137
]

tests/test_coordinate.py

Lines changed: 86 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
from decimal import Decimal
12
from re import Pattern
2-
from typing import Any, Optional
3+
from typing import Any, Optional, Union
34

45
import pytest
56
from pydantic import BaseModel, ValidationError
@@ -34,7 +35,14 @@ class Lng(BaseModel):
3435
(Coordinate(latitude=0, longitude=0), (0, 0), None),
3536
(ArgsKwargs(args=()), (0, 0), None),
3637
(ArgsKwargs(args=(1, 0.0)), (1.0, 0), None),
37-
# # Invalid coordinates
38+
# Decimal test cases
39+
((Decimal('20.0'), Decimal('10.0')), (Decimal('20.0'), Decimal('10.0')), None),
40+
((Decimal('-90.0'), Decimal('0.0')), (Decimal('-90.0'), Decimal('0.0')), None),
41+
((Decimal('45.678'), Decimal('-123.456')), (Decimal('45.678'), Decimal('-123.456')), None),
42+
(Coordinate(Decimal('20.0'), Decimal('10.0')), (Decimal('20.0'), Decimal('10.0')), None),
43+
(Coordinate(latitude=Decimal('0'), longitude=Decimal('0')), (Decimal('0'), Decimal('0')), None),
44+
(ArgsKwargs(args=(Decimal('1'), Decimal('0.0'))), (Decimal('1.0'), Decimal('0.0')), None),
45+
# Invalid coordinates
3846
((), None, 'Field required'), # Empty tuple
3947
((10.0,), None, 'Field required'), # Tuple with only one value
4048
(('ten, '), None, 'string is not recognized as a valid coordinate'),
@@ -49,10 +57,11 @@ class Lng(BaseModel):
4957
(2, None, 'Input should be a dictionary or an instance of Coordinate'), # Wrong type
5058
],
5159
)
52-
def test_format_for_coordinate(coord: (Any, Any), result: (float, float), error: Optional[Pattern]):
60+
def test_format_for_coordinate(
61+
coord: (Any, Any), result: (Union[float, Decimal], Union[float, Decimal]), error: Optional[Pattern]
62+
):
5363
if error is None:
5464
_coord: Coordinate = Coord(coord=coord).coord
55-
print('vars(_coord)', vars(_coord))
5665
assert _coord.latitude == result[0]
5766
assert _coord.longitude == result[1]
5867
else:
@@ -69,6 +78,16 @@ def test_format_for_coordinate(coord: (Any, Any), result: (float, float), error:
6978
# Invalid coordinates
7079
((-91.0, 0.0), 'Input should be greater than or equal to -90'),
7180
((50.0, 181.0), 'Input should be less than or equal to 180'),
81+
# Valid Decimal coordinates
82+
((Decimal('-90.0'), Decimal('0.0')), None),
83+
((Decimal('50.0'), Decimal('180.0')), None),
84+
((Decimal('-89.999999'), Decimal('179.999999')), None),
85+
((Decimal('0.0'), Decimal('0.0')), None),
86+
# Invalid Decimal coordinates
87+
((Decimal('-90.1'), Decimal('0.0')), 'Input should be greater than or equal to -90'),
88+
((Decimal('50.0'), Decimal('180.1')), 'Input should be less than or equal to 180'),
89+
((Decimal('90.1'), Decimal('0.0')), 'Input should be less than or equal to 90'),
90+
((Decimal('0.0'), Decimal('-180.1')), 'Input should be greater than or equal to -180'),
7291
],
7392
)
7493
def test_limit_for_coordinate(coord: (Any, Any), error: Optional[Pattern]):
@@ -91,17 +110,21 @@ def test_limit_for_coordinate(coord: (Any, Any), error: Optional[Pattern]):
91110
('90.0', True),
92111
(-90.0, True),
93112
('-90.0', True),
113+
(Decimal('90.0'), True),
114+
(Decimal('-90.0'), True),
94115
# Unvalid latitude
95116
(91.0, False),
96117
(-91.0, False),
118+
(Decimal('91.0'), False),
119+
(Decimal('-91.0'), False),
97120
],
98121
)
99122
def test_format_latitude(latitude: float, valid: bool):
100123
if valid:
101124
_lat = Lat(lat=latitude).lat
102125
assert _lat == float(latitude)
103126
else:
104-
with pytest.raises(ValidationError, match='1 validation error for Lat'):
127+
with pytest.raises(ValidationError, match='2 validation errors for Lat'):
105128
Lat(lat=latitude)
106129

107130

@@ -119,46 +142,89 @@ def test_format_latitude(latitude: float, valid: bool):
119142
(-91.0, True),
120143
(180.0, True),
121144
(-180.0, True),
145+
(Decimal('180.0'), True),
146+
(Decimal('-180.0'), True),
122147
# Unvalid latitude
123148
(181.0, False),
124149
(-181.0, False),
150+
(Decimal('181.0'), False),
151+
(Decimal('-181.0'), False),
125152
],
126153
)
127154
def test_format_longitude(longitude: float, valid: bool):
128155
if valid:
129156
_lng = Lng(lng=longitude).lng
130157
assert _lng == float(longitude)
131158
else:
132-
with pytest.raises(ValidationError, match='1 validation error for Lng'):
159+
with pytest.raises(ValidationError, match='2 validation errors for Lng'):
133160
Lng(lng=longitude)
134161

135162

136163
def test_str_repr():
164+
# Float tests
137165
assert str(Coord(coord=(20.0, 10.0)).coord) == '20.0,10.0'
138166
assert str(Coord(coord=('20.0, 10.0')).coord) == '20.0,10.0'
139167
assert repr(Coord(coord=(20.0, 10.0)).coord) == 'Coordinate(latitude=20.0, longitude=10.0)'
168+
# Decimal tests
169+
assert str(Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord) == '20.0,10.0'
170+
assert str(Coord(coord=(Decimal('20.000'), Decimal('10.000'))).coord) == '20.000,10.000'
171+
assert (
172+
repr(Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord)
173+
== "Coordinate(latitude=Decimal('20.0'), longitude=Decimal('10.0'))"
174+
)
140175

141176

142177
def test_eq():
178+
# Float tests
143179
assert Coord(coord=(20.0, 10.0)).coord != Coord(coord='20.0,11.0').coord
144180
assert Coord(coord=('20.0, 10.0')).coord != Coord(coord='20.0,11.0').coord
145181
assert Coord(coord=('20.0, 10.0')).coord != Coord(coord='20.0,11.0').coord
146182
assert Coord(coord=(20.0, 10.0)).coord == Coord(coord='20.0,10.0').coord
147183

184+
# Decimal tests
185+
assert Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord == Coord(coord='20.0,10.0').coord
186+
assert Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord == Coord(coord=(20.0, 10.0)).coord
187+
assert (
188+
Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord != Coord(coord=(Decimal('20.0'), Decimal('11.0'))).coord
189+
)
190+
assert (
191+
Coord(coord=(Decimal('20.000'), Decimal('10.000'))).coord
192+
== Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord
193+
)
194+
148195

149196
def test_hashable():
197+
# Float tests
150198
assert hash(Coord(coord=(20.0, 10.0)).coord) == hash(Coord(coord=(20.0, 10.0)).coord)
151199
assert hash(Coord(coord=(20.0, 11.0)).coord) != hash(Coord(coord=(20.0, 10.0)).coord)
152200

201+
# Decimal tests
202+
assert hash(Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord) == hash(
203+
Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord
204+
)
205+
assert hash(Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord) == hash(Coord(coord=(20.0, 10.0)).coord)
206+
assert hash(Coord(coord=(Decimal('20.0'), Decimal('11.0'))).coord) != hash(
207+
Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord
208+
)
209+
assert hash(Coord(coord=(Decimal('20.000'), Decimal('10.000'))).coord) == hash(
210+
Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord
211+
)
212+
153213

154214
def test_json_schema():
155215
class Model(BaseModel):
156216
value: Coordinate
157217

158218
assert Model.model_json_schema(mode='validation')['$defs']['Coordinate'] == {
159219
'properties': {
160-
'latitude': {'maximum': 90.0, 'minimum': -90.0, 'title': 'Latitude', 'type': 'number'},
161-
'longitude': {'maximum': 180.0, 'minimum': -180.0, 'title': 'Longitude', 'type': 'number'},
220+
'latitude': {
221+
'anyOf': [{'maximum': 90.0, 'minimum': -90.0, 'type': 'number'}, {'type': 'string'}],
222+
'title': 'Latitude',
223+
},
224+
'longitude': {
225+
'anyOf': [{'maximum': 180.0, 'minimum': -180.0, 'type': 'number'}, {'type': 'string'}],
226+
'title': 'Longitude',
227+
},
162228
},
163229
'required': ['latitude', 'longitude'],
164230
'title': 'Coordinate',
@@ -170,7 +236,10 @@ class Model(BaseModel):
170236
{
171237
'maxItems': 2,
172238
'minItems': 2,
173-
'prefixItems': [{'type': 'number'}, {'type': 'number'}],
239+
'prefixItems': [
240+
{'anyOf': [{'type': 'number'}, {'type': 'string'}]},
241+
{'anyOf': [{'type': 'number'}, {'type': 'string'}]},
242+
],
174243
'type': 'array',
175244
},
176245
{'type': 'string'},
@@ -181,8 +250,14 @@ class Model(BaseModel):
181250
'$defs': {
182251
'Coordinate': {
183252
'properties': {
184-
'latitude': {'maximum': 90.0, 'minimum': -90.0, 'title': 'Latitude', 'type': 'number'},
185-
'longitude': {'maximum': 180.0, 'minimum': -180.0, 'title': 'Longitude', 'type': 'number'},
253+
'latitude': {
254+
'anyOf': [{'maximum': 90.0, 'minimum': -90.0, 'type': 'number'}, {'type': 'string'}],
255+
'title': 'Latitude',
256+
},
257+
'longitude': {
258+
'anyOf': [{'maximum': 180.0, 'minimum': -180.0, 'type': 'number'}, {'type': 'string'}],
259+
'title': 'Longitude',
260+
},
186261
},
187262
'required': ['latitude', 'longitude'],
188263
'title': 'Coordinate',

0 commit comments

Comments
 (0)