Skip to content

Make sure the @on_import is always at the end of the file #386

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions aikido_zen/sinks/asyncpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@
from aikido_zen.sinks import patch_function, before, on_import


@before
def _execute(func, instance, args, kwargs):
query = get_argument(args, kwargs, 0, "query")

op = f"asyncpg.connection.Connection.{func.__name__}"
vulns.run_vulnerability_scan(kind="sql_injection", op=op, args=(query, "postgres"))


@on_import("asyncpg.connection", "asyncpg", version_requirement="0.27.0")
def patch(m):
"""
Expand All @@ -19,11 +27,3 @@ def patch(m):
patch_function(m, "Connection.execute", _execute)
patch_function(m, "Connection.executemany", _execute)
patch_function(m, "Connection._execute", _execute)


@before
def _execute(func, instance, args, kwargs):
query = get_argument(args, kwargs, 0, "query")

op = f"asyncpg.connection.Connection.{func.__name__}"
vulns.run_vulnerability_scan(kind="sql_injection", op=op, args=(query, "postgres"))
22 changes: 11 additions & 11 deletions aikido_zen/sinks/mysqlclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,6 @@
from aikido_zen.sinks import patch_function, on_import, before


@on_import("MySQLdb.cursors", "mysqlclient", version_requirement="1.5.0")
def patch(m):
"""
patching MySQLdb.cursors (mysqlclient)
- patches Cursor.execute(query, ...)
- patches Cursor.executemany(query, ...)
"""
patch_function(m, "Cursor.execute", _execute)
patch_function(m, "Cursor.executemany", _executemany)


@before
def _execute(func, instance, args, kwargs):
query = get_argument(args, kwargs, 0, "query")
Expand All @@ -37,3 +26,14 @@ def _executemany(func, instance, args, kwargs):
vulns.run_vulnerability_scan(
kind="sql_injection", op="MySQLdb.Cursor.executemany", args=(query, "mysql")
)


@on_import("MySQLdb.cursors", "mysqlclient", version_requirement="1.5.0")
def patch(m):
"""
patching MySQLdb.cursors (mysqlclient)
- patches Cursor.execute(query, ...)
- patches Cursor.executemany(query, ...)
"""
patch_function(m, "Cursor.execute", _execute)
patch_function(m, "Cursor.executemany", _executemany)
26 changes: 13 additions & 13 deletions aikido_zen/sinks/psycopg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,6 @@
from aikido_zen.sinks import patch_function, on_import, before


@on_import("psycopg.cursor", "psycopg", version_requirement="3.1.0")
def patch(m):
"""
patching module psycopg.cursor
- patches Cursor.copy
- patches Cursor.execute
- patches Cursor.executemany
"""
patch_function(m, "Cursor.copy", _copy)
patch_function(m, "Cursor.execute", _execute)
patch_function(m, "Cursor.executemany", _execute)


@before
def _copy(func, instance, args, kwargs):
statement = get_argument(args, kwargs, 0, "statement")
Expand All @@ -33,3 +20,16 @@ def _execute(func, instance, args, kwargs):
query = get_argument(args, kwargs, 0, "query")
op = f"psycopg.Cursor.{func.__name__}"
vulns.run_vulnerability_scan(kind="sql_injection", op=op, args=(query, "postgres"))


@on_import("psycopg.cursor", "psycopg", version_requirement="3.1.0")
def patch(m):
"""
patching module psycopg.cursor
- patches Cursor.copy
- patches Cursor.execute
- patches Cursor.executemany
"""
patch_function(m, "Cursor.copy", _copy)
patch_function(m, "Cursor.execute", _execute)
patch_function(m, "Cursor.executemany", _execute)
38 changes: 19 additions & 19 deletions aikido_zen/sinks/psycopg2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,6 @@
from aikido_zen.sinks import on_import, before, patch_function, after


@on_import("psycopg2")
def patch(m):
"""
patching module psycopg2
- patches psycopg2.connect
cannot set 'execute' attribute of immutable type 'psycopg2.extensions.cursor',
so we create our own cursor factory to bypass this limitation.
"""
compatible = is_package_compatible(
required_version="2.9.2", packages=["psycopg2", "psycopg2-binary"]
)
if not compatible:
# Users can install either psycopg2 or psycopg2-binary, we need to check if at least
# one is installed and if they meet version requirements
return

patch_function(m, "connect", _connect)


@after
def _connect(func, instance, _args, _kwargs, rv):
"""
Expand Down Expand Up @@ -56,3 +37,22 @@
query = get_argument(args, kwargs, 0, "query")
op = f"psycopg2.Connection.Cursor.{func.__name__}"
vulns.run_vulnerability_scan(kind="sql_injection", op=op, args=(query, "postgres"))


@on_import("psycopg2")
def patch(m):
"""
patching module psycopg2
- patches psycopg2.connect
cannot set 'execute' attribute of immutable type 'psycopg2.extensions.cursor',
so we create our own cursor factory to bypass this limitation.
"""
compatible = is_package_compatible(
required_version="2.9.2", packages=["psycopg2", "psycopg2-binary"]
)
if not compatible:
# Users can install either psycopg2 or psycopg2-binary, we need to check if at least
# one is installed and if they meet version requirements
return

Check warning on line 56 in aikido_zen/sinks/psycopg2.py

View check run for this annotation

Codecov / codecov/patch

aikido_zen/sinks/psycopg2.py#L56

Added line #L56 was not covered by tests

patch_function(m, "connect", _connect)
72 changes: 36 additions & 36 deletions aikido_zen/sinks/pymongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,42 +7,6 @@
from . import patch_function, on_import, before


@on_import("pymongo.collection", "pymongo", version_requirement="3.10.0")
def patch(m):
"""
patching pymongo.collection
- patches Collection.*(filter, ...)
- patches Collection.*(..., filter, ...)
- patches Collection.*(pipeline, ...)
- patches Collection.bulk_write
src: https://github.com/mongodb/mongo-python-driver/blob/98658cfd1fea42680a178373333bf27f41153759/pymongo/synchronous/collection.py#L136
"""
# func(filter, ...)
patch_function(m, "Collection.replace_one", _func_filter_first)
patch_function(m, "Collection.update_one", _func_filter_first)
patch_function(m, "Collection.update_many", _func_filter_first)
patch_function(m, "Collection.delete_one", _func_filter_first)
patch_function(m, "Collection.delete_many", _func_filter_first)
patch_function(m, "Collection.count_documents", _func_filter_first)
patch_function(m, "Collection.find_one_and_delete", _func_filter_first)
patch_function(m, "Collection.find_one_and_replace", _func_filter_first)
patch_function(m, "Collection.find_one_and_update", _func_filter_first)
patch_function(m, "Collection.find", _func_filter_first)
patch_function(m, "Collection.find_raw_batches", _func_filter_first)
# find_one not present in list since find_one calls find function.

# func(..., filter, ...)
patch_function(m, "Collection.distinct", _func_filter_second)

# func(pipeline, ...)
patch_function(m, "Collection.watch", _func_pipeline)
patch_function(m, "Collection.aggregate", _func_pipeline)
patch_function(m, "Collection.aggregate_raw_batches", _func_pipeline)

# bulk_write
patch_function(m, "Collection.bulk_write", _bulk_write)


@before
def _func_filter_first(func, instance, args, kwargs):
"""Collection.func(filter, ...)"""
Expand Down Expand Up @@ -97,3 +61,39 @@ def _bulk_write(func, instance, args, kwargs):
op="pymongo.collection.Collection.bulk_write",
args=(request._filter,),
)


@on_import("pymongo.collection", "pymongo", version_requirement="3.10.0")
def patch(m):
"""
patching pymongo.collection
- patches Collection.*(filter, ...)
- patches Collection.*(..., filter, ...)
- patches Collection.*(pipeline, ...)
- patches Collection.bulk_write
src: https://github.com/mongodb/mongo-python-driver/blob/98658cfd1fea42680a178373333bf27f41153759/pymongo/synchronous/collection.py#L136
"""
# func(filter, ...)
patch_function(m, "Collection.replace_one", _func_filter_first)
patch_function(m, "Collection.update_one", _func_filter_first)
patch_function(m, "Collection.update_many", _func_filter_first)
patch_function(m, "Collection.delete_one", _func_filter_first)
patch_function(m, "Collection.delete_many", _func_filter_first)
patch_function(m, "Collection.count_documents", _func_filter_first)
patch_function(m, "Collection.find_one_and_delete", _func_filter_first)
patch_function(m, "Collection.find_one_and_replace", _func_filter_first)
patch_function(m, "Collection.find_one_and_update", _func_filter_first)
patch_function(m, "Collection.find", _func_filter_first)
patch_function(m, "Collection.find_raw_batches", _func_filter_first)
# find_one not present in list since find_one calls find function.

# func(..., filter, ...)
patch_function(m, "Collection.distinct", _func_filter_second)

# func(pipeline, ...)
patch_function(m, "Collection.watch", _func_pipeline)
patch_function(m, "Collection.aggregate", _func_pipeline)
patch_function(m, "Collection.aggregate_raw_batches", _func_pipeline)

# bulk_write
patch_function(m, "Collection.bulk_write", _bulk_write)
Loading