Skip to content

Commit 2784c96

Browse files
committed
fix: correctly handle NULL types for execute_many in the latest ADBC
1 parent 5628f08 commit 2784c96

File tree

10 files changed

+429
-74
lines changed

10 files changed

+429
-74
lines changed

sqlspec/adapters/adbc/config.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional
77

88
from sqlspec.adapters.adbc.driver import AdbcConnection, AdbcDriver
9+
from sqlspec.adapters.adbc.transformers import AdbcPostgresTransformer
910
from sqlspec.config import NoPoolSyncConfig
1011
from sqlspec.exceptions import ImproperConfigurationError
1112
from sqlspec.statement.sql import SQLConfig
@@ -434,6 +435,17 @@ def session_manager() -> "Generator[AdbcDriver, None, None]":
434435
default_parameter_style=preferred_style,
435436
)
436437

438+
# Add ADBC PostgreSQL transformer if needed
439+
if self._get_dialect() == "postgres":
440+
# Get the default transformers from the pipeline
441+
pipeline = statement_config.get_statement_pipeline()
442+
existing_transformers = list(pipeline.transformers)
443+
444+
# Append our transformer to the existing ones
445+
existing_transformers.append(AdbcPostgresTransformer())
446+
447+
statement_config = replace(statement_config, transformers=existing_transformers)
448+
437449
driver = self.driver_type(connection=connection, config=statement_config)
438450
yield driver
439451

sqlspec/adapters/adbc/driver.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,9 @@ def _execute(
220220

221221
with self._get_cursor(txn_conn) as cursor:
222222
try:
223+
# ADBC PostgreSQL has issues with NULL parameters in some cases
224+
# The transformer handles all-NULL cases, but mixed NULL/non-NULL
225+
# can still cause "Can't map Arrow type 'na' to Postgres type" errors
223226
cursor.execute(sql, cursor_params or [])
224227
except Exception as e:
225228
# Rollback transaction on error for PostgreSQL to avoid
@@ -265,6 +268,17 @@ def _execute_many(
265268
# Normalize parameter list using consolidated utility
266269
converted_param_list = convert_parameter_sequence(param_list)
267270

271+
# Handle empty parameter list case for PostgreSQL
272+
if not converted_param_list and self.dialect == "postgres":
273+
# Return empty result without executing
274+
return SQLResult(
275+
statement=SQL(sql, _dialect=self.dialect),
276+
data=[],
277+
rows_affected=0,
278+
operation_type="EXECUTE",
279+
metadata={"status_message": "OK"},
280+
)
281+
268282
with self._get_cursor(txn_conn) as cursor:
269283
try:
270284
cursor.executemany(sql, converted_param_list or [])

sqlspec/adapters/adbc/transformers.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
"""ADBC-specific AST transformers for handling driver limitations."""
2+
3+
from typing import Optional
4+
5+
from sqlglot import exp
6+
7+
from sqlspec.protocols import ProcessorProtocol
8+
from sqlspec.statement.pipelines.context import SQLProcessingContext
9+
10+
__all__ = ("AdbcPostgresTransformer",)
11+
12+
13+
class AdbcPostgresTransformer(ProcessorProtocol):
14+
"""Transformer to handle ADBC PostgreSQL driver limitations.
15+
16+
This transformer addresses specific issues with the ADBC PostgreSQL driver:
17+
1. Empty parameter lists in executemany() causing "no parameter $1" errors
18+
2. NULL parameters causing "Can't map Arrow type 'na' to Postgres type" errors
19+
20+
The transformer works at the AST level to properly handle these edge cases.
21+
"""
22+
23+
def __init__(self) -> None:
24+
self.has_placeholders = False
25+
self.all_params_null = False
26+
self.is_empty_params = False
27+
self.has_null_params = False
28+
self.null_param_indices: list[int] = []
29+
30+
def process(self, expression: Optional[exp.Expression], context: SQLProcessingContext) -> Optional[exp.Expression]:
31+
"""Process the SQL expression to handle ADBC limitations."""
32+
if not expression:
33+
return expression
34+
35+
# Check if we have an empty parameter list for executemany
36+
# Look at the merged_parameters in the context
37+
params = context.merged_parameters
38+
39+
# For execute_many, check if we have an empty list
40+
if isinstance(params, list) and len(params) == 0:
41+
self.is_empty_params = True
42+
43+
# Check for NULL parameters
44+
if params:
45+
if isinstance(params, (list, tuple)):
46+
# Track which parameters are NULL
47+
self.null_param_indices = [i for i, p in enumerate(params) if p is None]
48+
self.has_null_params = len(self.null_param_indices) > 0
49+
self.all_params_null = len(self.null_param_indices) == len(params)
50+
51+
# For ADBC PostgreSQL, we need to replace NULL parameters with literals
52+
# and remove them from the parameter list
53+
if self.has_null_params:
54+
# Create new parameter list without NULLs
55+
new_params = [p for p in params if p is not None]
56+
context.merged_parameters = new_params
57+
58+
elif isinstance(params, dict):
59+
# For dict parameters, track which ones are NULL
60+
null_keys = [k for k, v in params.items() if v is None]
61+
self.has_null_params = len(null_keys) > 0
62+
self.all_params_null = len(null_keys) == len(params)
63+
64+
if self.has_null_params:
65+
# Remove NULL parameters from dict
66+
context.merged_parameters = {k: v for k, v in params.items() if v is not None}
67+
68+
# Transform the AST if needed
69+
if self.is_empty_params:
70+
# For empty parameters, we should skip transformation and let the driver handle it
71+
# The driver already has logic to return empty result for empty params
72+
return expression
73+
74+
if self.has_null_params:
75+
# Transform placeholders to NULL literals where needed
76+
self._parameter_index = 0 # Track current parameter position
77+
return expression.transform(self._transform_node)
78+
79+
return expression
80+
81+
def _transform_node(self, node: exp.Expression) -> exp.Expression:
82+
"""Transform individual AST nodes."""
83+
# Handle parameter nodes (e.g., $1, $2, etc. in PostgreSQL)
84+
if isinstance(node, exp.Parameter):
85+
# Access the parameter value directly from the AST node
86+
# The 'this' attribute contains a Literal node, whose 'this' contains the actual value
87+
if node.this and isinstance(node.this, exp.Literal):
88+
try:
89+
param_index = int(node.this.this) - 1 # Convert to 0-based index
90+
# Check if this parameter should be NULL
91+
if param_index in self.null_param_indices:
92+
return exp.Null()
93+
# Renumber the parameter based on how many NULLs came before it
94+
nulls_before = sum(1 for idx in self.null_param_indices if idx < param_index)
95+
new_index = param_index - nulls_before + 1 # Convert back to 1-based
96+
return exp.Parameter(this=exp.Literal.number(new_index))
97+
except (ValueError, IndexError):
98+
pass
99+
100+
# Handle placeholder nodes for other dialects
101+
elif isinstance(node, exp.Placeholder):
102+
# For placeholders, we need to track position
103+
if self._parameter_index in self.null_param_indices:
104+
self._parameter_index += 1
105+
return exp.Null()
106+
self._parameter_index += 1
107+
108+
return node

sqlspec/statement/parameters.py

Lines changed: 14 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,14 @@
1010
from collections.abc import Mapping, Sequence
1111
from dataclasses import dataclass, field
1212
from enum import Enum
13-
from typing import TYPE_CHECKING, Any, Final, Optional, Union
13+
from typing import Any, Final, Optional, Union
1414

15+
from sqlglot import exp
1516
from typing_extensions import TypedDict
1617

1718
from sqlspec.exceptions import ExtraParameterError, MissingParameterError, ParameterStyleMismatchError
1819
from sqlspec.typing import SQLParameterType
1920

20-
if TYPE_CHECKING:
21-
from sqlglot import exp
22-
2321
# Constants
2422
MAX_32BIT_INT: Final[int] = 2147483647
2523

@@ -28,7 +26,7 @@
2826
"ParameterConverter",
2927
"ParameterInfo",
3028
"ParameterStyle",
31-
"ParameterStyleTransformationState",
29+
"ParameterStyleConversionState",
3230
"ParameterValidator",
3331
"SQLParameterType",
3432
"TypedParameter",
@@ -169,7 +167,7 @@ class ParameterStyleInfo(TypedDict, total=False):
169167

170168

171169
@dataclass
172-
class ParameterStyleTransformationState:
170+
class ParameterStyleConversionState:
173171
"""Encapsulates all information about parameter style transformation.
174172
175173
This class provides a single source of truth for parameter style conversions,
@@ -213,7 +211,7 @@ class ConvertedParameters:
213211
merged_parameters: "SQLParameterType"
214212
"""Parameters after merging from various sources."""
215213

216-
conversion_state: ParameterStyleTransformationState
214+
conversion_state: ParameterStyleConversionState
217215
"""Complete conversion state for tracking conversions."""
218216

219217

@@ -314,17 +312,13 @@ def get_parameter_style(parameters_info: "list[ParameterInfo]") -> "ParameterSty
314312
"""
315313
if not parameters_info:
316314
return ParameterStyle.NONE
317-
318-
# Note: This logic prioritizes pyformat if present, then named, then positional.
319315
is_pyformat_named = any(p.style == ParameterStyle.NAMED_PYFORMAT for p in parameters_info)
320316
is_pyformat_positional = any(p.style == ParameterStyle.POSITIONAL_PYFORMAT for p in parameters_info)
321317

322318
if is_pyformat_named:
323319
return ParameterStyle.NAMED_PYFORMAT
324-
if is_pyformat_positional: # If only PYFORMAT_POSITIONAL and not PYFORMAT_NAMED
320+
if is_pyformat_positional:
325321
return ParameterStyle.POSITIONAL_PYFORMAT
326-
327-
# Simplified logic if not pyformat, checks for any named or any positional
328322
has_named = any(
329323
p.style
330324
in {
@@ -336,13 +330,7 @@ def get_parameter_style(parameters_info: "list[ParameterInfo]") -> "ParameterSty
336330
for p in parameters_info
337331
)
338332
has_positional = any(p.style in {ParameterStyle.QMARK, ParameterStyle.NUMERIC} for p in parameters_info)
339-
340-
# If mixed named and positional (non-pyformat), prefer named as dominant.
341-
# The choice of NAMED_COLON here is somewhat arbitrary if multiple named styles are mixed.
342333
if has_named:
343-
# Could refine to return the style of the first named param encountered, or most frequent.
344-
# For simplicity, returning a general named style like NAMED_COLON is often sufficient.
345-
# Or, more accurately, find the first named style:
346334
for p_style in (
347335
ParameterStyle.NAMED_COLON,
348336
ParameterStyle.POSITIONAL_COLON,
@@ -354,12 +342,11 @@ def get_parameter_style(parameters_info: "list[ParameterInfo]") -> "ParameterSty
354342
return ParameterStyle.NAMED_COLON
355343

356344
if has_positional:
357-
# Similarly, could choose QMARK or NUMERIC based on presence.
358345
if any(p.style == ParameterStyle.NUMERIC for p in parameters_info):
359346
return ParameterStyle.NUMERIC
360-
return ParameterStyle.QMARK # Default positional
347+
return ParameterStyle.QMARK
361348

362-
return ParameterStyle.NONE # Should not be reached if parameters_info is not empty
349+
return ParameterStyle.NONE
363350

364351
@staticmethod
365352
def determine_parameter_input_type(parameters_info: "list[ParameterInfo]") -> "Optional[type]":
@@ -384,9 +371,8 @@ def determine_parameter_input_type(parameters_info: "list[ParameterInfo]") -> "O
384371
if any(
385372
p.name is not None and p.style not in {ParameterStyle.POSITIONAL_COLON, ParameterStyle.NUMERIC}
386373
for p in parameters_info
387-
): # True for NAMED styles and PYFORMAT_NAMED
374+
):
388375
return dict
389-
# All parameters must have p.name is None or be positional styles (POSITIONAL_COLON, NUMERIC)
390376
if all(
391377
p.name is None or p.style in {ParameterStyle.POSITIONAL_COLON, ParameterStyle.NUMERIC}
392378
for p in parameters_info
@@ -400,9 +386,7 @@ def determine_parameter_input_type(parameters_info: "list[ParameterInfo]") -> "O
400386
"Ambiguous parameter structure for determining input type. "
401387
"Query might contain a mix of named and unnamed styles not typically supported together."
402388
)
403-
# Defaulting to dict if any named param is found, as that's the more common requirement for mixed scenarios.
404-
# However, strict validation should ideally prevent such mixed styles from being valid.
405-
return dict # Or raise an error for unsupported mixed styles.
389+
return dict
406390

407391
def validate_parameters(
408392
self,
@@ -421,12 +405,7 @@ def validate_parameters(
421405
ParameterStyleMismatchError: When style doesn't match
422406
"""
423407
expected_input_type = self.determine_parameter_input_type(parameters_info)
424-
425-
# Allow creating SQL statements with placeholders but no parameters
426-
# This enables patterns like SQL("SELECT * FROM users WHERE id = ?").as_many([...])
427-
# Validation will happen later when parameters are actually provided
428408
if provided_params is None and parameters_info:
429-
# Don't raise an error, just return - validation will happen later
430409
return
431410

432411
if (
@@ -707,7 +686,7 @@ def convert_parameters(
707686
self.validator.validate_parameters(parameters_info, merged_params, sql)
708687
if needs_conversion:
709688
transformed_sql, placeholder_map = self._transform_sql_for_parsing(sql, parameters_info)
710-
conversion_state = ParameterStyleTransformationState(
689+
conversion_state = ParameterStyleConversionState(
711690
was_transformed=True,
712691
original_styles=list({p.style for p in parameters_info}),
713692
transformation_style=ParameterStyle.NAMED_COLON,
@@ -716,7 +695,7 @@ def convert_parameters(
716695
)
717696
else:
718697
transformed_sql = sql
719-
conversion_state = ParameterStyleTransformationState(
698+
conversion_state = ParameterStyleConversionState(
720699
was_transformed=False,
721700
original_styles=list({p.style for p in parameters_info}),
722701
original_param_info=parameters_info,
@@ -775,10 +754,10 @@ def merge_parameters(
775754
return parameters
776755

777756
if kwargs is not None:
778-
return dict(kwargs) # Make a copy
757+
return dict(kwargs)
779758

780759
if args is not None:
781-
return list(args) # Convert tuple of args to list for consistency and mutability if needed later
760+
return list(args)
782761

783762
return None
784763

@@ -809,53 +788,34 @@ def wrap_parameters_with_types(
809788

810789
def infer_type_from_value(value: Any) -> tuple[str, "exp.DataType"]:
811790
"""Infer SQL type hint and SQLGlot DataType from Python value."""
812-
# Import here to avoid issues
813-
from sqlglot import exp
814791

815792
# None/NULL
816793
if value is None:
817794
return "null", exp.DataType.build("NULL")
818-
819-
# Boolean
820795
if isinstance(value, bool):
821796
return "boolean", exp.DataType.build("BOOLEAN")
822-
823-
# Integer types
824797
if isinstance(value, int) and not isinstance(value, bool):
825798
if abs(value) > MAX_32BIT_INT:
826799
return "bigint", exp.DataType.build("BIGINT")
827800
return "integer", exp.DataType.build("INT")
828-
829-
# Float/Decimal
830801
if isinstance(value, float):
831802
return "float", exp.DataType.build("FLOAT")
832803
if isinstance(value, Decimal):
833804
return "decimal", exp.DataType.build("DECIMAL")
834-
835-
# Date/Time types
836805
if isinstance(value, datetime):
837806
return "timestamp", exp.DataType.build("TIMESTAMP")
838807
if isinstance(value, date):
839808
return "date", exp.DataType.build("DATE")
840809
if isinstance(value, time):
841810
return "time", exp.DataType.build("TIME")
842-
843-
# JSON/Dict
844811
if isinstance(value, dict):
845812
return "json", exp.DataType.build("JSON")
846-
847-
# Array/List
848813
if isinstance(value, (list, tuple)):
849814
return "array", exp.DataType.build("ARRAY")
850-
851815
if isinstance(value, str):
852816
return "string", exp.DataType.build("VARCHAR")
853-
854-
# Bytes
855817
if isinstance(value, bytes):
856818
return "binary", exp.DataType.build("BINARY")
857-
858-
# Default fallback
859819
return "string", exp.DataType.build("VARCHAR")
860820

861821
def wrap_value(value: Any, semantic_name: Optional[str] = None) -> Any:

0 commit comments

Comments
 (0)