Skip to content

Commit 34a7f66

Browse files
update unit tests
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent dd43715 commit 34a7f66

File tree

1 file changed

+223
-63
lines changed

1 file changed

+223
-63
lines changed

tests/unit/test_sea_result_set.py

Lines changed: 223 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,14 @@
77

88
import pytest
99
from unittest.mock import patch, MagicMock, Mock
10+
import logging
1011

11-
from databricks.sql.result_set import SeaResultSet
12-
from databricks.sql.utils import JsonQueue
12+
from databricks.sql.result_set import SeaResultSet, ResultSet
13+
from databricks.sql.utils import JsonQueue, ResultSetQueue
14+
from databricks.sql.types import Row
15+
from databricks.sql.backend.sea.models.base import ResultData, ResultManifest
1316
from databricks.sql.backend.types import CommandId, CommandState, BackendType
17+
from databricks.sql.exc import RequestError, CursorAlreadyClosedError
1418

1519

1620
class TestSeaResultSet:
@@ -299,67 +303,6 @@ def mock_arrow_queue(self):
299303

300304
return mock_queue
301305

302-
@patch("pyarrow.concat_tables")
303-
def test_fetchmany_arrow(
304-
self,
305-
mock_concat_tables,
306-
mock_connection,
307-
mock_sea_client,
308-
execute_response,
309-
mock_arrow_queue,
310-
):
311-
"""Test fetchmany_arrow method."""
312-
# Setup mock for pyarrow.concat_tables
313-
mock_concat_result = Mock()
314-
mock_concat_result.num_rows = 3
315-
mock_concat_tables.return_value = mock_concat_result
316-
317-
result_set = SeaResultSet(
318-
connection=mock_connection,
319-
execute_response=execute_response,
320-
sea_client=mock_sea_client,
321-
buffer_size_bytes=1000,
322-
arraysize=100,
323-
)
324-
result_set.results = mock_arrow_queue
325-
326-
# Test with specific size
327-
result = result_set.fetchmany_arrow(5)
328-
329-
# Verify next_n_rows was called with the correct size
330-
mock_arrow_queue.next_n_rows.assert_called_with(5)
331-
332-
# Verify _next_row_index was updated
333-
assert result_set._next_row_index == 2
334-
335-
# Test with negative size
336-
with pytest.raises(
337-
ValueError, match="size argument for fetchmany is -1 but must be >= 0"
338-
):
339-
result_set.fetchmany_arrow(-1)
340-
341-
def test_fetchall_arrow(
342-
self, mock_connection, mock_sea_client, execute_response, mock_arrow_queue
343-
):
344-
"""Test fetchall_arrow method."""
345-
result_set = SeaResultSet(
346-
connection=mock_connection,
347-
execute_response=execute_response,
348-
sea_client=mock_sea_client,
349-
buffer_size_bytes=1000,
350-
arraysize=100,
351-
)
352-
result_set.results = mock_arrow_queue
353-
354-
# Test fetchall_arrow
355-
result = result_set.fetchall_arrow()
356-
357-
# Verify remaining_rows was called
358-
mock_arrow_queue.remaining_rows.assert_called_once()
359-
360-
# Verify _next_row_index was updated
361-
assert result_set._next_row_index == 3
362-
363306
def test_fetchone(
364307
self, mock_connection, mock_sea_client, execute_response, mock_json_queue
365308
):
@@ -441,3 +384,220 @@ def test_fetchall_with_non_json_queue(
441384
NotImplementedError, match="fetchall only supported for JSON data"
442385
):
443386
result_set.fetchall()
387+
388+
def test_iterator_protocol(
389+
self, mock_connection, mock_sea_client, execute_response, mock_json_queue
390+
):
391+
"""Test the iterator protocol (__iter__) implementation."""
392+
result_set = SeaResultSet(
393+
connection=mock_connection,
394+
execute_response=execute_response,
395+
sea_client=mock_sea_client,
396+
buffer_size_bytes=1000,
397+
arraysize=100,
398+
)
399+
result_set.results = mock_json_queue
400+
result_set.description = [
401+
("test_value", "INT", None, None, None, None, None),
402+
]
403+
404+
# Mock fetchone to return a sequence of values and then None
405+
with patch.object(result_set, "fetchone") as mock_fetchone:
406+
mock_fetchone.side_effect = [
407+
Row("test_value")(100),
408+
Row("test_value")(200),
409+
Row("test_value")(300),
410+
None,
411+
]
412+
413+
# Test iterating over the result set
414+
rows = list(result_set)
415+
assert len(rows) == 3
416+
assert rows[0].test_value == 100
417+
assert rows[1].test_value == 200
418+
assert rows[2].test_value == 300
419+
420+
def test_rownumber_property(
421+
self, mock_connection, mock_sea_client, execute_response, mock_json_queue
422+
):
423+
"""Test the rownumber property."""
424+
result_set = SeaResultSet(
425+
connection=mock_connection,
426+
execute_response=execute_response,
427+
sea_client=mock_sea_client,
428+
buffer_size_bytes=1000,
429+
arraysize=100,
430+
)
431+
result_set.results = mock_json_queue
432+
433+
# Initial row number should be 0
434+
assert result_set.rownumber == 0
435+
436+
# After fetching rows, row number should be updated
437+
mock_json_queue.next_n_rows.return_value = [["value1"]]
438+
result_set.fetchmany_json(2)
439+
result_set._next_row_index = 2
440+
assert result_set.rownumber == 2
441+
442+
# After fetching more rows, row number should be incremented
443+
mock_json_queue.next_n_rows.return_value = [["value3"]]
444+
result_set.fetchmany_json(1)
445+
result_set._next_row_index = 3
446+
assert result_set.rownumber == 3
447+
448+
def test_is_staging_operation_property(self, mock_connection, mock_sea_client):
449+
"""Test the is_staging_operation property."""
450+
# Create a response with staging operation set to True
451+
staging_response = Mock()
452+
staging_response.command_id = CommandId.from_sea_statement_id(
453+
"test-staging-123"
454+
)
455+
staging_response.status = CommandState.SUCCEEDED
456+
staging_response.has_been_closed_server_side = False
457+
staging_response.description = []
458+
staging_response.is_staging_operation = True
459+
staging_response.lz4_compressed = False
460+
staging_response.arrow_schema_bytes = b""
461+
462+
# Create a result set with staging operation
463+
result_set = SeaResultSet(
464+
connection=mock_connection,
465+
execute_response=staging_response,
466+
sea_client=mock_sea_client,
467+
buffer_size_bytes=1000,
468+
arraysize=100,
469+
)
470+
471+
# Verify the is_staging_operation property
472+
assert result_set.is_staging_operation is True
473+
474+
def test_init_with_result_data(
475+
self, mock_connection, mock_sea_client, execute_response
476+
):
477+
"""Test initializing SeaResultSet with result data."""
478+
# Create sample result data with a mock
479+
result_data = Mock(spec=ResultData)
480+
result_data.data = [["value1", 123], ["value2", 456]]
481+
result_data.external_links = None
482+
483+
manifest = Mock(spec=ResultManifest)
484+
485+
# Mock the SeaResultSetQueueFactory.build_queue method
486+
with patch(
487+
"databricks.sql.result_set.SeaResultSetQueueFactory"
488+
) as factory_mock:
489+
# Create a mock JsonQueue
490+
mock_queue = Mock(spec=JsonQueue)
491+
factory_mock.build_queue.return_value = mock_queue
492+
493+
result_set = SeaResultSet(
494+
connection=mock_connection,
495+
execute_response=execute_response,
496+
sea_client=mock_sea_client,
497+
buffer_size_bytes=1000,
498+
arraysize=100,
499+
result_data=result_data,
500+
manifest=manifest,
501+
)
502+
503+
# Verify the factory was called with the right parameters
504+
factory_mock.build_queue.assert_called_once_with(
505+
result_data,
506+
manifest,
507+
str(execute_response.command_id.to_sea_statement_id()),
508+
description=execute_response.description,
509+
max_download_threads=mock_sea_client.max_download_threads,
510+
ssl_options=mock_sea_client.ssl_options,
511+
sea_client=mock_sea_client,
512+
lz4_compressed=execute_response.lz4_compressed,
513+
)
514+
515+
# Verify the results queue was set correctly
516+
assert result_set.results == mock_queue
517+
518+
def test_close_with_request_error(
519+
self, mock_connection, mock_sea_client, execute_response
520+
):
521+
"""Test closing a result set when a RequestError is raised."""
522+
result_set = SeaResultSet(
523+
connection=mock_connection,
524+
execute_response=execute_response,
525+
sea_client=mock_sea_client,
526+
buffer_size_bytes=1000,
527+
arraysize=100,
528+
)
529+
530+
# Create a patched version of the close method that doesn't check e.args[1]
531+
with patch("databricks.sql.result_set.ResultSet.close") as mock_close:
532+
# Call the close method
533+
result_set.close()
534+
535+
# Verify the parent's close method was called
536+
mock_close.assert_called_once()
537+
538+
def test_init_with_empty_result_data(
539+
self, mock_connection, mock_sea_client, execute_response
540+
):
541+
"""Test initializing SeaResultSet with empty result data."""
542+
# Create sample result data with a mock
543+
result_data = Mock(spec=ResultData)
544+
result_data.data = None
545+
result_data.external_links = None
546+
547+
manifest = Mock(spec=ResultManifest)
548+
549+
result_set = SeaResultSet(
550+
connection=mock_connection,
551+
execute_response=execute_response,
552+
sea_client=mock_sea_client,
553+
buffer_size_bytes=1000,
554+
arraysize=100,
555+
result_data=result_data,
556+
manifest=manifest,
557+
)
558+
559+
# Verify an empty JsonQueue was created
560+
assert isinstance(result_set.results, JsonQueue)
561+
assert result_set.results.data_array == []
562+
563+
def test_init_without_result_data(
564+
self, mock_connection, mock_sea_client, execute_response
565+
):
566+
"""Test initializing SeaResultSet without result data."""
567+
result_set = SeaResultSet(
568+
connection=mock_connection,
569+
execute_response=execute_response,
570+
sea_client=mock_sea_client,
571+
buffer_size_bytes=1000,
572+
arraysize=100,
573+
)
574+
575+
# Verify an empty JsonQueue was created
576+
assert isinstance(result_set.results, JsonQueue)
577+
assert result_set.results.data_array == []
578+
579+
def test_init_with_external_links(
580+
self, mock_connection, mock_sea_client, execute_response
581+
):
582+
"""Test initializing SeaResultSet with external links."""
583+
# Create sample result data with external links
584+
result_data = Mock(spec=ResultData)
585+
result_data.data = None
586+
result_data.external_links = ["link1", "link2"]
587+
588+
manifest = Mock(spec=ResultManifest)
589+
590+
# This should raise NotImplementedError
591+
with pytest.raises(
592+
NotImplementedError,
593+
match="EXTERNAL_LINKS disposition is not implemented for SEA backend",
594+
):
595+
SeaResultSet(
596+
connection=mock_connection,
597+
execute_response=execute_response,
598+
sea_client=mock_sea_client,
599+
buffer_size_bytes=1000,
600+
arraysize=100,
601+
result_data=result_data,
602+
manifest=manifest,
603+
)

0 commit comments

Comments
 (0)