Skip to content

Commit 067b2c2

Browse files
committed
Rework payload validation to use declared definitions.
1 parent b9c6c14 commit 067b2c2

File tree

1 file changed

+63
-24
lines changed

1 file changed

+63
-24
lines changed

mig/services/coreapi/server.py

+63-24
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
import cgi
4444
import cgitb
4545
import codecs
46-
from collections import defaultdict, namedtuple
46+
from collections import defaultdict, namedtuple, OrderedDict
4747
from flask import Flask, request, Response
4848
from functools import partial, update_wrapper
4949
import os
@@ -158,6 +158,58 @@ def invalid_argument(arg):
158158
raise ValueError("Unexpected query variable: %s" % quoteattr(arg))
159159

160160

161+
class PayloadDefinition:
162+
def __init__(self, name, positional):
163+
self._item_checks = []
164+
self._item_names = []
165+
166+
for name, validator_fn in positional:
167+
self._item_names.append(name)
168+
self._item_checks.append(validator_fn)
169+
170+
@property
171+
def _fields(self):
172+
return self._item_names
173+
174+
@property
175+
def _validators(self):
176+
return self._item_checks
177+
178+
def __call__(self, *args):
179+
return self._extract_and_bundle(args, extract_by='position')
180+
181+
def _extract_and_bundle(self, args, extract_by=None):
182+
if extract_by == 'position':
183+
keys_to_bundle = list(range(len(args)))
184+
elif extract_by == 'name':
185+
keys_to_bundle = self._item_names
186+
elif extract_by == 'short':
187+
keys_to_bundle = self._item_short
188+
else:
189+
raise RuntimeError()
190+
191+
return Payload.from_args(self, args, keys_to_bundle)
192+
193+
@staticmethod
194+
def bundle_generic(thekeys):
195+
for key in thekeys:
196+
pass
197+
198+
199+
class Payload(OrderedDict):
200+
def __init__(self, definition, dictionary):
201+
super().__init__(dictionary)
202+
self._definition = definition
203+
204+
def __iter__(self):
205+
return iter(self.values())
206+
207+
@staticmethod
208+
def from_args(definition, args, keys):
209+
dictionary = {key:args[key] for key in keys}
210+
return Payload(definition, dictionary)
211+
212+
161213
class ValidationReport(RuntimeError):
162214
def __init__(self, errors_by_field):
163215
self.errors_by_field = errors_by_field
@@ -181,28 +233,17 @@ def _is_string_and_non_empty(value):
181233
return isinstance(value, str) and len(value) > 0
182234

183235

184-
_REQUEST_ARGS_POST_USER = namedtuple('PostUserArgs', [
185-
'full_name',
186-
'organization',
187-
'state',
188-
'country',
189-
'email',
190-
'comment',
191-
'password',
236+
_REQUEST_ARGS_POST_USER = PayloadDefinition('PostUserArgs', [
237+
('full_name', _is_string_and_non_empty),
238+
('organization', _is_string_and_non_empty),
239+
('state', _is_string_and_non_empty),
240+
('country', _is_string_and_non_empty),
241+
('email', _is_string_and_non_empty),
242+
('comment', _is_string_and_non_empty),
243+
('password', _is_string_and_non_empty),
192244
])
193245

194246

195-
_REQUEST_ARGS_POST_USER._validators = defaultdict(lambda: _is_not_none, dict(
196-
full_name=_is_string_and_non_empty,
197-
organization=_is_string_and_non_empty,
198-
state=_is_string_and_non_empty,
199-
country=_is_string_and_non_empty,
200-
email=_is_string_and_non_empty,
201-
comment=_is_string_and_non_empty,
202-
password=_is_string_and_non_empty,
203-
))
204-
205-
206247
def search_users(configuration, search_filter):
207248
conf_path = configuration.config_file
208249
db_path = default_db_path(configuration)
@@ -214,7 +255,7 @@ def validate_payload(definition, payload):
214255
args = definition(*[payload.get(field, None) for field in definition._fields])
215256

216257
errors_by_field = {}
217-
for field_name, field_value in args._asdict().items():
258+
for field_name, field_value in args.items():
218259
validator_fn = definition._validators[field_name]
219260
if not validator_fn(field_value):
220261
errors_by_field[field_name] = validator_fn.__doc__
@@ -257,9 +298,7 @@ def POST_user():
257298
except ValidationReport as vr:
258299
return http_error_from_status_code(400, None, vr.serialize())
259300

260-
args = list(validated)
261-
262-
ret = createuser(configuration, args)
301+
ret = createuser(configuration, validated)
263302
if ret != 0:
264303
raise http_error_from_status_code(400, None)
265304

0 commit comments

Comments
 (0)