|
1 | 1 | """
|
2 | 2 | Only exports find_hostname_in_userinput function
|
3 | 3 | """
|
4 |
| -from typing import Optional |
| 4 | + |
| 5 | +from typing import Optional, List, Dict |
5 | 6 |
|
6 | 7 | from aikido_zen.helpers.get_port_from_url import get_port_from_url
|
7 | 8 | from aikido_zen.helpers.try_parse_url import try_parse_url
|
8 | 9 |
|
9 | 10 |
|
10 |
| -def find_hostname_in_userinput(user_input: str, normalized_hostname: str, port: Optional[int] = None): |
| 11 | +def find_hostname_in_userinput( |
| 12 | + user_input: str, normalized_hostname: str, port: Optional[int] = None |
| 13 | +): |
11 | 14 | """
|
12 | 15 | Returns true if the hostname is in userinput
|
13 | 16 | """
|
| 17 | + normalized_hostname = normalized_hostname.lower() |
14 | 18 | if len(user_input) <= 1:
|
15 | 19 | return False
|
16 | 20 | if port and not str(port) in user_input:
|
17 | 21 | # Easy way for an early return: If a port is defined, it has to be inside the user input.
|
18 | 22 | return False
|
19 | 23 |
|
20 |
| - variants = [user_input, f"http://{user_input}", f"https://{user_input}"] |
21 |
| - for variant in variants: |
22 |
| - user_input_url = try_parse_url(variant) |
23 |
| - if not user_input_url: |
24 |
| - continue |
25 |
| - if user_input_url.hostname.lower() == normalized_hostname.lower(): |
26 |
| - user_port = get_port_from_url(user_input_url.geturl()) |
| 24 | + user_input_variants = [user_input, f"http://{user_input}", f"https://{user_input}"] |
| 25 | + user_input_normalized_variants = normalize_raw_url_variants(user_input_variants) |
27 | 26 |
|
| 27 | + for user_input_hostname, user_input_port in user_input_normalized_variants: |
| 28 | + hostname_variants = [normalized_hostname, f"[{normalized_hostname}]"] |
| 29 | + if user_input_hostname in hostname_variants: |
28 | 30 | # We were unable to retrieve the port from the URL, likely because it contains an invalid port.
|
29 | 31 | # Let's assume we have found the hostname in the user input, even though it doesn't match on port.
|
30 | 32 | # See: https://github.com/AikidoSec/firewall-python/pull/180.
|
31 |
| - if user_port is None: |
| 33 | + if user_input_port is None: |
32 | 34 | return True
|
33 | 35 |
|
34 | 36 | if port is None:
|
35 | 37 | return True
|
36 |
| - if port is not None and user_port == port: |
| 38 | + if port is not None and user_input_port == port: |
37 | 39 | return True
|
38 | 40 |
|
39 | 41 | return False
|
| 42 | + |
| 43 | + |
| 44 | +def normalize_raw_url_variants(url_variants: List[str]) -> Dict[str, Optional[int]]: |
| 45 | + normalized_variants = {} |
| 46 | + for variant in url_variants: |
| 47 | + # Try parse the variant as an url, |
| 48 | + user_input_url = try_parse_url(variant) |
| 49 | + if not user_input_url or not user_input_url.hostname: |
| 50 | + continue |
| 51 | + port = get_port_from_url(user_input_url.geturl()) |
| 52 | + normalized_variants[user_input_url.hostname.lower()] = port |
| 53 | + return normalized_variants |
0 commit comments