Skip to content

Commit e72d674

Browse files
committed
Resolved issue #39
Added VERIFICATION_URL_BUILDER setting
1 parent 31d8960 commit e72d674

File tree

7 files changed

+150
-26
lines changed

7 files changed

+150
-26
lines changed

rest_registration/settings_fields.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,20 @@ def __new__(cls, name, *, default=None, help=None, import_string=False):
185185
default='rest_registration.utils.html.convert_html_to_text_preserving_urls', # noqa: E501
186186
import_string=True,
187187
),
188+
Field(
189+
'VERIFICATION_URL_BUILDER',
190+
default='rest_registration.utils.verification.build_default_verification_url', # noqa: E501
191+
import_string=True,
192+
help=dedent("""\
193+
The builder function receives the ``signer`` object and construct
194+
the url using ``signer.get_base_url()``
195+
and ``signer.get_signed_data()``. The default url builder will use
196+
the base url and append the signed data as HTTP GET query string.
197+
It is be solely up to the implementer of custom builder function
198+
to encode the signed values properly in the URL.
199+
"""),
200+
201+
),
188202
]
189203

190204
CHANGE_PASSWORD_SETTINGS_FIELDS = [

rest_registration/utils/verification.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from urllib.parse import urlencode
2+
13
from django.core.signing import BadSignature, SignatureExpired
24

35
from rest_registration.exceptions import BadRequest
@@ -10,3 +12,12 @@ def verify_signer_or_bad_request(signer):
1012
raise BadRequest('Signature expired')
1113
except BadSignature:
1214
raise BadRequest('Invalid signature')
15+
16+
17+
def build_default_verification_url(signer):
18+
base_url = signer.get_base_url()
19+
params = urlencode(signer.get_signed_data())
20+
url = '{base_url}?{params}'.format(base_url=base_url, params=params)
21+
if signer.request:
22+
url = signer.request.build_absolute_uri(url)
23+
return url

rest_registration/verification.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import pickle
22
import time
3-
from urllib.parse import urlencode
43

54
from django.core.signing import BadSignature, SignatureExpired, Signer
65
from django.utils.crypto import constant_time_compare
76

7+
from rest_registration.settings import registration_settings
8+
89
PICKLE_REPR_PROTOCOL = 4
910

1011

@@ -85,9 +86,5 @@ def __init__(self, data, request=None, strict=True):
8586
self.request = request
8687

8788
def get_url(self):
88-
base_url = self.get_base_url()
89-
params = urlencode(self.get_signed_data())
90-
url = '{base_url}?{params}'.format(base_url=base_url, params=params)
91-
if self.request:
92-
url = self.request.build_absolute_uri(url)
93-
return url
89+
url_builder = registration_settings.VERIFICATION_URL_BUILDER
90+
return url_builder(self)

tests/api/base.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,16 +83,30 @@ def assert_response_is_not_found(self, response):
8383
)
8484

8585
def assert_valid_verification_url(
86-
self, url, expected_path=None, expected_query_keys=None):
87-
parsed_url = urlparse(url)
86+
self, url, expected_path=None, expected_fields=None,
87+
url_parser=None):
88+
if url_parser is None:
89+
url_parser = self._parse_verification_url
90+
try:
91+
url_path, verification_data = url_parser(url, expected_fields)
92+
except ValueError as e:
93+
self.fail(str(e))
8894
if expected_path is not None:
89-
self.assertEqual(parsed_url.path, expected_path)
95+
self.assertEqual(url_path, expected_path)
96+
if expected_fields is not None:
97+
self.assertSetEqual(
98+
set(verification_data.keys()), set(expected_fields))
99+
return verification_data
100+
101+
def _parse_verification_url(self, url, verification_field_names):
102+
parsed_url = urlparse(url)
90103
query = parse_qs(parsed_url.query, strict_parsing=True)
91-
if expected_query_keys is not None:
92-
self.assertSetEqual(set(query), set(expected_query_keys))
93104

94-
for values in query.values():
95-
self.assert_len_equals(values, 1)
105+
for key, values in query.items():
106+
if len(values) == 0:
107+
raise ValueError("no values for '{key}".format(key=key))
108+
if len(values) > 1:
109+
raise ValueError("multiple values for '{key}'".format(key=key))
96110

97111
verification_data = {key: values[0] for key, values in query.items()}
98-
return verification_data
112+
return parsed_url.path, verification_data

tests/api/test_register.py

Lines changed: 97 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@
22
import time
33
from unittest import mock
44
from unittest.mock import patch
5+
from urllib.parse import quote_plus as urlquote
6+
from urllib.parse import unquote_plus as urlunquote
7+
from urllib.parse import urlparse
58

69
from django.test.utils import override_settings
710
from rest_framework import status
811

912
from rest_registration.api.views.register import RegisterSigner
1013
from rest_registration.settings import registration_settings
11-
from tests.utils import shallow_merge_dicts
14+
from tests.utils import TestCase, shallow_merge_dicts
1215

1316
from .base import APIViewTestCase
1417

@@ -36,10 +39,9 @@
3639

3740

3841
@override_settings(REST_REGISTRATION=REST_REGISTRATION_WITH_VERIFICATION)
39-
class RegisterViewTestCase(APIViewTestCase):
40-
VIEW_NAME = 'register'
42+
class RegisterSerializerTestCase(TestCase):
4143

42-
def test_register_serializer_ok(self):
44+
def test_ok(self):
4345
serializer_class = registration_settings.REGISTER_SERIALIZER_CLASS
4446
serializer = serializer_class(data={})
4547
field_names = {f for f in serializer.get_fields()}
@@ -56,7 +58,7 @@ def test_register_serializer_ok(self):
5658
},
5759
),
5860
)
59-
def test_register_serializer_no_password_ok(self):
61+
def test_no_password_ok(self):
6062
serializer_class = registration_settings.REGISTER_SERIALIZER_CLASS
6163
serializer = serializer_class(data={})
6264
field_names = {f for f in serializer.get_fields()}
@@ -65,6 +67,52 @@ def test_register_serializer_no_password_ok(self):
6567
{'id', 'username', 'first_name', 'last_name', 'email', 'password'},
6668
)
6769

70+
71+
def build_custom_verification_url(signer):
72+
base_url = signer.get_base_url()
73+
signed_data = signer.get_signed_data()
74+
if signer.USE_TIMESTAMP:
75+
timestamp = signed_data.pop(signer.TIMESTAMP_FIELD)
76+
else:
77+
timestamp = None
78+
signature = signed_data.pop(signer.SIGNATURE_FIELD)
79+
segments = [signed_data[k] for k in sorted(signed_data.keys())]
80+
segments.append(signature)
81+
if timestamp:
82+
segments.append(timestamp)
83+
quoted_segments = [urlquote(str(s)) for s in segments]
84+
85+
url = base_url
86+
if not url.endswith('/'):
87+
url += '/'
88+
url += '/'.join(quoted_segments)
89+
url += '/'
90+
if signer.request:
91+
url = signer.request.build_absolute_uri(url)
92+
93+
return url
94+
95+
96+
def parse_custom_verification_url(url, verification_field_names):
97+
parsed_url = urlparse(url)
98+
num_of_fields = len(verification_field_names)
99+
url_path = parsed_url.path.rstrip('/')
100+
url_segments = url_path.rsplit('/', num_of_fields)
101+
if len(url_segments) != num_of_fields + 1:
102+
raise ValueError("Could not parse {url}".format(url=url))
103+
104+
data_segments = url_segments[1:]
105+
url_path = url_segments[0] + '/'
106+
verification_data = {
107+
name: urlunquote(value)
108+
for name, value in zip(verification_field_names, data_segments)}
109+
return url_path, verification_data
110+
111+
112+
@override_settings(REST_REGISTRATION=REST_REGISTRATION_WITH_VERIFICATION)
113+
class RegisterViewTestCase(APIViewTestCase):
114+
VIEW_NAME = 'register'
115+
68116
def test_register_ok(self):
69117
data = self._get_register_user_data(password='testpassword')
70118
request = self.create_post_request(data)
@@ -88,7 +136,48 @@ def test_register_ok(self):
88136
verification_data = self.assert_valid_verification_url(
89137
url,
90138
expected_path=REGISTER_VERIFICATION_URL,
91-
expected_query_keys={'signature', 'user_id', 'timestamp'},
139+
expected_fields={'signature', 'user_id', 'timestamp'},
140+
)
141+
url_user_id = int(verification_data['user_id'])
142+
self.assertEqual(url_user_id, user_id)
143+
url_sig_timestamp = int(verification_data['timestamp'])
144+
self.assertGreaterEqual(url_sig_timestamp, time_before)
145+
self.assertLessEqual(url_sig_timestamp, time_after)
146+
signer = RegisterSigner(verification_data)
147+
signer.verify()
148+
149+
@override_settings(
150+
REST_REGISTRATION=shallow_merge_dicts(
151+
REST_REGISTRATION_WITH_VERIFICATION, {
152+
'VERIFICATION_URL_BUILDER': build_custom_verification_url,
153+
},
154+
),
155+
)
156+
def test_register_with_custom_verification_url_ok(self):
157+
data = self._get_register_user_data(password='testpassword')
158+
request = self.create_post_request(data)
159+
time_before = math.floor(time.time())
160+
with self.assert_one_mail_sent() as sent_emails:
161+
response = self.view_func(request)
162+
time_after = math.ceil(time.time())
163+
self.assert_valid_response(response, status.HTTP_201_CREATED)
164+
user_id = response.data['id']
165+
# Check database state.
166+
user = self.user_class.objects.get(id=user_id)
167+
self.assertEqual(user.username, data['username'])
168+
self.assertTrue(user.check_password(data['password']))
169+
self.assertFalse(user.is_active)
170+
# Check verification e-mail.
171+
sent_email = sent_emails[0]
172+
self.assertEqual(sent_email.from_email, VERIFICATION_FROM_EMAIL)
173+
self.assertListEqual(sent_email.to, [data['email']])
174+
url = self.assert_one_url_line_in_text(sent_email.body)
175+
176+
verification_data = self.assert_valid_verification_url(
177+
url,
178+
expected_path=REGISTER_VERIFICATION_URL,
179+
expected_fields=['user_id', 'signature', 'timestamp'],
180+
url_parser=parse_custom_verification_url,
92181
)
93182
url_user_id = int(verification_data['user_id'])
94183
self.assertEqual(url_user_id, user_id)
@@ -98,7 +187,6 @@ def test_register_ok(self):
98187
signer = RegisterSigner(verification_data)
99188
signer.verify()
100189

101-
# TODO: unskip this test when &times entity problem will be fixed.
102190
@override_settings(
103191
REST_REGISTRATION=REST_REGISTRATION_WITH_HTML_EMAIL_VERIFICATION,
104192
)
@@ -125,7 +213,7 @@ def test_register_with_html_email_ok(self):
125213
verification_data = self.assert_valid_verification_url(
126214
url,
127215
expected_path=REGISTER_VERIFICATION_URL,
128-
expected_query_keys={'signature', 'user_id', 'timestamp'},
216+
expected_fields={'signature', 'user_id', 'timestamp'},
129217
)
130218
url_user_id = int(verification_data['user_id'])
131219
self.assertEqual(url_user_id, user_id)
@@ -166,7 +254,7 @@ def test_register_no_password_confirm_ok(self):
166254
verification_data = self.assert_valid_verification_url(
167255
url,
168256
expected_path=REGISTER_VERIFICATION_URL,
169-
expected_query_keys={'signature', 'user_id', 'timestamp'},
257+
expected_fields={'signature', 'user_id', 'timestamp'},
170258
)
171259
url_user_id = int(verification_data['user_id'])
172260
self.assertEqual(url_user_id, user_id)

tests/api/test_register_email.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def test_ok(self):
6262
verification_data = self.assert_valid_verification_url(
6363
url,
6464
expected_path=REGISTER_EMAIL_VERIFICATION_URL,
65-
expected_query_keys={'signature', 'user_id', 'timestamp', 'email'},
65+
expected_fields={'signature', 'user_id', 'timestamp', 'email'},
6666
)
6767
self.assertEqual(verification_data['email'], self.new_email)
6868
self.assertEqual(int(verification_data['user_id']), self.user.id)

tests/api/test_reset_password.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def _assert_valid_send_link_email(
9797
verification_data = self.assert_valid_verification_url(
9898
url,
9999
expected_path=RESET_PASSWORD_VERIFICATION_URL,
100-
expected_query_keys={'signature', 'user_id', 'timestamp'},
100+
expected_fields={'signature', 'user_id', 'timestamp'},
101101
)
102102
self.assertEqual(int(verification_data['user_id']), user.id)
103103
url_sig_timestamp = int(verification_data['timestamp'])

0 commit comments

Comments
 (0)