diff --git a/aikido_firewall/__init__.py b/aikido_firewall/__init__.py index 8eccc6e18..4d70284c4 100644 --- a/aikido_firewall/__init__.py +++ b/aikido_firewall/__init__.py @@ -36,5 +36,6 @@ def protect(module="any", server=True): # Import sinks import aikido_firewall.sinks.pymysql import aikido_firewall.sinks.mysqlclient + import aikido_firewall.sinks.pymongo logger.info("Aikido python firewall started") diff --git a/aikido_firewall/sinks/pymongo.py b/aikido_firewall/sinks/pymongo.py new file mode 100644 index 000000000..9b41f1a83 --- /dev/null +++ b/aikido_firewall/sinks/pymongo.py @@ -0,0 +1,61 @@ +""" +Sink module for `pymongo` +""" + +from importlib.metadata import version +from copy import deepcopy +import importhook +from aikido_firewall.helpers.logging import logger +from aikido_firewall.vulnerabilities.nosql_injection import detect_nosql_injection +from aikido_firewall.context import get_current_context +from aikido_firewall.background_process import get_comms + +OPERATIONS_WITH_FILTER = [ + "replace_one", # L1087 + "update_one", # L1189 + "update_many", # L1302 + "delete_one", # L1542 + "delete_many", # L1607 + "find_one", # L1672 + "count_documents", # L2020 + "find_one_and_delete", # L3207 + "find_one_and_replace", # L3296 + "find_one_and_update", # L3403 +] + +# ISSUE : Asynchronous +# ISSUE : `find` on L1707 and `find_raw_batches` on L1895 +# ISSUE : `aggregate` on L2847 and `aggregate_raw_batches` on L2942 +# ISSUE : `distinct` on L3054 + + +# Synchronous : +@importhook.on_import("pymongo.collection") +def on_pymongo_import(pymongo): + """ + Hook 'n wrap on `pymongo.collection` + Our goal is to wrap the following functions in the Collection class : + https://github.com/mongodb/mongo-python-driver/blob/98658cfd1fea42680a178373333bf27f41153759/pymongo/synchronous/collection.py#L136 + Returns : Modified pymongo.collection.Collection object + """ + modified_pymongo = importhook.copy_module(pymongo) + for operation in OPERATIONS_WITH_FILTER: + if not hasattr(pymongo.Collection, operation): + logger.warning("Operation `%s` not found on Collection object.", operation) + + prev_func = deepcopy(getattr(pymongo.Collection, operation)) + + def wrapped_operation_function( + _self, _filter, *args, prev_func=prev_func, op=operation, **kwargs + ): + context = get_current_context() + injection_results = detect_nosql_injection(context, _filter) + if injection_results["injection"]: + get_comms().send_data("ATTACK", injection_results) + raise Exception("NOSQL Injection [aikido_firewall]") + return prev_func(_self, _filter, *args, **kwargs) + + setattr(modified_pymongo.Collection, operation, wrapped_operation_function) + + # logger.debug("Wrapped `pymongo` with version %s", version("pymongo")) + return modified_pymongo diff --git a/aikido_firewall/sinks/pymongo_test.py b/aikido_firewall/sinks/pymongo_test.py new file mode 100644 index 000000000..3116a53cd --- /dev/null +++ b/aikido_firewall/sinks/pymongo_test.py @@ -0,0 +1,61 @@ +import pytest +from unittest.mock import MagicMock +from aikido_firewall.sinks.pymongo import on_pymongo_import + + +@pytest.fixture +def mock_pymongo(): + mock_pymongo = MagicMock() + mock_collection = MagicMock() + mock_pymongo.Collection = mock_collection + return mock_pymongo + + +def test_on_pymongo_import(mocker, mock_pymongo): + mocker.patch("importhook.copy_module", return_value=mock_pymongo) + + for operation in [ + "replace_one", + "update_one", + "update_many", + "delete_one", + "delete_many", + "find_one", + "count_documents", + "find_one_and_delete", + "find_one_and_replace", + "find_one_and_update", + ]: + setattr( + mock_pymongo.Collection, + operation, + MagicMock(return_value="original_result"), + ) + + mocker.patch("aikido_firewall.helpers.logging.logger") + mocker.patch( + "aikido_firewall.vulnerabilities.nosql_injection.detect_nosql_injection", + return_value={"injection": False}, + ) + mocker.patch("aikido_firewall.context.get_current_context", return_value={}) + mocker.patch( + "aikido_firewall.background_process.get_comms", + return_value=MagicMock(send_data=MagicMock()), + ) + + modified_pymongo = on_pymongo_import(mock_pymongo) + + for operation in [ + "replace_one", + "update_one", + "update_many", + "delete_one", + "delete_many", + "find_one", + "count_documents", + "find_one_and_delete", + "find_one_and_replace", + "find_one_and_update", + ]: + wrapped_function = getattr(modified_pymongo.Collection, operation) + assert wrapped_function is not None diff --git a/aikido_firewall/vulnerabilities/nosql_injection/__init__.py b/aikido_firewall/vulnerabilities/nosql_injection/__init__.py index 0637fad11..810c0c208 100644 --- a/aikido_firewall/vulnerabilities/nosql_injection/__init__.py +++ b/aikido_firewall/vulnerabilities/nosql_injection/__init__.py @@ -100,8 +100,8 @@ def detect_nosql_injection(request, _filter): return {"injection": False} for source in UINPUT_SOURCES: - if request.get(source): - result = find_filter_part_with_operators(request[source], _filter) + if hasattr(request, source): + result = find_filter_part_with_operators(getattr(request, source), _filter) if result.get("found"): return { diff --git a/aikido_firewall/vulnerabilities/nosql_injection/init_test.py b/aikido_firewall/vulnerabilities/nosql_injection/init_test.py index d3458e688..449d13bd6 100644 --- a/aikido_firewall/vulnerabilities/nosql_injection/init_test.py +++ b/aikido_firewall/vulnerabilities/nosql_injection/init_test.py @@ -1,5 +1,6 @@ import pytest from aikido_firewall.vulnerabilities.nosql_injection import detect_nosql_injection +from aikido_firewall.context import Context @pytest.fixture @@ -7,19 +8,24 @@ def create_context(): def _create_context( query=None, headers=None, body=None, cookies=None, route_params=None ): - context = { - "remote_address": "::1", - "method": "GET", - "url": "http://localhost:4000", - "query": query if query else {}, - "headers": headers if headers else {}, - "body": body, - "cookies": cookies if cookies else {}, - "route_params": route_params if route_params else {}, - "source": "express", - "route": "/posts/:id", - } - return context + class RequestContext: + remote_address = "::1" + method = "GET" + url = "http://localhost:4000" + query = {} + headers = {} + body = None + cookies = {} + route_params = {} + source = "express" + route = "/posts/:id" + + RequestContext.query = query if query else {} + RequestContext.headers = headers if headers else {} + RequestContext.body = body + RequestContext.cookies = cookies if cookies else {} + RequestContext.route_params = route_params if route_params else {} + return RequestContext() return _create_context