|
7 | 7 |
|
8 | 8 | import pytest
|
9 | 9 | from unittest.mock import patch, MagicMock, Mock
|
| 10 | +import logging |
10 | 11 |
|
11 |
| -from databricks.sql.result_set import SeaResultSet |
12 |
| -from databricks.sql.utils import JsonQueue |
| 12 | +from databricks.sql.result_set import SeaResultSet, ResultSet |
| 13 | +from databricks.sql.utils import JsonQueue, ResultSetQueue |
| 14 | +from databricks.sql.types import Row |
| 15 | +from databricks.sql.backend.sea.models.base import ResultData, ResultManifest |
13 | 16 | from databricks.sql.backend.types import CommandId, CommandState, BackendType
|
| 17 | +from databricks.sql.exc import RequestError, CursorAlreadyClosedError |
14 | 18 |
|
15 | 19 |
|
16 | 20 | class TestSeaResultSet:
|
@@ -299,67 +303,6 @@ def mock_arrow_queue(self):
|
299 | 303 |
|
300 | 304 | return mock_queue
|
301 | 305 |
|
302 |
| - @patch("pyarrow.concat_tables") |
303 |
| - def test_fetchmany_arrow( |
304 |
| - self, |
305 |
| - mock_concat_tables, |
306 |
| - mock_connection, |
307 |
| - mock_sea_client, |
308 |
| - execute_response, |
309 |
| - mock_arrow_queue, |
310 |
| - ): |
311 |
| - """Test fetchmany_arrow method.""" |
312 |
| - # Setup mock for pyarrow.concat_tables |
313 |
| - mock_concat_result = Mock() |
314 |
| - mock_concat_result.num_rows = 3 |
315 |
| - mock_concat_tables.return_value = mock_concat_result |
316 |
| - |
317 |
| - result_set = SeaResultSet( |
318 |
| - connection=mock_connection, |
319 |
| - execute_response=execute_response, |
320 |
| - sea_client=mock_sea_client, |
321 |
| - buffer_size_bytes=1000, |
322 |
| - arraysize=100, |
323 |
| - ) |
324 |
| - result_set.results = mock_arrow_queue |
325 |
| - |
326 |
| - # Test with specific size |
327 |
| - result = result_set.fetchmany_arrow(5) |
328 |
| - |
329 |
| - # Verify next_n_rows was called with the correct size |
330 |
| - mock_arrow_queue.next_n_rows.assert_called_with(5) |
331 |
| - |
332 |
| - # Verify _next_row_index was updated |
333 |
| - assert result_set._next_row_index == 2 |
334 |
| - |
335 |
| - # Test with negative size |
336 |
| - with pytest.raises( |
337 |
| - ValueError, match="size argument for fetchmany is -1 but must be >= 0" |
338 |
| - ): |
339 |
| - result_set.fetchmany_arrow(-1) |
340 |
| - |
341 |
| - def test_fetchall_arrow( |
342 |
| - self, mock_connection, mock_sea_client, execute_response, mock_arrow_queue |
343 |
| - ): |
344 |
| - """Test fetchall_arrow method.""" |
345 |
| - result_set = SeaResultSet( |
346 |
| - connection=mock_connection, |
347 |
| - execute_response=execute_response, |
348 |
| - sea_client=mock_sea_client, |
349 |
| - buffer_size_bytes=1000, |
350 |
| - arraysize=100, |
351 |
| - ) |
352 |
| - result_set.results = mock_arrow_queue |
353 |
| - |
354 |
| - # Test fetchall_arrow |
355 |
| - result = result_set.fetchall_arrow() |
356 |
| - |
357 |
| - # Verify remaining_rows was called |
358 |
| - mock_arrow_queue.remaining_rows.assert_called_once() |
359 |
| - |
360 |
| - # Verify _next_row_index was updated |
361 |
| - assert result_set._next_row_index == 3 |
362 |
| - |
363 | 306 | def test_fetchone(
|
364 | 307 | self, mock_connection, mock_sea_client, execute_response, mock_json_queue
|
365 | 308 | ):
|
@@ -441,3 +384,220 @@ def test_fetchall_with_non_json_queue(
|
441 | 384 | NotImplementedError, match="fetchall only supported for JSON data"
|
442 | 385 | ):
|
443 | 386 | result_set.fetchall()
|
| 387 | + |
| 388 | + def test_iterator_protocol( |
| 389 | + self, mock_connection, mock_sea_client, execute_response, mock_json_queue |
| 390 | + ): |
| 391 | + """Test the iterator protocol (__iter__) implementation.""" |
| 392 | + result_set = SeaResultSet( |
| 393 | + connection=mock_connection, |
| 394 | + execute_response=execute_response, |
| 395 | + sea_client=mock_sea_client, |
| 396 | + buffer_size_bytes=1000, |
| 397 | + arraysize=100, |
| 398 | + ) |
| 399 | + result_set.results = mock_json_queue |
| 400 | + result_set.description = [ |
| 401 | + ("test_value", "INT", None, None, None, None, None), |
| 402 | + ] |
| 403 | + |
| 404 | + # Mock fetchone to return a sequence of values and then None |
| 405 | + with patch.object(result_set, "fetchone") as mock_fetchone: |
| 406 | + mock_fetchone.side_effect = [ |
| 407 | + Row("test_value")(100), |
| 408 | + Row("test_value")(200), |
| 409 | + Row("test_value")(300), |
| 410 | + None, |
| 411 | + ] |
| 412 | + |
| 413 | + # Test iterating over the result set |
| 414 | + rows = list(result_set) |
| 415 | + assert len(rows) == 3 |
| 416 | + assert rows[0].test_value == 100 |
| 417 | + assert rows[1].test_value == 200 |
| 418 | + assert rows[2].test_value == 300 |
| 419 | + |
| 420 | + def test_rownumber_property( |
| 421 | + self, mock_connection, mock_sea_client, execute_response, mock_json_queue |
| 422 | + ): |
| 423 | + """Test the rownumber property.""" |
| 424 | + result_set = SeaResultSet( |
| 425 | + connection=mock_connection, |
| 426 | + execute_response=execute_response, |
| 427 | + sea_client=mock_sea_client, |
| 428 | + buffer_size_bytes=1000, |
| 429 | + arraysize=100, |
| 430 | + ) |
| 431 | + result_set.results = mock_json_queue |
| 432 | + |
| 433 | + # Initial row number should be 0 |
| 434 | + assert result_set.rownumber == 0 |
| 435 | + |
| 436 | + # After fetching rows, row number should be updated |
| 437 | + mock_json_queue.next_n_rows.return_value = [["value1"]] |
| 438 | + result_set.fetchmany_json(2) |
| 439 | + result_set._next_row_index = 2 |
| 440 | + assert result_set.rownumber == 2 |
| 441 | + |
| 442 | + # After fetching more rows, row number should be incremented |
| 443 | + mock_json_queue.next_n_rows.return_value = [["value3"]] |
| 444 | + result_set.fetchmany_json(1) |
| 445 | + result_set._next_row_index = 3 |
| 446 | + assert result_set.rownumber == 3 |
| 447 | + |
| 448 | + def test_is_staging_operation_property(self, mock_connection, mock_sea_client): |
| 449 | + """Test the is_staging_operation property.""" |
| 450 | + # Create a response with staging operation set to True |
| 451 | + staging_response = Mock() |
| 452 | + staging_response.command_id = CommandId.from_sea_statement_id( |
| 453 | + "test-staging-123" |
| 454 | + ) |
| 455 | + staging_response.status = CommandState.SUCCEEDED |
| 456 | + staging_response.has_been_closed_server_side = False |
| 457 | + staging_response.description = [] |
| 458 | + staging_response.is_staging_operation = True |
| 459 | + staging_response.lz4_compressed = False |
| 460 | + staging_response.arrow_schema_bytes = b"" |
| 461 | + |
| 462 | + # Create a result set with staging operation |
| 463 | + result_set = SeaResultSet( |
| 464 | + connection=mock_connection, |
| 465 | + execute_response=staging_response, |
| 466 | + sea_client=mock_sea_client, |
| 467 | + buffer_size_bytes=1000, |
| 468 | + arraysize=100, |
| 469 | + ) |
| 470 | + |
| 471 | + # Verify the is_staging_operation property |
| 472 | + assert result_set.is_staging_operation is True |
| 473 | + |
| 474 | + def test_init_with_result_data( |
| 475 | + self, mock_connection, mock_sea_client, execute_response |
| 476 | + ): |
| 477 | + """Test initializing SeaResultSet with result data.""" |
| 478 | + # Create sample result data with a mock |
| 479 | + result_data = Mock(spec=ResultData) |
| 480 | + result_data.data = [["value1", 123], ["value2", 456]] |
| 481 | + result_data.external_links = None |
| 482 | + |
| 483 | + manifest = Mock(spec=ResultManifest) |
| 484 | + |
| 485 | + # Mock the SeaResultSetQueueFactory.build_queue method |
| 486 | + with patch( |
| 487 | + "databricks.sql.result_set.SeaResultSetQueueFactory" |
| 488 | + ) as factory_mock: |
| 489 | + # Create a mock JsonQueue |
| 490 | + mock_queue = Mock(spec=JsonQueue) |
| 491 | + factory_mock.build_queue.return_value = mock_queue |
| 492 | + |
| 493 | + result_set = SeaResultSet( |
| 494 | + connection=mock_connection, |
| 495 | + execute_response=execute_response, |
| 496 | + sea_client=mock_sea_client, |
| 497 | + buffer_size_bytes=1000, |
| 498 | + arraysize=100, |
| 499 | + result_data=result_data, |
| 500 | + manifest=manifest, |
| 501 | + ) |
| 502 | + |
| 503 | + # Verify the factory was called with the right parameters |
| 504 | + factory_mock.build_queue.assert_called_once_with( |
| 505 | + result_data, |
| 506 | + manifest, |
| 507 | + str(execute_response.command_id.to_sea_statement_id()), |
| 508 | + description=execute_response.description, |
| 509 | + max_download_threads=mock_sea_client.max_download_threads, |
| 510 | + ssl_options=mock_sea_client.ssl_options, |
| 511 | + sea_client=mock_sea_client, |
| 512 | + lz4_compressed=execute_response.lz4_compressed, |
| 513 | + ) |
| 514 | + |
| 515 | + # Verify the results queue was set correctly |
| 516 | + assert result_set.results == mock_queue |
| 517 | + |
| 518 | + def test_close_with_request_error( |
| 519 | + self, mock_connection, mock_sea_client, execute_response |
| 520 | + ): |
| 521 | + """Test closing a result set when a RequestError is raised.""" |
| 522 | + result_set = SeaResultSet( |
| 523 | + connection=mock_connection, |
| 524 | + execute_response=execute_response, |
| 525 | + sea_client=mock_sea_client, |
| 526 | + buffer_size_bytes=1000, |
| 527 | + arraysize=100, |
| 528 | + ) |
| 529 | + |
| 530 | + # Create a patched version of the close method that doesn't check e.args[1] |
| 531 | + with patch("databricks.sql.result_set.ResultSet.close") as mock_close: |
| 532 | + # Call the close method |
| 533 | + result_set.close() |
| 534 | + |
| 535 | + # Verify the parent's close method was called |
| 536 | + mock_close.assert_called_once() |
| 537 | + |
| 538 | + def test_init_with_empty_result_data( |
| 539 | + self, mock_connection, mock_sea_client, execute_response |
| 540 | + ): |
| 541 | + """Test initializing SeaResultSet with empty result data.""" |
| 542 | + # Create sample result data with a mock |
| 543 | + result_data = Mock(spec=ResultData) |
| 544 | + result_data.data = None |
| 545 | + result_data.external_links = None |
| 546 | + |
| 547 | + manifest = Mock(spec=ResultManifest) |
| 548 | + |
| 549 | + result_set = SeaResultSet( |
| 550 | + connection=mock_connection, |
| 551 | + execute_response=execute_response, |
| 552 | + sea_client=mock_sea_client, |
| 553 | + buffer_size_bytes=1000, |
| 554 | + arraysize=100, |
| 555 | + result_data=result_data, |
| 556 | + manifest=manifest, |
| 557 | + ) |
| 558 | + |
| 559 | + # Verify an empty JsonQueue was created |
| 560 | + assert isinstance(result_set.results, JsonQueue) |
| 561 | + assert result_set.results.data_array == [] |
| 562 | + |
| 563 | + def test_init_without_result_data( |
| 564 | + self, mock_connection, mock_sea_client, execute_response |
| 565 | + ): |
| 566 | + """Test initializing SeaResultSet without result data.""" |
| 567 | + result_set = SeaResultSet( |
| 568 | + connection=mock_connection, |
| 569 | + execute_response=execute_response, |
| 570 | + sea_client=mock_sea_client, |
| 571 | + buffer_size_bytes=1000, |
| 572 | + arraysize=100, |
| 573 | + ) |
| 574 | + |
| 575 | + # Verify an empty JsonQueue was created |
| 576 | + assert isinstance(result_set.results, JsonQueue) |
| 577 | + assert result_set.results.data_array == [] |
| 578 | + |
| 579 | + def test_init_with_external_links( |
| 580 | + self, mock_connection, mock_sea_client, execute_response |
| 581 | + ): |
| 582 | + """Test initializing SeaResultSet with external links.""" |
| 583 | + # Create sample result data with external links |
| 584 | + result_data = Mock(spec=ResultData) |
| 585 | + result_data.data = None |
| 586 | + result_data.external_links = ["link1", "link2"] |
| 587 | + |
| 588 | + manifest = Mock(spec=ResultManifest) |
| 589 | + |
| 590 | + # This should raise NotImplementedError |
| 591 | + with pytest.raises( |
| 592 | + NotImplementedError, |
| 593 | + match="EXTERNAL_LINKS disposition is not implemented for SEA backend", |
| 594 | + ): |
| 595 | + SeaResultSet( |
| 596 | + connection=mock_connection, |
| 597 | + execute_response=execute_response, |
| 598 | + sea_client=mock_sea_client, |
| 599 | + buffer_size_bytes=1000, |
| 600 | + arraysize=100, |
| 601 | + result_data=result_data, |
| 602 | + manifest=manifest, |
| 603 | + ) |
0 commit comments