Skip to content

Commit 760f65d

Browse files
Update dbtools.py
1 parent 7c11e1c commit 760f65d

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

app/dbtools.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@ def run_latency_test(
2323
database="",
2424
url="",
2525
interval=1.0,
26-
period=10
26+
period=10,
27+
custom_sql=""
2728
):
2829
query_times = []
2930
result_info = {"success": False, "error": None, "latency_stats": {}, "details": []}
31+
custom_sql = (custom_sql or "").strip()
3032
try:
3133
end_time = time.perf_counter() + period
3234
while time.perf_counter() < end_time:
@@ -36,29 +38,33 @@ def run_latency_test(
3638
if dbtype == "oracle":
3739
conn = oracledb.connect(user=username, password=password, dsn=host)
3840
cursor = conn.cursor()
39-
cursor.execute("select 1 from dual")
41+
sql = custom_sql if custom_sql else "select 1 from dual"
42+
cursor.execute(sql)
4043
cursor.fetchall()
4144
cursor.close()
4245
conn.close()
4346
elif dbtype == "postgresql":
4447
conn = psycopg2.connect(host=host, port=port, dbname=database, user=username, password=password)
4548
cursor = conn.cursor()
46-
cursor.execute("SELECT 1")
49+
sql = custom_sql if custom_sql else "SELECT 1"
50+
cursor.execute(sql)
4751
cursor.fetchall()
4852
cursor.close()
4953
conn.close()
5054
elif dbtype == "mysql":
5155
conn = pymysql.connect(host=host, port=int(port), user=username, password=password, db=database)
5256
cursor = conn.cursor()
53-
cursor.execute("SELECT 1")
57+
sql = custom_sql if custom_sql else "SELECT 1"
58+
cursor.execute(sql)
5459
cursor.fetchall()
5560
cursor.close()
5661
conn.close()
5762
elif dbtype == "sqlserver" and mssql_ok:
5863
conn_str = f"DRIVER={{ODBC Driver 17 for SQL Server}};SERVER={host},{port};DATABASE={database};UID={username};PWD={password}"
5964
conn = pyodbc.connect(conn_str)
6065
cursor = conn.cursor()
61-
cursor.execute("SELECT 1")
66+
sql = custom_sql if custom_sql else "SELECT 1"
67+
cursor.execute(sql)
6268
cursor.fetchall()
6369
cursor.close()
6470
conn.close()
@@ -81,7 +87,6 @@ def run_latency_test(
8187
"error": error
8288
})
8389
time.sleep(interval)
84-
# Compute p99, p90, avg, stddev, mean on all query_times
8590
arr = np.array(query_times)
8691
result_info["latency_stats"] = {
8792
"p99": float(np.percentile(arr, 99)) if len(arr) else None,

0 commit comments

Comments
 (0)