10
10
from collections .abc import Mapping , Sequence
11
11
from dataclasses import dataclass , field
12
12
from enum import Enum
13
- from typing import TYPE_CHECKING , Any , Final , Optional , Union
13
+ from typing import Any , Final , Optional , Union
14
14
15
+ from sqlglot import exp
15
16
from typing_extensions import TypedDict
16
17
17
18
from sqlspec .exceptions import ExtraParameterError , MissingParameterError , ParameterStyleMismatchError
18
19
from sqlspec .typing import SQLParameterType
19
20
20
- if TYPE_CHECKING :
21
- from sqlglot import exp
22
-
23
21
# Constants
24
22
MAX_32BIT_INT : Final [int ] = 2147483647
25
23
28
26
"ParameterConverter" ,
29
27
"ParameterInfo" ,
30
28
"ParameterStyle" ,
31
- "ParameterStyleTransformationState " ,
29
+ "ParameterStyleConversionState " ,
32
30
"ParameterValidator" ,
33
31
"SQLParameterType" ,
34
32
"TypedParameter" ,
@@ -169,7 +167,7 @@ class ParameterStyleInfo(TypedDict, total=False):
169
167
170
168
171
169
@dataclass
172
- class ParameterStyleTransformationState :
170
+ class ParameterStyleConversionState :
173
171
"""Encapsulates all information about parameter style transformation.
174
172
175
173
This class provides a single source of truth for parameter style conversions,
@@ -213,7 +211,7 @@ class ConvertedParameters:
213
211
merged_parameters : "SQLParameterType"
214
212
"""Parameters after merging from various sources."""
215
213
216
- conversion_state : ParameterStyleTransformationState
214
+ conversion_state : ParameterStyleConversionState
217
215
"""Complete conversion state for tracking conversions."""
218
216
219
217
@@ -314,17 +312,13 @@ def get_parameter_style(parameters_info: "list[ParameterInfo]") -> "ParameterSty
314
312
"""
315
313
if not parameters_info :
316
314
return ParameterStyle .NONE
317
-
318
- # Note: This logic prioritizes pyformat if present, then named, then positional.
319
315
is_pyformat_named = any (p .style == ParameterStyle .NAMED_PYFORMAT for p in parameters_info )
320
316
is_pyformat_positional = any (p .style == ParameterStyle .POSITIONAL_PYFORMAT for p in parameters_info )
321
317
322
318
if is_pyformat_named :
323
319
return ParameterStyle .NAMED_PYFORMAT
324
- if is_pyformat_positional : # If only PYFORMAT_POSITIONAL and not PYFORMAT_NAMED
320
+ if is_pyformat_positional :
325
321
return ParameterStyle .POSITIONAL_PYFORMAT
326
-
327
- # Simplified logic if not pyformat, checks for any named or any positional
328
322
has_named = any (
329
323
p .style
330
324
in {
@@ -336,13 +330,7 @@ def get_parameter_style(parameters_info: "list[ParameterInfo]") -> "ParameterSty
336
330
for p in parameters_info
337
331
)
338
332
has_positional = any (p .style in {ParameterStyle .QMARK , ParameterStyle .NUMERIC } for p in parameters_info )
339
-
340
- # If mixed named and positional (non-pyformat), prefer named as dominant.
341
- # The choice of NAMED_COLON here is somewhat arbitrary if multiple named styles are mixed.
342
333
if has_named :
343
- # Could refine to return the style of the first named param encountered, or most frequent.
344
- # For simplicity, returning a general named style like NAMED_COLON is often sufficient.
345
- # Or, more accurately, find the first named style:
346
334
for p_style in (
347
335
ParameterStyle .NAMED_COLON ,
348
336
ParameterStyle .POSITIONAL_COLON ,
@@ -354,12 +342,11 @@ def get_parameter_style(parameters_info: "list[ParameterInfo]") -> "ParameterSty
354
342
return ParameterStyle .NAMED_COLON
355
343
356
344
if has_positional :
357
- # Similarly, could choose QMARK or NUMERIC based on presence.
358
345
if any (p .style == ParameterStyle .NUMERIC for p in parameters_info ):
359
346
return ParameterStyle .NUMERIC
360
- return ParameterStyle .QMARK # Default positional
347
+ return ParameterStyle .QMARK
361
348
362
- return ParameterStyle .NONE # Should not be reached if parameters_info is not empty
349
+ return ParameterStyle .NONE
363
350
364
351
@staticmethod
365
352
def determine_parameter_input_type (parameters_info : "list[ParameterInfo]" ) -> "Optional[type]" :
@@ -384,9 +371,8 @@ def determine_parameter_input_type(parameters_info: "list[ParameterInfo]") -> "O
384
371
if any (
385
372
p .name is not None and p .style not in {ParameterStyle .POSITIONAL_COLON , ParameterStyle .NUMERIC }
386
373
for p in parameters_info
387
- ): # True for NAMED styles and PYFORMAT_NAMED
374
+ ):
388
375
return dict
389
- # All parameters must have p.name is None or be positional styles (POSITIONAL_COLON, NUMERIC)
390
376
if all (
391
377
p .name is None or p .style in {ParameterStyle .POSITIONAL_COLON , ParameterStyle .NUMERIC }
392
378
for p in parameters_info
@@ -400,9 +386,7 @@ def determine_parameter_input_type(parameters_info: "list[ParameterInfo]") -> "O
400
386
"Ambiguous parameter structure for determining input type. "
401
387
"Query might contain a mix of named and unnamed styles not typically supported together."
402
388
)
403
- # Defaulting to dict if any named param is found, as that's the more common requirement for mixed scenarios.
404
- # However, strict validation should ideally prevent such mixed styles from being valid.
405
- return dict # Or raise an error for unsupported mixed styles.
389
+ return dict
406
390
407
391
def validate_parameters (
408
392
self ,
@@ -421,12 +405,7 @@ def validate_parameters(
421
405
ParameterStyleMismatchError: When style doesn't match
422
406
"""
423
407
expected_input_type = self .determine_parameter_input_type (parameters_info )
424
-
425
- # Allow creating SQL statements with placeholders but no parameters
426
- # This enables patterns like SQL("SELECT * FROM users WHERE id = ?").as_many([...])
427
- # Validation will happen later when parameters are actually provided
428
408
if provided_params is None and parameters_info :
429
- # Don't raise an error, just return - validation will happen later
430
409
return
431
410
432
411
if (
@@ -707,7 +686,7 @@ def convert_parameters(
707
686
self .validator .validate_parameters (parameters_info , merged_params , sql )
708
687
if needs_conversion :
709
688
transformed_sql , placeholder_map = self ._transform_sql_for_parsing (sql , parameters_info )
710
- conversion_state = ParameterStyleTransformationState (
689
+ conversion_state = ParameterStyleConversionState (
711
690
was_transformed = True ,
712
691
original_styles = list ({p .style for p in parameters_info }),
713
692
transformation_style = ParameterStyle .NAMED_COLON ,
@@ -716,7 +695,7 @@ def convert_parameters(
716
695
)
717
696
else :
718
697
transformed_sql = sql
719
- conversion_state = ParameterStyleTransformationState (
698
+ conversion_state = ParameterStyleConversionState (
720
699
was_transformed = False ,
721
700
original_styles = list ({p .style for p in parameters_info }),
722
701
original_param_info = parameters_info ,
@@ -775,10 +754,10 @@ def merge_parameters(
775
754
return parameters
776
755
777
756
if kwargs is not None :
778
- return dict (kwargs ) # Make a copy
757
+ return dict (kwargs )
779
758
780
759
if args is not None :
781
- return list (args ) # Convert tuple of args to list for consistency and mutability if needed later
760
+ return list (args )
782
761
783
762
return None
784
763
@@ -809,53 +788,34 @@ def wrap_parameters_with_types(
809
788
810
789
def infer_type_from_value (value : Any ) -> tuple [str , "exp.DataType" ]:
811
790
"""Infer SQL type hint and SQLGlot DataType from Python value."""
812
- # Import here to avoid issues
813
- from sqlglot import exp
814
791
815
792
# None/NULL
816
793
if value is None :
817
794
return "null" , exp .DataType .build ("NULL" )
818
-
819
- # Boolean
820
795
if isinstance (value , bool ):
821
796
return "boolean" , exp .DataType .build ("BOOLEAN" )
822
-
823
- # Integer types
824
797
if isinstance (value , int ) and not isinstance (value , bool ):
825
798
if abs (value ) > MAX_32BIT_INT :
826
799
return "bigint" , exp .DataType .build ("BIGINT" )
827
800
return "integer" , exp .DataType .build ("INT" )
828
-
829
- # Float/Decimal
830
801
if isinstance (value , float ):
831
802
return "float" , exp .DataType .build ("FLOAT" )
832
803
if isinstance (value , Decimal ):
833
804
return "decimal" , exp .DataType .build ("DECIMAL" )
834
-
835
- # Date/Time types
836
805
if isinstance (value , datetime ):
837
806
return "timestamp" , exp .DataType .build ("TIMESTAMP" )
838
807
if isinstance (value , date ):
839
808
return "date" , exp .DataType .build ("DATE" )
840
809
if isinstance (value , time ):
841
810
return "time" , exp .DataType .build ("TIME" )
842
-
843
- # JSON/Dict
844
811
if isinstance (value , dict ):
845
812
return "json" , exp .DataType .build ("JSON" )
846
-
847
- # Array/List
848
813
if isinstance (value , (list , tuple )):
849
814
return "array" , exp .DataType .build ("ARRAY" )
850
-
851
815
if isinstance (value , str ):
852
816
return "string" , exp .DataType .build ("VARCHAR" )
853
-
854
- # Bytes
855
817
if isinstance (value , bytes ):
856
818
return "binary" , exp .DataType .build ("BINARY" )
857
-
858
- # Default fallback
859
819
return "string" , exp .DataType .build ("VARCHAR" )
860
820
861
821
def wrap_value (value : Any , semantic_name : Optional [str ] = None ) -> Any :
0 commit comments