Skip to content

Commit 389dcaa

Browse files
committed
keep exact value for string values when used order
1 parent c82e8d8 commit 389dcaa

File tree

1 file changed

+79
-40
lines changed

1 file changed

+79
-40
lines changed

sqlalchemy_iris/base.py

Lines changed: 79 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
from sqlalchemy.sql import util as sql_util
1111
from sqlalchemy.sql import between
1212
from sqlalchemy.sql import func
13+
from sqlalchemy.sql.functions import ReturnTypeFromArgs
1314
from sqlalchemy.sql import expression
15+
from sqlalchemy.sql import schema
1416
from sqlalchemy import sql, text
1517
from sqlalchemy import util
1618
from sqlalchemy import types as sqltypes
@@ -406,59 +408,89 @@ def _use_top(self, select):
406408
or select._simple_int_clause(select._fetch_clause)
407409
)
408410

409-
def translate_select_structure(self, select_stmt, **kwargs):
410-
"""Look for ``LIMIT`` and OFFSET in a select statement, and if
411-
so tries to wrap it in a subquery with ``row_number()`` criterion.
411+
def visit_irisexact_func(self, fn, **kw):
412+
return "%EXACT" + self.function_argspec(fn)
412413

414+
def _use_exact_for_ordered_string(self, select):
415+
"""
416+
`SELECT string_value FROM some_table ORDER BY string_value`
417+
Will return `string_value` in uppercase
418+
So, this method fixes query to use %EXACT() function
419+
`SELECT %EXACT(string_value) AS string_value FROM some_table ORDER BY string_value`
413420
"""
421+
def _add_exact(column):
422+
if isinstance(column.type, sqltypes.String):
423+
return IRISExact(column).label(column._label if column._label else column.name)
424+
return column
425+
426+
_order_by_clauses = [
427+
sql_util.unwrap_label_reference(elem)
428+
for elem in select._order_by_clause.clauses
429+
]
430+
if _order_by_clauses:
431+
select._raw_columns = [
432+
(_add_exact(c) if isinstance(c, schema.Column) and c in _order_by_clauses else c)
433+
for c in select._raw_columns
434+
]
435+
436+
return select
437+
438+
def translate_select_structure(self, select_stmt, **kwargs):
414439
select = select_stmt
440+
if getattr(select, "_iris_visit", None) is True:
441+
return select
415442

416-
if (
443+
select._iris_visit = True
444+
select = select._generate()
445+
446+
select = self._use_exact_for_ordered_string(select)
447+
448+
if not (
417449
select._has_row_limiting_clause
418450
and not self._use_top(select)
419-
and not getattr(select, "_iris_visit", None)
420451
):
421-
_order_by_clauses = [
422-
sql_util.unwrap_label_reference(elem)
423-
for elem in select._order_by_clause.clauses
424-
]
425-
426-
if not _order_by_clauses:
427-
_order_by_clauses = [text('%id')]
452+
return select
428453

429-
limit_clause = self._get_limit_or_fetch(select)
430-
offset_clause = select._offset_clause
454+
"""Look for ``LIMIT`` and OFFSET in a select statement, and if
455+
so tries to wrap it in a subquery with ``row_number()`` criterion.
431456
432-
select = select._generate()
433-
select._iris_visit = True
434-
label = "iris_rn"
435-
select = (
436-
select.add_columns(
437-
sql.func.ROW_NUMBER()
438-
.over(order_by=_order_by_clauses)
439-
.label(label)
440-
)
441-
.order_by(None)
442-
.alias()
457+
"""
458+
_order_by_clauses = [
459+
sql_util.unwrap_label_reference(elem)
460+
for elem in select._order_by_clause.clauses
461+
]
462+
if not _order_by_clauses:
463+
_order_by_clauses = [text('%id')]
464+
465+
limit_clause = self._get_limit_or_fetch(select)
466+
offset_clause = select._offset_clause
467+
468+
label = "iris_rn"
469+
select = (
470+
select.add_columns(
471+
sql.func.ROW_NUMBER()
472+
.over(order_by=_order_by_clauses)
473+
.label(label)
443474
)
475+
.order_by(None)
476+
.alias()
477+
)
444478

445-
iris_rn = sql.column(label)
446-
limitselect = sql.select(
447-
*[c for c in select.c if c.key != label]
448-
)
449-
if offset_clause is not None:
450-
if limit_clause is not None:
451-
limitselect = limitselect.where(
452-
between(iris_rn, offset_clause + 1,
453-
limit_clause + offset_clause)
454-
)
455-
else:
456-
limitselect = limitselect.where(iris_rn > offset_clause)
479+
iris_rn = sql.column(label)
480+
limitselect = sql.select(
481+
*[c for c in select.c if c.key != label]
482+
)
483+
if offset_clause is not None:
484+
if limit_clause is not None:
485+
limitselect = limitselect.where(
486+
between(iris_rn, offset_clause + 1,
487+
limit_clause + offset_clause)
488+
)
457489
else:
458-
limitselect = limitselect.where(iris_rn <= (limit_clause))
459-
return limitselect
490+
limitselect = limitselect.where(iris_rn > offset_clause)
460491
else:
461-
return select
492+
limitselect = limitselect.where(iris_rn <= (limit_clause))
493+
return limitselect
462494

463495
def order_by_clause(self, select, **kw):
464496
order_by = self.process(select._order_by_clause, **kw)
@@ -488,6 +520,7 @@ def visit_drop_schema(self, drop, **kw):
488520

489521
def visit_check_constraint(self, constraint, **kw):
490522
raise exc.CompileError("Check CONSTRAINT not supported")
523+
# pass
491524

492525
def visit_computed_column(self, generated, **kwargs):
493526
text = self.sql_compiler.process(
@@ -591,6 +624,12 @@ def create_cursor(self):
591624
}
592625

593626

627+
class IRISExact(ReturnTypeFromArgs):
628+
"""The IRIS SQL %EXACT() function."""
629+
630+
inherit_cache = True
631+
632+
594633
class IRISDialect(default.DefaultDialect):
595634

596635
name = 'iris'

0 commit comments

Comments
 (0)