1
- from pathlib import Path
2
-
3
1
from supabase_mcp .exceptions import OperationNotAllowedError
4
2
from supabase_mcp .logger import logger
5
3
from supabase_mcp .services .database .migration_manager import MigrationManager
6
4
from supabase_mcp .services .database .postgres_client import PostgresClient , QueryResult
5
+ from supabase_mcp .services .database .sql .loader import SQLLoader
7
6
from supabase_mcp .services .database .sql .models import QueryValidationResults
8
7
from supabase_mcp .services .database .sql .validator import SQLValidator
9
8
from supabase_mcp .services .safety .models import ClientType , SafetyMode
@@ -24,26 +23,29 @@ class QueryManager:
24
23
validation and execution patterns.
25
24
"""
26
25
27
- # Path to SQL files directory
28
- SQL_DIR = Path (__file__ ).parent / "sql" / "queries"
29
-
30
26
def __init__ (
31
27
self ,
32
28
postgres_client : PostgresClient ,
33
29
safety_manager : SafetyManager ,
34
30
sql_validator : SQLValidator | None = None ,
35
31
migration_manager : MigrationManager | None = None ,
32
+ sql_loader : SQLLoader | None = None ,
36
33
):
37
34
"""
38
35
Initialize the QueryManager.
39
36
40
37
Args:
41
- db_client: The database client to use for executing queries
38
+ postgres_client: The database client to use for executing queries
39
+ safety_manager: The safety manager to use for validating operations
40
+ sql_validator: Optional SQL validator to use
41
+ migration_manager: Optional migration manager to use
42
+ sql_loader: Optional SQL loader to use
42
43
"""
43
44
self .db_client = postgres_client
44
45
self .safety_manager = safety_manager
45
46
self .validator = sql_validator or SQLValidator ()
46
- self .migration_manager = migration_manager or MigrationManager ()
47
+ self .sql_loader = sql_loader or SQLLoader ()
48
+ self .migration_manager = migration_manager or MigrationManager (loader = self .sql_loader )
47
49
48
50
def check_readonly (self ) -> bool :
49
51
"""Returns true if current safety mode is SAFE."""
@@ -103,7 +105,7 @@ async def handle_query_execution(self, validated_query: QueryValidationResults)
103
105
QueryResult: The result of the query execution
104
106
"""
105
107
readonly = self .check_readonly ()
106
- result = await self .db_client .execute_query_async (validated_query , readonly )
108
+ result = await self .db_client .execute_query (validated_query , readonly )
107
109
logger .debug (f"Query result: { result } " )
108
110
return result
109
111
@@ -120,24 +122,42 @@ async def handle_migration(
120
122
"""
121
123
# 1. Check if migration is needed
122
124
if not validation_result .needs_migration ():
123
- logger .info ("No migration needed for this query" )
125
+ logger .debug ("No migration needed for this query" )
124
126
return
125
127
126
- # 2. Create migration
127
- try :
128
- # Prepare the migration with the original query
129
- migration_query , migration_name = self .migration_manager .prepare_migration_query (
130
- validation_result , original_query , migration_name
131
- )
128
+ # 2. Prepare migration query
129
+ migration_query , name = self .migration_manager .prepare_migration_query (
130
+ validation_result , original_query , migration_name
131
+ )
132
+ logger .debug ("Migration query prepared" )
132
133
133
- # Validate migration query, since it's a raw query
134
- validated_query = self .validator .validate_query (migration_query )
134
+ # 3. Execute migration query
135
+ try :
136
+ # First, ensure the migration schema exists
137
+ await self .init_migration_schema ()
135
138
136
- await self .handle_query_execution (validated_query )
137
- logger .info (f"Successfully created migration: { migration_name } " )
139
+ # Then execute the migration query
140
+ migration_validation = self .validator .validate_query (migration_query )
141
+ await self .db_client .execute_query (migration_validation , readonly = False )
142
+ logger .info (f"Migration '{ name } ' executed successfully" )
138
143
except Exception as e :
139
144
logger .debug (f"Migration failure details: { str (e )} " )
140
- raise e
145
+ # We don't want to fail the main query if migration fails
146
+ # Just log the error and continue
147
+ logger .warning (f"Failed to record migration '{ name } ': { e } " )
148
+
149
+ async def init_migration_schema (self ) -> None :
150
+ """Initialize the migrations schema and table if they don't exist."""
151
+ try :
152
+ # Get the initialization query
153
+ init_query = self .sql_loader .get_init_migrations_query ()
154
+
155
+ # Validate and execute it
156
+ init_validation = self .validator .validate_query (init_query )
157
+ await self .db_client .execute_query (init_validation , readonly = False )
158
+ logger .debug ("Migrations schema initialized successfully" )
159
+ except Exception as e :
160
+ logger .warning (f"Failed to initialize migrations schema: { e } " )
141
161
142
162
async def handle_confirmation (self , confirmation_id : str ) -> QueryResult :
143
163
"""
@@ -163,101 +183,22 @@ async def handle_confirmation(self, confirmation_id: str) -> QueryResult:
163
183
# Call handle_query with the query and has_confirmation=True
164
184
return await self .handle_query (query , has_confirmation = True )
165
185
166
- @classmethod
167
- def load_sql (cls , filename : str ) -> str :
168
- """
169
- Load SQL from a file in the sql directory.
170
-
171
- Args:
172
- filename: Name of the SQL file (with or without .sql extension)
173
-
174
- Returns:
175
- str: The SQL query from the file
176
-
177
- Raises:
178
- FileNotFoundError: If the SQL file doesn't exist
179
- """
180
- # Ensure the filename has .sql extension
181
- if not filename .endswith (".sql" ):
182
- filename = f"{ filename } .sql"
183
-
184
- file_path = cls .SQL_DIR / filename
185
-
186
- if not file_path .exists ():
187
- logger .error (f"SQL file not found: { file_path } " )
188
- raise FileNotFoundError (f"SQL file not found: { file_path } " )
189
-
190
- with open (file_path ) as f :
191
- sql = f .read ().strip ()
192
- logger .debug (f"Loaded SQL file: { filename } ({ len (sql )} chars)" )
193
- return sql
194
-
195
186
def get_schemas_query (self ) -> str :
196
- """
197
- Get SQL query to list all schemas with their sizes and table counts.
198
-
199
- Returns:
200
- str: SQL query for listing schemas
201
- """
202
- logger .debug ("Getting schemas query" )
203
- return self .load_sql ("get_schemas" )
187
+ """Get a query to list all schemas."""
188
+ return self .sql_loader .get_schemas_query ()
204
189
205
190
def get_tables_query (self , schema_name : str ) -> str :
206
- """
207
- Get SQL query to list all tables in a schema.
208
-
209
- Args:
210
- schema_name: Name of the schema
211
-
212
- Returns:
213
- str: SQL query for listing tables
214
- """
215
- logger .debug (f"Getting tables query for schema: { schema_name } " )
216
- sql = self .load_sql ("get_tables" )
217
- return sql .format (schema_name = schema_name )
191
+ """Get a query to list all tables in a schema."""
192
+ return self .sql_loader .get_tables_query (schema_name )
218
193
219
194
def get_table_schema_query (self , schema_name : str , table : str ) -> str :
220
- """
221
- Get SQL query to get detailed table schema.
222
-
223
- Args:
224
- schema_name: Name of the schema
225
- table: Name of the table
226
-
227
- Returns:
228
- str: SQL query for getting table schema
229
- """
230
- logger .debug (f"Getting table schema query for { schema_name } .{ table } " )
231
- sql = self .load_sql ("get_table_schema" )
232
- return sql .format (schema_name = schema_name , table = table )
195
+ """Get a query to get the schema of a table."""
196
+ return self .sql_loader .get_table_schema_query (schema_name , table )
233
197
234
198
def get_migrations_query (
235
199
self , limit : int = 50 , offset : int = 0 , name_pattern : str = "" , include_full_queries : bool = False
236
200
) -> str :
237
- """
238
- Get a query to retrieve migrations from Supabase with filtering and pagination.
239
-
240
- Args:
241
- limit: Maximum number of migrations to return (default: 50)
242
- offset: Number of migrations to skip (for pagination)
243
- name_pattern: Optional pattern to filter migrations by name
244
- include_full_queries: Whether to include the full SQL statements in the result
245
-
246
- Returns:
247
- str: SQL query to get migrations with the specified filters
248
- """
249
- logger .debug (f"Getting migrations query with limit={ limit } , offset={ offset } , name_pattern='{ name_pattern } '" )
250
- sql = self .load_sql ("get_migrations" )
251
-
252
- # Sanitize inputs to prevent SQL injection
253
- sanitized_limit = max (1 , min (100 , limit )) # Limit between 1 and 100
254
- sanitized_offset = max (0 , offset )
255
- sanitized_name_pattern = name_pattern .replace ("'" , "''" ) # Escape single quotes
256
-
257
- # Format the SQL query with the parameters
258
- return sql .format (
259
- limit = sanitized_limit ,
260
- offset = sanitized_offset ,
261
- name_pattern = sanitized_name_pattern ,
262
- include_full_queries = "true" if include_full_queries else "false" ,
201
+ """Get a query to list migrations."""
202
+ return self .sql_loader .get_migrations_query (
203
+ limit = limit , offset = offset , name_pattern = name_pattern , include_full_queries = include_full_queries
263
204
)
0 commit comments