Skip to content

Commit 585c07d

Browse files
ArslanSaleemgventuriellipsis-dev[bot]
authored
fix(sql_query): validate if the query is not malicious (#1568)
* fix(sql_query): validate if the query is not malicious * fix(sql_sanitizer): fix condition of sql * feat(sql_sanitize): integrate sql_sanitize in new loader and test cases * chore: improve error message Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> * mock connector function --------- Co-authored-by: Gabriele Venturi <lele.venturi@gmail.com> Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
1 parent 847757a commit 585c07d

File tree

5 files changed

+208
-8
lines changed

5 files changed

+208
-8
lines changed

pandasai/data_loader/sql_loader.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import pandas as pd
55

66
from pandasai.dataframe.virtual_dataframe import VirtualDataFrame
7-
from pandasai.exceptions import InvalidDataSourceType
7+
from pandasai.exceptions import InvalidDataSourceType, MaliciousQueryError
8+
from pandasai.helpers.sql_sanitizer import is_sql_query_safe
89

910
from ..constants import (
1011
SUPPORTED_SOURCE_CONNECTORS,
@@ -36,6 +37,12 @@ def execute_query(self, query: str, params: Optional[list] = None) -> pd.DataFra
3637

3738
formatted_query = self.query_builder.format_query(query)
3839
load_function = self._get_loader_function(source_type)
40+
41+
if not is_sql_query_safe(formatted_query):
42+
raise MaliciousQueryError(
43+
"The SQL query is deemed unsafe and will not be executed."
44+
)
45+
3946
try:
4047
dataframe: pd.DataFrame = load_function(
4148
connection_info, formatted_query, params

pandasai/helpers/sql_sanitizer.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import os
22
import re
33

4+
import sqlglot
5+
46

57
def sanitize_sql_table_name(filepath: str) -> str:
68
# Extract the file name without extension
@@ -14,3 +16,71 @@ def sanitize_sql_table_name(filepath: str) -> str:
1416
sanitized_name = sanitized_name[:max_length]
1517

1618
return sanitized_name
19+
20+
21+
# Patterns whose presence anywhere in the query text marks it as unsafe:
# data/schema-modifying statements, session and procedure control, timing
# functions, server-metadata accessors and SQL comments. Compiled once at
# import time instead of on every call.
_UNSAFE_SQL_PATTERNS = tuple(
    re.compile(pattern, re.IGNORECASE | re.DOTALL)
    for pattern in (
        r"\bINSERT\b",
        r"\bUPDATE\b",
        r"\bDELETE\b",
        r"\bDROP\b",
        r"\bEXEC\b",
        r"\bALTER\b",
        r"\bCREATE\b",
        r"\bMERGE\b",
        r"\bREPLACE\b",
        r"\bTRUNCATE\b",
        r"\bLOAD\b",
        r"\bGRANT\b",
        r"\bREVOKE\b",
        r"\bCALL\b",
        r"\bEXECUTE\b",
        r"\bSHOW\b",
        r"\bDESCRIBE\b",
        r"\bEXPLAIN\b",
        r"\bUSE\b",
        r"\bSET\b",
        r"\bDECLARE\b",
        r"\bOPEN\b",
        r"\bFETCH\b",
        r"\bCLOSE\b",
        r"\bSLEEP\b",
        r"\bBENCHMARK\b",
        r"\bDATABASE\b",
        r"\bUSER\b",
        r"\bCURRENT_USER\b",
        r"\bSESSION_USER\b",
        r"\bSYSTEM_USER\b",
        r"\bVERSION\b",
        # No leading \b here: "@" is not a word character, so the former
        # r"\b@@VERSION\b" could only match when a word character directly
        # preceded "@@".
        r"@@VERSION\b",
        r"--",
        # re.DOTALL lets "." cross newlines, so block comments that span
        # several lines are detected as well as single-line ones.
        r"/\*.*\*/",
    )
)


def _has_unsafe_pattern(sql: str) -> bool:
    """Return True if *sql* matches any entry of ``_UNSAFE_SQL_PATTERNS``."""
    return any(pattern.search(sql) for pattern in _UNSAFE_SQL_PATTERNS)


def is_sql_query_safe(query: str) -> bool:
    """Return True only if *query* parses as a SELECT with no blocked keyword.

    The query is parsed with ``sqlglot.parse_one``; its top-level expression
    must be a SELECT. The raw query text and the generated SQL of every
    subquery are then scanned, case-insensitively, for the blocked keywords
    and comment markers. The scan is purely textual, so a blocked word that
    appears inside a string literal also causes rejection.

    Args:
        query: The SQL text to validate.

    Returns:
        True if the query passed every check, False otherwise. A query that
        sqlglot cannot tokenize or parse is reported as unsafe (False).
    """
    # NOTE(review): parse_one appears to return only the first statement of
    # a multi-statement string; later statements are covered solely by the
    # textual keyword scan below - confirm against the sqlglot version in use.
    try:
        parsed = sqlglot.parse_one(query)

        # Ensure the main query is SELECT.
        if parsed.key.upper() != "SELECT":
            return False

        # Check the full query text first, then each subquery as rendered
        # by sqlglot.
        if _has_unsafe_pattern(query):
            return False

        return not any(
            _has_unsafe_pattern(subquery.sql())
            for subquery in parsed.find_all(sqlglot.exp.Subquery)
        )

    except sqlglot.errors.SqlglotError:
        # SqlglotError is the base class of ParseError and TokenError, so
        # tokenizer failures are treated as unsafe instead of propagating.
        return False

tests/unit_tests/data_loader/test_loader.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import logging
21
from unittest.mock import mock_open, patch
32

43
import pandas as pd

tests/unit_tests/data_loader/test_sql_loader.py

Lines changed: 61 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
11
import logging
2-
from unittest.mock import MagicMock, mock_open, patch
2+
from unittest.mock import MagicMock, patch
33

44
import pandas as pd
55
import pytest
66

77
from pandasai import VirtualDataFrame
8-
from pandasai.data_loader.loader import DatasetLoader
9-
from pandasai.data_loader.local_loader import LocalDatasetLoader
10-
from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema
118
from pandasai.data_loader.sql_loader import SQLDatasetLoader
129
from pandasai.dataframe.base import DataFrame
13-
from pandasai.exceptions import InvalidDataSourceType
10+
from pandasai.exceptions import MaliciousQueryError
1411

1512

1613
class TestSqlDatasetLoader:
@@ -138,3 +135,62 @@ def test_load_with_transformation(self, mysql_schema):
138135
loader_function.call_args[0][1]
139136
== "SELECT email, first_name, timestamp FROM users ORDER BY RAND() LIMIT 5"
140137
)
138+
139+
def test_mysql_malicious_query(self, mysql_schema):
    """Reject a query flagged as unsafe before it reaches the connector.

    ``is_sql_query_safe`` is patched to return False, so ``execute_query``
    must raise ``MaliciousQueryError`` and must not call the loader function.
    """
    with patch(
        "pandasai.data_loader.sql_loader.is_sql_query_safe"
    ) as mock_sql_query, patch(
        "pandasai.data_loader.sql_loader.SQLDatasetLoader._get_loader_function"
    ) as mock_loader_function:
        mocked_exec_function = MagicMock()
        mock_df = DataFrame(
            pd.DataFrame(
                {
                    "email": ["test@example.com"],
                    "first_name": ["John"],
                    "timestamp": [pd.Timestamp.now()],
                }
            )
        )
        mocked_exec_function.return_value = mock_df
        mock_loader_function.return_value = mocked_exec_function
        loader = SQLDatasetLoader(mysql_schema, "test/users")
        mock_sql_query.return_value = False
        logging.debug("Loading schema from dataset path: %s", loader)

        with pytest.raises(MaliciousQueryError):
            loader.execute_query("DROP TABLE users")

        mock_sql_query.assert_called_once_with("DROP TABLE users")
        # The connector function must never run for a rejected query.
        mocked_exec_function.assert_not_called()
166+
167+
def test_mysql_safe_query(self, mysql_schema):
    """Execute a query that the safety check accepts.

    ``is_sql_query_safe`` is patched to return True and the loader function
    and ``_apply_transformations`` are mocked, so ``execute_query`` is
    expected to return a ``DataFrame`` after validating the query text.
    """
    with patch(
        "pandasai.data_loader.sql_loader.is_sql_query_safe"
    ) as mock_sql_query, patch(
        "pandasai.data_loader.sql_loader.SQLDatasetLoader._get_loader_function"
    ) as mock_loader_function, patch(
        "pandasai.data_loader.sql_loader.SQLDatasetLoader._apply_transformations"
    ) as mock_apply_transformations:
        mocked_exec_function = MagicMock()
        mock_df = DataFrame(
            pd.DataFrame(
                {
                    "email": ["test@example.com"],
                    "first_name": ["John"],
                    "timestamp": [pd.Timestamp.now()],
                }
            )
        )
        mocked_exec_function.return_value = mock_df
        mock_apply_transformations.return_value = mock_df
        mock_loader_function.return_value = mocked_exec_function
        loader = SQLDatasetLoader(mysql_schema, "test/users")
        mock_sql_query.return_value = True
        logging.debug("Loading schema from dataset path: %s", loader)

        result = loader.execute_query("select * from users")

        assert isinstance(result, DataFrame)
        mock_sql_query.assert_called_once_with("select * from users")

tests/unit_tests/helpers/test_sql_sanitizer.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from pandasai.helpers.sql_sanitizer import sanitize_sql_table_name
1+
from pandasai.helpers.sql_sanitizer import is_sql_query_safe, sanitize_sql_table_name
22

33

44
class TestSqlSanitizer:
@@ -17,3 +17,71 @@ def test_filename_with_long_name(self):
1717
filepath = "/path/to/" + "a" * 100 + ".csv"
1818
expected = "a" * 64
1919
assert sanitize_sql_table_name(filepath) == expected
20+
21+
def test_safe_select_query(self):
    """A plain SELECT with a WHERE clause is accepted."""
    sql = "SELECT * FROM users WHERE username = 'admin';"
    assert is_sql_query_safe(sql)
24+
25+
def test_safe_with_query(self):
    """A SELECT that reads from a WITH (CTE) clause is accepted."""
    sql = "WITH user_data AS (SELECT * FROM users) SELECT * FROM user_data;"
    assert is_sql_query_safe(sql)
28+
29+
def test_unsafe_insert_query(self):
    """An INSERT statement is rejected."""
    sql = "INSERT INTO users (username, password) VALUES ('admin', 'password');"
    assert not is_sql_query_safe(sql)
32+
33+
def test_unsafe_update_query(self):
    """An UPDATE statement is rejected."""
    sql = "UPDATE users SET password = 'newpassword' WHERE username = 'admin';"
    assert not is_sql_query_safe(sql)
36+
37+
def test_unsafe_delete_query(self):
    """A DELETE statement is rejected."""
    sql = "DELETE FROM users WHERE username = 'admin';"
    assert not is_sql_query_safe(sql)
40+
41+
def test_unsafe_drop_query(self):
    """A DROP TABLE statement is rejected."""
    sql = "DROP TABLE users;"
    assert not is_sql_query_safe(sql)
44+
45+
def test_unsafe_alter_query(self):
    """An ALTER TABLE statement is rejected."""
    sql = "ALTER TABLE users ADD COLUMN age INT;"
    assert not is_sql_query_safe(sql)
48+
49+
def test_unsafe_create_query(self):
    """A CREATE TABLE statement is rejected."""
    sql = "CREATE TABLE users (id INT, username VARCHAR(50));"
    assert not is_sql_query_safe(sql)
52+
53+
def test_safe_select_with_comment(self):
    """A SELECT followed by a ``--`` comment is rejected.

    Despite the method name, the expected result is "unsafe": the sanitizer
    blocks any query that contains a comment marker.
    """
    sql = "SELECT * FROM users WHERE username = 'admin' -- comment"
    assert not is_sql_query_safe(sql)  # Blocked by comment detection
56+
57+
def test_safe_select_with_inline_comment(self):
    """A SELECT containing a ``/* ... */`` comment is rejected.

    Despite the method name, the expected result is "unsafe": the sanitizer
    blocks any query that contains a block comment.
    """
    sql = "SELECT * FROM users /* inline comment */ WHERE username = 'admin';"
    assert not is_sql_query_safe(sql)  # Blocked by comment detection
60+
61+
def test_unsafe_query_with_subquery(self):
    """A SELECT whose subquery is itself a plain SELECT is accepted.

    Despite the method name, the expected result is "safe": neither the
    outer query nor the subquery contains a blocked keyword.
    """
    sql = "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders);"
    assert is_sql_query_safe(sql)  # No dangerous keyword in main or subquery
64+
65+
def test_unsafe_query_with_subquery_insert(self):
    """A SELECT that embeds an INSERT in its subquery position is rejected."""
    sql = "SELECT * FROM users WHERE id IN (INSERT INTO orders (user_id) VALUES (1));"
    assert not is_sql_query_safe(sql)  # Subquery contains INSERT, blocked
70+
71+
def test_invalid_sql(self):
    """Text that is not a valid SELECT statement is rejected."""
    sql = "INVALID SQL QUERY"
    assert not is_sql_query_safe(sql)  # Invalid query should return False
74+
75+
def test_safe_query_with_multiple_keywords(self):
    """A SELECT with several WHERE conditions joined by AND is accepted."""
    sql = "SELECT name FROM users WHERE username = 'admin' AND age > 30;"
    assert is_sql_query_safe(sql)  # Safe query with no dangerous keyword
78+
79+
def test_safe_query_with_subquery(self):
    """A SELECT filtered by a SELECT subquery is accepted."""
    sql = "SELECT name FROM users WHERE username IN (SELECT username FROM users WHERE age > 30);"
    # Safe query with subquery, no dangerous keyword
    assert is_sql_query_safe(sql)
84+
85+
86+
if __name__ == "__main__":
    # The visible top of this module has no ``import unittest``; import it
    # here so running the file directly does not fail with NameError.
    # NOTE(review): TestSqlSanitizer is declared without a unittest.TestCase
    # base, so unittest.main() will likely collect no tests - run via pytest.
    import unittest

    unittest.main()

0 commit comments

Comments
 (0)