Skip to content

Commit 2d10cb4

Browse files
introduce unit tests for added methods in THttpClient
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 9280fc2 commit 2d10cb4

File tree

1 file changed

+139
-0
lines changed

1 file changed

+139
-0
lines changed

tests/unit/test_thrift_http_client.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
import unittest
2+
import json
3+
import urllib
4+
from unittest.mock import patch, Mock, MagicMock
5+
import urllib3
6+
from http.client import HTTPResponse
7+
from io import BytesIO
8+
9+
from databricks.sql.auth.thrift_http_client import THttpClient
10+
from databricks.sql.exc import RequestError
11+
from databricks.sql.auth.retry import DatabricksRetryPolicy
12+
from databricks.sql.types import SSLOptions
13+
14+
15+
class TestTHttpClient(unittest.TestCase):
16+
"""Unit tests for the THttpClient class."""
17+
18+
@patch("urllib.request.getproxies")
19+
@patch("urllib.request.proxy_bypass")
20+
def setUp(self, mock_proxy_bypass, mock_getproxies):
21+
"""Set up test fixtures."""
22+
# Mock proxy functions
23+
mock_getproxies.return_value = {}
24+
mock_proxy_bypass.return_value = True
25+
26+
# Create auth provider mock
27+
self.mock_auth_provider = Mock()
28+
self.mock_auth_provider.add_headers = Mock()
29+
30+
# Create HTTP client
31+
self.uri = "https://example.com/path"
32+
self.http_client = THttpClient(
33+
auth_provider=self.mock_auth_provider,
34+
uri_or_host=self.uri,
35+
ssl_options=SSLOptions(),
36+
)
37+
38+
# Mock the connection pool
39+
self.mock_pool = Mock()
40+
self.http_client._THttpClient__pool = self.mock_pool
41+
42+
# Set custom headers to include User-Agent (required by the class)
43+
self.http_client._headers = {"User-Agent": "test-agent"}
44+
self.http_client.__custom_headers = {"User-Agent": "test-agent"}
45+
46+
def test_check_rest_response_for_error_success(self):
47+
"""Test _check_rest_response_for_error with success status."""
48+
# No exception should be raised for status codes < 400
49+
self.http_client._check_rest_response_for_error(200, None)
50+
self.http_client._check_rest_response_for_error(201, None)
51+
self.http_client._check_rest_response_for_error(302, None)
52+
# No assertion needed - test passes if no exception is raised
53+
54+
def test_check_rest_response_for_error_client_error(self):
55+
"""Test _check_rest_response_for_error with client error status."""
56+
# Setup response data with error message
57+
response_data = json.dumps({"message": "Bad request"}).encode("utf-8")
58+
59+
# Check that exception is raised for client error
60+
with self.assertRaises(RequestError) as context:
61+
self.http_client._check_rest_response_for_error(400, response_data)
62+
63+
# Verify the exception message
64+
self.assertIn(
65+
"REST HTTP request failed with status 400", str(context.exception)
66+
)
67+
self.assertIn("Bad request", str(context.exception))
68+
69+
def test_check_rest_response_for_error_server_error(self):
70+
"""Test _check_rest_response_for_error with server error status."""
71+
# Setup response data with error message
72+
response_data = json.dumps({"message": "Internal server error"}).encode("utf-8")
73+
74+
# Check that exception is raised for server error
75+
with self.assertRaises(RequestError) as context:
76+
self.http_client._check_rest_response_for_error(500, response_data)
77+
78+
# Verify the exception message
79+
self.assertIn(
80+
"REST HTTP request failed with status 500", str(context.exception)
81+
)
82+
self.assertIn("Internal server error", str(context.exception))
83+
84+
def test_check_rest_response_for_error_no_message(self):
85+
"""Test _check_rest_response_for_error with error but no message."""
86+
# Check that exception is raised with generic message
87+
with self.assertRaises(RequestError) as context:
88+
self.http_client._check_rest_response_for_error(404, None)
89+
90+
# Verify the exception message
91+
self.assertIn(
92+
"REST HTTP request failed with status 404", str(context.exception)
93+
)
94+
95+
def test_check_rest_response_for_error_invalid_json(self):
96+
"""Test _check_rest_response_for_error with invalid JSON response."""
97+
# Setup invalid JSON response
98+
response_data = "Not a JSON response".encode("utf-8")
99+
100+
# Check that exception is raised with generic message
101+
with self.assertRaises(RequestError) as context:
102+
self.http_client._check_rest_response_for_error(500, response_data)
103+
104+
# Verify the exception message
105+
self.assertIn(
106+
"REST HTTP request failed with status 500", str(context.exception)
107+
)
108+
109+
@patch("databricks.sql.auth.thrift_http_client.THttpClient.make_rest_request")
110+
def test_make_rest_request_integration(self, mock_make_rest_request):
111+
"""Test that make_rest_request can be called with the expected parameters."""
112+
# Setup mock return value
113+
expected_result = {"result": "success"}
114+
mock_make_rest_request.return_value = expected_result
115+
116+
# Call the original method to verify it works
117+
result = self.http_client.make_rest_request(
118+
method="GET",
119+
endpoint_path="test/endpoint",
120+
params={"param": "value"},
121+
data={"key": "value"},
122+
headers={"X-Custom-Header": "custom-value"},
123+
)
124+
125+
# Verify the result
126+
self.assertEqual(result, expected_result)
127+
128+
# Verify the method was called with correct parameters
129+
mock_make_rest_request.assert_called_once_with(
130+
method="GET",
131+
endpoint_path="test/endpoint",
132+
params={"param": "value"},
133+
data={"key": "value"},
134+
headers={"X-Custom-Header": "custom-value"},
135+
)
136+
137+
138+
if __name__ == "__main__":
139+
unittest.main()

0 commit comments

Comments
 (0)