Skip to content

Commit 0f5530a

Browse files
committed
Implement a user creation endpoint in the server.
Back a POST handler onto the refactored createuser functionality file which exposes a function that can be invoked programatically with a configuration.
1 parent a8de6dc commit 0f5530a

File tree

3 files changed

+201
-21
lines changed

3 files changed

+201
-21
lines changed

mig/api/server.py

Lines changed: 88 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import cgi
4444
import cgitb
4545
import codecs
46+
from collections import defaultdict, namedtuple
4647
from flask import Flask, request, Response
4748
from functools import partial, update_wrapper
4849
import os
@@ -56,7 +57,7 @@
5657
from wsgiref.simple_server import WSGIRequestHandler
5758

5859
from mig.shared.accountstate import check_account_accessible
59-
from mig.shared.base import client_dir_id, client_id_dir, cert_field_map
60+
from mig.shared.base import canonical_user, client_dir_id, client_id_dir, cert_field_map
6061
from mig.shared.conf import get_configuration_object
6162
from mig.shared.compat import PY2
6263
from mig.shared.griddaemons.openid import default_max_user_hits, \
@@ -72,10 +73,11 @@
7273
valid_complex_url, InputException
7374
from mig.shared.tlsserver import hardened_ssl_context
7475
from mig.shared.url import urlparse, urlencode, parse_qsl
75-
from mig.shared.useradm import create_user, get_any_oid_user_dn, check_password_scramble, \
76+
from mig.shared.useradm import get_any_oid_user_dn, check_password_scramble, \
7677
check_hash
7778
from mig.shared.userdb import default_db_path
7879
from mig.shared.validstring import possible_user_id, is_valid_email_address
80+
from mig.server.createuser import _main as createuser
7981

8082
# Update with extra fields
8183
cert_field_map.update({'role': 'ROLE', 'timezone': 'TZ', 'nickname': 'NICK',
@@ -114,8 +116,8 @@ def _ensure_encoded_string(chunk):
114116
exc.code: exc for exc in httpexceptions.__dict__.values() if hasattr(exc, 'code')}
115117

116118

117-
def http_error_from_status_code(http_status_code, http_url):
118-
return httpexceptions_by_code[http_status_code]()
119+
def http_error_from_status_code(http_status_code, http_url, description=None):
120+
return httpexceptions_by_code[http_status_code](description)
119121

120122

121123
def quoteattr(val):
@@ -156,7 +158,66 @@ def invalid_argument(arg):
156158
raise ValueError("Unexpected query variable: %s" % quoteattr(arg))
157159

158160

159-
def _create_and_expose_server(configuration):
161+
# requests
162+
163+
class ValidationReport(RuntimeError):
164+
def __init__(self, errors_by_field):
165+
self.errors_by_field = errors_by_field
166+
167+
def serialize(self, output_format='text'):
168+
if output_format == 'json':
169+
return dict(errors=self.errors_by_field)
170+
else:
171+
lines = ["- %s: required %s" % (k, v) for k, v in self.errors_by_field.items()]
172+
lines.insert(0, '')
173+
return 'payload failed to validate:%s' % ('\n'.join(lines),)
174+
175+
176+
def _is_not_none(value):
177+
"""value is not None"""
178+
return value is not None
179+
180+
181+
def _is_string_and_non_empty(value):
182+
"""value is a non-empty string"""
183+
return isinstance(value, str) and len(value) > 0
184+
185+
186+
_REQUEST_ARGS_POST_USER = namedtuple('PostUserArgs', [
187+
'full_name',
188+
'organization',
189+
'state',
190+
'country',
191+
'email',
192+
'comment',
193+
'password',
194+
])
195+
196+
197+
_REQUEST_ARGS_POST_USER._validators = defaultdict(lambda: _is_not_none, dict(
198+
full_name=_is_string_and_non_empty,
199+
organization=_is_string_and_non_empty,
200+
state=_is_string_and_non_empty,
201+
country=_is_string_and_non_empty,
202+
email=_is_string_and_non_empty,
203+
comment=_is_string_and_non_empty,
204+
password=_is_string_and_non_empty,
205+
))
206+
207+
def validate_payload(definition, payload):
208+
args = definition(*[payload.get(field, None) for field in definition._fields])
209+
210+
errors_by_field = {}
211+
for field_name, field_value in args._asdict().items():
212+
validator_fn = definition._validators[field_name]
213+
if not validator_fn(field_value):
214+
errors_by_field[field_name] = validator_fn.__doc__
215+
if errors_by_field:
216+
raise ValidationReport(errors_by_field)
217+
else:
218+
return args
219+
220+
def _create_and_expose_server(server, configuration):
160221
app = Flask('coreapi')
161222

162223
@app.get('/openid/user')
@@ -171,9 +232,28 @@ def GET_user_username(username):
171232
def POST_user():
172233
payload = request.get_json()
173234

174-
greeting = payload.get('greeting', '<none>')
175-
if greeting == 'provocation':
176-
raise http_error_from_status_code(422, None)
235+
# unpack the payload to a series of arguments
236+
try:
237+
validated = validate_payload(_REQUEST_ARGS_POST_USER, payload)
238+
except ValidationReport as vr:
239+
return http_error_from_status_code(400, None, vr.serialize())
240+
241+
args = list(validated)
242+
243+
try:
244+
245+
# user_dict = canonical_user(configuration, raw_user, raw_user.keys())
246+
# except (AttributeError, IndexError, KeyError) as e:
247+
# raise http_error_from_status_code(400, None)
248+
# except Exception as e:
249+
# pass
250+
251+
# try:
252+
createuser(configuration, args)
253+
except Exception as e:
254+
pass
255+
256+
greeting = 'hello client!'
177257
return Response(greeting, 201)
178258

179259
return app

tests/support/htmlsupp.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#!/usr/bin/python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# --- BEGIN_HEADER ---
5+
#
6+
# htmlsupp - test support library for HTML
7+
# Copyright (C) 2003-2024 The MiG Project by the Science HPC Center at UCPH
8+
#
9+
# This file is part of MiG.
10+
#
11+
# MiG is free software: you can redistribute it and/or modify
12+
# it under the terms of the GNU General Public License as published by
13+
# the Free Software Foundation; either version 2 of the License, or
14+
# (at your option) any later version.
15+
#
16+
# MiG is distributed in the hope that it will be useful,
17+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
18+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19+
# GNU General Public License for more details.
20+
#
21+
# You should have received a copy of the GNU General Public License
22+
# along with this program; if not, write to the Free Software
23+
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
24+
#
25+
# -- END_HEADER ---
26+
#
27+
28+
"""Test support library for HTML."""
29+
30+
31+
class HtmlAssertMixin:
32+
"""Custom assertions for HTML containing strings."""
33+
34+
def assertHtmlElement(self, value, tag_name):
35+
"""Check that an occurrence of the specifid tag within an HTML input
36+
string can be found. Returns the textual content of the first match.
37+
"""
38+
39+
self.assertIsValidHtmlDocument(value, permit_no_close=True)
40+
41+
# TODO: this is a definitively stop-gap way of finding a tag within the HTML
42+
# and is used purely to keep this initial change to a reasonable size.
43+
44+
tag_open = ''.join(['<', tag_name, '>'])
45+
tag_open_index = value.index(tag_open)
46+
tag_open_index_after = tag_open_index + len(tag_open)
47+
48+
tag_close = ''.join(['</', tag_name, '>'])
49+
tag_close_index = value.index(tag_close, tag_open_index_after)
50+
51+
return value[tag_open_index_after:tag_close_index]
52+
53+
def assertIsValidHtmlDocument(self, value, permit_no_close=False):
54+
"""Check that the input string contains a valid HTML document.
55+
"""
56+
57+
assert isinstance(value, type(u""))
58+
59+
error = None
60+
try:
61+
has_doctype = value.startswith("<!DOCTYPE html") or value.startswith("<!doctype html")
62+
assert has_doctype, "no valid document opener"
63+
end_html_tag_idx = value.rfind('</html>')
64+
if end_html_tag_idx == -1 and permit_no_close:
65+
return
66+
maybe_document_end = value[end_html_tag_idx:].rstrip()
67+
assert maybe_document_end == '</html>', "no valid document closer"
68+
except Exception as exc:
69+
error = exc
70+
if error:
71+
raise AssertionError("failed to verify input string as HTML: %s", str(error))

tests/test_mig_api.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
11
from __future__ import print_function
22
import codecs
3+
import errno
34
import json
45
import os
6+
import shutil
57
import sys
68
import unittest
79
from threading import Thread
810
from unittest import skip
911

1012
from tests.support import PY2, MIG_BASE, TEST_OUTPUT_DIR, MigTestCase, \
1113
testmain, temppath, make_wrapped_server
14+
from tests.support.htmlsupp import HtmlAssertMixin
1215

1316
from mig.api import ThreadedApiHttpServer, _extend_configuration, _create_and_expose_server
1417
from mig.shared.conf import get_configuration_object
18+
from mig.shared.useradm import _USERADM_CONFIG_DIR_KEYS
1519

1620
_PYTHON_MAJOR = '2' if PY2 else '3'
1721
_TEST_CONF_DIR = os.path.join(
@@ -26,11 +30,19 @@
2630
from urllib.request import urlopen, Request
2731

2832

29-
class MigServerGrid_openid(MigTestCase):
33+
class MigServerGrid_openid(MigTestCase, HtmlAssertMixin):
3034
def before_each(self):
3135
self.server_addr = None
3236
self.server_thread = None
3337

38+
for config_key in _USERADM_CONFIG_DIR_KEYS:
39+
dir_path = getattr(self.configuration, config_key)[0:-1]
40+
try:
41+
shutil.rmtree(dir_path)
42+
except OSError as exc:
43+
if exc.errno != errno.ENOENT: # FileNotFoundError
44+
pass
45+
3446
def _provide_configuration(self):
3547
return 'testconfig'
3648

@@ -61,7 +73,7 @@ def issue_GET(self, request_path):
6173

6274
return (status, data)
6375

64-
def issue_POST(self, request_path, request_data=None, request_json=None):
76+
def issue_POST(self, request_path, request_data=None, request_json=None, response_encoding='textual'):
6577
assert isinstance(request_path, str) and request_path.startswith(
6678
'/'), "require http path starting with /"
6779
request_url = ''.join(
@@ -93,12 +105,18 @@ def issue_POST(self, request_path, request_data=None, request_json=None):
93105
data = response.read()
94106
except HTTPError as httpexc:
95107
status = httpexc.code
96-
data = None
108+
data = httpexc.file.read()
97109

98-
try:
99-
data = json.loads(data)
100-
except Exception as e:
101-
pass
110+
if response_encoding == 'textual':
111+
data = codecs.decode(data, 'utf8')
112+
113+
try:
114+
data = json.loads(data)
115+
except Exception as e:
116+
pass
117+
elif response_encoding != 'binary':
118+
raise AssertionError(
119+
'issue_POST: unknown response_encoding "%s"' % (response_encoding,))
102120

103121
return (status, data)
104122

@@ -180,7 +198,11 @@ def test_POST_user__bad_input_data(self):
180198
'greeting': 'provocation'
181199
})
182200

183-
self.assertEqual(status, 422)
201+
self.assertEqual(status, 400)
202+
error_description = self.assertHtmlElement(content, 'p')
203+
error_description_lines = error_description.split('<br>')
204+
self.assertEqual(
205+
error_description_lines[0], 'payload failed to validate:')
184206

185207
@unittest.skipIf(PY2, "Python 3 only")
186208
def test_POST_user(self):
@@ -191,12 +213,18 @@ def test_POST_user(self):
191213
self.server_thread = self._make_server(configuration)
192214
self.server_thread.start_wait_until_ready()
193215

194-
status, content = self.issue_POST('/openid/user', request_json={
195-
'greeting': 'hello client!',
196-
})
216+
status, content = self.issue_POST('/openid/user', response_encoding='textual', request_json=dict(
217+
full_name="Test User",
218+
organization="Test Org",
219+
state="NA",
220+
country="DK",
221+
email="dummy-user",
222+
comment="This is the create comment",
223+
password="password",
224+
))
197225

198226
self.assertEqual(status, 201)
199-
self.assertEqual(content, b'hello client!')
227+
self.assertEqual(content, 'hello client!')
200228

201229
def _make_configuration(self, test_logger, server_addr, overrides=None):
202230
configuration = self.configuration
@@ -216,7 +244,8 @@ def _make_configuration(self, test_logger, server_addr, overrides=None):
216244
@staticmethod
217245
def _make_server(configuration):
218246
def _on_instance(server):
219-
server.server_app = _create_and_expose_server(server.configuration)
247+
server.server_app = _create_and_expose_server(
248+
server, server.configuration)
220249

221250
server_thread = make_wrapped_server(
222251
ThreadedApiHttpServer, configuration, on_instance=_on_instance)

0 commit comments

Comments
 (0)