diff --git a/mig/shared/sanitize.py b/mig/shared/sanitize.py new file mode 100644 index 000000000..5950cc4c2 --- /dev/null +++ b/mig/shared/sanitize.py @@ -0,0 +1,158 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# +# --- BEGIN_HEADER --- +# +# sanitize - Reversible encoding of values into safe names +# Copyright (C) 2003-2023 The MiG Project +# +# This file is part of MiG. +# +# MiG is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# MiG is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. 
#
# -- END_HEADER ---
#

"""Reversible encoding of arbitrary text into names built from safe characters.

A value is first punycoded, which yields pure ASCII of the form
``<basic>-<extended>``. Every ASCII character that is not part of
``username_charset`` is then replaced by ``INDICATOR_CH`` followed by its
two-digit uppercase hex code, and the punycode delimiter is replaced by
``MARKER`` (two indicator characters in a row).
"""

import base64
import codecs
import os
import string
import sys

sys.path.append(os.path.realpath(
    os.path.join(os.path.dirname(__file__), "../..")))

from mig.shared.defaults import username_charset

INDICATOR_CH = '\x1b'
# No longer consulted by the code below; kept because it is a module-level
# name that other code may import.
INVALID_INSERTION_POINT = -2
MARKER = INDICATOR_CH * 2
MARKER_HEXDIGIT_WIDTH = 2
UNSAFE_CHARS = sorted(set(string.printable) - set(username_charset))
PY2 = sys.version_info[0] == 2

# Characters that may appear verbatim in an encoded name.
_SAFE_CHARS = frozenset(username_charset)

if PY2:
    def _as_ascii_string(value):
        """Return value unchanged - byte strings are native strings on py2."""
        return value
else:
    def _as_ascii_string(value):
        """Return the ASCII bytes in value as a native (text) string."""
        return codecs.decode(value, 'ascii')


def _as_hexdigit(ch):
    """Return the two-digit uppercase hex code of the ASCII character ch."""
    # str.encode works on both python 2 and 3 whereas the two-argument
    # bytes() constructor only exists on python 3.
    return _as_ascii_string(base64.b16encode(ch.encode('ascii')))


UNSAFE_CHARS_HEXDIGITS = [_as_hexdigit(c) for c in UNSAFE_CHARS]
UNSAFE_SUBSTIUTIONS = dict(zip(UNSAFE_CHARS, UNSAFE_CHARS_HEXDIGITS))


class NotAnExistingSafenameError(RuntimeError):
    """Raised when a value handed to safename_decode is not a safename."""
    pass


# TODO
# - swap to converting the ord char value to hex as a way to save bytes


def _escape_unsafe(text):
    """Return ASCII text with every character outside the safe set escaped.

    Each escaped character becomes INDICATOR_CH followed by its hex code.
    """
    pieces = []
    for character in text:
        if character in _SAFE_CHARS:
            pieces.append(character)
            continue
        substitute = UNSAFE_SUBSTIUTIONS.get(character)
        if substitute is None:
            # not in the precomputed table, e.g. a control character such
            # as the indicator itself - it must still never pass through
            # raw or decoding could not tell it from an escape
            substitute = _as_hexdigit(character)
        pieces.append(INDICATOR_CH + substitute)
    return ''.join(pieces)


def safename_encode(value):
    """Encode value into a name consisting of safe characters and escapes.

    Returns the empty string for empty input. Otherwise the result is the
    escaped basic (ASCII) part of the punycoded value, then MARKER, then the
    escaped punycode extension part. Either part may be empty.
    """
    punycoded = _as_ascii_string(codecs.encode(value, 'punycode'))

    if not punycoded:
        return ''

    # The extension part never contains a hyphen, so the last hyphen is the
    # punycode delimiter. A value without any ASCII characters has no
    # delimiter at all; it is then encoded with an empty basic part.
    delimiter_at = punycoded.rfind('-')
    if delimiter_at < 0:
        basic, extended = '', punycoded
    else:
        basic = punycoded[:delimiter_at]
        extended = punycoded[delimiter_at + 1:]

    return ''.join((_escape_unsafe(basic), MARKER, _escape_unsafe(extended)))


def safename_decode(value):
    """Decode a name produced by safename_encode back into the original.

    Returns the empty string for empty input. Raises
    NotAnExistingSafenameError when value holds no MARKER or when an
    indicator character is not followed by a valid two-digit hex code.
    Errors from the final punycode decoding propagate unchanged.
    """
    if value == '':
        return value

    marker_at = value.rfind(MARKER)
    if marker_at < 0:
        raise NotAnExistingSafenameError()

    # Rewrite the marker as an ordinary escaped hyphen so the punycode
    # delimiter is restored by the generic unescaping below.
    value_to_decode = ''.join(
        (value[:marker_at + 1], _as_hexdigit('-'), value[marker_at + 2:]))

    chunked = value_to_decode.split(INDICATOR_CH)

    # The first chunk precedes any indicator and is therefore literal text.
    pieces = [chunked[0]]
    for chunk in chunked[1:]:
        if chunk == '':
            # tolerate stray indicator characters the way this always did
            continue
        hexdigit = chunk[:MARKER_HEXDIGIT_WIDTH]
        try:
            character = _as_ascii_string(base64.b16decode(hexdigit))
        except (TypeError, ValueError):
            # malformed escape: the value was not made by safename_encode
            raise NotAnExistingSafenameError()
        pieces.append(character + chunk[MARKER_HEXDIGIT_WIDTH:])

    return codecs.decode(''.join(pieces), 'punycode')


if __name__ == '__main__':
    def visibly_print(characters):
        """Return characters joined with whitespace and controls visible."""
        pieces = []
        for c in characters:
            c_ord = ord(c)
            if c == ' ':
                pieces.append("\\N{SPACE}")
            elif c_ord == 27:
                pieces.append("\\N{ESCAPE}")
            elif c == '"':
                pieces.append('\\"')
            elif c_ord < 32:
                # control chars: show the hex code of this very character
                pieces.append("\\x%02x" % c_ord)
            else:
                pieces.append(c)
        return ''.join(pieces)

    print("%d username chars: %s" % (len(UNSAFE_CHARS), visibly_print(UNSAFE_CHARS)))
# -*- coding: utf-8 -*-

import importlib
import os
import sys

sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), ".")))
from support import MigTestCase, testmain

from mig.shared.sanitize import safename_encode, safename_decode, NotAnExistingSafenameError

DUMMY_ASCII = u'abcde123467890'
DUMMY_ASCII_WITH_REPLACE = "$abcde$123467890$"
DUMMY_EXOTIC = u'UniCode123½¾µßðþđŋħĸþł@ª€£$¥©®'


class MigSharedSanitize_safename(MigTestCase):
    """Unit tests for safename_encode and safename_decode."""

    def test_encode_basic(self):
        # the empty value has the empty safename
        self.assertEqual(safename_encode(""), "")

    def test_encode_ascii(self):
        encoded = safename_encode(DUMMY_ASCII)

        self.assertEqual(
            encoded, "abcde123467890\x1b\x1b")

    def test_encode_exotic(self):
        encoded = safename_encode(DUMMY_EXOTIC)

        self.assertEqual(
            encoded, "UniCode123@\x1b24\x1b\x1blna3a4dm6e3ftgua80ewlwka88boszo7i7iv930g")

    def test_decode_a_non_safename(self):
        # a value without the marker was never produced by safename_encode
        with self.assertRaises(NotAnExistingSafenameError):
            safename_decode("foobar")

    def test_decode_basic(self):
        # the empty safename decodes to the empty value
        self.assertEqual(safename_decode(""), "")

    def test_decode_ascii(self):
        decoded = safename_decode("abcde123467890\x1b\x1b")

        self.assertEqual(decoded, DUMMY_ASCII)

    def test_decode_exotic(self):
        decoded = safename_decode("UniCode123@\x1b24\x1b\x1blna3a4dm6e3ftgua80ewlwka88boszo7i7iv930g")

        self.assertEqual(decoded, DUMMY_EXOTIC)

    def test_roundtrip_empty(self):
        inputvalue = ""

        outputvalue = safename_decode(safename_encode(inputvalue))

        self.assertEqual(outputvalue, inputvalue)

    def test_roundtrip_ascii(self):
        inputvalue = DUMMY_ASCII_WITH_REPLACE

        outputvalue = safename_decode(safename_encode(inputvalue))

        self.assertEqual(outputvalue, inputvalue)

    def test_roundtrip_exotic(self):
        # mixes ASCII, an escaped character and non-ASCII characters
        inputvalue = DUMMY_EXOTIC

        outputvalue = safename_decode(safename_encode(inputvalue))

        self.assertEqual(outputvalue, inputvalue)


def main():
    """Run the tests in this module, stopping at the first failure."""
    testmain(failfast=True)


if __name__ == '__main__':
    main()