Skip to content

Add the set_user function and report users to server #50

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Aug 7, 2024
Merged
6 changes: 3 additions & 3 deletions aikido_firewall/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
Aggregates from the different modules
"""

from dotenv import load_dotenv
# Re-export set_current_user :
from aikido_firewall.context.users import set_user

# Constants
PKG_VERSION = "0.0.1"
from dotenv import load_dotenv

# Import logger
from aikido_firewall.helpers.logging import logger
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def __init__(self, address, key):
self.reporter.routes.add_route(
method=data[1][0], path=data[1][1]
)
elif data[0] == "USER":
self.reporter.users.add_user(data[1])
elif data[0] == "SHOULD_RATELIMIT":
# Called to check if the context passed along as data should be
# Rate limited
Expand Down
15 changes: 11 additions & 4 deletions aikido_firewall/background_process/reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
from aikido_firewall.helpers.get_machine_ip import get_ip
from aikido_firewall.helpers.get_ua_from_context import get_ua_from_context
from aikido_firewall.helpers.get_current_unixtime_ms import get_unixtime_ms
from aikido_firewall import PKG_VERSION
from aikido_firewall.config import PKG_VERSION
from aikido_firewall.background_process.heartbeats import send_heartbeats_every_x_secs
from aikido_firewall.background_process.routes import Routes
from .service_config import ServiceConfig
from aikido_firewall.ratelimiting.rate_limiter import RateLimiter
from .service_config import ServiceConfig
from .users import Users
from .reporter_config import ReporterConfig


class Reporter:
Expand All @@ -33,6 +35,7 @@ def __init__(self, block, api, token, serverless, event_scheduler):
self.rate_limiter = RateLimiter(
max_items=5000, time_to_live_in_ms=120 * 60 * 1000 # 120 minutes
)
self.users = Users(1000)

if isinstance(serverless, str) and len(serverless) == 0:
raise ValueError("Serverless cannot be an empty string")
Expand Down Expand Up @@ -88,6 +91,10 @@ def send_heartbeat(self):
if not self.token:
return
logger.debug("Aikido Reporter : Sending out heartbeat")
users = self.users.as_array()
routes = list(self.routes)
self.users.clear()
self.routes.clear()
res = self.api.report(
self.token,
{
Expand All @@ -108,8 +115,8 @@ def send_heartbeat(self):
},
},
"hostnames": [],
"routes": list(self.routes),
"users": [],
"routes": routes,
"users": users,
},
self.timeout_in_sec,
)
Expand Down
59 changes: 59 additions & 0 deletions aikido_firewall/background_process/users.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""
Export the Users class
"""

from aikido_firewall.helpers.get_current_unixtime_ms import get_unixtime_ms


class Users:
"""
Class that holds users for the background process
"""

def __init__(self, max_entries=1000):
self.max_entries = max_entries
self.users = {}

def add_user(self, user):
"""Store a user"""
user_id = user["id"]
current_time = get_unixtime_ms()

existing = self.users.get(user_id)
if existing:
existing["name"] = user.get("name")
existing["lastIpAddress"] = user.get("lastIpAddress")
existing["lastSeenAt"] = current_time
return

if len(self.users) >= self.max_entries:
# Remove the first added user (FIFO)
first_added_key = next(iter(self.users))
del self.users[first_added_key]

self.users[user_id] = {
"id": user_id,
"name": user.get("name"),
"lastIpAddress": user.get("lastIpAddress"),
"firstSeenAt": current_time,
"lastSeenAt": current_time,
}

def as_array(self):
"""
Give all user entries back as an array
"""
return [
{
"id": user["id"],
"name": user["name"],
"lastIpAddress": user["lastIpAddress"],
"firstSeenAt": user["firstSeenAt"],
"lastSeenAt": user["lastSeenAt"],
}
for user in self.users.values()
]

def clear(self):
"""Clear out all users"""
self.users.clear()
73 changes: 73 additions & 0 deletions aikido_firewall/background_process/users_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import time
import pytest
from .users import Users # Assuming the Users class is in a file named users.py


@pytest.fixture
def users():
"""Fixture to create a Users instance with a max of 2 entries."""
return Users(max_entries=2)


def test_users(users):
assert users.as_array() == []

users.add_user({"id": "1", "name": "John", "lastIpAddress": "::1"})
user1 = users.as_array()[0]
assert user1["id"] == "1"
assert user1["name"] == "John"
assert user1["lastIpAddress"] == "::1"
assert user1["lastSeenAt"] >= 0 # lastSeenAt should be initialized
assert (
user1["lastSeenAt"] == user1["firstSeenAt"]
) # Initially, they should be equal

# Simulate the passage of time
time.sleep(0.001) # Sleep for a short time to simulate ticking the clock
users.add_user({"id": "1", "name": "John Doe", "lastIpAddress": "1.2.3.4"})
user1_updated = users.as_array()[0]
assert user1_updated["id"] == "1"
assert user1_updated["name"] == "John Doe"
assert user1_updated["lastIpAddress"] == "1.2.3.4"
assert (
user1_updated["lastSeenAt"] >= user1_updated["firstSeenAt"]
) # lastSeenAt should be >= firstSeenAt
assert (
user1_updated["lastSeenAt"] == user1_updated["firstSeenAt"] + 1
) # lastSeenAt should be +1

users.add_user({"id": "2", "name": "Jane", "lastIpAddress": "1.2.3.4"})
user2 = users.as_array()[1]
assert user2["id"] == "2"
assert user2["name"] == "Jane"
assert user2["lastIpAddress"] == "1.2.3.4"
assert (
user2["lastSeenAt"] >= user2["firstSeenAt"]
) # lastSeenAt should be >= firstSeenAt
assert (
user2["lastSeenAt"] == user2["firstSeenAt"]
) # Initially, they should be equal

users.add_user({"id": "3", "name": "Alice", "lastIpAddress": "1.2.3.4"})
user2_updated = users.as_array()[0] # Jane should still be the first user
user3 = users.as_array()[1] # Alice should be the second user
assert user2_updated["id"] == "2"
assert user2_updated["name"] == "Jane"
assert user2_updated["lastIpAddress"] == "1.2.3.4"
assert (
user2_updated["lastSeenAt"] >= user2_updated["firstSeenAt"]
) # lastSeenAt should be >= firstSeenAt
assert (
user2_updated["lastSeenAt"] == user2_updated["firstSeenAt"]
) # Should still be equal

assert user3["id"] == "3"
assert user3["name"] == "Alice"
assert user3["lastIpAddress"] == "1.2.3.4"
assert (
user3["lastSeenAt"] >= user3["firstSeenAt"]
) # lastSeenAt should be >= firstSeenAt
assert user3["lastSeenAt"] == user3["firstSeenAt"] # Should be equal

users.clear()
assert users.as_array() == []
3 changes: 3 additions & 0 deletions aikido_firewall/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Contains package versions"""

PKG_VERSION = "0.0.1"
12 changes: 12 additions & 0 deletions aikido_firewall/context/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from http.cookies import SimpleCookie
from aikido_firewall.helpers.build_route_from_url import build_route_from_url
from aikido_firewall.helpers.get_subdomains_from_url import get_subdomains_from_url
from aikido_firewall.helpers.logging import logger
from aikido_firewall.helpers.get_ip_from_request import get_ip_from_request

SUPPORTED_SOURCES = ["django", "flask", "django-gunicorn"]
Expand All @@ -15,6 +16,15 @@
local = threading.local()


def set_current_user(user):
"""Sets the current user"""
if hasattr(local, "user") and local.user is not None:
logger.debug(
"Evicting a saved users, this probably means a user was set twice."
)
local.user = user


def get_current_context():
"""Returns the current context"""
try:
Expand Down Expand Up @@ -71,6 +81,7 @@ def __init__(self, context_obj=None, req=None, source=None):
self.set_django_gunicorn_attrs(req)
self.route = build_route_from_url(self.url)
self.subdomains = get_subdomains_from_url(self.url)
self.user = local.user if hasattr(local, "user") else None
self.remote_address = get_ip_from_request(self.raw_ip, self.headers)

def set_django_gunicorn_attrs(self, req):
Expand Down Expand Up @@ -116,6 +127,7 @@ def __reduce__(self):
"source": self.source,
"route": self.route,
"subdomains": self.subdomains,
"user": self.user,
},
None,
None,
Expand Down
58 changes: 58 additions & 0 deletions aikido_firewall/context/users.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""
Users file
"""

from aikido_firewall.helpers.logging import logger
from . import set_current_user, get_current_context
from aikido_firewall.background_process import get_comms


def set_user(user):
"""
External function for applications to set a user
"""
validated_user = validate_user(user)
if not validated_user:
return
logger.debug("Validated user : %s", validated_user)

set_current_user(validated_user)

context = get_current_context()
if not context:
return
validated_user["lastIpAddress"] = context.remote_address

# Send validated_user object to Agent
get_comms().send_data_to_bg_process("USER", validated_user)


def validate_user(user):
"""This validates the user object"""
if not isinstance(user, dict):
logger.info(
"set_user(...) expects a dict with 'id' and 'name' properties, found %s instead.",
type(user),
)
return

# Validate user's id :
if not "id" in user:
logger.info("set_user(...) expects an object with 'id' property.")
return
if not isinstance(user["id"], str) and not isinstance(user["id"], int):
logger.info(
"set_user(...) expects an object with 'id' property of type string or number, found %s instead.",
type(user["id"]),
)
return
if isinstance(user["id"], str) and len(user["id"]) is 0:
logger.info(
"set_user(...) expects an object with 'id' property non-empty string."
)
return
valid_user = {"id": str(user["id"])}
if "name" in user and isinstance(user["name"], str) and len(user["name"]) > 0:
valid_user["name"] = str(user["name"])

return valid_user
64 changes: 64 additions & 0 deletions aikido_firewall/context/users_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import pytest

from .users import validate_user


def test_validate_user_valid_input():
user = {"id": "123", "name": "Alice"}
result = validate_user(user)
assert result == {"id": "123", "name": "Alice"}


def test_validate_user_valid_input_with_int_id():
user = {"id": 456, "name": "Bob"}
result = validate_user(user)
assert result == {"id": "456", "name": "Bob"}


def test_validate_user_missing_id(caplog):
user = {"name": "Charlie"}
result = validate_user(user)
assert result is None
assert "expects an object with 'id' property." in caplog.text


def test_validate_user_invalid_id_type(caplog):
user = {"id": 12.34, "name": "David"}
result = validate_user(user)
assert result is None
assert (
"expects an object with 'id' property of type string or number" in caplog.text
)


def test_validate_user_empty_string_id(caplog):
user = {"id": "", "name": "Eve"}
result = validate_user(user)
assert result is None
assert "expects an object with 'id' property non-empty string." in caplog.text


def test_validate_user_missing_name(caplog):
user = {"id": "789"}
result = validate_user(user)
assert result == {"id": "789"}


def test_validate_user_empty_name(caplog):
user = {"id": "101", "name": ""}
result = validate_user(user)
assert result == {"id": "101"}


def test_validate_user_invalid_user_type(caplog):
user = ["id", "name"]
result = validate_user(user)
assert result is None
assert "expects a dict with 'id' and 'name' properties" in caplog.text


def test_validate_user_invalid_user_type_dict_without_id(caplog):
user = {"name": "Frank"}
result = validate_user(user)
assert result is None
assert "expects an object with 'id' property." in caplog.text
8 changes: 8 additions & 0 deletions sample-apps/flask-mysql/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@

@app.route("/")
def homepage():
aikido_firewall.set_user({
"id": 1,
"name": "Wout"
})
cursor = mysql.get_db().cursor()
cursor.execute("SELECT * FROM db.dogs")
dogs = cursor.fetchall()
Expand All @@ -28,6 +32,10 @@ def homepage():

@app.route('/dogpage/<int:dog_id>')
def get_dogpage(dog_id):
aikido_firewall.set_user({
"id": 2,
"name": "Wout 2"
})
cursor = mysql.get_db().cursor()
cursor.execute("SELECT * FROM db.dogs WHERE id = " + str(dog_id))
dog = cursor.fetchmany(1)[0]
Expand Down
Loading