From 2728848dfe68e4b295d997aa8372c1be84facf36 Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Thu, 11 Jul 2024 19:01:23 +0300 Subject: [PATCH 1/3] Disable SSL verification for CloudFetch links Signed-off-by: Levko Kravets --- src/databricks/sql/cloudfetch/downloader.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index 3b1e01263..6663db7d7 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -95,7 +95,9 @@ def run(self) -> DownloadedFile: try: # Get the file via HTTP request response = session.get( - self.link.fileLink, timeout=self.settings.download_timeout + self.link.fileLink, + timeout=self.settings.download_timeout, + verify=False, ) response.raise_for_status() From 0cd439b6ff0770841b84f47a6a76ec3283fcd988 Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Mon, 15 Jul 2024 22:15:21 +0300 Subject: [PATCH 2/3] Use existing `_tls_no_verify` option in CloudFetch downloader Signed-off-by: Levko Kravets --- src/databricks/sql/client.py | 2 ++ src/databricks/sql/cloudfetch/download_manager.py | 9 ++++++++- src/databricks/sql/cloudfetch/downloader.py | 7 ++++++- src/databricks/sql/thrift_backend.py | 6 +++++- src/databricks/sql/utils.py | 13 +++++++++++-- 5 files changed, 32 insertions(+), 5 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index e56d22f62..084c42dfa 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -171,6 +171,8 @@ def read(self) -> Optional[OAuthToken]: # Which port to connect to # _skip_routing_headers: # Don't set routing headers if set to True (for use when connecting directly to server) + # _tls_no_verify + # Set to True (Boolean) to completely disable SSL verification. # _tls_verify_hostname # Set to False (Boolean) to disable SSL hostname verification, but check certificate. # _tls_trusted_ca_file diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index 93b6f623c..e30adcd6e 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -1,5 +1,6 @@ import logging +from ssl import SSLContext from concurrent.futures import ThreadPoolExecutor, Future from typing import List, Union @@ -19,6 +20,7 @@ def __init__( links: List[TSparkArrowResultLink], max_download_threads: int, lz4_compressed: bool, + ssl_context: SSLContext, ): self._pending_links: List[TSparkArrowResultLink] = [] for link in links: @@ -36,6 +38,7 @@ def __init__( self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads) self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed) + self._ssl_context = ssl_context def get_next_downloaded_file( self, next_row_offset: int @@ -89,7 +92,11 @@ def _schedule_downloads(self): logger.debug( "- start: {}, row count: {}".format(link.startRowOffset, link.rowCount) ) - handler = ResultSetDownloadHandler(self._downloadable_result_settings, link) + handler = ResultSetDownloadHandler( + settings=self._downloadable_result_settings, + link=link, + ssl_context=self._ssl_context, + ) task = self._thread_pool.submit(handler.run) self._download_tasks.append(task) diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index 6663db7d7..00ffecd02 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -3,6 +3,7 @@ import requests from requests.adapters import HTTPAdapter, Retry +from ssl import SSLContext, CERT_NONE import lz4.frame import time @@ -65,9 +66,11 @@ def __init__( self, settings: DownloadableResultSettings, link: TSparkArrowResultLink, + ssl_context: SSLContext, ): self.settings = settings self.link = link + self._ssl_context = ssl_context def run(self) -> DownloadedFile: """ @@ -92,12 +95,14 @@ def run(self) -> DownloadedFile: session.mount("http://", HTTPAdapter(max_retries=retryPolicy)) session.mount("https://", HTTPAdapter(max_retries=retryPolicy)) + ssl_verify = self._ssl_context.verify_mode != CERT_NONE + try: # Get the file via HTTP request response = session.get( self.link.fileLink, timeout=self.settings.download_timeout, - verify=False, + verify=ssl_verify, ) response.raise_for_status() diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 79293e857..56412fcee 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -184,6 +184,8 @@ def __init__( password=tls_client_cert_key_password, ) + self._ssl_context = ssl_context + self._auth_provider = auth_provider # Connector version 3 retry approach @@ -223,7 +225,7 @@ def __init__( self._transport = databricks.sql.auth.thrift_http_client.THttpClient( auth_provider=self._auth_provider, uri_or_host=uri, - ssl_context=ssl_context, + ssl_context=self._ssl_context, **additional_transport_args, # type: ignore ) @@ -774,6 +776,7 @@ def _results_message_to_execute_response(self, resp, operation_state): max_download_threads=self.max_download_threads, lz4_compressed=lz4_compressed, description=description, + ssl_context=self._ssl_context, ) else: arrow_queue_opt = None @@ -1005,6 +1008,7 @@ def fetch_results( max_download_threads=self.max_download_threads, lz4_compressed=lz4_compressed, description=description, + ssl_context=self._ssl_context, ) return queue, resp.hasMoreRows diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 4a770079d..c22688bb0 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -9,6 +9,7 @@ from enum import Enum from typing import Any, Dict, List, Optional, Union import re +from ssl import SSLContext import lz4.frame import pyarrow @@ -47,6 +48,7 @@ def build_queue( t_row_set: TRowSet, arrow_schema_bytes: bytes, max_download_threads: int, + ssl_context: SSLContext, lz4_compressed: bool = True, description: Optional[List[List[Any]]] = None, ) -> ResultSetQueue: @@ -60,6 +62,7 @@ def build_queue( lz4_compressed (bool): Whether result data has been lz4 compressed. description (List[List[Any]]): Hive table schema description. max_download_threads (int): Maximum number of downloader thread pool threads. + ssl_context (SSLContext): SSLContext object for CloudFetchQueue Returns: ResultSetQueue @@ -82,12 +85,13 @@ def build_queue( return ArrowQueue(converted_arrow_table, n_valid_rows) elif row_set_type == TSparkRowSetType.URL_BASED_SET: return CloudFetchQueue( - arrow_schema_bytes, + schema_bytes=arrow_schema_bytes, start_row_offset=t_row_set.startRowOffset, result_links=t_row_set.resultLinks, lz4_compressed=lz4_compressed, description=description, max_download_threads=max_download_threads, + ssl_context=ssl_context, ) else: raise AssertionError("Row set type is not valid") @@ -133,6 +137,7 @@ def __init__( self, schema_bytes, max_download_threads: int, + ssl_context: SSLContext, start_row_offset: int = 0, result_links: Optional[List[TSparkArrowResultLink]] = None, lz4_compressed: bool = True, @@ -155,6 +160,7 @@ def __init__( self.result_links = result_links self.lz4_compressed = lz4_compressed self.description = description + self._ssl_context = ssl_context logger.debug( "Initialize CloudFetch loader, row set start offset: {}, file list:".format( @@ -169,7 +175,10 @@ def __init__( ) ) self.download_manager = ResultFileDownloadManager( - result_links or [], self.max_download_threads, self.lz4_compressed + links=result_links or [], + max_download_threads=self.max_download_threads, + lz4_compressed=self.lz4_compressed, + ssl_context=self._ssl_context, ) self.table = self._create_next_table() From 7413f5a992acfa8848928943f4b7ef5f68ac971a Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Mon, 15 Jul 2024 22:26:30 +0300 Subject: [PATCH 3/3] Update tests Signed-off-by: Levko Kravets --- tests/unit/test_client.py | 2 +- tests/unit/test_cloud_fetch_queue.py | 111 +++++++++++++++++++++++---- tests/unit/test_download_manager.py | 9 ++- tests/unit/test_downloader.py | 15 ++-- 4 files changed, 113 insertions(+), 24 deletions(-) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 68e8b8303..c86a9f7f5 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -361,7 +361,7 @@ def test_cancel_command_calls_the_backend(self): mock_op_handle = Mock() cursor.active_op_handle = mock_op_handle cursor.cancel() - self.assertTrue(mock_thrift_backend.cancel_command.called_with(mock_op_handle)) + mock_thrift_backend.cancel_command.assert_called_with(mock_op_handle) @patch("databricks.sql.client.logger") def test_cancel_command_will_issue_warning_for_cancel_with_no_executing_command( diff --git a/tests/unit/test_cloud_fetch_queue.py b/tests/unit/test_cloud_fetch_queue.py index e9dfd712d..cd14c676e 100644 --- a/tests/unit/test_cloud_fetch_queue.py +++ b/tests/unit/test_cloud_fetch_queue.py @@ -1,11 +1,11 @@ import pyarrow import unittest from unittest.mock import MagicMock, patch +from ssl import create_default_context from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink import databricks.sql.utils as utils - class CloudFetchQueueSuite(unittest.TestCase): def create_result_link( @@ -47,7 +47,12 @@ def get_schema_bytes(): def test_initializer_adds_links(self, mock_create_next_table): schema_bytes = MagicMock() result_links = self.create_result_links(10) - queue = utils.CloudFetchQueue(schema_bytes, result_links=result_links, max_download_threads=10) + queue = utils.CloudFetchQueue( + schema_bytes, + result_links=result_links, + max_download_threads=10, + ssl_context=create_default_context(), + ) assert len(queue.download_manager._pending_links) == 10 assert len(queue.download_manager._download_tasks) == 0 @@ -56,7 +61,12 @@ def test_initializer_adds_links(self, mock_create_next_table): def test_initializer_no_links_to_add(self): schema_bytes = MagicMock() result_links = [] - queue = utils.CloudFetchQueue(schema_bytes, result_links=result_links, max_download_threads=10) + queue = utils.CloudFetchQueue( + schema_bytes, + result_links=result_links, + max_download_threads=10, + ssl_context=create_default_context(), + ) assert len(queue.download_manager._pending_links) == 0 assert len(queue.download_manager._download_tasks) == 0 @@ -64,7 +74,12 @@ def test_initializer_no_links_to_add(self): @patch("databricks.sql.cloudfetch.download_manager.ResultFileDownloadManager.get_next_downloaded_file", return_value=None) def test_create_next_table_no_download(self, mock_get_next_downloaded_file): - queue = utils.CloudFetchQueue(MagicMock(), result_links=[], max_download_threads=10) + queue = utils.CloudFetchQueue( + MagicMock(), + result_links=[], + max_download_threads=10, + ssl_context=create_default_context(), + ) assert queue._create_next_table() is None mock_get_next_downloaded_file.assert_called_with(0) @@ -75,7 +90,13 @@ def test_create_next_table_no_download(self, mock_get_next_downloaded_file): def test_initializer_create_next_table_success(self, mock_get_next_downloaded_file, mock_create_arrow_table): mock_create_arrow_table.return_value = self.make_arrow_table() schema_bytes, description = MagicMock(), MagicMock() - queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) + queue = utils.CloudFetchQueue( + schema_bytes, + result_links=[], + description=description, + max_download_threads=10, + ssl_context=create_default_context(), + ) expected_result = self.make_arrow_table() mock_get_next_downloaded_file.assert_called_with(0) @@ -94,7 +115,13 @@ def test_initializer_create_next_table_success(self, mock_get_next_downloaded_fi def test_next_n_rows_0_rows(self, mock_create_next_table): mock_create_next_table.return_value = self.make_arrow_table() schema_bytes, description = MagicMock(), MagicMock() - queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) + queue = utils.CloudFetchQueue( + schema_bytes, + result_links=[], + description=description, + max_download_threads=10, + ssl_context=create_default_context(), + ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 assert queue.table_row_index == 0 @@ -108,7 +135,13 @@ def test_next_n_rows_0_rows(self, mock_create_next_table): def test_next_n_rows_partial_table(self, mock_create_next_table): mock_create_next_table.return_value = self.make_arrow_table() schema_bytes, description = MagicMock(), MagicMock() - queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) + queue = utils.CloudFetchQueue( + schema_bytes, + result_links=[], + description=description, + max_download_threads=10, + ssl_context=create_default_context(), + ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 assert queue.table_row_index == 0 @@ -122,7 +155,13 @@ def test_next_n_rows_partial_table(self, mock_create_next_table): def test_next_n_rows_more_than_one_table(self, mock_create_next_table): mock_create_next_table.return_value = self.make_arrow_table() schema_bytes, description = MagicMock(), MagicMock() - queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) + queue = utils.CloudFetchQueue( + schema_bytes, + result_links=[], + description=description, + max_download_threads=10, + ssl_context=create_default_context(), + ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 assert queue.table_row_index == 0 @@ -136,7 +175,13 @@ def test_next_n_rows_more_than_one_table(self, mock_create_next_table): def test_next_n_rows_only_one_table_returned(self, mock_create_next_table): mock_create_next_table.side_effect = [self.make_arrow_table(), None] schema_bytes, description = MagicMock(), MagicMock() - queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) + queue = utils.CloudFetchQueue( + schema_bytes, + result_links=[], + description=description, + max_download_threads=10, + ssl_context=create_default_context(), + ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 assert queue.table_row_index == 0 @@ -149,7 +194,13 @@ def test_next_n_rows_only_one_table_returned(self, mock_create_next_table): def test_next_n_rows_empty_table(self, mock_create_next_table): schema_bytes = self.get_schema_bytes() description = MagicMock() - queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) + queue = utils.CloudFetchQueue( + schema_bytes, + result_links=[], + description=description, + max_download_threads=10, + ssl_context=create_default_context(), + ) assert queue.table is None result = queue.next_n_rows(100) @@ -160,7 +211,13 @@ def test_next_n_rows_empty_table(self, mock_create_next_table): def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table): mock_create_next_table.side_effect = [self.make_arrow_table(), None, 0] schema_bytes, description = MagicMock(), MagicMock() - queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) + queue = utils.CloudFetchQueue( + schema_bytes, + result_links=[], + description=description, + max_download_threads=10, + ssl_context=create_default_context(), + ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 queue.table_row_index = 4 @@ -173,7 +230,13 @@ def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table) def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_table): mock_create_next_table.side_effect = [self.make_arrow_table(), None] schema_bytes, description = MagicMock(), MagicMock() - queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) + queue = utils.CloudFetchQueue( + schema_bytes, + result_links=[], + description=description, + max_download_threads=10, + ssl_context=create_default_context(), + ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 queue.table_row_index = 2 @@ -186,7 +249,13 @@ def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_tabl def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table): mock_create_next_table.side_effect = [self.make_arrow_table(), None] schema_bytes, description = MagicMock(), MagicMock() - queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) + queue = utils.CloudFetchQueue( + schema_bytes, + result_links=[], + description=description, + max_download_threads=10, + ssl_context=create_default_context(), + ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 assert queue.table_row_index == 0 @@ -199,7 +268,13 @@ def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table): def test_remaining_rows_multiple_tables_fully_returned(self, mock_create_next_table): mock_create_next_table.side_effect = [self.make_arrow_table(), self.make_arrow_table(), None] schema_bytes, description = MagicMock(), MagicMock() - queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) + queue = utils.CloudFetchQueue( + schema_bytes, + result_links=[], + description=description, + max_download_threads=10, + ssl_context=create_default_context(), + ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 queue.table_row_index = 3 @@ -213,7 +288,13 @@ def test_remaining_rows_multiple_tables_fully_returned(self, mock_create_next_ta def test_remaining_rows_empty_table(self, mock_create_next_table): schema_bytes = self.get_schema_bytes() description = MagicMock() - queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) + queue = utils.CloudFetchQueue( + schema_bytes, + result_links=[], + description=description, + max_download_threads=10, + ssl_context=create_default_context(), + ) assert queue.table is None result = queue.remaining_rows() diff --git a/tests/unit/test_download_manager.py b/tests/unit/test_download_manager.py index 7a35e65aa..c084d8e47 100644 --- a/tests/unit/test_download_manager.py +++ b/tests/unit/test_download_manager.py @@ -1,6 +1,8 @@ import unittest from unittest.mock import patch, MagicMock +from ssl import create_default_context + import databricks.sql.cloudfetch.download_manager as download_manager from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink @@ -11,7 +13,12 @@ class DownloadManagerTests(unittest.TestCase): """ def create_download_manager(self, links, max_download_threads=10, lz4_compressed=True): - return download_manager.ResultFileDownloadManager(links, max_download_threads, lz4_compressed) + return download_manager.ResultFileDownloadManager( + links, + max_download_threads, + lz4_compressed, + ssl_context=create_default_context(), + ) def create_result_link( self, diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py index e138cdbb9..b6e473b5a 100644 --- a/tests/unit/test_downloader.py +++ b/tests/unit/test_downloader.py @@ -2,6 +2,7 @@ from unittest.mock import Mock, patch, MagicMock import requests +from ssl import create_default_context import databricks.sql.cloudfetch.downloader as downloader from databricks.sql.exc import Error @@ -25,7 +26,7 @@ def test_run_link_expired(self, mock_time): result_link = Mock() # Already expired result_link.expiryTime = 999 - d = downloader.ResultSetDownloadHandler(settings, result_link) + d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_context=create_default_context()) with self.assertRaises(Error) as context: d.run() @@ -39,7 +40,7 @@ def test_run_link_past_expiry_buffer(self, mock_time): result_link = Mock() # Within the expiry buffer time result_link.expiryTime = 1004 - d = downloader.ResultSetDownloadHandler(settings, result_link) + d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_context=create_default_context()) with self.assertRaises(Error) as context: d.run() @@ -57,7 +58,7 @@ def test_run_get_response_not_ok(self, mock_time, mock_session): settings.use_proxy = False result_link = Mock(expiryTime=1001) - d = downloader.ResultSetDownloadHandler(settings, result_link) + d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_context=create_default_context()) with self.assertRaises(requests.exceptions.HTTPError) as context: d.run() self.assertTrue('404' in str(context.exception)) @@ -72,7 +73,7 @@ def test_run_uncompressed_successful(self, mock_time, mock_session): settings.is_lz4_compressed = False result_link = Mock(bytesNum=100, expiryTime=1001) - d = downloader.ResultSetDownloadHandler(settings, result_link) + d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_context=create_default_context()) file = d.run() assert file.file_bytes == b"1234567890" * 10 @@ -88,7 +89,7 @@ def test_run_compressed_successful(self, mock_time, mock_session): settings.is_lz4_compressed = True result_link = Mock(bytesNum=100, expiryTime=1001) - d = downloader.ResultSetDownloadHandler(settings, result_link) + d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_context=create_default_context()) file = d.run() assert file.file_bytes == b"1234567890" * 10 @@ -101,7 +102,7 @@ def test_download_connection_error(self, mock_time, mock_session): mock_session.return_value.get.return_value.content = \ b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' - d = downloader.ResultSetDownloadHandler(settings, result_link) + d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_context=create_default_context()) with self.assertRaises(ConnectionError): d.run() @@ -113,6 +114,6 @@ def test_download_timeout(self, mock_time, mock_session): mock_session.return_value.get.return_value.content = \ b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' - d = downloader.ResultSetDownloadHandler(settings, result_link) + d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_context=create_default_context()) with self.assertRaises(TimeoutError): d.run()