Skip to content

Commit 5de1b54

Browse files
authored
Merge pull request #696 from Mause/try-cast
feat: try_cast operator
2 parents 062a1ef + 2993e14 commit 5de1b54

File tree

2 files changed

+32
-2
lines changed

2 files changed

+32
-2
lines changed

duckdb_engine/__init__.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,16 @@
1414
)
1515

1616
import duckdb
17+
import sqlalchemy
1718
from sqlalchemy import pool, text
1819
from sqlalchemy import types as sqltypes
1920
from sqlalchemy import util
2021
from sqlalchemy.dialects.postgresql import UUID
21-
from sqlalchemy.dialects.postgresql.base import PGDialect, PGInspector
22+
from sqlalchemy.dialects.postgresql.base import PGDialect, PGInspector, PGTypeCompiler
2223
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
2324
from sqlalchemy.engine.default import DefaultDialect
2425
from sqlalchemy.engine.url import URL
26+
from sqlalchemy.ext.compiler import compiles
2527

2628
from .config import apply_config, get_core_config
2729
from .datatypes import ISCHEMA_NAMES, register_extension_types
@@ -357,3 +359,18 @@ def get_multi_columns(
357359
columns = self._get_columns_info(rows, domains, enums, schema) # type: ignore[attr-defined]
358360

359361
return columns.items()
362+
363+
364+
if sqlalchemy.__version__ >= "2.0.14":
365+
from sqlalchemy import TryCast # type: ignore[attr-defined]
366+
367+
@compiles(TryCast, "duckdb") # type: ignore[misc]
368+
def visit_try_cast(
369+
instance: TryCast,
370+
compiler: PGTypeCompiler,
371+
**kw: Any,
372+
) -> str:
373+
return "TRY_CAST({} AS {})".format(
374+
compiler.process(instance.clause, **kw),
375+
compiler.process(instance.typeclause, **kw),
376+
)

duckdb_engine/tests/test_basic.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import os
33
import zlib
4-
from datetime import timedelta
4+
from datetime import datetime, timedelta
55
from pathlib import Path
66
from typing import Any, Generic, Optional, TypeVar, cast
77

@@ -13,6 +13,7 @@
1313
from pytest import LogCaptureFixture, fixture, importorskip, mark, raises
1414
from sqlalchemy import (
1515
Column,
16+
DateTime,
1617
ForeignKey,
1718
Integer,
1819
Interval,
@@ -22,6 +23,7 @@
2223
Table,
2324
create_engine,
2425
inspect,
26+
select,
2527
text,
2628
types,
2729
)
@@ -408,3 +410,14 @@ def test_do_ping(tmp_path: Path, caplog: LogCaptureFixture) -> None:
408410
assert any(
409411
"Pool pre-ping on connection" in message for message in caplog.messages
410412
)
413+
414+
415+
def test_try_cast(engine: Engine) -> None:
416+
try_cast = importorskip("sqlalchemy", "2.0.14").try_cast
417+
418+
with engine.connect() as conn:
419+
query = select(try_cast("2022-01-01", DateTime))
420+
assert conn.execute(query).one() == (datetime(2022, 1, 1),)
421+
422+
query = select(try_cast("not a date", DateTime))
423+
assert conn.execute(query).one() == (None,)

0 commit comments

Comments
 (0)