@@ -242,10 +242,126 @@ def setoutputsize(self, size, column):
242
242
raise prestodb .exceptions .NotSupportedError
243
243
244
244
def execute (self , operation , params = None ):
245
- self ._query = prestodb .client .PrestoQuery (self ._request , sql = operation )
246
- result = self ._query .execute ()
247
- self ._iterator = iter (result )
248
- return result
245
+ if params :
246
+ assert isinstance (params , (list , tuple )), (
247
+ "params must be a list or tuple containing the query "
248
+ "parameter values"
249
+ )
250
+
251
+ statement_name = self ._generate_unique_statement_name ()
252
+ self ._prepare_statement (operation , statement_name )
253
+
254
+ try :
255
+ # Send execute statement and assign the return value to `results`
256
+ # as it will be returned by the function
257
+ self ._query = self ._execute_prepared_statement (statement_name , params )
258
+ self ._iterator = iter (self ._query .execute ())
259
+ finally :
260
+ # Send deallocate statement
261
+ # At this point the query can be deallocated since it has already
262
+ # been executed
263
+ # TODO: Consider caching prepared statements if requested by caller
264
+ self ._deallocate_prepared_statement (statement_name )
265
+ else :
266
+ self ._query = prestodb .client .PrestoQuery (self ._request , sql = operation )
267
+ self ._iterator = iter (self ._query .execute ())
268
+ return self
269
+
270
+ def _generate_unique_statement_name (self ):
271
+ return "st_" + uuid .uuid4 ().hex .replace ("-" , "" )
272
+
273
+ def _prepare_statement (self , statement : str , name : str ) -> None :
274
+ sql = f"PREPARE { name } FROM { statement } "
275
+ query = prestodb .client .PrestoQuery (self ._request , sql = sql )
276
+ query .execute ()
277
+
278
+ def _execute_prepared_statement (self , statement_name , params ):
279
+ sql = (
280
+ "EXECUTE "
281
+ + statement_name
282
+ + " USING "
283
+ + "," .join (map (self ._format_prepared_param , params ))
284
+ )
285
+ return prestodb .client .PrestoQuery (self ._request , sql = sql )
286
+
287
+ def _deallocate_prepared_statement (self , statement_name : str ) -> None :
288
+ sql = "DEALLOCATE PREPARE " + statement_name
289
+ query = prestodb .client .PrestoQuery (self ._request , sql = sql )
290
+ query .execute ()
291
+
292
+ def _format_prepared_param (self , param ):
293
+ """
294
+ Formats parameters to be passed in an
295
+ EXECUTE statement.
296
+ """
297
+ if param is None :
298
+ return "NULL"
299
+
300
+ if isinstance (param , bool ):
301
+ return "true" if param else "false"
302
+
303
+ if isinstance (param , int ):
304
+ # TODO represent numbers exceeding 64-bit (BIGINT) as DECIMAL
305
+ return "%d" % param
306
+
307
+ if isinstance (param , float ):
308
+ if param == float ("+inf" ):
309
+ return "infinity()"
310
+ if param == float ("-inf" ):
311
+ return "-infinity()"
312
+ return "DOUBLE '%s'" % param
313
+
314
+ if isinstance (param , str ):
315
+ return "'%s'" % param .replace ("'" , "''" )
316
+
317
+ if isinstance (param , bytes ):
318
+ return "X'%s'" % param .hex ()
319
+
320
+ if isinstance (param , datetime .datetime ) and param .tzinfo is None :
321
+ datetime_str = param .strftime ("%Y-%m-%d %H:%M:%S.%f" )
322
+ return "TIMESTAMP '%s'" % datetime_str
323
+
324
+ if isinstance (param , datetime .datetime ) and param .tzinfo is not None :
325
+ datetime_str = param .strftime ("%Y-%m-%d %H:%M:%S.%f" )
326
+ # offset-based timezones
327
+ return "TIMESTAMP '%s %s'" % (datetime_str , param .tzinfo .tzname (param ))
328
+
329
+ # We can't calculate the offset for a time without a point in time
330
+ if isinstance (param , datetime .time ) and param .tzinfo is None :
331
+ time_str = param .strftime ("%H:%M:%S.%f" )
332
+ return "TIME '%s'" % time_str
333
+
334
+ if isinstance (param , datetime .time ) and param .tzinfo is not None :
335
+ time_str = param .strftime ("%H:%M:%S.%f" )
336
+ # offset-based timezones
337
+ return "TIME '%s %s'" % (time_str , param .strftime ("%Z" )[3 :])
338
+
339
+ if isinstance (param , datetime .date ):
340
+ date_str = param .strftime ("%Y-%m-%d" )
341
+ return "DATE '%s'" % date_str
342
+
343
+ if isinstance (param , list ):
344
+ return "ARRAY[%s]" % "," .join (map (self ._format_prepared_param , param ))
345
+
346
+ if isinstance (param , tuple ):
347
+ return "ROW(%s)" % "," .join (map (self ._format_prepared_param , param ))
348
+
349
+ if isinstance (param , dict ):
350
+ keys = list (param .keys ())
351
+ values = [param [key ] for key in keys ]
352
+ return "MAP({}, {})" .format (
353
+ self ._format_prepared_param (keys ), self ._format_prepared_param (values )
354
+ )
355
+
356
+ if isinstance (param , uuid .UUID ):
357
+ return "UUID '%s'" % param
358
+
359
+ if isinstance (param , (bytes , bytearray )):
360
+ return "X'%s'" % binascii .hexlify (param ).decode ("utf-8" )
361
+
362
+ raise prestodb .exceptions .NotSupportedError (
363
+ "Query parameter of type '%s' is not supported." % type (param )
364
+ )
249
365
250
366
def executemany (self , operation , seq_of_params ):
251
367
raise prestodb .exceptions .NotSupportedError
0 commit comments