|
2 | 2 | import logging
|
3 | 3 | import re
|
4 | 4 | import time
|
5 |
| -from typing import Dict, List |
| 5 | +from typing import Dict, List, Tuple |
6 | 6 |
|
7 | 7 | from fastapi import APIRouter, Body, Depends
|
8 | 8 |
|
@@ -95,66 +95,109 @@ async def get_editor_sql(
|
95 | 95 | return Result.failed(msg="not have sql!")
|
96 | 96 |
|
97 | 97 |
|
# Lexer for the few SQL constructs that decide how the surrounding text must
# be read. Alternatives are tried left to right at every position, so a quote
# inside a comment (or a comment marker inside a literal) is consumed by
# whichever construct opened first.
_SQL_TOKEN_RE = re.compile(
    r"(?P<string>'(?:[^']|'')*')"  # string literal; '' is an escaped quote
    r'|(?P<ident>"(?:[^"]|"")*"|`[^`]*`)'  # quoted identifier, kept verbatim
    r"|(?P<block>/\*.*?\*/)"  # block comment, may span lines (DOTALL)
    r"|(?P<line>--[^\n]*)"  # line comment, up to the end of the line
    r"|(?P<stray>')",  # a quote that never closes
    re.DOTALL,
)

# Operations rejected for every database. Leading word boundaries stop
# harmless identifiers such as "system_id" or "immigrant" from matching.
_FORBIDDEN_COMMON = tuple(
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        r"\bINTO\s+(?:OUT|DUMP)FILE",
        r"\bLOAD\s+DATA",
        r"\bSYSTEM\b",
        r"\bEXEC\s+",
        # Deliberately no leading boundary: also catches e.g. xp_cmdshell.
        r"SHELL\b",
        r"\bDROP\s+DATABASE",
        r"\bDROP\s+USER",
        r"\bGRANT\s+",
        r"\bREVOKE\s+",
        r"\bALTER\s+(?:USER|DATABASE)",
    )
)

# Extra operations rejected when the target database is DuckDB.
_FORBIDDEN_DUCKDB = tuple(
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        r"\bCOPY\b",
        r"\bEXPORT\b",
        r"\bIMPORT\b",
        r"\bINSTALL\b",
        r"\bREAD_\w+",
        r"\bWRITE_\w+",
        r"\.EXECUTE\(",
        r"\bPRAGMA\b",
    )
)

# Statement kinds the editor is allowed to run.
_ALLOWED_PREFIX_RE = re.compile(
    r"^\s*(SELECT|CREATE\s+TABLE|INSERT\s+INTO|UPDATE|DELETE\s+FROM|ALTER\s+TABLE)\b",
    re.IGNORECASE,
)


def sanitize_sql(sql: str, db_type: str = None) -> Tuple[bool, str, dict]:
    """Validate a user-supplied SQL statement and parameterize its literals.

    The statement is lexed once: comments are dropped, quoted identifiers are
    kept verbatim, and every single-quoted string literal is replaced by a
    ``:param_N`` placeholder whose value goes into the returned mapping. All
    checks run on that masked text, which is exactly the text handed back to
    the caller, so literal contents (semicolons, ``--``, keywords) can neither
    trigger nor evade a rule.

    NOTE(review): literals inside DDL (e.g. ``DEFAULT 'x'``) are parameterized
    too; confirm the connector accepts bind parameters there.

    Args:
        sql: The raw SQL text to check.
        db_type: Database dialect name; ``"duckdb"`` (any case) enables the
            additional DuckDB rules. ``None`` applies only the common rules.

    Returns:
        ``(True, parameterized_sql, params)`` when the statement is accepted,
        otherwise ``(False, reason, {})``. ``reason`` is a human-readable
        message and never contains internal regex patterns.
    """
    if not isinstance(sql, str) or not sql.strip():
        return False, "SQL statement is empty", {}

    params = {}
    unterminated = False

    def _mask(match):
        """Rewrite one lexed token into its masked form."""
        nonlocal unterminated
        kind = match.lastgroup
        if kind == "string":
            name = f"param_{len(params)}"
            # Drop the delimiters and undo the SQL '' escape so the bound
            # value is the text the user actually meant.
            params[name] = match.group(0)[1:-1].replace("''", "'")
            return f":{name}"
        if kind == "ident":
            return match.group(0)
        if kind == "stray":
            unterminated = True
            return match.group(0)
        # Block or line comment: a space keeps neighbouring tokens apart.
        return " "

    masked = _SQL_TOKEN_RE.sub(_mask, sql)
    if unterminated:
        # Fail closed instead of guessing where the literal was meant to end.
        return False, "Unterminated string literal", {}
    masked = re.sub(r"\s+", " ", masked).strip()

    # One optional trailing semicolon is fine; any other one separates
    # statements.
    statement = masked[:-1] if masked.endswith(";") else masked
    if ";" in statement:
        return False, "Multiple SQL statements are not allowed", {}

    forbidden = _FORBIDDEN_COMMON
    if (db_type or "").lower() == "duckdb":
        forbidden = forbidden + _FORBIDDEN_DUCKDB
    for pattern in forbidden:
        hit = pattern.search(masked)
        if hit:
            # Report the offending text, not the regex that found it.
            label = " ".join(hit.group(0).upper().split())
            return False, f"Forbidden keyword or function: {label}", {}

    # SELECT, CREATE TABLE, INSERT, UPDATE, DELETE and ALTER TABLE only; the
    # editor is intentionally not restricted to read-only statements.
    if not _ALLOWED_PREFIX_RE.match(masked):
        return (
            False,
            "Operation not supported. Only SELECT, CREATE TABLE, INSERT, UPDATE, "
            "DELETE and ALTER TABLE operations are allowed",
            {},
        )

    return True, masked, params
| 176 | + |
| 177 | + |
98 | 178 | @router.post("/v1/editor/sql/run", response_model=Result[SqlRunData])
|
99 | 179 | async def editor_sql_run(run_param: dict = Body()):
|
100 | 180 | logger.info(f"editor_sql_run:{run_param}")
|
101 | 181 | db_name = run_param["db_name"]
|
102 | 182 | sql = run_param["sql"]
|
| 183 | + |
103 | 184 | if not db_name and not sql:
|
104 | 185 | return Result.failed(msg="SQL run param error!")
|
105 | 186 |
|
106 |
| - # Validate database type and prevent dangerous operations |
| 187 | + # Get database connection |
107 | 188 | conn = CFG.local_db_manager.get_connector(db_name)
|
108 | 189 | db_type = getattr(conn, "db_type", "").lower()
|
109 | 190 |
|
110 |
| - # Block dangerous operations for DuckDB |
111 |
| - if db_type == "duckdb": |
112 |
| - # Block file operations and system commands |
113 |
| - dangerous_keywords = [ |
114 |
| - # File operations |
115 |
| - "copy", |
116 |
| - "export", |
117 |
| - "import", |
118 |
| - "load", |
119 |
| - "install", |
120 |
| - "read_", |
121 |
| - "write_", |
122 |
| - "save", |
123 |
| - "from_", |
124 |
| - "to_", |
125 |
| - # System commands |
126 |
| - "create_", |
127 |
| - "drop_", |
128 |
| - ".execute(", |
129 |
| - "system", |
130 |
| - "shell", |
131 |
| - # Additional DuckDB specific operations |
132 |
| - "attach", |
133 |
| - "detach", |
134 |
| - "pragma", |
135 |
| - "checkpoint", |
136 |
| - "load_extension", |
137 |
| - "unload_extension", |
138 |
| - # File paths |
139 |
| - "/'", |
140 |
| - "'/'", |
141 |
| - "\\", |
142 |
| - "://", |
143 |
| - ] |
144 |
| - sql_lower = sql.lower().replace(" ", "") # Remove spaces to prevent bypass |
145 |
| - if any(keyword in sql_lower for keyword in dangerous_keywords): |
146 |
| - logger.warning(f"Blocked dangerous SQL operation attempt: {sql}") |
147 |
| - return Result.failed(msg="Operation not allowed for security reasons") |
148 |
| - |
149 |
| - # Additional check for file path patterns |
150 |
| - if re.search(r"['\"].*[/\\].*['\"]", sql): |
151 |
| - logger.warning(f"Blocked file path in SQL: {sql}") |
152 |
| - return Result.failed(msg="File operations not allowed") |
| 191 | + # Sanitize and parameterize the SQL query |
| 192 | + is_safe, result, params = sanitize_sql(sql, db_type) |
| 193 | + if not is_safe: |
| 194 | + logger.warning(f"Blocked dangerous SQL: {sql}") |
| 195 | + return Result.failed(msg=f"Operation not allowed: {result}") |
153 | 196 |
|
154 | 197 | try:
|
155 | 198 | start_time = time.time() * 1000
|
156 |
| - # Add timeout protection |
157 |
| - colunms, sql_result = conn.query_ex(sql, timeout=30) |
| 199 | + # Use the parameterized query and parameters |
| 200 | + colunms, sql_result = conn.query_ex(result, params=params, timeout=30) |
158 | 201 | # Convert result type safely
|
159 | 202 | sql_result = [
|
160 | 203 | tuple(str(x) if x is not None else None for x in row) for row in sql_result
|
@@ -216,103 +259,57 @@ async def get_editor_chart_info(
|
216 | 259 |
|
217 | 260 |
|
@router.post("/v1/editor/chart/run", response_model=Result[ChartRunData])
async def chart_run(run_param: dict = Body()):
    """Run an editor SQL statement and shape its rows into chart values.

    Args:
        run_param: Request body; reads the ``db_name``, ``sql`` and
            ``chart_type`` keys.

    Returns:
        ``Result.succ`` wrapping a ``ChartRunData`` when the query runs, or
        ``Result.failed`` when a required parameter is missing, the statement
        is rejected by ``sanitize_sql``, or the connector raises.
    """
    logger.info(f"chart_run:{run_param}")
    db_name = run_param.get("db_name")
    sql = run_param.get("sql")
    chart_type = run_param.get("chart_type")

    # Reject incomplete requests up front instead of failing with a KeyError
    # or passing None to the connector.
    if not db_name or not sql or not chart_type:
        return Result.failed(msg="Required parameters missing")

    try:
        # Inside the try so an unknown database name is reported as a failed
        # Result like any other execution error.
        db_conn = CFG.local_db_manager.get_connector(db_name)
        # `or ""` guards against a connector whose db_type is None.
        db_type = (getattr(db_conn, "db_type", "") or "").lower()

        # Sanitize and parameterize the SQL query
        is_safe, result, params = sanitize_sql(sql, db_type)
        if not is_safe:
            logger.warning(f"Blocked dangerous SQL: {sql}")
            return Result.failed(msg=f"Operation not allowed: {result}")

        start_time = time.time() * 1000
        # Execute the parameterized statement with its bound values
        colunms, sql_result = db_conn.query_ex(result, params=params, timeout=30)
        # Stringify every non-null cell so the payload serializes uniformly
        sql_result = [
            tuple(str(x) if x is not None else None for x in row) for row in sql_result
        ]
        # Calculate execution time
        end_time = time.time() * 1000
        sql_run_data: SqlRunData = SqlRunData(
            result_info="",
            run_cost=(end_time - start_time) / 1000,
            colunms=colunms,
            values=sql_result,
        )

        # First column is the label, second the value; single-column results
        # fall back to "0".
        chart_values = [
            {
                "name": row[0],
                "type": "value",
                "value": row[1] if len(row) > 1 else "0",
            }
            for row in sql_result
        ]

        chart_data: ChartRunData = ChartRunData(
            sql_data=sql_run_data, chart_values=chart_values, chart_type=chart_type
        )
        return Result.succ(chart_data)
    except Exception as e:
        # Endpoint boundary: log with traceback and report the failure to the
        # client rather than letting it surface as an unhandled error.
        logger.exception(f"chart_run exception: {str(e)}")
        return Result.failed(msg=str(e))
316 | 313 |
|
317 | 314 |
|
318 | 315 | @router.post("/v1/chart/editor/submit", response_model=Result[bool])
|
|
0 commit comments