Skip to content

Commit d210ccd

Browse files
remove changes in sea result set testing
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent f6c5950 commit d210ccd

File tree

2 files changed

+433
-2
lines changed

2 files changed

+433
-2
lines changed

tests/unit/test_sea_result_set.py

Lines changed: 346 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
from databricks.sql.result_set import SeaResultSet
1212
from databricks.sql.backend.types import CommandId, CommandState, BackendType
13+
from databricks.sql.utils import JsonQueue
14+
from databricks.sql.types import Row
1315

1416

1517
class TestSeaResultSet:
@@ -20,12 +22,15 @@ def mock_connection(self):
2022
"""Create a mock connection."""
2123
connection = Mock()
2224
connection.open = True
25+
connection.disable_pandas = False
2326
return connection
2427

2528
@pytest.fixture
2629
def mock_sea_client(self):
2730
"""Create a mock SEA client."""
28-
return Mock()
31+
client = Mock()
32+
client.max_download_threads = 10
33+
return client
2934

3035
@pytest.fixture
3136
def execute_response(self):
@@ -37,11 +42,27 @@ def execute_response(self):
3742
mock_response.is_direct_results = False
3843
mock_response.results_queue = None
3944
mock_response.description = [
40-
("test_value", "INT", None, None, None, None, None)
45+
("col1", "INT", None, None, None, None, None),
46+
("col2", "STRING", None, None, None, None, None),
4147
]
4248
mock_response.is_staging_operation = False
49+
mock_response.lz4_compressed = False
50+
mock_response.arrow_schema_bytes = b""
4351
return mock_response
4452

53+
@pytest.fixture
54+
def mock_result_data(self):
55+
"""Create mock result data."""
56+
result_data = Mock()
57+
result_data.data = [[1, "value1"], [2, "value2"], [3, "value3"]]
58+
result_data.external_links = None
59+
return result_data
60+
61+
@pytest.fixture
62+
def mock_manifest(self):
63+
"""Create a mock manifest."""
64+
return Mock()
65+
4566
def test_init_with_execute_response(
4667
self, mock_connection, mock_sea_client, execute_response
4768
):
@@ -63,6 +84,49 @@ def test_init_with_execute_response(
6384
assert result_set.arraysize == 100
6485
assert result_set.description == execute_response.description
6586

87+
# Verify that a JsonQueue was created with empty data
88+
assert isinstance(result_set.results, JsonQueue)
89+
assert result_set.results.data_array == []
90+
91+
def test_init_with_result_data(
92+
self,
93+
mock_connection,
94+
mock_sea_client,
95+
execute_response,
96+
mock_result_data,
97+
mock_manifest,
98+
):
99+
"""Test initializing SeaResultSet with result data."""
100+
with patch(
101+
"databricks.sql.result_set.SeaResultSetQueueFactory"
102+
) as mock_factory:
103+
mock_queue = Mock(spec=JsonQueue)
104+
mock_factory.build_queue.return_value = mock_queue
105+
106+
result_set = SeaResultSet(
107+
connection=mock_connection,
108+
execute_response=execute_response,
109+
sea_client=mock_sea_client,
110+
buffer_size_bytes=1000,
111+
arraysize=100,
112+
result_data=mock_result_data,
113+
manifest=mock_manifest,
114+
)
115+
116+
# Verify that the factory was called with the correct arguments
117+
mock_factory.build_queue.assert_called_once_with(
118+
mock_result_data,
119+
mock_manifest,
120+
str(execute_response.command_id.to_sea_statement_id()),
121+
description=execute_response.description,
122+
max_download_threads=mock_sea_client.max_download_threads,
123+
sea_client=mock_sea_client,
124+
lz4_compressed=execute_response.lz4_compressed,
125+
)
126+
127+
# Verify that the queue was set correctly
128+
assert result_set.results == mock_queue
129+
66130
def test_close(self, mock_connection, mock_sea_client, execute_response):
67131
"""Test closing a result set."""
68132
result_set = SeaResultSet(
@@ -122,3 +186,283 @@ def test_close_when_connection_closed(
122186
mock_sea_client.close_command.assert_not_called()
123187
assert result_set.has_been_closed_server_side is True
124188
assert result_set.status == CommandState.CLOSED
189+
190+
def test_convert_json_table(
191+
self, mock_connection, mock_sea_client, execute_response
192+
):
193+
"""Test converting JSON data to Row objects."""
194+
result_set = SeaResultSet(
195+
connection=mock_connection,
196+
execute_response=execute_response,
197+
sea_client=mock_sea_client,
198+
buffer_size_bytes=1000,
199+
arraysize=100,
200+
)
201+
202+
# Sample data
203+
data = [[1, "value1"], [2, "value2"]]
204+
205+
# Convert to Row objects
206+
rows = result_set._convert_json_table(data)
207+
208+
# Check that we got Row objects with the correct values
209+
assert len(rows) == 2
210+
assert isinstance(rows[0], Row)
211+
assert rows[0].col1 == 1
212+
assert rows[0].col2 == "value1"
213+
assert rows[1].col1 == 2
214+
assert rows[1].col2 == "value2"
215+
216+
def test_convert_json_table_empty(
217+
self, mock_connection, mock_sea_client, execute_response
218+
):
219+
"""Test converting empty JSON data."""
220+
result_set = SeaResultSet(
221+
connection=mock_connection,
222+
execute_response=execute_response,
223+
sea_client=mock_sea_client,
224+
buffer_size_bytes=1000,
225+
arraysize=100,
226+
)
227+
228+
# Empty data
229+
data = []
230+
231+
# Convert to Row objects
232+
rows = result_set._convert_json_table(data)
233+
234+
# Check that we got an empty list
235+
assert rows == []
236+
237+
def test_convert_json_table_no_description(
238+
self, mock_connection, mock_sea_client, execute_response
239+
):
240+
"""Test converting JSON data with no description."""
241+
execute_response.description = None
242+
result_set = SeaResultSet(
243+
connection=mock_connection,
244+
execute_response=execute_response,
245+
sea_client=mock_sea_client,
246+
buffer_size_bytes=1000,
247+
arraysize=100,
248+
)
249+
250+
# Sample data
251+
data = [[1, "value1"], [2, "value2"]]
252+
253+
# Convert to Row objects
254+
rows = result_set._convert_json_table(data)
255+
256+
# Check that we got the original data
257+
assert rows == data
258+
259+
def test_fetchone(
260+
self, mock_connection, mock_sea_client, execute_response, mock_result_data
261+
):
262+
"""Test fetching one row."""
263+
# Create a result set with data
264+
result_set = SeaResultSet(
265+
connection=mock_connection,
266+
execute_response=execute_response,
267+
sea_client=mock_sea_client,
268+
buffer_size_bytes=1000,
269+
arraysize=100,
270+
result_data=mock_result_data,
271+
)
272+
273+
# Replace the results queue with a JsonQueue containing test data
274+
result_set.results = JsonQueue([[1, "value1"], [2, "value2"], [3, "value3"]])
275+
276+
# Fetch one row
277+
row = result_set.fetchone()
278+
279+
# Check that we got a Row object with the correct values
280+
assert isinstance(row, Row)
281+
assert row.col1 == 1
282+
assert row.col2 == "value1"
283+
284+
# Check that the row index was updated
285+
assert result_set._next_row_index == 1
286+
287+
def test_fetchone_empty(self, mock_connection, mock_sea_client, execute_response):
288+
"""Test fetching one row from an empty result set."""
289+
result_set = SeaResultSet(
290+
connection=mock_connection,
291+
execute_response=execute_response,
292+
sea_client=mock_sea_client,
293+
buffer_size_bytes=1000,
294+
arraysize=100,
295+
)
296+
297+
# Fetch one row
298+
row = result_set.fetchone()
299+
300+
# Check that we got None
301+
assert row is None
302+
303+
def test_fetchmany(
304+
self, mock_connection, mock_sea_client, execute_response, mock_result_data
305+
):
306+
"""Test fetching multiple rows."""
307+
# Create a result set with data
308+
result_set = SeaResultSet(
309+
connection=mock_connection,
310+
execute_response=execute_response,
311+
sea_client=mock_sea_client,
312+
buffer_size_bytes=1000,
313+
arraysize=100,
314+
result_data=mock_result_data,
315+
)
316+
317+
# Replace the results queue with a JsonQueue containing test data
318+
result_set.results = JsonQueue([[1, "value1"], [2, "value2"], [3, "value3"]])
319+
320+
# Fetch two rows
321+
rows = result_set.fetchmany(2)
322+
323+
# Check that we got two Row objects with the correct values
324+
assert len(rows) == 2
325+
assert isinstance(rows[0], Row)
326+
assert rows[0].col1 == 1
327+
assert rows[0].col2 == "value1"
328+
assert rows[1].col1 == 2
329+
assert rows[1].col2 == "value2"
330+
331+
# Check that the row index was updated
332+
assert result_set._next_row_index == 2
333+
334+
def test_fetchmany_negative_size(
335+
self, mock_connection, mock_sea_client, execute_response
336+
):
337+
"""Test fetching with a negative size."""
338+
result_set = SeaResultSet(
339+
connection=mock_connection,
340+
execute_response=execute_response,
341+
sea_client=mock_sea_client,
342+
buffer_size_bytes=1000,
343+
arraysize=100,
344+
)
345+
346+
# Try to fetch with a negative size
347+
with pytest.raises(
348+
ValueError, match="size argument for fetchmany is -1 but must be >= 0"
349+
):
350+
result_set.fetchmany(-1)
351+
352+
def test_fetchall(
353+
self, mock_connection, mock_sea_client, execute_response, mock_result_data
354+
):
355+
"""Test fetching all rows."""
356+
# Create a result set with data
357+
result_set = SeaResultSet(
358+
connection=mock_connection,
359+
execute_response=execute_response,
360+
sea_client=mock_sea_client,
361+
buffer_size_bytes=1000,
362+
arraysize=100,
363+
result_data=mock_result_data,
364+
)
365+
366+
# Replace the results queue with a JsonQueue containing test data
367+
result_set.results = JsonQueue([[1, "value1"], [2, "value2"], [3, "value3"]])
368+
369+
# Fetch all rows
370+
rows = result_set.fetchall()
371+
372+
# Check that we got three Row objects with the correct values
373+
assert len(rows) == 3
374+
assert isinstance(rows[0], Row)
375+
assert rows[0].col1 == 1
376+
assert rows[0].col2 == "value1"
377+
assert rows[1].col1 == 2
378+
assert rows[1].col2 == "value2"
379+
assert rows[2].col1 == 3
380+
assert rows[2].col2 == "value3"
381+
382+
# Check that the row index was updated
383+
assert result_set._next_row_index == 3
384+
385+
def test_fetchmany_json(
386+
self, mock_connection, mock_sea_client, execute_response, mock_result_data
387+
):
388+
"""Test fetching JSON data directly."""
389+
# Create a result set with data
390+
result_set = SeaResultSet(
391+
connection=mock_connection,
392+
execute_response=execute_response,
393+
sea_client=mock_sea_client,
394+
buffer_size_bytes=1000,
395+
arraysize=100,
396+
result_data=mock_result_data,
397+
)
398+
399+
# Replace the results queue with a JsonQueue containing test data
400+
result_set.results = JsonQueue([[1, "value1"], [2, "value2"], [3, "value3"]])
401+
402+
# Fetch two rows as JSON
403+
rows = result_set.fetchmany_json(2)
404+
405+
# Check that we got the raw data
406+
assert rows == [[1, "value1"], [2, "value2"]]
407+
408+
# Check that the row index was updated
409+
assert result_set._next_row_index == 2
410+
411+
def test_fetchall_json(
412+
self, mock_connection, mock_sea_client, execute_response, mock_result_data
413+
):
414+
"""Test fetching all JSON data directly."""
415+
# Create a result set with data
416+
result_set = SeaResultSet(
417+
connection=mock_connection,
418+
execute_response=execute_response,
419+
sea_client=mock_sea_client,
420+
buffer_size_bytes=1000,
421+
arraysize=100,
422+
result_data=mock_result_data,
423+
)
424+
425+
# Replace the results queue with a JsonQueue containing test data
426+
result_set.results = JsonQueue([[1, "value1"], [2, "value2"], [3, "value3"]])
427+
428+
# Fetch all rows as JSON
429+
rows = result_set.fetchall_json()
430+
431+
# Check that we got the raw data
432+
assert rows == [[1, "value1"], [2, "value2"], [3, "value3"]]
433+
434+
# Check that the row index was updated
435+
assert result_set._next_row_index == 3
436+
437+
def test_iteration(
438+
self, mock_connection, mock_sea_client, execute_response, mock_result_data
439+
):
440+
"""Test iterating over the result set."""
441+
# Create a result set with data
442+
result_set = SeaResultSet(
443+
connection=mock_connection,
444+
execute_response=execute_response,
445+
sea_client=mock_sea_client,
446+
buffer_size_bytes=1000,
447+
arraysize=100,
448+
result_data=mock_result_data,
449+
)
450+
451+
# Replace the results queue with a JsonQueue containing test data
452+
result_set.results = JsonQueue([[1, "value1"], [2, "value2"], [3, "value3"]])
453+
454+
# Iterate over the result set
455+
rows = list(result_set)
456+
457+
# Check that we got three Row objects with the correct values
458+
assert len(rows) == 3
459+
assert isinstance(rows[0], Row)
460+
assert rows[0].col1 == 1
461+
assert rows[0].col2 == "value1"
462+
assert rows[1].col1 == 2
463+
assert rows[1].col2 == "value2"
464+
assert rows[2].col1 == 3
465+
assert rows[2].col2 == "value3"
466+
467+
# Check that the row index was updated
468+
assert result_set._next_row_index == 3

0 commit comments

Comments
 (0)