Skip to content

Commit c1182a2

Browse files
committed
Implement a user creation endpoint in the server.
Back a POST handler onto the refactored createuser functionality file which exposes a function that can be invoked programatically with a configuration.
1 parent d8a1c53 commit c1182a2

File tree

2 files changed

+120
-21
lines changed

2 files changed

+120
-21
lines changed

mig/services/coreapi/server.py

Lines changed: 78 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import cgi
4444
import cgitb
4545
import codecs
46+
from collections import defaultdict, namedtuple
4647
from flask import Flask, request, Response
4748
from functools import partial, update_wrapper
4849
import os
@@ -56,7 +57,7 @@
5657
from wsgiref.simple_server import WSGIRequestHandler
5758

5859
from mig.shared.accountstate import check_account_accessible
59-
from mig.shared.base import client_dir_id, client_id_dir, cert_field_map
60+
from mig.shared.base import canonical_user, client_dir_id, client_id_dir, cert_field_map
6061
from mig.shared.conf import get_configuration_object
6162
from mig.shared.compat import PY2
6263
from mig.shared.griddaemons.openid import default_max_user_hits, \
@@ -72,10 +73,11 @@
7273
valid_complex_url, InputException
7374
from mig.shared.tlsserver import hardened_ssl_context
7475
from mig.shared.url import urlparse, urlencode, parse_qsl
75-
from mig.shared.useradm import create_user, get_any_oid_user_dn, check_password_scramble, \
76+
from mig.shared.useradm import get_any_oid_user_dn, check_password_scramble, \
7677
check_hash
7778
from mig.shared.userdb import default_db_path
7879
from mig.shared.validstring import possible_user_id, is_valid_email_address
80+
from mig.server.createuser import _main as createuser
7981

8082
# Update with extra fields
8183
cert_field_map.update({'role': 'ROLE', 'timezone': 'TZ', 'nickname': 'NICK',
@@ -114,8 +116,8 @@ def _ensure_encoded_string(chunk):
114116
exc.code: exc for exc in httpexceptions.__dict__.values() if hasattr(exc, 'code')}
115117

116118

117-
def http_error_from_status_code(http_status_code, http_url):
118-
return httpexceptions_by_code[http_status_code]()
119+
def http_error_from_status_code(http_status_code, http_url, description=None):
120+
return httpexceptions_by_code[http_status_code](description)
119121

120122

121123
def quoteattr(val):
@@ -156,7 +158,66 @@ def invalid_argument(arg):
156158
raise ValueError("Unexpected query variable: %s" % quoteattr(arg))
157159

158160

159-
def _create_and_expose_server(configuration):
161+
class ValidationReport(RuntimeError):
162+
def __init__(self, errors_by_field):
163+
self.errors_by_field = errors_by_field
164+
165+
def serialize(self, output_format='text'):
166+
if output_format == 'json':
167+
return dict(errors=self.errors_by_field)
168+
else:
169+
lines = ["- %s: required %s" % (k, v) for k, v in self.errors_by_field.items()]
170+
lines.insert(0, '')
171+
return 'payload failed to validate:%s' % ('\n'.join(lines),)
172+
173+
174+
def _is_not_none(value):
175+
"""value is not None"""
176+
return value is not None
177+
178+
179+
def _is_string_and_non_empty(value):
180+
"""value is a non-empty string"""
181+
return isinstance(value, str) and len(value) > 0
182+
183+
184+
_REQUEST_ARGS_POST_USER = namedtuple('PostUserArgs', [
185+
'full_name',
186+
'organization',
187+
'state',
188+
'country',
189+
'email',
190+
'comment',
191+
'password',
192+
])
193+
194+
195+
_REQUEST_ARGS_POST_USER._validators = defaultdict(lambda: _is_not_none, dict(
196+
full_name=_is_string_and_non_empty,
197+
organization=_is_string_and_non_empty,
198+
state=_is_string_and_non_empty,
199+
country=_is_string_and_non_empty,
200+
email=_is_string_and_non_empty,
201+
comment=_is_string_and_non_empty,
202+
password=_is_string_and_non_empty,
203+
))
204+
205+
206+
def validate_payload(definition, payload):
207+
args = definition(*[payload.get(field, None) for field in definition._fields])
208+
209+
errors_by_field = {}
210+
for field_name, field_value in args._asdict().items():
211+
validator_fn = definition._validators[field_name]
212+
if not validator_fn(field_value):
213+
errors_by_field[field_name] = validator_fn.__doc__
214+
if errors_by_field:
215+
raise ValidationReport(errors_by_field)
216+
else:
217+
return args
218+
219+
220+
def _create_and_expose_server(server, configuration):
160221
app = Flask('coreapi')
161222

162223
@app.get('/user')
@@ -171,9 +232,18 @@ def GET_user_username(username):
171232
def POST_user():
172233
payload = request.get_json()
173234

174-
greeting = payload.get('greeting', '<none>')
175-
if greeting == 'provocation':
176-
raise http_error_from_status_code(422, None)
235+
try:
236+
validated = validate_payload(_REQUEST_ARGS_POST_USER, payload)
237+
except ValidationReport as vr:
238+
return http_error_from_status_code(400, None, vr.serialize())
239+
240+
args = list(validated)
241+
242+
ret = createuser(configuration, args)
243+
if ret != 0:
244+
raise http_error_from_status_code(400, None)
245+
246+
greeting = 'hello client!'
177247
return Response(greeting, 201)
178248

179249
return app

tests/test_mig_services_coreapi.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
11
from __future__ import print_function
22
import codecs
3+
import errno
34
import json
45
import os
6+
import shutil
57
import sys
68
import unittest
79
from threading import Thread
810
from unittest import skip
911

1012
from tests.support import PY2, MIG_BASE, TEST_OUTPUT_DIR, MigTestCase, \
1113
testmain, temppath, make_wrapped_server
14+
from tests.support.htmlsupp import HtmlAssertMixin
1215

1316
from mig.services.coreapi import ThreadedApiHttpServer, \
1417
_create_and_expose_server
1518
from mig.shared.conf import get_configuration_object
19+
from mig.shared.useradm import _USERADM_CONFIG_DIR_KEYS
1620

1721
_PYTHON_MAJOR = '2' if PY2 else '3'
1822
_TEST_CONF_DIR = os.path.join(
@@ -27,11 +31,19 @@
2731
from urllib.request import urlopen, Request
2832

2933

30-
class MigServerGrid_openid(MigTestCase):
34+
class MigServerGrid_openid(MigTestCase, HtmlAssertMixin):
3135
def before_each(self):
3236
self.server_addr = None
3337
self.server_thread = None
3438

39+
for config_key in _USERADM_CONFIG_DIR_KEYS:
40+
dir_path = getattr(self.configuration, config_key)[0:-1]
41+
try:
42+
shutil.rmtree(dir_path)
43+
except OSError as exc:
44+
if exc.errno != errno.ENOENT: # FileNotFoundError
45+
pass
46+
3547
def _provide_configuration(self):
3648
return 'testconfig'
3749

@@ -62,7 +74,7 @@ def issue_GET(self, request_path):
6274

6375
return (status, data)
6476

65-
def issue_POST(self, request_path, request_data=None, request_json=None):
77+
def issue_POST(self, request_path, request_data=None, request_json=None, response_encoding='textual'):
6678
assert isinstance(request_path, str) and request_path.startswith(
6779
'/'), "require http path starting with /"
6880
request_url = ''.join(
@@ -94,12 +106,18 @@ def issue_POST(self, request_path, request_data=None, request_json=None):
94106
data = response.read()
95107
except HTTPError as httpexc:
96108
status = httpexc.code
97-
data = None
109+
data = httpexc.file.read()
98110

99-
try:
100-
data = json.loads(data)
101-
except Exception as e:
102-
pass
111+
if response_encoding == 'textual':
112+
data = codecs.decode(data, 'utf8')
113+
114+
try:
115+
data = json.loads(data)
116+
except Exception as e:
117+
pass
118+
elif response_encoding != 'binary':
119+
raise AssertionError(
120+
'issue_POST: unknown response_encoding "%s"' % (response_encoding,))
103121

104122
return (status, data)
105123

@@ -171,7 +189,11 @@ def test_POST_user__bad_input_data(self):
171189
'greeting': 'provocation'
172190
})
173191

174-
self.assertEqual(status, 422)
192+
self.assertEqual(status, 400)
193+
error_description = self.assertHtmlElement(content, 'p')
194+
error_description_lines = error_description.split('<br>')
195+
self.assertEqual(
196+
error_description_lines[0], 'payload failed to validate:')
175197

176198
@unittest.skipIf(PY2, "Python 3 only")
177199
def test_POST_user(self):
@@ -181,12 +203,18 @@ def test_POST_user(self):
181203
self.server_thread = self._make_server(self.configuration, self.logger, self.server_addr)
182204
self.server_thread.start_wait_until_ready()
183205

184-
status, content = self.issue_POST('/user', request_json={
185-
'greeting': 'hello client!',
186-
})
206+
status, content = self.issue_POST('/user', response_encoding='textual', request_json=dict(
207+
full_name="Test User",
208+
organization="Test Org",
209+
state="NA",
210+
country="DK",
211+
email="dummy-user",
212+
comment="This is the create comment",
213+
password="password",
214+
))
187215

188216
self.assertEqual(status, 201)
189-
self.assertEqual(content, b'hello client!')
217+
self.assertEqual(content, 'hello client!')
190218

191219
def _make_configuration(self, test_logger, server_addr, overrides=None):
192220
configuration = self.configuration
@@ -206,7 +234,8 @@ def _make_configuration(self, test_logger, server_addr, overrides=None):
206234
@staticmethod
207235
def _make_server(configuration, logger=None, server_address=None):
208236
def _on_instance(server):
209-
server.server_app = _create_and_expose_server(server.configuration)
237+
server.server_app = _create_and_expose_server(
238+
server, server.configuration)
210239

211240
(host, port) = server_address
212241
server_thread = make_wrapped_server(ThreadedApiHttpServer, \

0 commit comments

Comments
 (0)