Skip to content

Commit 7446bde

Browse files
committed
wrote test for invalid hint (and fixed hint check notice)
1 parent 0cbcbba commit 7446bde

File tree

2 files changed

+32
-15
lines changed

2 files changed

+32
-15
lines changed

env/integtest_pg_conn.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,7 @@ def test_time_query(self) -> None:
129129
self.assertIsNone(explain_data)
130130

131131
def test_time_query_with_explain(self) -> None:
132-
_, _, explain_data = self.pg_conn.time_query(
133-
"select pg_sleep(1)", add_explain=True
134-
)
132+
_, _, explain_data = self.pg_conn.time_query("select 1", add_explain=True)
135133
self.assertIsNotNone(explain_data)
136134

137135
def test_time_query_with_timeout(self) -> None:
@@ -150,7 +148,7 @@ def test_time_query_with_invalid_table(self) -> None:
150148
with self.assertRaises(psycopg.errors.UndefinedTable):
151149
self.pg_conn.time_query("select * from itemline limit 10")
152150

153-
def test_time_query_with_hint(self) -> None:
151+
def test_time_query_with_valid_hints(self) -> None:
154152
join_query = """SELECT *
155153
FROM orders
156154
JOIN lineitem ON o_orderkey = l_orderkey
@@ -168,9 +166,17 @@ def test_time_query_with_hint(self) -> None:
168166
query_knobs=[f"{hint_join_type}(lineitem orders)"],
169167
add_explain=True,
170168
)
169+
assert explain_data is not None # This assertion is for mypy.
171170
actual_join_type = explain_data["Plan"]["Plans"][0]["Node Type"]
172171
self.assertEqual(expected_join_type, actual_join_type)
173172

173+
def test_time_query_with_invalid_hint(self) -> None:
174+
with self.assertRaises(RuntimeError) as context:
175+
self.pg_conn.time_query("select 1", query_knobs=["dbgym"])
176+
self.assertTrue(
177+
'Unrecognized hint keyword "dbgym"' in str(context.exception)
178+
)
179+
174180

175181
if __name__ == "__main__":
176182
unittest.main()

env/pg_conn.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def __init__(
7070
self.dbdata_dpath = self.dbdata_parent_dpath / f"dbdata{self.pgport}"
7171

7272
self._conn: Optional[psycopg.Connection[Any]] = None
73+
self.hint_check_failed_with: Optional[str] = None
7374

7475
def get_kv_connstr(self) -> str:
7576
return get_kv_connstr(self.pgport)
@@ -79,6 +80,21 @@ def conn(self) -> psycopg.Connection[Any]:
7980
self._conn = psycopg.connect(
8081
self.get_kv_connstr(), autocommit=True, prepare_threshold=None
8182
)
83+
84+
def hint_check_notice_handler(notice: psycopg.errors.Diagnostic) -> None:
85+
"""
86+
Custom handler for raising errors if hints fail.
87+
"""
88+
if (
89+
notice.message_detail is not None
90+
and "hint" in notice.message_detail.lower()
91+
):
92+
self.hint_check_failed_with = notice.message_detail
93+
94+
# We add the notice handler when the _conn is created instead of before executing a
95+
# query to avoid adding it more than once.
96+
self._conn.add_notice_handler(hint_check_notice_handler)
97+
8298
return self._conn
8399

84100
def disconnect(self) -> None:
@@ -137,17 +153,6 @@ def time_query(
137153
did_time_out = False
138154
explain_data = None
139155

140-
# def hint_notice_handler(notice) -> None:
141-
# """
142-
# Custom handler for database notices.
143-
# Raises an error or logs the notice if it indicates a problem.
144-
# """
145-
# logging.getLogger(DBGYM_LOGGER_NAME).warning(f"Postgres notice: {notice}")
146-
# if "hint" in notice.message.lower():
147-
# raise RuntimeError(f"Query hint failed: {notice.message}")
148-
149-
# self.conn().add_notice_handler(hint_notice_handler)
150-
151156
try:
152157
if query_knobs:
153158
query = f"/*+ {' '.join(query_knobs)} */ {query}"
@@ -158,10 +163,16 @@ def time_query(
158163
), "If you're using add_explain, don't also write explain manually in the query."
159164
query = f"explain (analyze, format json, timing off) {query}"
160165

166+
# Reset this every time before calling execute() so that hint_check_notice_handler works correctly.
167+
self.hint_check_failed_with = None
168+
161169
start_time = time.time()
162170
cursor = self.conn().execute(query)
163171
qid_runtime = (time.time() - start_time) * 1e6
164172

173+
if self.hint_check_failed_with is not None:
174+
raise RuntimeError(f"Query hint failed: {self.hint_check_failed_with}")
175+
165176
if add_explain:
166177
c = [c for c in cursor][0][0][0]
167178
assert "Execution Time" in c

0 commit comments

Comments
 (0)