diff --git a/aikido_firewall/__init__.py b/aikido_firewall/__init__.py index 0348d9589..678bb1c73 100644 --- a/aikido_firewall/__init__.py +++ b/aikido_firewall/__init__.py @@ -2,10 +2,10 @@ Aggregates from the different modules """ -from dotenv import load_dotenv +# Re-export set_current_user : +from aikido_firewall.context.users import set_user -# Constants -PKG_VERSION = "0.0.1" +from dotenv import load_dotenv # Import logger from aikido_firewall.helpers.logging import logger diff --git a/aikido_firewall/background_process/aikido_background_process.py b/aikido_firewall/background_process/aikido_background_process.py index 4f4505180..8bb2320b9 100644 --- a/aikido_firewall/background_process/aikido_background_process.py +++ b/aikido_firewall/background_process/aikido_background_process.py @@ -73,6 +73,8 @@ def __init__(self, address, key): self.reporter.routes.add_route( method=data[1][0], path=data[1][1] ) + elif data[0] == "USER": + self.reporter.users.add_user(data[1]) elif data[0] == "SHOULD_RATELIMIT": # Called to check if the context passed along as data should be # Rate limited diff --git a/aikido_firewall/background_process/reporter.py b/aikido_firewall/background_process/reporter.py index db44240f6..8aec2b82b 100644 --- a/aikido_firewall/background_process/reporter.py +++ b/aikido_firewall/background_process/reporter.py @@ -11,11 +11,13 @@ from aikido_firewall.helpers.get_machine_ip import get_ip from aikido_firewall.helpers.get_ua_from_context import get_ua_from_context from aikido_firewall.helpers.get_current_unixtime_ms import get_unixtime_ms -from aikido_firewall import PKG_VERSION +from aikido_firewall.config import PKG_VERSION from aikido_firewall.background_process.heartbeats import send_heartbeats_every_x_secs from aikido_firewall.background_process.routes import Routes -from .service_config import ServiceConfig from aikido_firewall.ratelimiting.rate_limiter import RateLimiter +from .service_config import ServiceConfig +from .users import Users +from .reporter_config import ReporterConfig class Reporter: @@ -33,6 +35,7 @@ def __init__(self, block, api, token, serverless, event_scheduler): self.rate_limiter = RateLimiter( max_items=5000, time_to_live_in_ms=120 * 60 * 1000 # 120 minutes ) + self.users = Users(1000) if isinstance(serverless, str) and len(serverless) == 0: raise ValueError("Serverless cannot be an empty string") @@ -88,6 +91,10 @@ def send_heartbeat(self): if not self.token: return logger.debug("Aikido Reporter : Sending out heartbeat") + users = self.users.as_array() + routes = list(self.routes) + self.users.clear() + self.routes.clear() res = self.api.report( self.token, { @@ -108,8 +115,8 @@ def send_heartbeat(self): }, }, "hostnames": [], - "routes": list(self.routes), - "users": [], + "routes": routes, + "users": users, }, self.timeout_in_sec, ) diff --git a/aikido_firewall/background_process/users.py b/aikido_firewall/background_process/users.py new file mode 100644 index 000000000..f81748df7 --- /dev/null +++ b/aikido_firewall/background_process/users.py @@ -0,0 +1,59 @@ +""" +Export the Users class +""" + +from aikido_firewall.helpers.get_current_unixtime_ms import get_unixtime_ms + + +class Users: + """ + Class that holds users for the background process + """ + + def __init__(self, max_entries=1000): + self.max_entries = max_entries + self.users = {} + + def add_user(self, user): + """Store a user""" + user_id = user["id"] + current_time = get_unixtime_ms() + + existing = self.users.get(user_id) + if existing: + existing["name"] = user.get("name") + existing["lastIpAddress"] = user.get("lastIpAddress") + existing["lastSeenAt"] = current_time + return + + if len(self.users) >= self.max_entries: + # Remove the first added user (FIFO) + first_added_key = next(iter(self.users)) + del self.users[first_added_key] + + self.users[user_id] = { + "id": user_id, + "name": user.get("name"), + "lastIpAddress": user.get("lastIpAddress"), + "firstSeenAt": current_time, + "lastSeenAt": current_time, + } + + def as_array(self): + """ + Give all user entries back as an array + """ + return [ + { + "id": user["id"], + "name": user["name"], + "lastIpAddress": user["lastIpAddress"], + "firstSeenAt": user["firstSeenAt"], + "lastSeenAt": user["lastSeenAt"], + } + for user in self.users.values() + ] + + def clear(self): + """Clear out all users""" + self.users.clear() diff --git a/aikido_firewall/background_process/users_test.py b/aikido_firewall/background_process/users_test.py new file mode 100644 index 000000000..fd646185c --- /dev/null +++ b/aikido_firewall/background_process/users_test.py @@ -0,0 +1,73 @@ +import time +import pytest +from .users import Users # Assuming the Users class is in a file named users.py + + +@pytest.fixture +def users(): + """Fixture to create a Users instance with a max of 2 entries.""" + return Users(max_entries=2) + + +def test_users(users): + assert users.as_array() == [] + + users.add_user({"id": "1", "name": "John", "lastIpAddress": "::1"}) + user1 = users.as_array()[0] + assert user1["id"] == "1" + assert user1["name"] == "John" + assert user1["lastIpAddress"] == "::1" + assert user1["lastSeenAt"] >= 0 # lastSeenAt should be initialized + assert ( + user1["lastSeenAt"] == user1["firstSeenAt"] + ) # Initially, they should be equal + + # Simulate the passage of time + time.sleep(0.001) # Sleep for a short time to simulate ticking the clock + users.add_user({"id": "1", "name": "John Doe", "lastIpAddress": "1.2.3.4"}) + user1_updated = users.as_array()[0] + assert user1_updated["id"] == "1" + assert user1_updated["name"] == "John Doe" + assert user1_updated["lastIpAddress"] == "1.2.3.4" + assert ( + user1_updated["lastSeenAt"] >= user1_updated["firstSeenAt"] + ) # lastSeenAt should be >= firstSeenAt + assert ( + user1_updated["lastSeenAt"] == user1_updated["firstSeenAt"] + 1 + ) # lastSeenAt should be +1 + + users.add_user({"id": "2", "name": "Jane", "lastIpAddress": "1.2.3.4"}) + user2 = users.as_array()[1] + assert user2["id"] == "2" + assert user2["name"] == "Jane" + assert user2["lastIpAddress"] == "1.2.3.4" + assert ( + user2["lastSeenAt"] >= user2["firstSeenAt"] + ) # lastSeenAt should be >= firstSeenAt + assert ( + user2["lastSeenAt"] == user2["firstSeenAt"] + ) # Initially, they should be equal + + users.add_user({"id": "3", "name": "Alice", "lastIpAddress": "1.2.3.4"}) + user2_updated = users.as_array()[0] # Jane should still be the first user + user3 = users.as_array()[1] # Alice should be the second user + assert user2_updated["id"] == "2" + assert user2_updated["name"] == "Jane" + assert user2_updated["lastIpAddress"] == "1.2.3.4" + assert ( + user2_updated["lastSeenAt"] >= user2_updated["firstSeenAt"] + ) # lastSeenAt should be >= firstSeenAt + assert ( + user2_updated["lastSeenAt"] == user2_updated["firstSeenAt"] + ) # Should still be equal + + assert user3["id"] == "3" + assert user3["name"] == "Alice" + assert user3["lastIpAddress"] == "1.2.3.4" + assert ( + user3["lastSeenAt"] >= user3["firstSeenAt"] + ) # lastSeenAt should be >= firstSeenAt + assert user3["lastSeenAt"] == user3["firstSeenAt"] # Should be equal + + users.clear() + assert users.as_array() == [] diff --git a/aikido_firewall/config.py b/aikido_firewall/config.py new file mode 100644 index 000000000..a82a47a20 --- /dev/null +++ b/aikido_firewall/config.py @@ -0,0 +1,3 @@ +"""Contains package versions""" + +PKG_VERSION = "0.0.1" diff --git a/aikido_firewall/context/__init__.py b/aikido_firewall/context/__init__.py index 7d962d3f0..619bc877e 100644 --- a/aikido_firewall/context/__init__.py +++ b/aikido_firewall/context/__init__.py @@ -7,6 +7,7 @@ from http.cookies import SimpleCookie from aikido_firewall.helpers.build_route_from_url import build_route_from_url from aikido_firewall.helpers.get_subdomains_from_url import get_subdomains_from_url +from aikido_firewall.helpers.logging import logger from aikido_firewall.helpers.get_ip_from_request import get_ip_from_request SUPPORTED_SOURCES = ["django", "flask", "django-gunicorn"] @@ -15,6 +16,15 @@ local = threading.local() +def set_current_user(user): + """Sets the current user""" + if hasattr(local, "user") and local.user is not None: + logger.debug( + "Evicting a saved users, this probably means a user was set twice." + ) + local.user = user + + def get_current_context(): """Returns the current context""" try: @@ -71,6 +81,7 @@ def __init__(self, context_obj=None, req=None, source=None): self.set_django_gunicorn_attrs(req) self.route = build_route_from_url(self.url) self.subdomains = get_subdomains_from_url(self.url) + self.user = local.user if hasattr(local, "user") else None self.remote_address = get_ip_from_request(self.raw_ip, self.headers) def set_django_gunicorn_attrs(self, req): @@ -116,6 +127,7 @@ def __reduce__(self): "source": self.source, "route": self.route, "subdomains": self.subdomains, + "user": self.user, }, None, None, diff --git a/aikido_firewall/context/users.py b/aikido_firewall/context/users.py new file mode 100644 index 000000000..ab32ba9d5 --- /dev/null +++ b/aikido_firewall/context/users.py @@ -0,0 +1,58 @@ +""" +Users file +""" + +from aikido_firewall.helpers.logging import logger +from . import set_current_user, get_current_context +from aikido_firewall.background_process import get_comms + + +def set_user(user): + """ + External function for applications to set a user + """ + validated_user = validate_user(user) + if not validated_user: + return + logger.debug("Validated user : %s", validated_user) + + set_current_user(validated_user) + + context = get_current_context() + if not context: + return + validated_user["lastIpAddress"] = context.remote_address + + # Send validated_user object to Agent + get_comms().send_data_to_bg_process("USER", validated_user) + + +def validate_user(user): + """This validates the user object""" + if not isinstance(user, dict): + logger.info( + "set_user(...) expects a dict with 'id' and 'name' properties, found %s instead.", + type(user), + ) + return + + # Validate user's id : + if not "id" in user: + logger.info("set_user(...) expects an object with 'id' property.") + return + if not isinstance(user["id"], str) and not isinstance(user["id"], int): + logger.info( + "set_user(...) expects an object with 'id' property of type string or number, found %s instead.", + type(user["id"]), + ) + return + if isinstance(user["id"], str) and len(user["id"]) is 0: + logger.info( + "set_user(...) expects an object with 'id' property non-empty string." + ) + return + valid_user = {"id": str(user["id"])} + if "name" in user and isinstance(user["name"], str) and len(user["name"]) > 0: + valid_user["name"] = str(user["name"]) + + return valid_user diff --git a/aikido_firewall/context/users_test.py b/aikido_firewall/context/users_test.py new file mode 100644 index 000000000..bc54e6fcc --- /dev/null +++ b/aikido_firewall/context/users_test.py @@ -0,0 +1,64 @@ +import pytest + +from .users import validate_user + + +def test_validate_user_valid_input(): + user = {"id": "123", "name": "Alice"} + result = validate_user(user) + assert result == {"id": "123", "name": "Alice"} + + +def test_validate_user_valid_input_with_int_id(): + user = {"id": 456, "name": "Bob"} + result = validate_user(user) + assert result == {"id": "456", "name": "Bob"} + + +def test_validate_user_missing_id(caplog): + user = {"name": "Charlie"} + result = validate_user(user) + assert result is None + assert "expects an object with 'id' property." in caplog.text + + +def test_validate_user_invalid_id_type(caplog): + user = {"id": 12.34, "name": "David"} + result = validate_user(user) + assert result is None + assert ( + "expects an object with 'id' property of type string or number" in caplog.text + ) + + +def test_validate_user_empty_string_id(caplog): + user = {"id": "", "name": "Eve"} + result = validate_user(user) + assert result is None + assert "expects an object with 'id' property non-empty string." in caplog.text + + +def test_validate_user_missing_name(caplog): + user = {"id": "789"} + result = validate_user(user) + assert result == {"id": "789"} + + +def test_validate_user_empty_name(caplog): + user = {"id": "101", "name": ""} + result = validate_user(user) + assert result == {"id": "101"} + + +def test_validate_user_invalid_user_type(caplog): + user = ["id", "name"] + result = validate_user(user) + assert result is None + assert "expects a dict with 'id' and 'name' properties" in caplog.text + + +def test_validate_user_invalid_user_type_dict_without_id(caplog): + user = {"name": "Frank"} + result = validate_user(user) + assert result is None + assert "expects an object with 'id' property." in caplog.text diff --git a/sample-apps/flask-mysql/app.py b/sample-apps/flask-mysql/app.py index 901a2db16..15cb1f3eb 100644 --- a/sample-apps/flask-mysql/app.py +++ b/sample-apps/flask-mysql/app.py @@ -20,6 +20,10 @@ @app.route("/") def homepage(): + aikido_firewall.set_user({ + "id": 1, + "name": "Wout" + }) cursor = mysql.get_db().cursor() cursor.execute("SELECT * FROM db.dogs") dogs = cursor.fetchall() @@ -28,6 +32,10 @@ def homepage(): @app.route('/dogpage/') def get_dogpage(dog_id): + aikido_firewall.set_user({ + "id": 2, + "name": "Wout 2" + }) cursor = mysql.get_db().cursor() cursor.execute("SELECT * FROM db.dogs WHERE id = " + str(dog_id)) dog = cursor.fetchmany(1)[0]