Skip to content

Commit f667367

Browse files
authored
fix(sql): allow parameterized query through sql sanitization (#1576)
1 parent 0c6738b commit f667367

File tree

5 files changed

+61
-4
lines changed

5 files changed

+61
-4
lines changed

extensions/connectors/sql/pandasai_sql/__init__.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from typing import Optional
23

34
import pandas as pd
@@ -17,7 +18,11 @@ def load_from_mysql(
1718
database=connection_info.database,
1819
port=connection_info.port,
1920
)
20-
return pd.read_sql(query, conn, params=params)
21+
# Suppress warnings of SqlAlchemy
22+
# TODO - Can be removed later, once SQLAlchemy is used for the connection
23+
with warnings.catch_warnings():
24+
warnings.simplefilter("ignore", category=UserWarning)
25+
return pd.read_sql(query, conn, params=params)
2126

2227

2328
def load_from_postgres(
@@ -32,7 +37,11 @@ def load_from_postgres(
3237
dbname=connection_info.database,
3338
port=connection_info.port,
3439
)
35-
return pd.read_sql(query, conn, params=params)
40+
# Suppress warnings of SqlAlchemy
41+
# TODO - Can be removed later, once SQLAlchemy is used for the connection
42+
with warnings.catch_warnings():
43+
warnings.simplefilter("ignore", category=UserWarning)
44+
return pd.read_sql(query, conn, params=params)
3645

3746

3847
def load_from_cockroachdb(
@@ -47,7 +56,11 @@ def load_from_cockroachdb(
4756
dbname=connection_info.database,
4857
port=connection_info.port,
4958
)
50-
return pd.read_sql(query, conn, params=params)
59+
# Suppress warnings of SqlAlchemy
60+
# TODO - Can be removed later, once SQLAlchemy is used for the connection
61+
with warnings.catch_warnings():
62+
warnings.simplefilter("ignore", category=UserWarning)
63+
return pd.read_sql(query, conn, params=params)
5164

5265

5366
__all__ = [

pandasai/data_loader/sql_loader.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,12 @@ def execute_query(self, query: str, params: Optional[list] = None) -> pd.DataFra
4747
connection_info, formatted_query, params
4848
)
4949
return self._apply_transformations(dataframe)
50+
51+
except ModuleNotFoundError as e:
52+
raise ImportError(
53+
f"{source_type.capitalize()} connector not found. Please install the pandasai_sql[{source_type}] library, e.g. `pip install pandasai_sql[{source_type}]`."
54+
) from e
55+
5056
except Exception as e:
5157
raise RuntimeError(
5258
f"Failed to execute query for '{source_type}' with: {formatted_query}"

pandasai/helpers/sql_sanitizer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,14 @@ def is_sql_query_safe(query: str) -> bool:
5858
r"--",
5959
r"/\*.*\*/", # Block comments and inline comments
6060
]
61+
62+
placeholder = "___PLACEHOLDER___" # Temporary placeholder for params
63+
64+
# Replace '%s' (MySQL, Psycopg2) with a unique placeholder
65+
temp_query = query.replace("%s", placeholder)
66+
6167
# Parse the query to extract its structure
62-
parsed = sqlglot.parse_one(query)
68+
parsed = sqlglot.parse_one(temp_query)
6369

6470
# Ensure the main query is SELECT
6571
if parsed.key.upper() != "SELECT":

tests/unit_tests/data_loader/test_sql_loader.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,3 +197,31 @@ def test_mysql_safe_query(self, mysql_schema):
197197

198198
assert isinstance(result, DataFrame)
199199
mock_sql_query.assert_called_once_with("select * from users")
200+
201+
def test_mysql_malicious_with_no_import(self, mysql_schema):
202+
"""Test loading data from a MySQL source creates a VirtualDataFrame and handles queries correctly."""
203+
with patch(
204+
"pandasai.data_loader.sql_loader.is_sql_query_safe"
205+
) as mock_sql_query, patch(
206+
"pandasai.data_loader.sql_loader.SQLDatasetLoader._get_loader_function"
207+
) as mock_loader_function:
208+
mocked_exec_function = MagicMock()
209+
mock_df = DataFrame(
210+
pd.DataFrame(
211+
{
212+
"email": ["test@example.com"],
213+
"first_name": ["John"],
214+
"timestamp": [pd.Timestamp.now()],
215+
}
216+
)
217+
)
218+
mocked_exec_function.return_value = mock_df
219+
220+
mock_exec_function = MagicMock()
221+
mock_loader_function.return_value = mock_exec_function
222+
mock_exec_function.side_effect = ModuleNotFoundError("Error")
223+
loader = SQLDatasetLoader(mysql_schema, "test/users")
224+
mock_sql_query.return_value = True
225+
logging.debug("Loading schema from dataset path: %s", loader)
226+
with pytest.raises(ImportError):
227+
loader.execute_query("select * from users")

tests/unit_tests/helpers/test_sql_sanitizer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ def test_safe_query_with_subquery(self):
8282
query
8383
) # Safe query with subquery, no dangerous keyword
8484

85+
def test_safe_query_with_query_params(self):
86+
query = "SELECT * FROM (SELECT * FROM heart_data) AS filtered_data LIMIT %s OFFSET %s"
87+
assert is_sql_query_safe(query)
88+
8589

8690
if __name__ == "__main__":
8791
unittest.main()

0 commit comments

Comments
 (0)