Skip to content

Commit 7209ac6

Browse files
committed
feat:Support Oracle Database 12.1 (or later)
1 parent c57ee02 commit 7209ac6

File tree

3 files changed

+368
-4
lines changed

3 files changed

+368
-4
lines changed

dbgpt/datasource/manages/connector_manager.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def on_init(self):
4949
from dbgpt.datasource.rdbms.conn_hive import HiveConnector # noqa: F401
5050
from dbgpt.datasource.rdbms.conn_mssql import MSSQLConnector # noqa: F401
5151
from dbgpt.datasource.rdbms.conn_mysql import MySQLConnector # noqa: F401
52-
from dbgpt.datasource.rdbms.conn_oceanbase import OceanBaseConnect # noqa: F401
52+
from dbgpt.datasource.rdbms.conn_oracle import OracleConnector # noqa: F401
5353
from dbgpt.datasource.rdbms.conn_postgresql import ( # noqa: F401
5454
PostgreSQLConnector,
5555
)
@@ -58,9 +58,6 @@ def on_init(self):
5858
StarRocksConnector,
5959
)
6060
from dbgpt.datasource.rdbms.conn_vertica import VerticaConnector # noqa: F401
61-
from dbgpt.datasource.rdbms.dialect.oceanbase.ob_dialect import ( # noqa: F401
62-
OBDialect,
63-
)
6461

6562
from .connect_config_db import ConnectConfigEntity # noqa: F401
6663

dbgpt/datasource/rdbms/conn_oracle.py

Lines changed: 281 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
"""Oracle connector."""
2+
3+
import logging
4+
from typing import Any, Dict, Iterable, List, Optional, Tuple, cast
5+
6+
import sqlparse
7+
from sqlalchemy import MetaData, Table, create_engine, inspect, select, text
8+
from sqlalchemy.engine import Engine
9+
from sqlalchemy.exc import SQLAlchemyError
10+
from sqlalchemy.sql import column, table, text
11+
12+
from .base import RDBMSConnector
13+
14+
# Module-level logger named after this module; OracleConnector uses it to log the
# SQL it runs and any metadata-lookup errors.
logger = logging.getLogger(__name__)
15+
16+
17+
def _remove_trailing_semicolon(sql: str) -> str:
18+
"""Remove trailing semicolon if present."""
19+
return sql.rstrip(';')
20+
21+
22+
class OracleConnector(RDBMSConnector):
    """Oracle connector.

    Oracle Database 12.1 (or later) is required.

    Table, view and column metadata is read from Oracle's data-dictionary views and
    restricted to objects owned by the connected user (``owner = USER``).
    """

    driver = "oracle+oracledb"
    db_type = "oracle"
    db_dialect = "oracle"

    def __init__(self, engine: Engine, *args, **kwargs):
        """Initialize Oracle connector with SQLAlchemy engine."""
        super().__init__(engine, *args, **kwargs)

    @classmethod
    def from_uri_db(
        cls,
        host: str,
        port: int,
        user: str,
        pwd: str,
        db_name: str,
        engine_args: Optional[dict] = None,
        **kwargs: Any,
    ) -> "OracleConnector":
        """Create a new OracleConnector from host, port, user, pwd, db_name.

        Args:
            host (str): database host.
            port (int): listener port.
            user (str): user name; URL-escaped before it is placed in the URL.
            pwd (str): password; URL-escaped before it is placed in the URL.
            db_name (str): database identifier placed in the path part of the URL.
            engine_args (Optional[dict]): forwarded unchanged to ``from_uri``.
            **kwargs: forwarded unchanged to ``from_uri``.

        Returns:
            OracleConnector: connector built from the assembled URL.
        """
        from urllib.parse import quote

        # Credentials may contain characters such as "@", ":" or "/" that would
        # otherwise be parsed as URL delimiters.
        safe_user = quote(user, safe="")
        safe_pwd = quote(pwd, safe="")
        db_url = f"{cls.driver}://{safe_user}:{safe_pwd}@{host}:{port}/{db_name}"
        return cast(OracleConnector, cls.from_uri(db_url, engine_args, **kwargs))

    def _sync_tables_from_db(self) -> Iterable[str]:
        """Synchronize tables from the database.

        Collects the names of the tables and views owned by the current user, stores
        them in ``self._all_tables`` and refreshes the reflected metadata.

        Returns:
            Iterable[str]: the combined set of table and view names.
        """
        table_rows = self.session.execute(
            text("SELECT table_name FROM all_tables WHERE owner = USER")
        )
        view_rows = self.session.execute(
            text("SELECT view_name FROM all_views WHERE owner = USER")
        )
        table_names = {row[0] for row in table_rows}
        view_names = {row[0] for row in view_rows}
        self._all_tables = table_names | view_names
        self._metadata.reflect(bind=self._engine)
        return self._all_tables

    def get_current_db_name(self) -> str:
        """Get current Oracle schema name instead of database name."""
        return self.session.execute(text("SELECT USER FROM DUAL")).scalar()

    def table_simple_info(self):
        """Return table simple info for Oracle.

        Returns:
            The ``(table_name, column_name)`` rows of ``all_tab_columns`` for the
            current user.
        """
        _sql = """
            SELECT table_name, column_name
            FROM all_tab_columns
            WHERE owner = USER
        """
        cursor = self.session.execute(text(_sql))
        return cursor.fetchall()

    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables for Oracle.

        Table names are compared case-insensitively by upper-casing both the usable
        names and the requested names.

        Args:
            table_names (Optional[List[str]]): tables to describe; ``None`` means
                every usable table.

        Returns:
            str: one ``CREATE TABLE``-style description per table, optionally
            followed by index information and sample rows, joined by blank lines.

        Raises:
            ValueError: if a requested table is not among the usable tables.
        """
        inspector = inspect(self._engine)
        selected_names = {name.upper() for name in self.get_usable_table_names()}

        if table_names is not None:
            requested_names = {name.upper() for name in table_names}
            missing_tables = requested_names - selected_names
            if missing_tables:
                raise ValueError(
                    f"Specified table_names {missing_tables} not found in the database."
                )
            selected_names = requested_names

        tables_info = []
        # Sort so that the generated description is stable between calls.
        for table_name in sorted(selected_names):
            columns_info = inspector.get_columns(table_name)
            column_defs = ",\n".join(
                f"{col['name']} {col['type']}" for col in columns_info
            )
            table_info = f"CREATE TABLE {table_name} (\n{column_defs}\n);"

            if self._indexes_in_table_info:
                index_info = self._get_table_indexes(table_name)
                table_info += f"\n\n-- Indexes:\n{index_info}"

            if self._sample_rows_in_table_info:
                sample_rows = self._get_sample_rows(table_name)
                table_info += f"\n\n-- Sample Rows:\n{sample_rows}"

            tables_info.append(table_info)

        return "\n\n".join(tables_info)

    def _get_table_indexes(self, table: Table) -> str:
        """Get table indexes for an Oracle table.

        Args:
            table (Table): the table to inspect. A plain table-name string is
                accepted as well, because ``get_table_info`` calls this method with
                names rather than ``Table`` objects.

        Returns:
            str: a formatted list of ``{"name", "column_names"}`` entries, or
            ``"[]"`` if the lookup fails.
        """
        # A Table exposes .name; a str does not, so the str itself is the name.
        table_name = getattr(table, "name", table)
        try:
            indexes = self._inspector.get_indexes(table_name)
            indexes_formatted = [
                {"name": idx["name"], "column_names": idx["column_names"]}
                for idx in indexes
            ]
            return f"Table Indexes:\n{indexes_formatted}"
        except SQLAlchemyError as e:
            logger.error("Error fetching indexes: %s", e)
            return "[]"

    def _get_sample_rows(self, table_name: str) -> str:
        """Fetch sample rows from the specified Oracle table.

        ``ROWNUM`` is used to limit the result and the statement carries no trailing
        semicolon.

        Args:
            table_name (str): name of the table to sample. A ``Table`` object is
                accepted as well; its ``name`` is used.

        Returns:
            str: a header line, the tab-separated column names and one tab-separated
            line per row (each cell truncated to 100 characters), or an error
            message if the query fails.
        """
        name = getattr(table_name, "name", table_name)
        row_limit = int(self._sample_rows_in_table_info)
        # Let the dialect quote the identifier instead of pasting raw text into SQL.
        safe_identifier = self._engine.dialect.identifier_preparer.quote(name)
        sample_query = text(
            f"SELECT * FROM {safe_identifier} WHERE ROWNUM <= {row_limit}"
        )

        try:
            with self._engine.connect() as conn:
                result = conn.execute(sample_query)
                # Column names come from the result itself, which avoids a separate
                # table reflection round trip.
                columns_str = "\t".join(str(key) for key in result.keys())
                sample_rows = result.fetchall()
        except SQLAlchemyError as e:
            logger.error("Error fetching sample rows: %s", e)
            return "Error fetching sample rows."

        # Format each row as a tab-separated string, limiting string lengths.
        sample_rows_str = "\n".join(
            "\t".join(str(cell)[:100] for cell in row) for row in sample_rows
        )
        return (
            f"{row_limit} rows from {name} table:\n"
            f"{columns_str}\n"
            f"{sample_rows_str}"
        )

    def get_columns(self, table_name: str) -> List[Dict]:
        """Get columns about specified Oracle table.

        Args:
            table_name (str): table name.

        Returns:
            List[Dict]: one dict per column with the keys ``name``, ``type``,
            ``default_expression``, ``is_in_primary_key`` and ``comment``.
        """
        columns_info = self._inspector.get_columns(table_name)

        primary_key_info = self._inspector.get_pk_constraint(table_name)
        primary_key_columns = primary_key_info.get("constrained_columns") or []
        if not isinstance(primary_key_columns, list):
            primary_key_columns = [primary_key_columns]

        enhanced_columns = []
        for col in columns_info:
            default = col.get("default")
            enhanced_columns.append(
                {
                    "name": col["name"],
                    "type": str(col["type"]),  # convert SQLAlchemy type to string
                    "default_expression": str(default) if default is not None else None,
                    "is_in_primary_key": col["name"] in primary_key_columns,
                    "comment": col.get("comment"),
                }
            )
        return enhanced_columns

    def convert_sql_write_to_select(self, write_sql: str) -> str:
        """Convert SQL write command to a SELECT command for Oracle.

        NOTE(review): this is still the placeholder implementation. It wraps the
        write statement in a subquery, which is not valid SQL for INSERT/UPDATE/
        DELETE statements -- confirm how callers use the result and replace it with
        real conversion logic.
        """
        return f"SELECT * FROM ({write_sql}) WHERE 1=0"

    def get_table_comment(self, table_name: str) -> Dict:
        """Get table comments for an Oracle table.

        Args:
            table_name (str): table name
        Returns:
            comment: Dict, which contains text: Optional[str], eg:{"text": "comment"}
        """
        # Unquoted Oracle identifiers are stored upper-cased, so also match the
        # upper-cased name; an exact-case match is preferred when both exist.
        comment_query = text(
            "SELECT comments FROM user_tab_comments "
            "WHERE table_name IN (:table_name, UPPER(:table_name)) "
            "ORDER BY CASE WHEN table_name = :table_name THEN 0 ELSE 1 END"
        )
        try:
            result = self.session.execute(
                comment_query, {"table_name": table_name}
            ).fetchone()
            return {"text": result[0]} if result else {"text": None}
        except SQLAlchemyError as e:
            logger.error("Error getting table comment: %s", e)
            return {"text": None}

    def get_grants(self):
        """Get grant info for Oracle.

        Returns:
            list: always empty; grant lookup is not implemented for Oracle.
        """
        # No session is opened here: nothing is queried, so opening one would only
        # leak a connection.
        return []

    def get_charset(self) -> str:
        """Get character set.

        Returns:
            str: value of ``NLS_CHARACTERSET`` from ``NLS_DATABASE_PARAMETERS``.
        """
        charset_query = text(
            "SELECT value FROM NLS_DATABASE_PARAMETERS "
            "WHERE parameter = 'NLS_CHARACTERSET'"
        )
        # Reuse the connector's session, as the other metadata queries do, instead
        # of opening a new session that is never closed.
        return self.session.execute(charset_query).scalar()

    def get_collation(self) -> Optional[str]:
        """Get collation for Oracle.

        Collation querying is treated as not applicable for Oracle; a warning is
        logged and ``None`` is returned.
        """
        logger.warning(
            "Collation querying is not applicable in Oracle as it does not support "
            "database-level collations."
        )
        return None

    def _write(self, write_sql: str):
        """Run a SQL write command after removing any trailing semicolon.

        Args:
            write_sql (str): SQL write command to run

        Returns:
            Whatever the base class ``_write`` returns for the command.
        """
        logger.info("Write[%s]", write_sql)
        command = _remove_trailing_semicolon(write_sql)
        return super()._write(command)

    def _query(self, query: str, fetch: str = "all"):
        """Run a SQL query after removing any trailing semicolon.

        Args:
            query (str): SQL query to run
            fetch (str): fetch type

        Returns:
            Whatever the base class ``_query`` returns for the query.
        """
        logger.info("Query[%s]", query)
        query = _remove_trailing_semicolon(query)
        return super()._query(query, fetch)

    def run(self, command: str, fetch: str = "all") -> List:
        """Execute a SQL command after removing any trailing semicolon.

        Args:
            command (str): SQL command to run
            fetch (str): fetch type

        Returns:
            List: whatever the base class ``run`` returns for the command.
        """
        logger.info("SQL:%s", command)
        command = _remove_trailing_semicolon(command)
        return super().run(command, fetch)
281+
dbgpt/datasource/rdbms/tests/test_conn_oracle.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""Integration tests for the Oracle connector.

Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_oracle.py

The tests connect to an Oracle instance on localhost:1521, for example:

    docker run -d -p 1521:1521 -e ORACLE_PASSWORD=oracle gvenzl/oracle-xe:21
    docker exec -it <container_id> /bin/bash
    sqlplus system/oracle

NOTE(review): the ``db`` fixture logs in as user ``oracle`` (password ``oracle``)
against ``XE``, and ``test_run`` / ``test_get_table_comment`` read an ``EMPLOYEES``
table; neither is created by the commands above -- confirm the intended setup.
"""
10+
11+
import pytest
12+
13+
from dbgpt.datasource.rdbms.conn_oracle import OracleConnector
14+
15+
_create_table_sql = """
16+
CREATE TABLE test (
17+
id NUMBER(11) PRIMARY KEY
18+
)
19+
"""
20+
21+
22+
@pytest.fixture
def db():
    """Yield an OracleConnector for ``XE`` on localhost:1521 (user/password ``oracle``)."""
    yield OracleConnector.from_uri_db("localhost", 1521, "oracle", "oracle", "XE")
26+
27+
28+
def test_get_usable_table_names(db):
    """Create the ``test`` table and expect ``TEST`` to be the only usable table."""
    db.run(_create_table_sql)
    synced_names = db._sync_tables_from_db()
    print(synced_names)
    usable_names = list(db.get_usable_table_names())
    assert usable_names == ["TEST"]
32+
33+
34+
def test_get_columns(db):
    """Smoke test: print the column descriptions of the ``test`` table."""
    columns = db.get_columns("test")
    print(columns)
36+
37+
38+
def test_get_table_info_with_table(db):
    """Smoke test: print the table info of every usable table.

    The ``test`` table is not created here; this relies on whatever already exists
    in the schema.
    """
    table_info = db.get_table_info()
    print(table_info)
42+
43+
44+
def test_get_current_db_name(db):
    """The current schema name is expected to be ``ORACLE``."""
    print(db.get_current_db_name())
    expected_schema = "ORACLE"
    assert db.get_current_db_name() == expected_schema
47+
48+
49+
def test_table_simple_info(db):
    """Smoke test: print the rows returned by ``table_simple_info``."""
    simple_info = db.table_simple_info()
    print(simple_info)
51+
52+
53+
def test_get_table_names(db):
    """Smoke test: print the table names reported by the connector."""
    table_names = db.get_table_names()
    print(table_names)
55+
56+
57+
def test_get_sample_rows(db):
    """Smoke test: print the sample rows of the ``TEST`` table.

    ``_get_sample_rows`` takes a table name, so the name is passed directly rather
    than a ``Table`` object looked up in the reflected metadata.
    """
    print(db._get_sample_rows("TEST"))
59+
60+
61+
def test_get_table_indexes(db):
    """Smoke test: print the index information of the ``TEST`` table.

    The reflected table is looked up case-insensitively because the casing of the
    keys in the reflected metadata may differ from the name stored by Oracle.
    """
    reflected_table = next(
        table
        for key, table in db._metadata.tables.items()
        if key.upper() == "TEST"
    )
    print(db._get_table_indexes(reflected_table))
63+
64+
65+
def test_run(db):
    """Smoke test: run a row-limited query against ``EMPLOYEES`` and print the result."""
    query = "SELECT * FROM EMPLOYEES FETCH FIRST 50 ROWS ONLY"
    rows = db.run(query)
    print(rows)
68+
69+
def test_get_table_comment(db):
    """Smoke test: print the table comment of ``EMPLOYEES``."""
    comment = db.get_table_comment("EMPLOYEES")
    print(comment)
72+
73+
def test_get_fields(db):
    """The first item of the first field returned for ``test`` must be ``id``."""
    first_field = db.get_fields("test")[0]
    first_item = list(first_field)[0]
    assert first_item == "id"
75+
76+
77+
def test_get_users(db):
    """Smoke test: print the users reported by the connector."""
    users = db.get_users()
    print(users)
79+
80+
81+
def test_get_charset(db):
    """Smoke test: print the database character set."""
    charset = db.get_charset()
    print(charset)
83+
84+
85+
def test_get_collation(db):
    """Smoke test: print the collation reported by the connector."""
    collation = db.get_collation()
    print(collation)

0 commit comments

Comments
 (0)