Skip to content

Commit d2be48c

Browse files
authored
Support model class inheritance. Fixes #164 (#862)
1 parent 9d57373 commit d2be48c

File tree

2 files changed

+48
-19
lines changed

2 files changed

+48
-19
lines changed

pynamodb/models.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from pynamodb.expressions.update import Action
1313
from pynamodb.exceptions import DoesNotExist, TableDoesNotExist, TableError, InvalidStateError, PutError
1414
from pynamodb.attributes import (
15-
Attribute, AttributeContainer, AttributeContainerMeta, TTLAttribute, VersionAttribute
15+
AttributeContainer, AttributeContainerMeta, TTLAttribute, VersionAttribute
1616
)
1717
from pynamodb.connection.table import TableConnection
1818
from pynamodb.expressions.condition import Condition
@@ -151,10 +151,6 @@ def commit(self) -> None:
151151
unprocessed_items = data.get(UNPROCESSED_ITEMS, {}).get(self.model.Meta.table_name)
152152

153153

154-
class DefaultMeta(object):
155-
pass
156-
157-
158154
class MetaModel(AttributeContainerMeta):
159155
table_name: str
160156
read_capacity_units: Optional[int]
@@ -184,17 +180,26 @@ def __init__(self, name: str, bases: Any, attrs: Dict[str, Any]) -> None:
184180
cls = cast(Type['Model'], self)
185181
for attr_name, attribute in cls.get_attributes().items():
186182
if attribute.is_hash_key:
183+
if cls._hash_keyname and cls._hash_keyname != attr_name:
184+
raise ValueError(f"{cls.__name__} has more than one hash key: {cls._hash_keyname}, {attr_name}")
187185
cls._hash_keyname = attr_name
188186
if attribute.is_range_key:
187+
if cls._range_keyname and cls._range_keyname != attr_name:
188+
raise ValueError(f"{cls.__name__} has more than one range key: {cls._range_keyname}, {attr_name}")
189189
cls._range_keyname = attr_name
190190
if isinstance(attribute, VersionAttribute):
191-
if cls._version_attribute_name:
191+
if cls._version_attribute_name and cls._version_attribute_name != attr_name:
192192
raise ValueError(
193193
"The model has more than one Version attribute: {}, {}"
194194
.format(cls._version_attribute_name, attr_name)
195195
)
196196
cls._version_attribute_name = attr_name
197197

198+
ttl_attr_names = [name for name, attr in cls.get_attributes().items() if isinstance(attr, TTLAttribute)]
199+
if len(ttl_attr_names) > 1:
200+
raise ValueError("{} has more than one TTL attribute: {}".format(
201+
cls.__name__, ", ".join(ttl_attr_names)))
202+
198203
if isinstance(attrs, dict):
199204
for attr_name, attr_obj in attrs.items():
200205
if attr_name == META_CLASS_NAME:
@@ -226,16 +231,6 @@ def __init__(self, name: str, bases: Any, attrs: Dict[str, Any]) -> None:
226231
attr_obj.Meta.model = cls
227232
if not hasattr(attr_obj.Meta, "index_name"):
228233
attr_obj.Meta.index_name = attr_name
229-
elif isinstance(attr_obj, Attribute):
230-
if attr_obj.attr_name is None:
231-
attr_obj.attr_name = attr_name
232-
233-
ttl_attr_names = [name for name, attr_obj in attrs.items() if isinstance(attr_obj, TTLAttribute)]
234-
if len(ttl_attr_names) > 1:
235-
raise ValueError("The model has more than one TTL attribute: {}".format(", ".join(ttl_attr_names)))
236-
237-
if META_CLASS_NAME not in attrs:
238-
setattr(cls, META_CLASS_NAME, DefaultMeta)
239234

240235
# create a custom Model.DoesNotExist derived from pynamodb.exceptions.DoesNotExist,
241236
# so that "except Model.DoesNotExist:" would not catch other models' exceptions

tests/test_model.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
Test model API
33
"""
44
import base64
5-
import random
65
import json
76
import copy
87
from datetime import datetime
@@ -14,9 +13,7 @@
1413
import pytest
1514

1615
from .deep_eq import deep_eq
17-
from pynamodb.util import snake_to_camel_case
1816
from pynamodb.exceptions import DoesNotExist, TableError, PutError, AttributeDeserializationError
19-
from pynamodb.types import RANGE
2017
from pynamodb.constants import (
2118
ITEM, STRING, ALL, KEYS_ONLY, INCLUDE, REQUEST_ITEMS, UNPROCESSED_KEYS, CAMEL_COUNT,
2219
RESPONSES, KEYS, ITEMS, LAST_EVALUATED_KEY, EXCLUSIVE_START_KEY, ATTRIBUTES, BINARY,
@@ -2424,6 +2421,17 @@ def test_old_style_model_exception(self):
24242421
with self.assertRaises(AttributeError):
24252422
OldStyleModel.exists()
24262423

2424+
def test_no_table_name_exception(self):
2425+
"""
2426+
Display warning for Models without table names
2427+
"""
2428+
class MissingTableNameModel(Model):
2429+
class Meta:
2430+
pass
2431+
user_name = UnicodeAttribute(hash_key=True)
2432+
with self.assertRaises(AttributeError):
2433+
MissingTableNameModel.exists()
2434+
24272435
def _get_office_employee(self):
24282436
justin = Person(
24292437
fname='Justin',
@@ -3214,6 +3222,24 @@ def test_deserialized_with_ttl(self):
32143222
def test_deserialized_with_invalid_type(self):
32153223
self.assertRaises(AttributeDeserializationError, TTLModel.from_raw_data, {'my_ttl': {'S': '1546300800'}})
32163224

3225+
def test_multiple_hash_keys(self):
3226+
with self.assertRaises(ValueError):
3227+
class BadHashKeyModel(Model):
3228+
class Meta:
3229+
table_name = 'BadHashKeyModel'
3230+
3231+
foo = UnicodeAttribute(hash_key=True)
3232+
bar = UnicodeAttribute(hash_key=True)
3233+
3234+
def test_multiple_range_keys(self):
3235+
with self.assertRaises(ValueError):
3236+
class BadRangeKeyModel(Model):
3237+
class Meta:
3238+
table_name = 'BadRangeKeyModel'
3239+
3240+
foo = UnicodeAttribute(range_key=True)
3241+
bar = UnicodeAttribute(range_key=True)
3242+
32173243
def test_multiple_version_attributes(self):
32183244
with self.assertRaises(ValueError):
32193245
class BadVersionedModel(Model):
@@ -3222,3 +3248,11 @@ class Meta:
32223248

32233249
version = VersionAttribute()
32243250
another_version = VersionAttribute()
3251+
3252+
def test_inherit_metaclass(self):
3253+
class ParentModel(Model):
3254+
class Meta:
3255+
table_name = 'foo'
3256+
class ChildModel(ParentModel):
3257+
pass
3258+
self.assertEqual(ParentModel.Meta.table_name, ChildModel.Meta.table_name)

0 commit comments

Comments
 (0)