Skip to content

Commit 5a687b4

Browse files
mlyublena and harph committed
Support parameterized statements in Presto Python Client
Cherry-pick of trinodb/trino-python-client@a743855 Co-authored-by: Harrington Joseph <harph@hjoseph.com>
1 parent 08c2cca commit 5a687b4

File tree

4 files changed

+194
-7
lines changed

4 files changed

+194
-7
lines changed

integration_tests/test_dbapi.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,12 @@
1111
# limitations under the License.
1212
from __future__ import absolute_import, division, print_function
1313

14-
import integration_tests.fixtures as fixtures
14+
from datetime import date, datetime
15+
import numpy as np
1516
import prestodb
1617
import pytest
18+
19+
import integration_tests.fixtures as fixtures
1720
from integration_tests.fixtures import run_presto
1821
from prestodb.transaction import IsolationLevel
1922

@@ -116,7 +119,64 @@ def test_select_failed_query(presto_connection):
116119
cur.execute("select * from catalog.schema.do_not_exist")
117120
cur.fetchall()
118121

122+
def test_select_query_result_iteration_statement_params(presto_connection):
    """A single ``?`` placeholder bound through ``params`` filters the rows."""
    cursor = presto_connection.cursor()
    cursor.execute(
        """
        select * from (
            values
                (1, 'one', 'a'),
                (2, 'two', 'b'),
                (3, 'three', 'c'),
                (4, 'four', 'd'),
                (5, 'five', 'e')
        ) x (id, name, letter)
        where id >= ?
        """,
        params=(3,)  # expecting all the rows with id >= 3
    )

    fetched = cursor.fetchall()
    assert len(fetched) == 3

    # Every returned row must have an id greater than or equal to 3.
    for record in fetched:
        assert record[0] >= 3
145+
146+
147+
def test_select_query_param_types(presto_connection):
    """Date, timestamp, float and tuple parameters round-trip through ``select ?``."""
    cursor = presto_connection.cursor()

    expected_date = date.today()
    # Microseconds are zeroed before the value is sent.
    expected_timestamp = datetime.now().replace(microsecond=0)
    expected_float = 1.5
    expected_sequence = (1, 2, 3)

    bound_values = (
        expected_date,
        expected_timestamp,
        expected_float,
        expected_sequence,
    )
    cursor.execute(
        """
        select ?,?,?,?
        """,
        params=bound_values
    )

    fetched = cursor.fetchall()
    assert len(fetched) == 1
    for record in fetched:
        assert date.fromisoformat(record[0]) == expected_date
        parsed_timestamp = datetime.strptime(record[1], "%Y-%m-%d %H:%M:%S.%f")
        assert parsed_timestamp == expected_timestamp
        assert record[2] == expected_float
        assert (record[3] == np.array(expected_sequence)).all()
168+
169+
@pytest.mark.parametrize(
    'params',
    [
        'NOT A LIST OR TUPPLE',
        {'invalid', 'params'},
        object,
    ],
)
def test_select_query_invalid_params(presto_connection, params):
    """``execute`` rejects ``params`` values that are not a list or tuple."""
    cursor = presto_connection.cursor()
    with pytest.raises(AssertionError):
        cursor.execute('select ?', params=params)
119178

179+
120180
def test_select_tpch_1000(presto_connection):
121181
cur = presto_connection.cursor()
122182
cur.execute("SELECT * FROM tpch.sf1.customer LIMIT 1000")

prestodb/client.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ def __init__(
238238
)
239239
self._http_session.headers.update(self.get_oauth_token())
240240

241+
self.prepared_statements = []
241242
self._http_session.headers.update(self.http_headers)
242243
self._exceptions = self.HTTP_EXCEPTIONS
243244
self._auth = auth
@@ -270,6 +271,8 @@ def http_headers(self):
270271
headers[constants.HEADER_SCHEMA] = self._client_session.schema
271272
headers[constants.HEADER_SOURCE] = self._client_session.source
272273
headers[constants.HEADER_USER] = self._client_session.user
274+
if len(self.prepared_statements) > 0:
275+
headers[constants.HEADER_PREPARED_STATEMENT] = ",".join(self.prepared_statements)
273276

274277
headers[constants.HEADER_SESSION] = ",".join(
275278
# ``name`` must not contain ``=``
@@ -417,6 +420,11 @@ def process(self, http_response):
417420
):
418421
self._client_session.properties[key] = value
419422

423+
if constants.HEADER_ADDED_PREPARE in http_response.headers:
424+
self._http_session.headers[
425+
constants.HEADER_PREPARED_STATEMENT
426+
] = http_response.headers[constants.HEADER_ADDED_PREPARE]
427+
420428
self._next_uri = response.get("nextUri")
421429

422430
return PrestoStatus(
@@ -529,12 +537,12 @@ def execute(self):
529537

530538
response = self._request.post(self._sql)
531539
status = self._request.process(response)
532-
if status.next_uri is None:
533-
self._finished = True
534540
self.query_id = status.id
535541
self._stats.update({"queryId": self.query_id})
536542
self._stats.update(status.stats)
537543
self._warnings = getattr(status, "warnings", [])
544+
if status.next_uri is None:
545+
self._finished = True
538546
self._result = PrestoResult(self, status.rows)
539547
while (
540548
not self._finished and not self._cancelled

prestodb/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@
4242
HEADER_STARTED_TRANSACTION = HEADER_PREFIX + "Started-Transaction-Id"
4343
HEADER_TRANSACTION = HEADER_PREFIX + "Transaction-Id"
4444

45+
# Request header; prestodb/client.py fills it with the comma-joined
# ``prepared_statements`` list when that list is non-empty.
HEADER_PREPARED_STATEMENT = 'X-Presto-Prepared-Statement'
# Response header; when present, prestodb/client.py copies its value into
# HEADER_PREPARED_STATEMENT on the HTTP session.
HEADER_ADDED_PREPARE = 'X-Presto-Added-Prepare'
47+
4548
PRESTO_EXTRA_CREDENTIAL = "X-Presto-Extra-Credential"
4649
GCS_CREDENTIALS_OAUTH_TOKEN_KEY = "hive.gcs.oauth"
4750

prestodb/dbapi.py

Lines changed: 120 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -242,10 +242,126 @@ def setoutputsize(self, size, column):
242242
raise prestodb.exceptions.NotSupportedError
243243

244244
def execute(self, operation, params=None):
    """
    Run ``operation``, optionally binding ``params`` to its ``?`` placeholders.

    When ``params`` is truthy the statement is sent in three steps: a
    PREPARE under a freshly generated name, an EXECUTE ... USING with the
    formatted parameter values, and finally a DEALLOCATE PREPARE. When
    ``params`` is ``None`` or empty, ``operation`` is sent unchanged.

    :param operation: SQL text of the statement to run.
    :param params: list or tuple of parameter values, or ``None``.
    :returns: this cursor.
    :raises AssertionError: if ``params`` is truthy but is neither a list
        nor a tuple.
    """
    if params:
        # Raised explicitly rather than via ``assert`` so the validation is
        # not stripped when Python runs with -O; the exception type callers
        # see is unchanged.
        if not isinstance(params, (list, tuple)):
            raise AssertionError(
                "params must be a list or tuple containing the query "
                "parameter values"
            )

        statement_name = self._generate_unique_statement_name()
        self._prepare_statement(operation, statement_name)

        try:
            # Send the EXECUTE statement and expose its result through the
            # cursor's iterator.
            self._query = self._execute_prepared_statement(statement_name, params)
            self._iterator = iter(self._query.execute())
        finally:
            # Send the DEALLOCATE statement whether or not the EXECUTE
            # succeeded, so the prepared statement is not left behind.
            # TODO: Consider caching prepared statements if requested by caller
            self._deallocate_prepared_statement(statement_name)
    else:
        self._query = prestodb.client.PrestoQuery(self._request, sql=operation)
        self._iterator = iter(self._query.execute())
    return self
269+
270+
def _generate_unique_statement_name(self):
271+
return "st_" + uuid.uuid4().hex.replace("-", "")
272+
273+
def _prepare_statement(self, statement: str, name: str) -> None:
    """Send a ``PREPARE <name> FROM <statement>`` query over this cursor's request."""
    prepare_sql = "PREPARE {} FROM {}".format(name, statement)
    prestodb.client.PrestoQuery(self._request, sql=prepare_sql).execute()
277+
278+
def _execute_prepared_statement(self, statement_name, params):
279+
sql = (
280+
"EXECUTE "
281+
+ statement_name
282+
+ " USING "
283+
+ ",".join(map(self._format_prepared_param, params))
284+
)
285+
return prestodb.client.PrestoQuery(self._request, sql=sql)
286+
287+
def _deallocate_prepared_statement(self, statement_name: str) -> None:
    """Send a ``DEALLOCATE PREPARE <statement_name>`` query over this cursor's request."""
    deallocate_sql = "DEALLOCATE PREPARE " + statement_name
    prestodb.client.PrestoQuery(self._request, sql=deallocate_sql).execute()
291+
292+
def _format_prepared_param(self, param):
293+
"""
294+
Formats parameters to be passed in an
295+
EXECUTE statement.
296+
"""
297+
if param is None:
298+
return "NULL"
299+
300+
if isinstance(param, bool):
301+
return "true" if param else "false"
302+
303+
if isinstance(param, int):
304+
# TODO represent numbers exceeding 64-bit (BIGINT) as DECIMAL
305+
return "%d" % param
306+
307+
if isinstance(param, float):
308+
if param == float("+inf"):
309+
return "infinity()"
310+
if param == float("-inf"):
311+
return "-infinity()"
312+
return "DOUBLE '%s'" % param
313+
314+
if isinstance(param, str):
315+
return "'%s'" % param.replace("'", "''")
316+
317+
if isinstance(param, bytes):
318+
return "X'%s'" % param.hex()
319+
320+
if isinstance(param, datetime.datetime) and param.tzinfo is None:
321+
datetime_str = param.strftime("%Y-%m-%d %H:%M:%S.%f")
322+
return "TIMESTAMP '%s'" % datetime_str
323+
324+
if isinstance(param, datetime.datetime) and param.tzinfo is not None:
325+
datetime_str = param.strftime("%Y-%m-%d %H:%M:%S.%f")
326+
# offset-based timezones
327+
return "TIMESTAMP '%s %s'" % (datetime_str, param.tzinfo.tzname(param))
328+
329+
# We can't calculate the offset for a time without a point in time
330+
if isinstance(param, datetime.time) and param.tzinfo is None:
331+
time_str = param.strftime("%H:%M:%S.%f")
332+
return "TIME '%s'" % time_str
333+
334+
if isinstance(param, datetime.time) and param.tzinfo is not None:
335+
time_str = param.strftime("%H:%M:%S.%f")
336+
# offset-based timezones
337+
return "TIME '%s %s'" % (time_str, param.strftime("%Z")[3:])
338+
339+
if isinstance(param, datetime.date):
340+
date_str = param.strftime("%Y-%m-%d")
341+
return "DATE '%s'" % date_str
342+
343+
if isinstance(param, list):
344+
return "ARRAY[%s]" % ",".join(map(self._format_prepared_param, param))
345+
346+
if isinstance(param, tuple):
347+
return "ROW(%s)" % ",".join(map(self._format_prepared_param, param))
348+
349+
if isinstance(param, dict):
350+
keys = list(param.keys())
351+
values = [param[key] for key in keys]
352+
return "MAP({}, {})".format(
353+
self._format_prepared_param(keys), self._format_prepared_param(values)
354+
)
355+
356+
if isinstance(param, uuid.UUID):
357+
return "UUID '%s'" % param
358+
359+
if isinstance(param, (bytes, bytearray)):
360+
return "X'%s'" % binascii.hexlify(param).decode("utf-8")
361+
362+
raise prestodb.exceptions.NotSupportedError(
363+
"Query parameter of type '%s' is not supported." % type(param)
364+
)
249365

250366
def executemany(self, operation, seq_of_params):
251367
raise prestodb.exceptions.NotSupportedError

0 commit comments

Comments
 (0)