Skip to content

[BUG] SQL Injection through CVE Bypass in DB-GPT 0.7.0 (CVE-2024-10835 & CVE-2024-10901) #2650

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 124 additions & 127 deletions packages/dbgpt-app/src/dbgpt_app/openapi/api_v1/editor/api_editor_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import re
import time
from typing import Dict, List
from typing import Dict, List, Tuple

from fastapi import APIRouter, Body, Depends

Expand Down Expand Up @@ -95,66 +95,109 @@ async def get_editor_sql(
return Result.failed(msg="not have sql!")


def sanitize_sql(sql: str, db_type: str = None) -> Tuple[bool, str, dict]:
"""Simple SQL sanitizer to prevent injection.
Returns:
Tuple of (is_safe, reason, params)
"""
# Normalize SQL (remove comments and excess whitespace)
sql = re.sub(r"/\*.*?\*/", " ", sql)
sql = re.sub(r"--.*?$", " ", sql, flags=re.MULTILINE)
sql = re.sub(r"\s+", " ", sql).strip()

# Block multiple statements
if re.search(r";\s*(?!--|\*/|$)", sql):
return False, "Multiple SQL statements are not allowed", {}

# Block dangerous operations for all databases
dangerous_patterns = [
r"(?i)INTO\s+(?:OUT|DUMP)FILE",
r"(?i)LOAD\s+DATA",
r"(?i)SYSTEM",
r"(?i)EXEC\s+",
r"(?i)SHELL\b",
r"(?i)DROP\s+DATABASE",
r"(?i)DROP\s+USER",
r"(?i)GRANT\s+",
r"(?i)REVOKE\s+",
r"(?i)ALTER\s+(USER|DATABASE)",
]

# Add DuckDB specific patterns
if db_type == "duckdb":
dangerous_patterns.extend(
[
r"(?i)COPY\b",
r"(?i)EXPORT\b",
r"(?i)IMPORT\b",
r"(?i)INSTALL\b",
r"(?i)READ_\w+\b",
r"(?i)WRITE_\w+\b",
r"(?i)\.EXECUTE\(",
r"(?i)PRAGMA\b",
]
)

for pattern in dangerous_patterns:
if re.search(pattern, sql):
return False, f"Operation not allowed: {pattern}", {}

# Allow SELECT, CREATE TABLE, INSERT, UPDATE, and DELETE operations
# We're no longer restricting to read-only operations
allowed_operations = re.match(
r"(?i)^\s*(SELECT|CREATE\s+TABLE|INSERT\s+INTO|UPDATE|DELETE\s+FROM|ALTER\s+TABLE)\b",
sql,
)
if not allowed_operations:
return (
False,
"Operation not supported. Only SELECT, CREATE TABLE, INSERT, UPDATE, "
"DELETE and ALTER TABLE operations are allowed",
{},
)

# Extract parameters (simplified)
params = {}
param_count = 0

# Extract string literals
def replace_string(match):
nonlocal param_count
param_name = f"param_{param_count}"
params[param_name] = match.group(1)
param_count += 1
return f":{param_name}"

# Replace string literals with parameters
parameterized_sql = re.sub(r"'([^']*)'", replace_string, sql)

return True, parameterized_sql, params


@router.post("/v1/editor/sql/run", response_model=Result[SqlRunData])
async def editor_sql_run(run_param: dict = Body()):
logger.info(f"editor_sql_run:{run_param}")
db_name = run_param["db_name"]
sql = run_param["sql"]

if not db_name and not sql:
return Result.failed(msg="SQL run param error!")

# Validate database type and prevent dangerous operations
# Get database connection
conn = CFG.local_db_manager.get_connector(db_name)
db_type = getattr(conn, "db_type", "").lower()

# Block dangerous operations for DuckDB
if db_type == "duckdb":
# Block file operations and system commands
dangerous_keywords = [
# File operations
"copy",
"export",
"import",
"load",
"install",
"read_",
"write_",
"save",
"from_",
"to_",
# System commands
"create_",
"drop_",
".execute(",
"system",
"shell",
# Additional DuckDB specific operations
"attach",
"detach",
"pragma",
"checkpoint",
"load_extension",
"unload_extension",
# File paths
"/'",
"'/'",
"\\",
"://",
]
sql_lower = sql.lower().replace(" ", "") # Remove spaces to prevent bypass
if any(keyword in sql_lower for keyword in dangerous_keywords):
logger.warning(f"Blocked dangerous SQL operation attempt: {sql}")
return Result.failed(msg="Operation not allowed for security reasons")

# Additional check for file path patterns
if re.search(r"['\"].*[/\\].*['\"]", sql):
logger.warning(f"Blocked file path in SQL: {sql}")
return Result.failed(msg="File operations not allowed")
# Sanitize and parameterize the SQL query
is_safe, result, params = sanitize_sql(sql, db_type)
if not is_safe:
logger.warning(f"Blocked dangerous SQL: {sql}")
return Result.failed(msg=f"Operation not allowed: {result}")

try:
start_time = time.time() * 1000
# Add timeout protection
colunms, sql_result = conn.query_ex(sql, timeout=30)
# Use the parameterized query and parameters
colunms, sql_result = conn.query_ex(result, params=params, timeout=30)
# Convert result type safely
sql_result = [
tuple(str(x) if x is not None else None for x in row) for row in sql_result
Expand Down Expand Up @@ -216,103 +259,57 @@ async def get_editor_chart_info(


@router.post("/v1/editor/chart/run", response_model=Result[ChartRunData])
async def editor_chart_run(run_param: dict = Body()):
logger.info(f"editor_chart_run:{run_param}")
async def chart_run(run_param: dict = Body()):
logger.info(f"chart_run:{run_param}")
db_name = run_param["db_name"]
sql = run_param["sql"]
chart_type = run_param["chart_type"]

# Validate input parameters
if not db_name or not sql or not chart_type:
return Result.failed("Required parameters missing")

try:
# Validate database type and prevent dangerous operations
db_conn = CFG.local_db_manager.get_connector(db_name)
db_type = getattr(db_conn, "db_type", "").lower()

# Block dangerous operations for DuckDB
if db_type == "duckdb":
# Block file operations and system commands
dangerous_keywords = [
# File operations
"copy",
"export",
"import",
"load",
"install",
"read_",
"write_",
"save",
"from_",
"to_",
# System commands
"create_",
"drop_",
".execute(",
"system",
"shell",
# Additional DuckDB specific operations
"attach",
"detach",
"pragma",
"checkpoint",
"load_extension",
"unload_extension",
# File paths
"/'",
"'/'",
"\\",
"://",
]
sql_lower = sql.lower().replace(" ", "") # Remove spaces to prevent bypass
if any(keyword in sql_lower for keyword in dangerous_keywords):
logger.warning(
f"Blocked dangerous SQL operation attempt in chart: {sql}"
)
return Result.failed(msg="Operation not allowed for security reasons")

# Additional check for file path patterns
if re.search(r"['\"].*[/\\].*['\"]", sql):
logger.warning(f"Blocked file path in chart SQL: {sql}")
return Result.failed(msg="File operations not allowed")
# Get database connection
db_conn = CFG.local_db_manager.get_connector(db_name)
db_type = getattr(db_conn, "db_type", "").lower()

dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
# Sanitize and parameterize the SQL query
is_safe, result, params = sanitize_sql(sql, db_type)
if not is_safe:
logger.warning(f"Blocked dangerous SQL: {sql}")
return Result.failed(msg=f"Operation not allowed: {result}")

try:
start_time = time.time() * 1000

# Execute query with timeout
colunms, sql_result = db_conn.query_ex(sql, timeout=30)

# Safely convert and process results
field_names, chart_values = dashboard_data_loader.get_chart_values_by_data(
colunms,
[
tuple(str(x) if x is not None else None for x in row)
for row in sql_result
],
sql,
)

# Use the parameterized query and parameters
colunms, sql_result = db_conn.query_ex(result, params=params, timeout=30)
# Convert result type safely
sql_result = [
tuple(str(x) if x is not None else None for x in row) for row in sql_result
]
# Calculate execution time
end_time = time.time() * 1000
sql_run_data: SqlRunData = SqlRunData(
result_info="",
run_cost=(end_time - start_time) / 1000,
colunms=colunms,
values=[list(row) for row in sql_result],
values=sql_result,
)
return Result.succ(
ChartRunData(
sql_data=sql_run_data, chart_values=chart_values, chart_type=chart_type

chart_values = []
for i in range(len(sql_result)):
row = sql_result[i]
chart_values.append(
{
"name": row[0],
"type": "value",
"value": row[1] if len(row) > 1 else "0",
}
)

chart_data: ChartRunData = ChartRunData(
sql_data=sql_run_data, chart_values=chart_values, chart_type=chart_type
)
return Result.succ(chart_data)
except Exception as e:
logger.exception("Chart sql run failed!")
sql_result = SqlRunData(result_info=str(e), run_cost=0, colunms=[], values=[])
return Result.succ(
ChartRunData(sql_data=sql_result, chart_values=[], chart_type=chart_type)
)
logger.error(f"chart_run exception: {str(e)}", exc_info=True)
return Result.failed(msg=str(e))


@router.post("/v1/chart/editor/submit", response_model=Result[bool])
Expand Down
Loading