Skip to content

Commit 20166b3

Browse files
committed
Fix up find_hostname_in_userinput to be more efficient
1 parent 18ea2cd commit 20166b3

File tree

1 file changed

+25
-11
lines changed

1 file changed

+25
-11
lines changed
Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,53 @@
11
"""
22
Only exports find_hostname_in_userinput function
33
"""
4-
from typing import Optional
4+
5+
from typing import Optional, List, Dict
56

67
from aikido_zen.helpers.get_port_from_url import get_port_from_url
78
from aikido_zen.helpers.try_parse_url import try_parse_url
89

910

10-
def find_hostname_in_userinput(user_input: str, normalized_hostname: str, port: Optional[int] = None):
11+
def find_hostname_in_userinput(
12+
user_input: str, normalized_hostname: str, port: Optional[int] = None
13+
):
1114
"""
1215
Returns true if the hostname is in userinput
1316
"""
17+
normalized_hostname = normalized_hostname.lower()
1418
if len(user_input) <= 1:
1519
return False
1620
if port and not str(port) in user_input:
1721
# Easy way for an early return: If a port is defined, it has to be inside the user input.
1822
return False
1923

20-
variants = [user_input, f"http://{user_input}", f"https://{user_input}"]
21-
for variant in variants:
22-
user_input_url = try_parse_url(variant)
23-
if not user_input_url:
24-
continue
25-
if user_input_url.hostname.lower() == normalized_hostname.lower():
26-
user_port = get_port_from_url(user_input_url.geturl())
24+
user_input_variants = [user_input, f"http://{user_input}", f"https://{user_input}"]
25+
user_input_normalized_variants = normalize_raw_url_variants(user_input_variants)
2726

27+
for user_input_hostname, user_input_port in user_input_normalized_variants:
28+
hostname_variants = [normalized_hostname, f"[{normalized_hostname}]"]
29+
if user_input_hostname in hostname_variants:
2830
# We were unable to retrieve the port from the URL, likely because it contains an invalid port.
2931
# Let's assume we have found the hostname in the user input, even though it doesn't match on port.
3032
# See: https://github.com/AikidoSec/firewall-python/pull/180.
31-
if user_port is None:
33+
if user_input_port is None:
3234
return True
3335

3436
if port is None:
3537
return True
36-
if port is not None and user_port == port:
38+
if port is not None and user_input_port == port:
3739
return True
3840

3941
return False
42+
43+
44+
def normalize_raw_url_variants(url_variants: List[str]) -> Dict[str, Optional[int]]:
45+
normalized_variants = {}
46+
for variant in url_variants:
47+
# Try parse the variant as an url,
48+
user_input_url = try_parse_url(variant)
49+
if not user_input_url or not user_input_url.hostname:
50+
continue
51+
port = get_port_from_url(user_input_url.geturl())
52+
normalized_variants[user_input_url.hostname.lower()] = port
53+
return normalized_variants

0 commit comments

Comments
 (0)