9
9
import pytest
10
10
from unittest .mock import patch , MagicMock , Mock
11
11
12
- from databricks .sql .backend .sea .backend import SeaDatabricksClient
12
+ from databricks .sql .backend .sea .backend import (
13
+ SeaDatabricksClient ,
14
+ _filter_session_configuration ,
15
+ )
13
16
from databricks .sql .result_set import SeaResultSet
14
17
from databricks .sql .backend .types import SessionId , CommandId , CommandState , BackendType
15
18
from databricks .sql .types import SSLOptions
16
19
from databricks .sql .auth .authenticators import AuthProvider
17
- from databricks .sql .exc import Error , NotSupportedError
20
+ from databricks .sql .exc import Error , NotSupportedError , ServerOperationError
18
21
19
22
20
23
class TestSeaBackend :
@@ -305,6 +308,32 @@ def test_execute_command_async(
305
308
assert isinstance (mock_cursor .active_command_id , CommandId )
306
309
assert mock_cursor .active_command_id .guid == "test-statement-456"
307
310
311
+ def test_execute_command_async_missing_statement_id (
312
+ self , sea_client , mock_http_client , mock_cursor , sea_session_id
313
+ ):
314
+ """Test executing an async command that returns no statement ID."""
315
+ # Set up mock response with status but no statement_id
316
+ mock_http_client ._make_request .return_value = {"status" : {"state" : "PENDING" }}
317
+
318
+ # Call the method and expect an error
319
+ with pytest .raises (ServerOperationError ) as excinfo :
320
+ sea_client .execute_command (
321
+ operation = "SELECT 1" ,
322
+ session_id = sea_session_id ,
323
+ max_rows = 100 ,
324
+ max_bytes = 1000 ,
325
+ lz4_compression = False ,
326
+ cursor = mock_cursor ,
327
+ use_cloud_fetch = False ,
328
+ parameters = [],
329
+ async_op = True , # Async mode
330
+ enforce_embedded_schema_correctness = False ,
331
+ )
332
+
333
+ assert "Failed to execute command: No statement ID returned" in str (
334
+ excinfo .value
335
+ )
336
+
308
337
def test_execute_command_with_polling (
309
338
self , sea_client , mock_http_client , mock_cursor , sea_session_id
310
339
):
@@ -442,6 +471,32 @@ def test_execute_command_failure(
442
471
443
472
assert "Statement execution did not succeed" in str (excinfo .value )
444
473
474
+ def test_execute_command_missing_statement_id (
475
+ self , sea_client , mock_http_client , mock_cursor , sea_session_id
476
+ ):
477
+ """Test executing a command that returns no statement ID."""
478
+ # Set up mock response with status but no statement_id
479
+ mock_http_client ._make_request .return_value = {"status" : {"state" : "SUCCEEDED" }}
480
+
481
+ # Call the method and expect an error
482
+ with pytest .raises (ServerOperationError ) as excinfo :
483
+ sea_client .execute_command (
484
+ operation = "SELECT 1" ,
485
+ session_id = sea_session_id ,
486
+ max_rows = 100 ,
487
+ max_bytes = 1000 ,
488
+ lz4_compression = False ,
489
+ cursor = mock_cursor ,
490
+ use_cloud_fetch = False ,
491
+ parameters = [],
492
+ async_op = False ,
493
+ enforce_embedded_schema_correctness = False ,
494
+ )
495
+
496
+ assert "Failed to execute command: No statement ID returned" in str (
497
+ excinfo .value
498
+ )
499
+
445
500
def test_cancel_command (self , sea_client , mock_http_client , sea_command_id ):
446
501
"""Test canceling a command."""
447
502
# Set up mock response
@@ -533,7 +588,6 @@ def test_get_execution_result(
533
588
534
589
# Create a real result set to verify the implementation
535
590
result = sea_client .get_execution_result (sea_command_id , mock_cursor )
536
- print (result )
537
591
538
592
# Verify basic properties of the result
539
593
assert result .command_id .to_sea_statement_id () == "test-statement-123"
@@ -546,3 +600,242 @@ def test_get_execution_result(
546
600
assert kwargs ["path" ] == sea_client .STATEMENT_PATH_WITH_ID .format (
547
601
"test-statement-123"
548
602
)
603
+
604
+ def test_get_execution_result_with_invalid_command_id (
605
+ self , sea_client , mock_cursor
606
+ ):
607
+ """Test getting execution result with an invalid command ID."""
608
+ # Create a Thrift command ID (not SEA)
609
+ mock_thrift_operation_handle = MagicMock ()
610
+ mock_thrift_operation_handle .operationId .guid = b"guid"
611
+ mock_thrift_operation_handle .operationId .secret = b"secret"
612
+ command_id = CommandId .from_thrift_handle (mock_thrift_operation_handle )
613
+
614
+ # Call the method and expect an error
615
+ with pytest .raises (ValueError ) as excinfo :
616
+ sea_client .get_execution_result (command_id , mock_cursor )
617
+
618
+ assert "Not a valid SEA command ID" in str (excinfo .value )
619
+
620
+ def test_max_download_threads_property (self , mock_http_client ):
621
+ """Test the max_download_threads property."""
622
+ # Test with default value
623
+ client = SeaDatabricksClient (
624
+ server_hostname = "test-server.databricks.com" ,
625
+ port = 443 ,
626
+ http_path = "/sql/warehouses/abc123" ,
627
+ http_headers = [],
628
+ auth_provider = AuthProvider (),
629
+ ssl_options = SSLOptions (),
630
+ )
631
+ assert client .max_download_threads == 10
632
+
633
+ # Test with custom value
634
+ client = SeaDatabricksClient (
635
+ server_hostname = "test-server.databricks.com" ,
636
+ port = 443 ,
637
+ http_path = "/sql/warehouses/abc123" ,
638
+ http_headers = [],
639
+ auth_provider = AuthProvider (),
640
+ ssl_options = SSLOptions (),
641
+ max_download_threads = 5 ,
642
+ )
643
+ assert client .max_download_threads == 5
644
+
645
+ def test_get_default_session_configuration_value (self ):
646
+ """Test the get_default_session_configuration_value static method."""
647
+ # Test with supported configuration parameter
648
+ value = SeaDatabricksClient .get_default_session_configuration_value ("ANSI_MODE" )
649
+ assert value == "true"
650
+
651
+ # Test with unsupported configuration parameter
652
+ value = SeaDatabricksClient .get_default_session_configuration_value (
653
+ "UNSUPPORTED_PARAM"
654
+ )
655
+ assert value is None
656
+
657
+ # Test with case-insensitive parameter name
658
+ value = SeaDatabricksClient .get_default_session_configuration_value ("ansi_mode" )
659
+ assert value == "true"
660
+
661
+ def test_get_allowed_session_configurations (self ):
662
+ """Test the get_allowed_session_configurations static method."""
663
+ configs = SeaDatabricksClient .get_allowed_session_configurations ()
664
+ assert isinstance (configs , list )
665
+ assert len (configs ) > 0
666
+ assert "ANSI_MODE" in configs
667
+
668
+ def test_extract_description_from_manifest (self , sea_client ):
669
+ """Test the _extract_description_from_manifest method."""
670
+ # Test with valid manifest containing columns
671
+ manifest_obj = MagicMock ()
672
+ manifest_obj .schema = {
673
+ "columns" : [
674
+ {
675
+ "name" : "col1" ,
676
+ "type_name" : "STRING" ,
677
+ "precision" : 10 ,
678
+ "scale" : 2 ,
679
+ "nullable" : True ,
680
+ },
681
+ {
682
+ "name" : "col2" ,
683
+ "type_name" : "INT" ,
684
+ "nullable" : False ,
685
+ },
686
+ ]
687
+ }
688
+
689
+ description = sea_client ._extract_description_from_manifest (manifest_obj )
690
+ assert description is not None
691
+ assert len (description ) == 2
692
+
693
+ # Check first column
694
+ assert description [0 ][0 ] == "col1" # name
695
+ assert description [0 ][1 ] == "STRING" # type_code
696
+ assert description [0 ][4 ] == 10 # precision
697
+ assert description [0 ][5 ] == 2 # scale
698
+ assert description [0 ][6 ] is True # null_ok
699
+
700
+ # Check second column
701
+ assert description [1 ][0 ] == "col2" # name
702
+ assert description [1 ][1 ] == "INT" # type_code
703
+ assert description [1 ][6 ] is False # null_ok
704
+
705
+ # Test with manifest containing non-dict column
706
+ manifest_obj .schema = {"columns" : ["not_a_dict" ]}
707
+ description = sea_client ._extract_description_from_manifest (manifest_obj )
708
+ assert (
709
+ description is None
710
+ ) # Method returns None when no valid columns are found
711
+
712
+ # Test with manifest without columns
713
+ manifest_obj .schema = {}
714
+ description = sea_client ._extract_description_from_manifest (manifest_obj )
715
+ assert description is None
716
+
717
+ def test_cancel_command_with_invalid_command_id (self , sea_client ):
718
+ """Test canceling a command with an invalid command ID."""
719
+ # Create a Thrift command ID (not SEA)
720
+ mock_thrift_operation_handle = MagicMock ()
721
+ mock_thrift_operation_handle .operationId .guid = b"guid"
722
+ mock_thrift_operation_handle .operationId .secret = b"secret"
723
+ command_id = CommandId .from_thrift_handle (mock_thrift_operation_handle )
724
+
725
+ # Call the method and expect an error
726
+ with pytest .raises (ValueError ) as excinfo :
727
+ sea_client .cancel_command (command_id )
728
+
729
+ assert "Not a valid SEA command ID" in str (excinfo .value )
730
+
731
+ def test_close_command_with_invalid_command_id (self , sea_client ):
732
+ """Test closing a command with an invalid command ID."""
733
+ # Create a Thrift command ID (not SEA)
734
+ mock_thrift_operation_handle = MagicMock ()
735
+ mock_thrift_operation_handle .operationId .guid = b"guid"
736
+ mock_thrift_operation_handle .operationId .secret = b"secret"
737
+ command_id = CommandId .from_thrift_handle (mock_thrift_operation_handle )
738
+
739
+ # Call the method and expect an error
740
+ with pytest .raises (ValueError ) as excinfo :
741
+ sea_client .close_command (command_id )
742
+
743
+ assert "Not a valid SEA command ID" in str (excinfo .value )
744
+
745
+ def test_get_query_state_with_invalid_command_id (self , sea_client ):
746
+ """Test getting query state with an invalid command ID."""
747
+ # Create a Thrift command ID (not SEA)
748
+ mock_thrift_operation_handle = MagicMock ()
749
+ mock_thrift_operation_handle .operationId .guid = b"guid"
750
+ mock_thrift_operation_handle .operationId .secret = b"secret"
751
+ command_id = CommandId .from_thrift_handle (mock_thrift_operation_handle )
752
+
753
+ # Call the method and expect an error
754
+ with pytest .raises (ValueError ) as excinfo :
755
+ sea_client .get_query_state (command_id )
756
+
757
+ assert "Not a valid SEA command ID" in str (excinfo .value )
758
+
759
+ def test_unimplemented_metadata_methods (
760
+ self , sea_client , sea_session_id , mock_cursor
761
+ ):
762
+ """Test that metadata methods raise NotImplementedError."""
763
+ # Test get_catalogs
764
+ with pytest .raises (NotImplementedError ) as excinfo :
765
+ sea_client .get_catalogs (sea_session_id , 100 , 1000 , mock_cursor )
766
+ assert "get_catalogs is not implemented for SEA backend" in str (excinfo .value )
767
+
768
+ # Test get_schemas
769
+ with pytest .raises (NotImplementedError ) as excinfo :
770
+ sea_client .get_schemas (sea_session_id , 100 , 1000 , mock_cursor )
771
+ assert "get_schemas is not implemented for SEA backend" in str (excinfo .value )
772
+
773
+ # Test get_schemas with optional parameters
774
+ with pytest .raises (NotImplementedError ) as excinfo :
775
+ sea_client .get_schemas (
776
+ sea_session_id , 100 , 1000 , mock_cursor , "catalog" , "schema"
777
+ )
778
+ assert "get_schemas is not implemented for SEA backend" in str (excinfo .value )
779
+
780
+ # Test get_tables
781
+ with pytest .raises (NotImplementedError ) as excinfo :
782
+ sea_client .get_tables (sea_session_id , 100 , 1000 , mock_cursor )
783
+ assert "get_tables is not implemented for SEA backend" in str (excinfo .value )
784
+
785
+ # Test get_tables with optional parameters
786
+ with pytest .raises (NotImplementedError ) as excinfo :
787
+ sea_client .get_tables (
788
+ sea_session_id ,
789
+ 100 ,
790
+ 1000 ,
791
+ mock_cursor ,
792
+ catalog_name = "catalog" ,
793
+ schema_name = "schema" ,
794
+ table_name = "table" ,
795
+ table_types = ["TABLE" , "VIEW" ],
796
+ )
797
+ assert "get_tables is not implemented for SEA backend" in str (excinfo .value )
798
+
799
+ # Test get_columns
800
+ with pytest .raises (NotImplementedError ) as excinfo :
801
+ sea_client .get_columns (sea_session_id , 100 , 1000 , mock_cursor )
802
+ assert "get_columns is not implemented for SEA backend" in str (excinfo .value )
803
+
804
+ # Test get_columns with optional parameters
805
+ with pytest .raises (NotImplementedError ) as excinfo :
806
+ sea_client .get_columns (
807
+ sea_session_id ,
808
+ 100 ,
809
+ 1000 ,
810
+ mock_cursor ,
811
+ catalog_name = "catalog" ,
812
+ schema_name = "schema" ,
813
+ table_name = "table" ,
814
+ column_name = "column" ,
815
+ )
816
+ assert "get_columns is not implemented for SEA backend" in str (excinfo .value )
817
+
818
+ def test_execute_command_with_invalid_session_id (self , sea_client , mock_cursor ):
819
+ """Test executing a command with an invalid session ID type."""
820
+ # Create a Thrift session ID (not SEA)
821
+ mock_thrift_handle = MagicMock ()
822
+ mock_thrift_handle .sessionId .guid = b"guid"
823
+ mock_thrift_handle .sessionId .secret = b"secret"
824
+ session_id = SessionId .from_thrift_handle (mock_thrift_handle )
825
+
826
+ # Call the method and expect an error
827
+ with pytest .raises (ValueError ) as excinfo :
828
+ sea_client .execute_command (
829
+ operation = "SELECT 1" ,
830
+ session_id = session_id ,
831
+ max_rows = 100 ,
832
+ max_bytes = 1000 ,
833
+ lz4_compression = False ,
834
+ cursor = mock_cursor ,
835
+ use_cloud_fetch = False ,
836
+ parameters = [],
837
+ async_op = False ,
838
+ enforce_embedded_schema_correctness = False ,
839
+ )
840
+
841
+ assert "Not a valid SEA session ID" in str (excinfo .value )
0 commit comments