Skip to content

Commit 9a5dc09

Browse files
committed
Split generic tests into it's own file.
1 parent f2734cc commit 9a5dc09

File tree

2 files changed

+87
-74
lines changed

2 files changed

+87
-74
lines changed

tests/test_class_schema.py

Lines changed: 1 addition & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from marshmallow.fields import List as ListField
1818
from marshmallow.validate import Validator
1919

20-
from marshmallow_dataclass import NewType, _is_generic_alias_of_dataclass, class_schema
20+
from marshmallow_dataclass import NewType, class_schema
2121

2222

2323
class TestClassSchema(unittest.TestCase):
@@ -460,79 +460,6 @@ class Meta:
460460
self.assertNotIn("no_init", class_schema(NoInit)().fields)
461461
self.assertIn("no_init", class_schema(Init)().fields)
462462

463-
def test_generic_dataclass(self):
464-
T = typing.TypeVar("T")
465-
466-
@dataclasses.dataclass
467-
class SimpleGeneric(typing.Generic[T]):
468-
data: T
469-
470-
@dataclasses.dataclass
471-
class NestedFixed:
472-
data: SimpleGeneric[int]
473-
474-
@dataclasses.dataclass
475-
class NestedGeneric(typing.Generic[T]):
476-
data: SimpleGeneric[T]
477-
478-
self.assertTrue(_is_generic_alias_of_dataclass(SimpleGeneric[int]))
479-
self.assertFalse(_is_generic_alias_of_dataclass(SimpleGeneric))
480-
481-
schema_s = class_schema(SimpleGeneric[str])()
482-
self.assertEqual(SimpleGeneric(data="a"), schema_s.load({"data": "a"}))
483-
self.assertEqual(schema_s.dump(SimpleGeneric(data="a")), {"data": "a"})
484-
with self.assertRaises(ValidationError):
485-
schema_s.load({"data": 2})
486-
487-
schema_nested = class_schema(NestedFixed)()
488-
self.assertEqual(
489-
NestedFixed(data=SimpleGeneric(1)),
490-
schema_nested.load({"data": {"data": 1}}),
491-
)
492-
self.assertEqual(
493-
schema_nested.dump(NestedFixed(data=SimpleGeneric(data=1))),
494-
{"data": {"data": 1}},
495-
)
496-
with self.assertRaises(ValidationError):
497-
schema_nested.load({"data": {"data": "str"}})
498-
499-
schema_nested_generic = class_schema(NestedGeneric[int])()
500-
self.assertEqual(
501-
NestedGeneric(data=SimpleGeneric(1)),
502-
schema_nested_generic.load({"data": {"data": 1}}),
503-
)
504-
self.assertEqual(
505-
schema_nested_generic.dump(NestedGeneric(data=SimpleGeneric(data=1))),
506-
{"data": {"data": 1}},
507-
)
508-
with self.assertRaises(ValidationError):
509-
schema_nested_generic.load({"data": {"data": "str"}})
510-
511-
def test_generic_dataclass_repeated_fields(self):
512-
T = typing.TypeVar("T")
513-
514-
@dataclasses.dataclass
515-
class AA:
516-
a: int
517-
518-
@dataclasses.dataclass
519-
class BB(typing.Generic[T]):
520-
b: T
521-
522-
@dataclasses.dataclass
523-
class Nested:
524-
x: BB[float]
525-
z: BB[float]
526-
# if y is the first field in this class, deserialisation will fail.
527-
# see https://github.com/lovasoa/marshmallow_dataclass/pull/172#issuecomment-1334024027
528-
y: BB[AA]
529-
530-
schema_nested = class_schema(Nested)()
531-
self.assertEqual(
532-
Nested(x=BB(b=1), z=BB(b=1), y=BB(b=AA(1))),
533-
schema_nested.load({"x": {"b": 1}, "z": {"b": 1}, "y": {"b": {"a": 1}}}),
534-
)
535-
536463

537464
if __name__ == "__main__":
538465
unittest.main()

tests/test_generics.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import dataclasses
2+
import typing
3+
import unittest
4+
5+
from marshmallow import ValidationError
6+
7+
from marshmallow_dataclass import _is_generic_alias_of_dataclass, class_schema
8+
9+
10+
class TestGenerics(unittest.TestCase):
11+
def test_generic_dataclass(self):
12+
T = typing.TypeVar("T")
13+
14+
@dataclasses.dataclass
15+
class SimpleGeneric(typing.Generic[T]):
16+
data: T
17+
18+
@dataclasses.dataclass
19+
class NestedFixed:
20+
data: SimpleGeneric[int]
21+
22+
@dataclasses.dataclass
23+
class NestedGeneric(typing.Generic[T]):
24+
data: SimpleGeneric[T]
25+
26+
self.assertTrue(_is_generic_alias_of_dataclass(SimpleGeneric[int]))
27+
self.assertFalse(_is_generic_alias_of_dataclass(SimpleGeneric))
28+
29+
schema_s = class_schema(SimpleGeneric[str])()
30+
self.assertEqual(SimpleGeneric(data="a"), schema_s.load({"data": "a"}))
31+
self.assertEqual(schema_s.dump(SimpleGeneric(data="a")), {"data": "a"})
32+
with self.assertRaises(ValidationError):
33+
schema_s.load({"data": 2})
34+
35+
schema_nested = class_schema(NestedFixed)()
36+
self.assertEqual(
37+
NestedFixed(data=SimpleGeneric(1)),
38+
schema_nested.load({"data": {"data": 1}}),
39+
)
40+
self.assertEqual(
41+
schema_nested.dump(NestedFixed(data=SimpleGeneric(data=1))),
42+
{"data": {"data": 1}},
43+
)
44+
with self.assertRaises(ValidationError):
45+
schema_nested.load({"data": {"data": "str"}})
46+
47+
schema_nested_generic = class_schema(NestedGeneric[int])()
48+
self.assertEqual(
49+
NestedGeneric(data=SimpleGeneric(1)),
50+
schema_nested_generic.load({"data": {"data": 1}}),
51+
)
52+
self.assertEqual(
53+
schema_nested_generic.dump(NestedGeneric(data=SimpleGeneric(data=1))),
54+
{"data": {"data": 1}},
55+
)
56+
with self.assertRaises(ValidationError):
57+
schema_nested_generic.load({"data": {"data": "str"}})
58+
59+
def test_generic_dataclass_repeated_fields(self):
60+
T = typing.TypeVar("T")
61+
62+
@dataclasses.dataclass
63+
class AA:
64+
a: int
65+
66+
@dataclasses.dataclass
67+
class BB(typing.Generic[T]):
68+
b: T
69+
70+
@dataclasses.dataclass
71+
class Nested:
72+
x: BB[float]
73+
z: BB[float]
74+
# if y is the first field in this class, deserialisation will fail.
75+
# see https://github.com/lovasoa/marshmallow_dataclass/pull/172#issuecomment-1334024027
76+
y: BB[AA]
77+
78+
schema_nested = class_schema(Nested)()
79+
self.assertEqual(
80+
Nested(x=BB(b=1), z=BB(b=1), y=BB(b=AA(1))),
81+
schema_nested.load({"x": {"b": 1}, "z": {"b": 1}, "y": {"b": {"a": 1}}}),
82+
)
83+
84+
85+
if __name__ == "__main__":
86+
unittest.main()

0 commit comments

Comments
 (0)