Skip to content

Commit df6dac2

Browse files
add more unit tests
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 1d57c99 commit df6dac2

File tree

1 file changed

+296
-3
lines changed

1 file changed

+296
-3
lines changed

tests/unit/test_sea_backend.py

Lines changed: 296 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,15 @@
99
import pytest
1010
from unittest.mock import patch, MagicMock, Mock
1111

12-
from databricks.sql.backend.sea.backend import SeaDatabricksClient
12+
from databricks.sql.backend.sea.backend import (
13+
SeaDatabricksClient,
14+
_filter_session_configuration,
15+
)
1316
from databricks.sql.result_set import SeaResultSet
1417
from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType
1518
from databricks.sql.types import SSLOptions
1619
from databricks.sql.auth.authenticators import AuthProvider
17-
from databricks.sql.exc import Error, NotSupportedError
20+
from databricks.sql.exc import Error, NotSupportedError, ServerOperationError
1821

1922

2023
class TestSeaBackend:
@@ -305,6 +308,32 @@ def test_execute_command_async(
305308
assert isinstance(mock_cursor.active_command_id, CommandId)
306309
assert mock_cursor.active_command_id.guid == "test-statement-456"
307310

311+
def test_execute_command_async_missing_statement_id(
312+
self, sea_client, mock_http_client, mock_cursor, sea_session_id
313+
):
314+
"""Test executing an async command that returns no statement ID."""
315+
# Set up mock response with status but no statement_id
316+
mock_http_client._make_request.return_value = {"status": {"state": "PENDING"}}
317+
318+
# Call the method and expect an error
319+
with pytest.raises(ServerOperationError) as excinfo:
320+
sea_client.execute_command(
321+
operation="SELECT 1",
322+
session_id=sea_session_id,
323+
max_rows=100,
324+
max_bytes=1000,
325+
lz4_compression=False,
326+
cursor=mock_cursor,
327+
use_cloud_fetch=False,
328+
parameters=[],
329+
async_op=True, # Async mode
330+
enforce_embedded_schema_correctness=False,
331+
)
332+
333+
assert "Failed to execute command: No statement ID returned" in str(
334+
excinfo.value
335+
)
336+
308337
def test_execute_command_with_polling(
309338
self, sea_client, mock_http_client, mock_cursor, sea_session_id
310339
):
@@ -442,6 +471,32 @@ def test_execute_command_failure(
442471

443472
assert "Statement execution did not succeed" in str(excinfo.value)
444473

474+
def test_execute_command_missing_statement_id(
475+
self, sea_client, mock_http_client, mock_cursor, sea_session_id
476+
):
477+
"""Test executing a command that returns no statement ID."""
478+
# Set up mock response with status but no statement_id
479+
mock_http_client._make_request.return_value = {"status": {"state": "SUCCEEDED"}}
480+
481+
# Call the method and expect an error
482+
with pytest.raises(ServerOperationError) as excinfo:
483+
sea_client.execute_command(
484+
operation="SELECT 1",
485+
session_id=sea_session_id,
486+
max_rows=100,
487+
max_bytes=1000,
488+
lz4_compression=False,
489+
cursor=mock_cursor,
490+
use_cloud_fetch=False,
491+
parameters=[],
492+
async_op=False,
493+
enforce_embedded_schema_correctness=False,
494+
)
495+
496+
assert "Failed to execute command: No statement ID returned" in str(
497+
excinfo.value
498+
)
499+
445500
def test_cancel_command(self, sea_client, mock_http_client, sea_command_id):
446501
"""Test canceling a command."""
447502
# Set up mock response
@@ -533,7 +588,6 @@ def test_get_execution_result(
533588

534589
# Create a real result set to verify the implementation
535590
result = sea_client.get_execution_result(sea_command_id, mock_cursor)
536-
print(result)
537591

538592
# Verify basic properties of the result
539593
assert result.command_id.to_sea_statement_id() == "test-statement-123"
@@ -546,3 +600,242 @@ def test_get_execution_result(
546600
assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format(
547601
"test-statement-123"
548602
)
603+
604+
def test_get_execution_result_with_invalid_command_id(
605+
self, sea_client, mock_cursor
606+
):
607+
"""Test getting execution result with an invalid command ID."""
608+
# Create a Thrift command ID (not SEA)
609+
mock_thrift_operation_handle = MagicMock()
610+
mock_thrift_operation_handle.operationId.guid = b"guid"
611+
mock_thrift_operation_handle.operationId.secret = b"secret"
612+
command_id = CommandId.from_thrift_handle(mock_thrift_operation_handle)
613+
614+
# Call the method and expect an error
615+
with pytest.raises(ValueError) as excinfo:
616+
sea_client.get_execution_result(command_id, mock_cursor)
617+
618+
assert "Not a valid SEA command ID" in str(excinfo.value)
619+
620+
def test_max_download_threads_property(self, mock_http_client):
621+
"""Test the max_download_threads property."""
622+
# Test with default value
623+
client = SeaDatabricksClient(
624+
server_hostname="test-server.databricks.com",
625+
port=443,
626+
http_path="/sql/warehouses/abc123",
627+
http_headers=[],
628+
auth_provider=AuthProvider(),
629+
ssl_options=SSLOptions(),
630+
)
631+
assert client.max_download_threads == 10
632+
633+
# Test with custom value
634+
client = SeaDatabricksClient(
635+
server_hostname="test-server.databricks.com",
636+
port=443,
637+
http_path="/sql/warehouses/abc123",
638+
http_headers=[],
639+
auth_provider=AuthProvider(),
640+
ssl_options=SSLOptions(),
641+
max_download_threads=5,
642+
)
643+
assert client.max_download_threads == 5
644+
645+
def test_get_default_session_configuration_value(self):
646+
"""Test the get_default_session_configuration_value static method."""
647+
# Test with supported configuration parameter
648+
value = SeaDatabricksClient.get_default_session_configuration_value("ANSI_MODE")
649+
assert value == "true"
650+
651+
# Test with unsupported configuration parameter
652+
value = SeaDatabricksClient.get_default_session_configuration_value(
653+
"UNSUPPORTED_PARAM"
654+
)
655+
assert value is None
656+
657+
# Test with case-insensitive parameter name
658+
value = SeaDatabricksClient.get_default_session_configuration_value("ansi_mode")
659+
assert value == "true"
660+
661+
def test_get_allowed_session_configurations(self):
662+
"""Test the get_allowed_session_configurations static method."""
663+
configs = SeaDatabricksClient.get_allowed_session_configurations()
664+
assert isinstance(configs, list)
665+
assert len(configs) > 0
666+
assert "ANSI_MODE" in configs
667+
668+
def test_extract_description_from_manifest(self, sea_client):
669+
"""Test the _extract_description_from_manifest method."""
670+
# Test with valid manifest containing columns
671+
manifest_obj = MagicMock()
672+
manifest_obj.schema = {
673+
"columns": [
674+
{
675+
"name": "col1",
676+
"type_name": "STRING",
677+
"precision": 10,
678+
"scale": 2,
679+
"nullable": True,
680+
},
681+
{
682+
"name": "col2",
683+
"type_name": "INT",
684+
"nullable": False,
685+
},
686+
]
687+
}
688+
689+
description = sea_client._extract_description_from_manifest(manifest_obj)
690+
assert description is not None
691+
assert len(description) == 2
692+
693+
# Check first column
694+
assert description[0][0] == "col1" # name
695+
assert description[0][1] == "STRING" # type_code
696+
assert description[0][4] == 10 # precision
697+
assert description[0][5] == 2 # scale
698+
assert description[0][6] is True # null_ok
699+
700+
# Check second column
701+
assert description[1][0] == "col2" # name
702+
assert description[1][1] == "INT" # type_code
703+
assert description[1][6] is False # null_ok
704+
705+
# Test with manifest containing non-dict column
706+
manifest_obj.schema = {"columns": ["not_a_dict"]}
707+
description = sea_client._extract_description_from_manifest(manifest_obj)
708+
assert (
709+
description is None
710+
) # Method returns None when no valid columns are found
711+
712+
# Test with manifest without columns
713+
manifest_obj.schema = {}
714+
description = sea_client._extract_description_from_manifest(manifest_obj)
715+
assert description is None
716+
717+
def test_cancel_command_with_invalid_command_id(self, sea_client):
718+
"""Test canceling a command with an invalid command ID."""
719+
# Create a Thrift command ID (not SEA)
720+
mock_thrift_operation_handle = MagicMock()
721+
mock_thrift_operation_handle.operationId.guid = b"guid"
722+
mock_thrift_operation_handle.operationId.secret = b"secret"
723+
command_id = CommandId.from_thrift_handle(mock_thrift_operation_handle)
724+
725+
# Call the method and expect an error
726+
with pytest.raises(ValueError) as excinfo:
727+
sea_client.cancel_command(command_id)
728+
729+
assert "Not a valid SEA command ID" in str(excinfo.value)
730+
731+
def test_close_command_with_invalid_command_id(self, sea_client):
732+
"""Test closing a command with an invalid command ID."""
733+
# Create a Thrift command ID (not SEA)
734+
mock_thrift_operation_handle = MagicMock()
735+
mock_thrift_operation_handle.operationId.guid = b"guid"
736+
mock_thrift_operation_handle.operationId.secret = b"secret"
737+
command_id = CommandId.from_thrift_handle(mock_thrift_operation_handle)
738+
739+
# Call the method and expect an error
740+
with pytest.raises(ValueError) as excinfo:
741+
sea_client.close_command(command_id)
742+
743+
assert "Not a valid SEA command ID" in str(excinfo.value)
744+
745+
def test_get_query_state_with_invalid_command_id(self, sea_client):
746+
"""Test getting query state with an invalid command ID."""
747+
# Create a Thrift command ID (not SEA)
748+
mock_thrift_operation_handle = MagicMock()
749+
mock_thrift_operation_handle.operationId.guid = b"guid"
750+
mock_thrift_operation_handle.operationId.secret = b"secret"
751+
command_id = CommandId.from_thrift_handle(mock_thrift_operation_handle)
752+
753+
# Call the method and expect an error
754+
with pytest.raises(ValueError) as excinfo:
755+
sea_client.get_query_state(command_id)
756+
757+
assert "Not a valid SEA command ID" in str(excinfo.value)
758+
759+
def test_unimplemented_metadata_methods(
760+
self, sea_client, sea_session_id, mock_cursor
761+
):
762+
"""Test that metadata methods raise NotImplementedError."""
763+
# Test get_catalogs
764+
with pytest.raises(NotImplementedError) as excinfo:
765+
sea_client.get_catalogs(sea_session_id, 100, 1000, mock_cursor)
766+
assert "get_catalogs is not implemented for SEA backend" in str(excinfo.value)
767+
768+
# Test get_schemas
769+
with pytest.raises(NotImplementedError) as excinfo:
770+
sea_client.get_schemas(sea_session_id, 100, 1000, mock_cursor)
771+
assert "get_schemas is not implemented for SEA backend" in str(excinfo.value)
772+
773+
# Test get_schemas with optional parameters
774+
with pytest.raises(NotImplementedError) as excinfo:
775+
sea_client.get_schemas(
776+
sea_session_id, 100, 1000, mock_cursor, "catalog", "schema"
777+
)
778+
assert "get_schemas is not implemented for SEA backend" in str(excinfo.value)
779+
780+
# Test get_tables
781+
with pytest.raises(NotImplementedError) as excinfo:
782+
sea_client.get_tables(sea_session_id, 100, 1000, mock_cursor)
783+
assert "get_tables is not implemented for SEA backend" in str(excinfo.value)
784+
785+
# Test get_tables with optional parameters
786+
with pytest.raises(NotImplementedError) as excinfo:
787+
sea_client.get_tables(
788+
sea_session_id,
789+
100,
790+
1000,
791+
mock_cursor,
792+
catalog_name="catalog",
793+
schema_name="schema",
794+
table_name="table",
795+
table_types=["TABLE", "VIEW"],
796+
)
797+
assert "get_tables is not implemented for SEA backend" in str(excinfo.value)
798+
799+
# Test get_columns
800+
with pytest.raises(NotImplementedError) as excinfo:
801+
sea_client.get_columns(sea_session_id, 100, 1000, mock_cursor)
802+
assert "get_columns is not implemented for SEA backend" in str(excinfo.value)
803+
804+
# Test get_columns with optional parameters
805+
with pytest.raises(NotImplementedError) as excinfo:
806+
sea_client.get_columns(
807+
sea_session_id,
808+
100,
809+
1000,
810+
mock_cursor,
811+
catalog_name="catalog",
812+
schema_name="schema",
813+
table_name="table",
814+
column_name="column",
815+
)
816+
assert "get_columns is not implemented for SEA backend" in str(excinfo.value)
817+
818+
def test_execute_command_with_invalid_session_id(self, sea_client, mock_cursor):
819+
"""Test executing a command with an invalid session ID type."""
820+
# Create a Thrift session ID (not SEA)
821+
mock_thrift_handle = MagicMock()
822+
mock_thrift_handle.sessionId.guid = b"guid"
823+
mock_thrift_handle.sessionId.secret = b"secret"
824+
session_id = SessionId.from_thrift_handle(mock_thrift_handle)
825+
826+
# Call the method and expect an error
827+
with pytest.raises(ValueError) as excinfo:
828+
sea_client.execute_command(
829+
operation="SELECT 1",
830+
session_id=session_id,
831+
max_rows=100,
832+
max_bytes=1000,
833+
lz4_compression=False,
834+
cursor=mock_cursor,
835+
use_cloud_fetch=False,
836+
parameters=[],
837+
async_op=False,
838+
enforce_embedded_schema_correctness=False,
839+
)
840+
841+
assert "Not a valid SEA session ID" in str(excinfo.value)

0 commit comments

Comments
 (0)