From 3fc85c15ee528eae8b9f6fa956d53309a543b5cf Mon Sep 17 00:00:00 2001 From: Alex Burke Date: Fri, 21 Feb 2025 15:28:20 +0100 Subject: [PATCH 1/5] Implement declarative payloads Add the logic to automatically package values as a bundle. Implement declarative argument definitions and logic to bundle values. carve out payloads and rework them to make use of the validation helper Expand the arguments module with the notions of a defined grouping of arguments (ArgumentBundleDefinition) and bundles of particular arguments (ArgumentBundle). Use this as the mechanism by which payloads are checked for validity. As part of declaring a bundle definition the expected positional arguments are declared so implement basic length checks that catch missing positional arguments which were required. Each argument itself is also tested aganst an optional validity function which can be specified at the point of definition. try to combine payloads with the argument stuff move the test further re-integration match the other branch fixup fixup fixup repair and test error conditions fixup fixup fixup quack like a named tuple raise uniform payload exceptions and properly handle dictionaries --- .gitignore | 1 + mig/lib/coresvc/payloads.py | 212 +++++++++++++++++++++++++ tests/test_mig_lib_coresvc_payloads.py | 79 +++++++++ 3 files changed, 292 insertions(+) create mode 100644 mig/lib/coresvc/payloads.py create mode 100644 tests/test_mig_lib_coresvc_payloads.py diff --git a/.gitignore b/.gitignore index 81d3cfafe..b2f79a899 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ downloads/ eggs/ .eggs/ lib/ +!/mig/lib/ lib64/ parts/ sdist/ diff --git a/mig/lib/coresvc/payloads.py b/mig/lib/coresvc/payloads.py new file mode 100644 index 000000000..c7caf7028 --- /dev/null +++ b/mig/lib/coresvc/payloads.py @@ -0,0 +1,212 @@ +from collections import defaultdict, namedtuple, OrderedDict + +from mig.shared.safeinput import validate_helper + + +_EMPTY_LIST = {} +_REQUIRED_FIELD = object() + + +def _is_not_none(value): + """value is not None""" + assert value is not None, _is_not_none.__doc__ + + +def _is_string_and_non_empty(value): + """value is a non-empty string""" + assert isinstance(value, str) and len(value) > 0, _is_string_and_non_empty.__doc__ + + +class PayloadException(ValueError): + def __str__(self): + return self.serialize(output_format='text') + + def serialize(self, output_format='text'): + error_message = self.args[0] + + if output_format == 'json': + return dict(error=error_message) + else: + return error_message + + +class PayloadReport(PayloadException): + def __init__(self, errors_by_field): + self.errors_by_field = errors_by_field + + def serialize(self, output_format='text'): + if output_format == 'json': + return dict(errors=self.errors_by_field) + else: + lines = ["- %s: %s" % + (k, v) for k, v in self.errors_by_field.items()] + lines.insert(0, '') + return 'payload failed to validate:%s' % ('\n'.join(lines),) + + +class _MissingField: + def __init__(self, field, message=None): + assert message is not None + self._field = field + self._message = message + + def replace(self, _, __): + return self._field + + @classmethod + def assert_not_instance(cls, value): + assert not isinstance(value, cls), value._message + + +class Payload(OrderedDict): + def __init__(self, definition, dictionary): + super(Payload, self).__init__(dictionary) + self._definition = definition + + @property + def _fields(self): + return self._definition._fields + + @property + def name(self): + return self._definition._definition_name + + def __iter__(self): + return iter(self.values()) + + def items(self): + return zip(self._definition._item_names, self.values()) + + @staticmethod + def define(payload_name, payload_fields, validators_by_field): + positionals = list((field, validators_by_field[field]) for field in payload_fields) + return PayloadDefinition(payload_name, positionals) + + +class PayloadDefinition: + def __init__(self, name, positionals=_EMPTY_LIST): + self._definition_name = name + self._expected_positions = 0 + self._item_checks = [] + self._item_names = [] + + if positionals is not _EMPTY_LIST: + for positional in positionals: + self._define_positional(positional) + + @property + def _fields(self): + return self._item_names + + def __call__(self, *args): + return self._extract_and_bundle(args, extract_by='position') + + def _define_positional(self, positional): + assert len(positional) == 2 + + name, validator_fn = positional + + self._item_names.append(name) + self._item_checks.append(validator_fn) + + self._expected_positions += 1 + + def _extract_and_bundle(self, args, extract_by=None): + definition = self + + if extract_by == 'position': + actual_positions = len(args) + expected_positions = definition._expected_positions + if actual_positions < expected_positions: + raise PayloadException('Error: too few arguments given (expected %d got %d)' % ( + expected_positions, actual_positions)) + positions = list(range(actual_positions)) + dictionary = {definition._item_names[position]: args[position] for position in positions} + elif extract_by == 'name': + dictionary = {key: args.get(key, None) for key in definition._item_names} + else: + raise RuntimeError() + + return Payload(definition, dictionary) + + def ensure(self, bundle_or_args): + bundle_definition = self + + if isinstance(bundle_or_args, Payload): + assert bundle_or_args.name == bundle_definition._definition_name + return bundle_or_args + elif isinstance(bundle_or_args, dict): + bundle = self._extract_and_bundle(bundle_or_args, extract_by='name') + else: + bundle = bundle_definition(*bundle_or_args) + + return _validate_bundle(self, bundle) + + def ensure_bundle(self, bundle_or_args): + return self.ensure(bundle_or_args) + + def to_checks(self): + type_checks = {} + for key in self._fields: + type_checks[key] = _MissingField.assert_not_instance + + value_checks = dict(zip(self._item_names, self._item_checks)) + + return type_checks, value_checks + + +def _extract_field_error(bad_value): + try: + message = bad_value[0][1] + if not message: + raise IndexError + return message + except IndexError: + return 'required' + + +def _prepare_validate_helper_input(definition, payload): + def _covert_field_value(payload, field): + value = payload.get(field, _REQUIRED_FIELD) + if value is _REQUIRED_FIELD: + return _MissingField(field, 'required') + if value is None: + return _MissingField(field, 'missing') + return value + return {field: _covert_field_value(payload, field) + for field in definition._fields} + + +def _validate_bundle(definition, payload): + assert isinstance(payload, Payload) + + input_dict = _prepare_validate_helper_input(definition, payload) + type_checks, value_checks = definition.to_checks() + _, bad_values = validate_helper(input_dict, definition._fields, + type_checks, value_checks, list_wrap=True) + + if bad_values: + errors_by_field = {field_name: _extract_field_error(bad_value) + for field_name, bad_value in bad_values.items()} + raise PayloadReport(errors_by_field) + + return payload + + +PAYLOAD_POST_USER = Payload.define('PostUserArgs', [ + 'full_name', + 'organization', + 'state', + 'country', + 'email', + 'comment', + 'password', +], defaultdict(lambda: _is_not_none, dict( + full_name=_is_string_and_non_empty, + organization=_is_string_and_non_empty, + state=_is_string_and_non_empty, + country=_is_string_and_non_empty, + email=_is_string_and_non_empty, + comment=_is_string_and_non_empty, + password=_is_string_and_non_empty, +))) diff --git a/tests/test_mig_lib_coresvc_payloads.py b/tests/test_mig_lib_coresvc_payloads.py new file mode 100644 index 000000000..dae5616d6 --- /dev/null +++ b/tests/test_mig_lib_coresvc_payloads.py @@ -0,0 +1,79 @@ +from __future__ import print_function +import sys + +from tests.support import MigTestCase, testmain + +from mig.lib.coresvc.payloads import \ + Payload as ArgumentBundle, \ + PayloadDefinition as ArgumentBundleDefinition, \ + PayloadException + + +def _contains_a_thing(value): + assert 'thing' in value + + +def _upper_case_only(value): + """value must be upper case""" + assert value == value.upper(), _upper_case_only.__doc__ + + +class TestMigSharedArguments__bundles(MigTestCase): + ThingsBundle = ArgumentBundleDefinition('Things', [ + ('some_field', _contains_a_thing), + ('other_field', _contains_a_thing), + ]) + + def assertBundleOfKind(self, value, bundle_kind=None): + assert isinstance(bundle_kind, str) and bundle_kind + self.assertIsInstance(value, ArgumentBundle, "value is not an argument bundle") + self.assertEqual(value.name, bundle_kind, "expected %s bundle, got %s" % (bundle_kind, value.name)) + + def test_bundling_arguments_produces_a_bundle(self): + bundle = self.ThingsBundle('abcthing', 'thingdef') + + self.assertBundleOfKind(bundle, bundle_kind='Things') + + def test_raises_on_missing_positional_arguments(self): + with self.assertRaises(PayloadException) as raised: + self.ThingsBundle(['a']) + self.assertEqual(str(raised.exception), 'Error: too few arguments given (expected 2 got 1)') + + def test_ensuring_arguments_returns_a_bundle(self): + bundle = self.ThingsBundle.ensure_bundle(['abcthing', 'thingdef']) + + self.assertBundleOfKind(bundle, bundle_kind='Things') + + def test_ensuring_an_existing_bundle_returns_it_unchanged(self): + existing_bundle = self.ThingsBundle('abcthing', 'thingdef') + + bundle = self.ThingsBundle.ensure_bundle(existing_bundle) + + self.assertIs(bundle, existing_bundle) + + def test_ensuring_on_a_list_of_args_validates_them(self): + with self.assertRaises(Exception) as raised: + bundle = self.ThingsBundle.ensure_bundle(['abcthing', 'def']) + self.assertEqual(str(raised.exception), 'payload failed to validate:\n- other_field: required') + + def test_ensuring_on_invalid_args_produces_reports_with_errors(self): + UpperCaseValue = ArgumentBundle.define('UpperCaseValue', ['ustring'], { + 'ustring': _upper_case_only + }) + + with self.assertRaises(Exception) as raised: + bundle = UpperCaseValue.ensure_bundle(['lowerCHARS']) + self.assertEqual(str(raised.exception), 'payload failed to validate:\n- ustring: value must be upper case') + + def test_ensuring_on_invalid_args_containing_none_behaves_correctly(self): + UpperCaseValue = ArgumentBundle.define('UpperCaseValue', ['ustring'], { + 'ustring': _upper_case_only + }) + + with self.assertRaises(Exception) as raised: + bundle = UpperCaseValue.ensure_bundle([None]) + self.assertEqual(str(raised.exception), 'payload failed to validate:\n- ustring: missing') + + +if __name__ == '__main__': + testmain() From f093e2f58a083f821ea080464900a5ade1f91435 Mon Sep 17 00:00:00 2001 From: Alex Burke Date: Wed, 28 Aug 2024 11:43:03 +0200 Subject: [PATCH 2/5] Implement an intitial server based on flask. Do the core work necessary to have a first response served to a GET request and some basic wiring to allow submission of JSON via POST. CHECKPOINT: running coreapi server overrides? not yet sure why this was needed Implement a user creation endpoint in the server. Back a POST handler onto the refactored createuser functionality file which exposes a function that can be invoked programatically with a configuration. Split out the response data decoding chunk. axe depedency pn reworkings on userapi fixups - move closer to the server working without createuser changes Implement finding users by email address. relocate fixup carve out payloads and rework them to make use of the validation helper shut up flake complaints about things that have changed further fixup another --- mig/lib/__init__.py | 0 mig/lib/coresvc/__init__.py | 2 + mig/lib/coresvc/__main__.py | 30 ++++ mig/lib/coresvc/server.py | 243 ++++++++++++++++++++++++++++++ mig/shared/useradm.py | 4 +- requirements.txt | 1 + tests/support/httpsupp.py | 98 ++++++++++++ tests/support/serversupp.py | 3 + tests/test_mig_lib_coresvc.py | 273 ++++++++++++++++++++++++++++++++++ 9 files changed, 653 insertions(+), 1 deletion(-) create mode 100644 mig/lib/__init__.py create mode 100644 mig/lib/coresvc/__init__.py create mode 100644 mig/lib/coresvc/__main__.py create mode 100755 mig/lib/coresvc/server.py create mode 100644 tests/support/httpsupp.py create mode 100644 tests/test_mig_lib_coresvc.py diff --git a/mig/lib/__init__.py b/mig/lib/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mig/lib/coresvc/__init__.py b/mig/lib/coresvc/__init__.py new file mode 100644 index 000000000..412a33c62 --- /dev/null +++ b/mig/lib/coresvc/__init__.py @@ -0,0 +1,2 @@ +from mig.lib.coresvc.server import ThreadedApiHttpServer, \ + _create_and_expose_server diff --git a/mig/lib/coresvc/__main__.py b/mig/lib/coresvc/__main__.py new file mode 100644 index 000000000..1a8155104 --- /dev/null +++ b/mig/lib/coresvc/__main__.py @@ -0,0 +1,30 @@ +from argparse import ArgumentError +from getopt import getopt +import sys + +from mig.shared.conf import get_configuration_object +from mig.services.coreapi.server import main as server_main + + +def _getopt_opts_to_options(opts): + options = {} + for k, v in opts: + options[k[1:]] = v + return options + + +def _required_argument_error(option, argument_name): + raise ArgumentError(None, 'Missing required argument: %s %s' % + (option, argument_name.upper())) + + +if __name__ == '__main__': + (opts, args) = getopt(sys.argv[1:], 'c:') + opts_dict = _getopt_opts_to_options(opts) + + if 'c' not in opts_dict: + raise _required_argument_error('-c', 'config_file') + + configuration = get_configuration_object(opts_dict['c'], + skip_log=True, disable_auth_log=True) + server_main(configuration) diff --git a/mig/lib/coresvc/server.py b/mig/lib/coresvc/server.py new file mode 100755 index 000000000..bf6cd860f --- /dev/null +++ b/mig/lib/coresvc/server.py @@ -0,0 +1,243 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# +# --- BEGIN_HEADER --- +# +# mig/services/coreapi/server - coreapi service server internals +# Copyright (C) 2003-2025 The MiG Project by the Science HPC Center at UCPH +# +# This file is part of MiG. +# +# MiG is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# MiG is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +# +# -- END_HEADER --- +# + + +"""HTTP server parts of the coreapi service.""" + +from __future__ import print_function +from __future__ import absolute_import + +from http.server import HTTPServer, BaseHTTPRequestHandler +from socketserver import ThreadingMixIn + +import base64 +from collections import defaultdict, namedtuple +from flask import Flask, request, Response +import os +import sys +import threading +import time +import werkzeug.exceptions as httpexceptions +from wsgiref.simple_server import WSGIRequestHandler + +from mig.lib.coresvc.payloads import PayloadException, \ + PAYLOAD_POST_USER as _REQUEST_ARGS_POST_USER +from mig.shared.base import canonical_user, keyword_auto, force_native_str_rec +from mig.shared.useradm import fill_user, \ + create_user as useradm_create_user, search_users as useradm_search_users +from mig.shared.userdb import default_db_path + + +httpexceptions_by_code = { + exc.code: exc for exc in httpexceptions.__dict__.values() if hasattr(exc, 'code')} + + +def http_error_from_status_code(http_status_code, http_url, description=None): + return httpexceptions_by_code[http_status_code](description) + + +def _create_user(user_dict, conf_path, **kwargs): + try: + useradm_create_user(user_dict, conf_path, keyword_auto, **kwargs) + except Exception as exc: + return 1 + return 0 + + +def search_users(configuration, search_filter): + _, hits = useradm_search_users(search_filter, configuration, keyword_auto) + return list((obj for _, obj in hits)) + + +def _create_and_expose_server(server, configuration): + app = Flask('coreapi') + + @app.get('/user') + def GET_user(): + raise http_error_from_status_code(400, None) + + @app.get('/user/') + def GET_user_username(username): + return 'FOOBAR' + + @app.get('/user/find') + def GET_user_find(): + query_params = request.args + + objects = search_users(configuration, { + 'email': query_params['email'] + }) + + if len(objects) != 1: + raise http_error_from_status_code(404, None) + + return dict(objects=objects) + + @app.post('/user') + def POST_user(): + payload = request.get_json() + + try: + validated = _REQUEST_ARGS_POST_USER.ensure(payload) + except PayloadException as vr: + return http_error_from_status_code(400, None, vr.serialize()) + + user_dict = canonical_user( + configuration, validated, _REQUEST_ARGS_POST_USER._fields) + fill_user(user_dict) + force_native_str_rec(user_dict) + + ret = _create_user(user_dict, configuration, default_renew=True) + if ret != 0: + raise http_error_from_status_code(400, None) + + greeting = 'hello client!' + return Response(greeting, 201) + + return app + + +class ApiHttpServer(HTTPServer): + """ + http(s) server that contains a reference to an OpenID Server and + knows its base URL. + Extended to fork on requests to avoid one slow or broken login stalling + the rest. + """ + + def __init__(self, configuration, logger=None, host=None, port=None, **kwargs): + self.configuration = configuration + self.logger = logger if logger else configuration.logger + self.server_app = None + self._on_start = kwargs.pop('on_start', lambda _: None) + + addr = (host, port) + HTTPServer.__init__(self, addr, ApiHttpRequestHandler, **kwargs) + + @property + def base_environ(self): + return {} + + def get_app(self): + return self.server_app + + def server_activate(self): + HTTPServer.server_activate(self) + self._on_start(self) + + +class ThreadedApiHttpServer(ThreadingMixIn, ApiHttpServer): + """Multi-threaded version of the ApiHttpServer""" + + @property + def base_url(self): + proto = 'http' + return '%s://%s:%d/' % (proto, self.server_name, self.server_port) + + +class ApiHttpRequestHandler(WSGIRequestHandler): + """TODO: docstring""" + + def __init__(self, socket, addr, server, **kwargs): + self.server = server + + # NOTE: drop idle clients after N seconds to clean stale connections. + # Does NOT include clients that connect and do nothing at all :-( + self.timeout = 120 + + self._http_url = None + self.parsed_uri = None + self.path_parts = None + self.retry_url = '' + + WSGIRequestHandler.__init__(self, socket, addr, server, **kwargs) + + @property + def configuration(self): + return self.server.configuration + + @property + def daemon_conf(self): + return self.server.configuration.daemon_conf + + @property + def logger(self): + return self.server.logger + + +def start_service(configuration, host=None, port=None): + assert host is not None, "required kwarg: host" + assert port is not None, "required kwarg: port" + + logger = configuration.logger + + def _on_start(server, *args, **kwargs): + server.server_app = _create_and_expose_server( + None, server.configuration) + + httpserver = ThreadedApiHttpServer( + configuration, host=host, port=port, on_start=_on_start) + + serve_msg = 'Server running at: %s' % httpserver.base_url + logger.info(serve_msg) + print(serve_msg) + while True: + logger.debug('handle next request') + httpserver.handle_request() + logger.debug('done handling request') + httpserver.expire_volatile() + + +def main(configuration=None): + if not configuration: + from mig.shared.conf import get_configuration_object + # Force no log init since we use separate logger + configuration = get_configuration_object(skip_log=True) + + logger = configuration.logger + + # Allow e.g. logrotate to force log re-open after rotates + #register_hangup_handler(configuration) + + # FIXME: + host = 'localhost' # configuration.user_openid_address + port = 5555 # configuration.user_openid_port + server_address = (host, port) + + info_msg = "Starting coreapi..." + logger.info(info_msg) + print(info_msg) + + try: + start_service(configuration, host=host, port=port) + except KeyboardInterrupt: + info_msg = "Received user interrupt" + logger.info(info_msg) + print(info_msg) + info_msg = "Leaving with no more workers active" + logger.info(info_msg) + print(info_msg) diff --git a/mig/shared/useradm.py b/mig/shared/useradm.py index 2812dc5f5..9b2eb9cb9 100644 --- a/mig/shared/useradm.py +++ b/mig/shared/useradm.py @@ -2318,7 +2318,9 @@ def search_users(search_filter, conf_path, db_path, fnmatch for. """ - if conf_path: + if isinstance(conf_path, Configuration): + configuration = conf_path + elif conf_path: if isinstance(conf_path, basestring): configuration = Configuration(conf_path) else: diff --git a/requirements.txt b/requirements.txt index 5c2b1bc8f..8c398a2ec 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ # migrid core dependencies on a format suitable for pip install as described on # https://pip.pypa.io/en/stable/reference/requirement-specifiers/ +flask future # cgi was removed from the standard library in Python 3.13 diff --git a/tests/support/httpsupp.py b/tests/support/httpsupp.py new file mode 100644 index 000000000..4a115c489 --- /dev/null +++ b/tests/support/httpsupp.py @@ -0,0 +1,98 @@ +import codecs +import json + +from tests.support._env import PY2 + +if PY2: + from urllib2 import HTTPError, Request, urlopen + from urllib import urlencode +else: + from urllib.error import HTTPError + from urllib.parse import urlencode + from urllib.request import urlopen, Request + + +def attempt_to_decode_response_data(data, response_encoding=None): + if data is None: + return None + elif response_encoding == 'textual': + data = codecs.decode(data, 'utf8') + + try: + return json.loads(data) + except Exception as e: + return data + elif response_encoding == 'binary': + return data + else: + raise AssertionError( + 'issue_POST: unknown response_encoding "%s"' % (response_encoding,)) + + +class HttpAssertMixin: + + def _issue_GET(self, server_address, request_path, query_dict=None, response_encoding='textual'): + assert isinstance(server_address, tuple) and len( + server_address) == 2, "require server address tuple" + assert isinstance(request_path, str) and request_path.startswith( + '/'), "require http path starting with /" + request_url = ''.join( + ('http://', server_address[0], ':', str(server_address[1]), request_path)) + + if query_dict is not None: + query_string = urlencode(query_dict) + request_url = ''.join((request_url, '?', query_string)) + + status = 0 + data = None + + try: + response = urlopen(request_url, None, timeout=2000) + + status = response.getcode() + data = response.read() + except HTTPError as httpexc: + status = httpexc.code + data = None + + content = attempt_to_decode_response_data(data, response_encoding) + return (status, content) + + def _issue_POST(self, server_address, request_path, request_data=None, request_json=None, response_encoding='textual'): + assert isinstance(server_address, tuple) and len( + server_address) == 2, "require server address tuple" + assert isinstance(request_path, str) and request_path.startswith( + '/'), "require http path starting with /" + request_url = ''.join( + ('http://', server_address[0], ':', str(server_address[1]), request_path)) + + if request_data and request_json: + raise ValueError( + "only one of data or json request data may be specified") + + status = 0 + data = None + + try: + if request_json is not None: + request_data = codecs.encode(json.dumps(request_json), 'utf8') + request_headers = { + 'Content-Type': 'application/json' + } + request = Request(request_url, request_data, + headers=request_headers) + elif request_data is not None: + request = Request(request_url, request_data) + else: + request = Request(request_url) + + response = urlopen(request, timeout=2000) + + status = response.getcode() + data = response.read() + except HTTPError as httpexc: + status = httpexc.code + data = httpexc.file.read() + + content = attempt_to_decode_response_data(data, response_encoding) + return (status, content) diff --git a/tests/support/serversupp.py b/tests/support/serversupp.py index 0e0fd4b94..74baa2a62 100644 --- a/tests/support/serversupp.py +++ b/tests/support/serversupp.py @@ -41,6 +41,7 @@ class ServerWithinThreadExecutor: def __init__(self, ServerClass, *args, **kwargs): self._serverclass = ServerClass + self._serverclass_on_instance = kwargs.pop('on_instance') self._arguments = (args, kwargs) self._started = ThreadEvent() self._thread = None @@ -53,6 +54,8 @@ def run(self): server_kwargs['on_start'] = lambda _: self._started.set() self._wrapped = self._serverclass(*server_args, **server_kwargs) + if self._serverclass_on_instance: + self._serverclass_on_instance(self._wrapped) try: self._wrapped.serve_forever() diff --git a/tests/test_mig_lib_coresvc.py b/tests/test_mig_lib_coresvc.py new file mode 100644 index 000000000..bd2cd151e --- /dev/null +++ b/tests/test_mig_lib_coresvc.py @@ -0,0 +1,273 @@ +from __future__ import print_function +import codecs +import errno +import json +import os +import shutil +import sys +import unittest +from threading import Thread +from unittest import skip + +from tests.support import PY2, MigTestCase, testmain, temppath, \ + make_wrapped_server +from tests.support.httpsupp import HttpAssertMixin + +from mig.shared.base import keyword_auto +from mig.shared.useradm import create_user +from mig.lib.coresvc import ThreadedApiHttpServer, \ + _create_and_expose_server + +_TAG_P_OPEN = '

' +_TAG_P_CLOSE = '

' +_USERADM_PATH_KEYS = ('user_cache', 'user_db_home', 'user_home', + 'user_settings', 'mrsl_files_dir', 'resource_pending') + + +def _extend_configuration(*args): + pass + + +def ensure_dirs_needed_by_create_user(configuration): + for config_key in _USERADM_PATH_KEYS: + dir_path = getattr(configuration, config_key)[0:-1] + try: + os.mkdir(dir_path) + except OSError as exc: + pass + + +def extract_error_description_from_html(content): + open_tag_index = content.find(_TAG_P_OPEN) + start_index = open_tag_index + len(_TAG_P_OPEN) + end_index = content.find(_TAG_P_CLOSE) + error_desription = content[start_index:end_index] + return error_desription + + +class MigServerGrid_openid(MigTestCase, HttpAssertMixin): + def before_each(self): + self.server_addr = None + self.server_thread = None + + ensure_dirs_needed_by_create_user(self.configuration) + + def _provide_configuration(self): + return 'testconfig' + + def after_each(self): + if self.server_thread: + self.server_thread.stop() + + def issue_GET(self, request_path): + return self._issue_GET(self.server_addr, request_path) + + def issue_POST(self, request_path, **kwargs): + return self._issue_POST(self.server_addr, request_path, **kwargs) + + @unittest.skipIf(PY2, "Python 3 only") + def test__GET_returns_not_found_for_missing_path(self): + self.server_addr = ('localhost', 4567) + self.server_thread = self._make_server( + self.configuration, self.logger, self.server_addr) + self.server_thread.start_wait_until_ready() + + status, _ = self.issue_GET('/nonexistent') + + self.assertEqual(status, 404) + + @unittest.skipIf(PY2, "Python 3 only") + def test_GET_user__top_level_request(self): + self.server_addr = ('localhost', 4567) + self.server_thread = self._make_server( + self.configuration, self.logger, self.server_addr) + self.server_thread.start_wait_until_ready() + + status, _ = self.issue_GET('/user') + + self.assertEqual(status, 400) + + @unittest.skipIf(PY2, "Python 3 only") + def test_GET__user_userid_request_succeeds_with_status_ok(self): + example_username = 'dummy-user' + example_username_home_dir = temppath( + 'state/user_home/%s' % example_username, self, ensure_dir=True) + test_user_home = os.path.dirname( + example_username_home_dir) # strip user from path + test_state_dir = os.path.dirname(test_user_home) + test_user_db_home = os.path.join(test_state_dir, "user_db_home") + + self.server_addr = ('localhost', 4567) + self.server_thread = self._make_server( + self.configuration, self.logger, self.server_addr) + self.server_thread.start_wait_until_ready() + + the_url = '/user/%s' % (example_username,) + status, content = self.issue_GET(the_url) + + self.assertEqual(status, 200) + self.assertEqual(content, 'FOOBAR') + + @unittest.skipIf(PY2, "Python 3 only") + def test_GET_openid_user_username(self): + flask_app = None + + self.server_addr = ('localhost', 4567) + self.server_thread = self._make_server( + self.configuration, self.logger, self.server_addr) + self.server_thread.start_wait_until_ready() + + request_json = json.dumps({}) + request_data = codecs.encode(request_json, 'utf8') + + status, content = self.issue_GET('/user/dummy-user') + + self.assertEqual(status, 200) + self.assertEqual(content, 'FOOBAR') + + @unittest.skipIf(PY2, "Python 3 only") + def test_POST_user__bad_input_data(self): + flask_app = None + + self.server_addr = ('localhost', 4567) + self.server_thread = self._make_server( + self.configuration, self.logger, self.server_addr) + self.server_thread.start_wait_until_ready() + + status, content = self.issue_POST('/user', request_json={ + 'greeting': 'provocation' + }) + + self.assertEqual(status, 400) + error_description = extract_error_description_from_html(content) + error_description_lines = error_description.split('
') + self.assertEqual( + error_description_lines[0], 'payload failed to validate:') + + @unittest.skipIf(PY2, "Python 3 only") + def test_POST_user(self): + flask_app = None + + self.server_addr = ('localhost', 4567) + self.server_thread = self._make_server( + self.configuration, self.logger, self.server_addr) + self.server_thread.start_wait_until_ready() + + status, content = self.issue_POST('/user', response_encoding='textual', request_json=dict( + full_name="Test User", + organization="Test Org", + state="NA", + country="DK", + email="user@example.com", + comment="This is the create comment", + password="password", + )) + + self.assertEqual(status, 201) + self.assertEqual(content, 'hello client!') + + def _make_configuration(self, test_logger, server_addr): + configuration = self.configuration + _extend_configuration( + configuration, + server_addr[0], + server_addr[1], + logger=test_logger, + expandusername=False, + host_rsa_key='', + nossl=True, + show_address=False, + show_port=False, + ) + return configuration + + @staticmethod + def _make_server(configuration, logger=None, server_address=None): + def _on_instance(server): + server.server_app = _create_and_expose_server( + server, server.configuration) + + (host, port) = server_address + server_thread = make_wrapped_server(ThreadedApiHttpServer, + configuration, logger, host, port, on_instance=_on_instance) + return server_thread + + +class MigServerGrid_openid__existing_user(MigTestCase, HttpAssertMixin): + def before_each(self): + self.server_addr = None + self.server_thread = None + + ensure_dirs_needed_by_create_user(self.configuration) + + user_dict = { + 'full_name': "Test User", + 'organization': "Test Org", + 'state': "NA", + 'country': "DK", + 'email': "user@example.com", + 'comment': "This is the create comment", + 'password': "password", + } + create_user(user_dict, self.configuration, + keyword_auto, default_renew=True) + + def _provide_configuration(self): + return 'testconfig' + + def after_each(self): + if self.server_thread: + self.server_thread.stop() + + @unittest.skipIf(PY2, "Python 3 only") + def test_GET_openid_user_find(self): + flask_app = None + + self.server_addr = ('localhost', 4567) + self.server_thread = self._make_server( + self.configuration, self.logger, self.server_addr) + self.server_thread.start_wait_until_ready() + + status, content = self._issue_GET(self.server_addr, '/user/find', { + 'email': 'user@example.com' + }) + + self.assertEqual(status, 200) + + self.assertIsInstance(content, dict) + self.assertIn('objects', content) + self.assertIsInstance(content['objects'], list) + + user = content['objects'][0] + # check we received the correct user + self.assertEqual(user['full_name'], 'Test User') + + def _make_configuration(self, test_logger, server_addr): + configuration = self.configuration + _extend_configuration( + configuration, + server_addr[0], + server_addr[1], + logger=test_logger, + expandusername=False, + host_rsa_key='', + nossl=True, + show_address=False, + show_port=False, + ) + return configuration + + @staticmethod + def _make_server(configuration, logger=None, server_address=None): + def _on_instance(server): + server.server_app = _create_and_expose_server( + server, server.configuration) + + (host, port) = server_address + server_thread = make_wrapped_server(ThreadedApiHttpServer, + configuration, logger, host, port, on_instance=_on_instance) + return server_thread + + +if __name__ == '__main__': + testmain() From 023780d2ac773906dbef54768901b33328b09633 Mon Sep 17 00:00:00 2001 From: Alex Burke Date: Mon, 10 Mar 2025 14:51:27 +0100 Subject: [PATCH 3/5] further server work --- mig/lib/coresvc/server.py | 42 +++++++++++++++++------------ mig/shared/useradm.py | 4 ++- tests/test_mig_lib_coresvc.py | 50 ++++++++--------------------------- 3 files changed, 39 insertions(+), 57 deletions(-) diff --git a/mig/lib/coresvc/server.py b/mig/lib/coresvc/server.py index bf6cd860f..f96f4fa28 100755 --- a/mig/lib/coresvc/server.py +++ b/mig/lib/coresvc/server.py @@ -37,6 +37,7 @@ import base64 from collections import defaultdict, namedtuple from flask import Flask, request, Response +import json import os import sys import threading @@ -60,12 +61,28 @@ def http_error_from_status_code(http_status_code, http_url, description=None): return httpexceptions_by_code[http_status_code](description) -def _create_user(user_dict, conf_path, **kwargs): +def json_reponse_from_status_code(http_status_code, content): + json_content = json.dumps(content) + return Response(json_content, http_status_code, { 'Content-Type': 'application/json' }) + + +def _create_user(configuration, payload): + user_dict = canonical_user( + configuration, payload, _REQUEST_ARGS_POST_USER._fields) + fill_user(user_dict) + force_native_str_rec(user_dict) + try: - useradm_create_user(user_dict, conf_path, keyword_auto, **kwargs) - except Exception as exc: - return 1 - return 0 + useradm_create_user(user_dict, configuration, keyword_auto, default_renew=True) + except: + raise http_error_from_status_code(500, None) + user_email = user_dict['email'] + objects = search_users(configuration, { + 'email': user_email + }) + if len(objects) != 1: + raise http_error_from_status_code(400, None) + return objects[0] def search_users(configuration, search_filter): @@ -102,21 +119,12 @@ def POST_user(): payload = request.get_json() try: - validated = _REQUEST_ARGS_POST_USER.ensure(payload) + payload = _REQUEST_ARGS_POST_USER.ensure(payload) except PayloadException as vr: return http_error_from_status_code(400, None, vr.serialize()) - user_dict = canonical_user( - configuration, validated, _REQUEST_ARGS_POST_USER._fields) - fill_user(user_dict) - force_native_str_rec(user_dict) - - ret = _create_user(user_dict, configuration, default_renew=True) - if ret != 0: - raise http_error_from_status_code(400, None) - - greeting = 'hello client!' - return Response(greeting, 201) + user = _create_user(configuration, payload) + return json_reponse_from_status_code(201, user) return app diff --git a/mig/shared/useradm.py b/mig/shared/useradm.py index 9b2eb9cb9..a144f8c0a 100644 --- a/mig/shared/useradm.py +++ b/mig/shared/useradm.py @@ -1027,7 +1027,9 @@ def create_user(user, conf_path, db_path, force=False, verbose=False, format as a first step. """ - if conf_path: + if isinstance(conf_path, Configuration): + configuration = conf_path + elif conf_path: if isinstance(conf_path, basestring): # has been checked for accessibility above... diff --git a/tests/test_mig_lib_coresvc.py b/tests/test_mig_lib_coresvc.py index bd2cd151e..75ca71fa2 100644 --- a/tests/test_mig_lib_coresvc.py +++ b/tests/test_mig_lib_coresvc.py @@ -52,6 +52,10 @@ def before_each(self): ensure_dirs_needed_by_create_user(self.configuration) + self.server_addr = ('localhost', 4567) + self.server_thread = self._make_server( + self.configuration, self.logger, self.server_addr) + def _provide_configuration(self): return 'testconfig' @@ -67,9 +71,6 @@ def issue_POST(self, request_path, **kwargs): @unittest.skipIf(PY2, "Python 3 only") def test__GET_returns_not_found_for_missing_path(self): - self.server_addr = ('localhost', 4567) - self.server_thread = self._make_server( - self.configuration, self.logger, self.server_addr) self.server_thread.start_wait_until_ready() status, _ = self.issue_GET('/nonexistent') @@ -78,9 +79,6 @@ def test__GET_returns_not_found_for_missing_path(self): @unittest.skipIf(PY2, "Python 3 only") def test_GET_user__top_level_request(self): - self.server_addr = ('localhost', 4567) - self.server_thread = self._make_server( - self.configuration, self.logger, self.server_addr) self.server_thread.start_wait_until_ready() status, _ = self.issue_GET('/user') @@ -96,30 +94,17 @@ def test_GET__user_userid_request_succeeds_with_status_ok(self): example_username_home_dir) # strip user from path test_state_dir = os.path.dirname(test_user_home) test_user_db_home = os.path.join(test_state_dir, "user_db_home") - - self.server_addr = ('localhost', 4567) - self.server_thread = self._make_server( - self.configuration, self.logger, self.server_addr) self.server_thread.start_wait_until_ready() - the_url = '/user/%s' % (example_username,) - status, content = self.issue_GET(the_url) + status, content = self.issue_GET('/user/dummy-user') self.assertEqual(status, 200) self.assertEqual(content, 'FOOBAR') @unittest.skipIf(PY2, "Python 3 only") def test_GET_openid_user_username(self): - flask_app = None - - self.server_addr = ('localhost', 4567) - self.server_thread = self._make_server( - self.configuration, self.logger, self.server_addr) self.server_thread.start_wait_until_ready() - request_json = json.dumps({}) - request_data = codecs.encode(request_json, 'utf8') - status, content = self.issue_GET('/user/dummy-user') self.assertEqual(status, 200) @@ -127,11 +112,6 @@ def test_GET_openid_user_username(self): @unittest.skipIf(PY2, "Python 3 only") def test_POST_user__bad_input_data(self): - flask_app = None - - self.server_addr = ('localhost', 4567) - self.server_thread = self._make_server( - self.configuration, self.logger, self.server_addr) self.server_thread.start_wait_until_ready() status, content = self.issue_POST('/user', request_json={ @@ -146,11 +126,6 @@ def test_POST_user__bad_input_data(self): @unittest.skipIf(PY2, "Python 3 only") def test_POST_user(self): - flask_app = None - - self.server_addr = ('localhost', 4567) - self.server_thread = self._make_server( - self.configuration, self.logger, self.server_addr) self.server_thread.start_wait_until_ready() status, content = self.issue_POST('/user', response_encoding='textual', request_json=dict( @@ -164,7 +139,8 @@ def test_POST_user(self): )) self.assertEqual(status, 201) - self.assertEqual(content, 'hello client!') + self.assertIsInstance(content, dict) + self.assertIn('unique_id', content) def _make_configuration(self, test_logger, server_addr): configuration = self.configuration @@ -195,9 +171,6 @@ def _on_instance(server): class MigServerGrid_openid__existing_user(MigTestCase, HttpAssertMixin): def before_each(self): - self.server_addr = None - self.server_thread = None - ensure_dirs_needed_by_create_user(self.configuration) user_dict = { @@ -212,6 +185,10 @@ def before_each(self): create_user(user_dict, self.configuration, keyword_auto, default_renew=True) + self.server_addr = ('localhost', 4567) + self.server_thread = self._make_server( + self.configuration, self.logger, self.server_addr) + def _provide_configuration(self): return 'testconfig' @@ -221,11 +198,6 @@ def after_each(self): @unittest.skipIf(PY2, "Python 3 only") def test_GET_openid_user_find(self): - flask_app = None - - self.server_addr = ('localhost', 4567) - self.server_thread = self._make_server( - self.configuration, self.logger, self.server_addr) self.server_thread.start_wait_until_ready() status, content = self._issue_GET(self.server_addr, '/user/find', { From 6614c05685f1383994976cd4f0ae4f78b10a96ab Mon Sep 17 00:00:00 2001 From: Alex Burke Date: Wed, 5 Mar 2025 20:50:37 +0100 Subject: [PATCH 4/5] Put together first pass of a self contained core service API client. --- mig/lib/coreapi/__init__.py | 97 +++++++++++++++++++++++++++++++++++ tests/support/serversupp.py | 18 ++++--- tests/test_mig_lib_coreapi.py | 70 +++++++++++++++++++++++++ 3 files changed, 179 insertions(+), 6 deletions(-) create mode 100644 mig/lib/coreapi/__init__.py create mode 100644 tests/test_mig_lib_coreapi.py diff --git a/mig/lib/coreapi/__init__.py b/mig/lib/coreapi/__init__.py new file mode 100644 index 000000000..1bfeeeae9 --- /dev/null +++ b/mig/lib/coreapi/__init__.py @@ -0,0 +1,97 @@ +import codecs +import json + +from tests.support._env import PY2 + +if PY2: + from urllib2 import HTTPError, Request, urlopen + from urllib import urlencode +else: + from urllib.error import HTTPError + from urllib.parse import urlencode + from urllib.request import urlopen, Request + +from mig.lib.coresvc.payloads import PAYLOAD_POST_USER + + +def attempt_to_decode_response_data(data, response_encoding=None): + if data is None: + return None + elif response_encoding == 'textual': + data = codecs.decode(data, 'utf8') + + try: + return json.loads(data) + except Exception as e: + return data + elif response_encoding == 'binary': + return data + else: + raise AssertionError( + 'issue_POST: unknown response_encoding "%s"' % (response_encoding,)) + + +class CoreApiClient: + def __init__(self, base_url): + self._base_url = base_url + + def _issue_GET(self, request_path, query_dict=None, response_encoding='textual'): + request_url = ''.join((self._base_url, request_path)) + + if query_dict is not None: + query_string = urlencode(query_dict) + request_url = ''.join((request_url, '?', query_string)) + + status = 0 + data = None + + try: + response = urlopen(request_url, None, timeout=2000) + + status = response.getcode() + data = response.read() + except HTTPError as httpexc: + status = httpexc.code + data = None + + content = attempt_to_decode_response_data(data, response_encoding) + return (status, content) + + def _issue_POST(self, request_path, request_data=None, request_json=None, response_encoding='textual'): + request_url = ''.join((self._base_url, request_path)) + + if request_data and request_json: + raise ValueError( + "only one of data or json request data may be specified") + + status = 0 + data = None + + try: + if request_json is not None: + request_data = codecs.encode(json.dumps(request_json), 'utf8') + request_headers = { + 'Content-Type': 'application/json' + } + request = Request(request_url, request_data, + headers=request_headers) + elif request_data is not None: + request = Request(request_url, request_data) + else: + request = Request(request_url) + + response = urlopen(request, timeout=2000) + + status = response.getcode() + data = response.read() + except HTTPError as httpexc: + status = httpexc.code + data = httpexc.fp.read() + + content = attempt_to_decode_response_data(data, response_encoding) + return (status, content) + + def createUser(self, user_dict): + payload = PAYLOAD_POST_USER.ensure(user_dict) + + return self._issue_POST('/user', request_json=dict(payload)) diff --git a/tests/support/serversupp.py b/tests/support/serversupp.py index 74baa2a62..f7d78bffd 100644 --- a/tests/support/serversupp.py +++ b/tests/support/serversupp.py @@ -41,12 +41,16 @@ class ServerWithinThreadExecutor: def __init__(self, ServerClass, *args, **kwargs): self._serverclass = ServerClass - self._serverclass_on_instance = kwargs.pop('on_instance') + self._serverclass_on_instance = kwargs.pop('on_instance', None) self._arguments = (args, kwargs) self._started = ThreadEvent() self._thread = None self._wrapped = None + def __getattr__(self, attr): + assert self._wrapped, "wrapped instance was not created" + return getattr(self._wrapped, attr) + def run(self): """Mimic the same method from the standard thread API""" server_args, server_kwargs = self._arguments @@ -76,14 +80,16 @@ def start_wait_until_ready(self): def stop(self): """Mimic the same method from the standard thread API""" self.stop_server() - self._wrapped = None - self._thread.join() - self._thread = None + if self._thread: + self._thread.join() + self._thread = None def stop_server(self): """Stop server thread""" - self._wrapped.shutdown() - self._wrapped.server_close() + if self._wrapped: + self._wrapped.shutdown() + self._wrapped.server_close() + self._wrapped = None def make_wrapped_server(ServerClass, *args, **kwargs): diff --git a/tests/test_mig_lib_coreapi.py b/tests/test_mig_lib_coreapi.py new file mode 100644 index 000000000..73f73ae09 --- /dev/null +++ b/tests/test_mig_lib_coreapi.py @@ -0,0 +1,70 @@ +from http.server import HTTPServer, BaseHTTPRequestHandler + +from tests.support import MigTestCase, testmain +from tests.support.serversupp import make_wrapped_server + +from mig.lib.coreapi import CoreApiClient + + +class TestRequestHandler(BaseHTTPRequestHandler): + def do_POST(self): + test_server = self.server + + programmed_error = test_server._programmed_error + if programmed_error: + status, content = programmed_error + self.send_response(status) + self.end_headers() + self.wfile.write(content) + + +class TestHTTPServer(HTTPServer): + def __init__(self, addr, **kwargs): + self._programmed_error = None + self._on_start = kwargs.pop('on_start', lambda _: None) + + HTTPServer.__init__(self, addr, TestRequestHandler, **kwargs) + + def clear_programmed(self): + self._programmed_error = None + + def set_programmed_error(self, status, content): + assert isinstance(content, bytes) + self._programmed_error = (status, content) + + def server_activate(self): + HTTPServer.server_activate(self) + self._on_start(self) + + +class TestMigLibCoreapi(MigTestCase): + def before_each(self): + self.server_addr = ('localhost', 4567) + self.server = make_wrapped_server(TestHTTPServer, self.server_addr) + + def after_each(self): + server = getattr(self, 'server', None) + setattr(self, 'server', None) + if server: + server.stop() + + def test_true(self): + self.server.start_wait_until_ready() + self.server.set_programmed_error(418, b'tea; earl grey; hot') + instance = CoreApiClient("http://%s:%s/" % self.server_addr) + + status, content = instance.createUser({ + 'full_name': "Test User", + 'organization': "Test Org", + 'state': "NA", + 'country': "DK", + 'email': "user@example.com", + 'comment': "This is the create comment", + 'password': "password", + }) + + self.assertEqual(status, 418) + + +if __name__ == '__main__': + testmain() From 811b34002a69af735ad4f6747d94033673d98bb5 Mon Sep 17 00:00:00 2001 From: Alex Burke Date: Mon, 10 Mar 2025 16:05:21 +0100 Subject: [PATCH 5/5] make it usable --- mig/lib/coreapi/__init__.py | 15 +++++++++- tests/test_mig_lib_coreapi.py | 54 ++++++++++++++++++++++++++++------- 2 files changed, 58 insertions(+), 11 deletions(-) diff --git a/mig/lib/coreapi/__init__.py b/mig/lib/coreapi/__init__.py index 1bfeeeae9..995053549 100644 --- a/mig/lib/coreapi/__init__.py +++ b/mig/lib/coreapi/__init__.py @@ -1,5 +1,6 @@ import codecs import json +import werkzeug.exceptions as httpexceptions from tests.support._env import PY2 @@ -14,6 +15,10 @@ from mig.lib.coresvc.payloads import PAYLOAD_POST_USER +httpexceptions_by_code = { + exc.code: exc for exc in httpexceptions.__dict__.values() if hasattr(exc, 'code')} + + def attempt_to_decode_response_data(data, response_encoding=None): if data is None: return None @@ -31,6 +36,10 @@ def attempt_to_decode_response_data(data, response_encoding=None): 'issue_POST: unknown response_encoding "%s"' % (response_encoding,)) +def http_error_from_status_code(http_status_code, description=None): + return httpexceptions_by_code[http_status_code](description) + + class CoreApiClient: def __init__(self, base_url): self._base_url = base_url @@ -94,4 +103,8 @@ def _issue_POST(self, request_path, request_data=None, request_json=None, respon def createUser(self, user_dict): payload = PAYLOAD_POST_USER.ensure(user_dict) - return self._issue_POST('/user', request_json=dict(payload)) + status, output = self._issue_POST('/user', request_json=dict(payload)) + if status != 201: + description = output if isinstance(output, str) else None + raise http_error_from_status_code(status, description) + return output diff --git a/tests/test_mig_lib_coreapi.py b/tests/test_mig_lib_coreapi.py index 73f73ae09..06a2e0a8c 100644 --- a/tests/test_mig_lib_coreapi.py +++ b/tests/test_mig_lib_coreapi.py @@ -1,3 +1,5 @@ +import codecs +import json from http.server import HTTPServer, BaseHTTPRequestHandler from tests.support import MigTestCase, testmain @@ -10,17 +12,20 @@ class TestRequestHandler(BaseHTTPRequestHandler): def do_POST(self): test_server = self.server - programmed_error = test_server._programmed_error - if programmed_error: - status, content = programmed_error - self.send_response(status) - self.end_headers() - self.wfile.write(content) + if test_server._programmed_response: + status, content = test_server._programmed_response + elif test_server._programmed_error: + status, content = test_server._programmed_error + + self.send_response(status) + self.end_headers() + self.wfile.write(content) class TestHTTPServer(HTTPServer): def __init__(self, addr, **kwargs): self._programmed_error = None + self._programmed_response = None self._on_start = kwargs.pop('on_start', lambda _: None) HTTPServer.__init__(self, addr, TestRequestHandler, **kwargs) @@ -29,9 +34,18 @@ def clear_programmed(self): self._programmed_error = None def set_programmed_error(self, status, content): + assert self._programmed_response is None assert isinstance(content, bytes) self._programmed_error = (status, content) + def set_programmed_response(self, status, content): + assert self._programmed_error is None + assert isinstance(content, bytes) + self._programmed_response = (status, content) + + def set_programmed_json_response(self, status, content): + self.set_programmed_response(status, codecs.encode(json.dumps(content), 'utf8')) + def server_activate(self): HTTPServer.server_activate(self) self._on_start(self) @@ -48,12 +62,32 @@ def after_each(self): if server: server.stop() - def test_true(self): + def test_raises_in_the_absence_of_success(self): self.server.start_wait_until_ready() self.server.set_programmed_error(418, b'tea; earl grey; hot') instance = CoreApiClient("http://%s:%s/" % self.server_addr) - status, content = instance.createUser({ + with self.assertRaises(Exception): + instance.createUser({ + 'full_name': "Test User", + 'organization': "Test Org", + 'state': "NA", + 'country': "DK", + 'email': "user@example.com", + 'comment': "This is the create comment", + 'password': "password", + }) + + def test_returs_a_user_object(self): + test_content = { + 'foo': 1, + 'bar': True + } + self.server.start_wait_until_ready() + self.server.set_programmed_json_response(201, test_content) + instance = CoreApiClient("http://%s:%s/" % self.server_addr) + + content = instance.createUser({ 'full_name': "Test User", 'organization': "Test Org", 'state': "NA", @@ -63,8 +97,8 @@ def test_true(self): 'password': "password", }) - self.assertEqual(status, 418) - + self.assertIsInstance(content, dict) + self.assertEqual(content, test_content) if __name__ == '__main__': testmain()