Skip to content

Commit c819b1c

Browse files
committed
fix result type from dbapi
1 parent fb27dbc commit c819b1c

File tree

1 file changed

+45
-11
lines changed

1 file changed

+45
-11
lines changed

sqlalchemy_iris/base.py

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import datetime
22
from telnetlib import BINARY
33
from iris.dbapi._DBAPI import Cursor
4+
from iris.dbapi._Column import _Column
45
from iris.dbapi._ResultSetRow import _ResultSetRow
56
from iris.dbapi._DBAPI import SQLType as IRISSQLType
67
import iris._IRISNative as irisnative
@@ -516,21 +517,42 @@ def __init__(self, dialect):
516517
class CursorWrapper(Cursor):
517518
def __init__(self, connection):
518519
super(CursorWrapper, self).__init__(connection)
520+
521+
_types = {
522+
IRISSQLType.INTEGER: int,
523+
IRISSQLType.BIGINT: int,
524+
525+
IRISSQLType.VARCHAR: str,
526+
}
527+
528+
# Workaround for issue, when type of variable not the same as column type
529+
def _fix_type(self, value, sql_type: IRISSQLType):
530+
if value is None:
531+
return value
532+
533+
try:
534+
expected_type = self._types.get(sql_type)
535+
if expected_type and not isinstance(value, expected_type):
536+
value = expected_type(value)
537+
except Exception:
538+
pass
539+
540+
return value
519541

520542
def fetchone(self):
521543
retval = super(CursorWrapper, self).fetchone()
522544
if retval is None:
523545
return None
524546
if not isinstance(retval, _ResultSetRow.DataRow):
525547
return retval
548+
# return retval[:]
526549

527550
# Workaround for fetchone, which returns values in row not from 0
528551
row = []
552+
self._columns: list[_Column]
529553
for c in self._columns:
530-
value = retval[c.name]
531-
# Workaround for issue, when int returned as string
532-
if value is not None and c.type in (IRISSQLType.INTEGER, IRISSQLType.BIGINT,) and type(value) is not int:
533-
value = int(value)
554+
value = retval[c.name]
555+
value = self._fix_type(value, c.type)
534556
row.append(value)
535557
return row
536558

@@ -846,14 +868,22 @@ def get_indexes(self, connection, table_name, schema=None, unique=False, **kw):
846868

847869
indexes = util.defaultdict(dict)
848870
for row in rs:
849-
indexrec = indexes[row["INDEX_NAME"]]
871+
(
872+
idxname,
873+
colname,
874+
_,
875+
nuniq,
876+
_,
877+
) = row
878+
879+
indexrec = indexes[idxname]
850880
if "name" not in indexrec:
851-
indexrec["name"] = self.normalize_name(row["INDEX_NAME"])
881+
indexrec["name"] = self.normalize_name(idxname)
852882
indexrec["column_names"] = []
853-
indexrec["unique"] = not row["NON_UNIQUE"]
883+
indexrec["unique"] = not nuniq
854884

855885
indexrec["column_names"].append(
856-
self.normalize_name(row["COLUMN_NAME"])
886+
self.normalize_name(colname)
857887
)
858888

859889
indexes = list(indexes.values())
@@ -890,8 +920,12 @@ def get_pk_constraint(self, connection, table_name, schema=None, **kw):
890920
constraint_name = None
891921
pkfields = []
892922
for row in rs:
893-
constraint_name = self.normalize_name(row["CONSTRAINT_NAME"])
894-
pkfields.append(self.normalize_name(row["COLUMN_NAME"]))
923+
(
924+
name,
925+
colname,
926+
) = row
927+
constraint_name = self.normalize_name(name)
928+
pkfields.append(self.normalize_name(colname))
895929

896930
if pkfields:
897931
return {
@@ -1038,7 +1072,7 @@ def get_columns(self, connection, table_name, schema=None, **kw):
10381072
for row in c.mappings():
10391073
name = row[columns.c.column_name]
10401074
type_ = row[columns.c.data_type].upper()
1041-
nullable = row[columns.c.is_nullable] == "YES"
1075+
nullable = row[columns.c.is_nullable]
10421076
charlen = row[columns.c.character_maximum_length]
10431077
numericprec = row[columns.c.numeric_precision]
10441078
numericscale = row[columns.c.numeric_scale]

0 commit comments

Comments
 (0)