From 5d0221606adaa86ffa7ae20c93d623812baea998 Mon Sep 17 00:00:00 2001 From: ponyisi Date: Tue, 13 May 2025 11:57:17 -0500 Subject: [PATCH 1/2] Change from aiohttp to httpx --- pyproject.toml | 2 +- servicex/servicex_adapter.py | 270 ++++++++++++++++----------------- tests/test_servicex_adapter.py | 174 +++++++++------------ 3 files changed, 210 insertions(+), 236 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f606f7c3..a3a7acf5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,8 +36,8 @@ dependencies = [ "func_adl>=3.2.6", "requests>=2.31", "pydantic>=2.6.0", - "aiohttp-retry>=2.8.3", "httpx>=0.24", + "httpx_retries>=0.3.2", "aioboto3>=14.1.0", "tinydb>=4.7", "google-auth>=2.17", diff --git a/servicex/servicex_adapter.py b/servicex/servicex_adapter.py index baa7b5b0..d01aa06d 100644 --- a/servicex/servicex_adapter.py +++ b/servicex/servicex_adapter.py @@ -29,10 +29,10 @@ import time from typing import Optional, Dict, List -from aiohttp import ClientSession import httpx -from aiohttp_retry import RetryClient, ExponentialRetry, ClientResponse -from aiohttp import ContentTypeError +from httpx import AsyncClient as ClientSession, Response +from json import JSONDecodeError +from httpx_retries import RetryTransport, Retry from google.auth import jwt from tenacity import ( AsyncRetrying, @@ -48,12 +48,12 @@ class AuthorizationError(BaseException): pass -async def _extract_message(r: ClientResponse): +async def _extract_message(r: Response): try: o = await r.json() error_message = o.get("message", str(r)) - except ContentTypeError: - error_message = await r.text() + except JSONDecodeError: + error_message = r.text return error_message @@ -66,15 +66,15 @@ def __init__(self, url: str, refresh_token: Optional[str] = None): async def _get_token(self): url = f"{self.url}/token/refresh" headers = {"Authorization": f"Bearer {self.refresh_token}"} - async with RetryClient() as client: - async with client.post(url, headers=headers, json=None) as r: - if r.status == 200: - o = await r.json() - self.token = o["access_token"] - else: - raise AuthorizationError( - f"ServiceX access token request rejected [{r.status} {r.reason}]" - ) + async with ClientSession() as client: + r = await client.post(url, headers=headers, json=None) + if r.status_code == 200: + o = await r.json() + self.token = o["access_token"] + else: + raise AuthorizationError( + f"ServiceX access token request rejected [{r.status_code} {r.reason_phrase}]" + ) @staticmethod def _get_bearer_token_file(): @@ -112,23 +112,25 @@ async def _get_authorization(self, force_reauth: bool = False) -> Dict[str, str] async def get_transforms(self) -> List[TransformStatus]: headers = await self._get_authorization() - retry_options = ExponentialRetry(attempts=3, start_timeout=10) - async with RetryClient(retry_options=retry_options) as client: - async with client.get( + retry_options = Retry(total=3, backoff_factor=10) + async with ClientSession( + transport=RetryTransport(retry=retry_options) + ) as client: + r = await client.get( url=f"{self.url}/servicex/transformation", headers=headers - ) as r: - if r.status == 401: - raise AuthorizationError( - f"Not authorized to access serviceX at {self.url}" - ) - elif r.status > 400: - error_message = await _extract_message(r) - raise RuntimeError( - "ServiceX WebAPI Error during transformation " - f"submission: {r.status} - {error_message}" - ) - o = await r.json() - statuses = [TransformStatus(**status) for status in o["requests"]] + ) + if r.status_code == 401: + raise AuthorizationError( + f"Not authorized to access serviceX at {self.url}" + ) + elif r.status_code > 400: + error_message = await _extract_message(r) + raise RuntimeError( + "ServiceX WebAPI Error during transformation " + f"submission: {r.status_code} - {error_message}" + ) + o = await r.json() + statuses = [TransformStatus(**status) for status in o["requests"]] return statuses def get_code_generators(self): @@ -150,19 +152,18 @@ async def get_datasets( params["show-deleted"] = True async with ClientSession() as session: - async with session.get( + r = await session.get( headers=headers, url=f"{self.url}/servicex/datasets", params=params - ) as r: - - if r.status == 403: - raise AuthorizationError( - f"Not authorized to access serviceX at {self.url}" - ) - elif r.status != 200: - msg = await _extract_message(r) - raise RuntimeError(f"Failed to get datasets: {r.status} - {msg}") + ) + if r.status_code == 403: + raise AuthorizationError( + f"Not authorized to access serviceX at {self.url}" + ) + elif r.status_code != 200: + msg = await _extract_message(r) + raise RuntimeError(f"Failed to get datasets: {r.status_code} - {msg}") - result = await r.json() + result = await r.json() datasets = [CachedDataset(**d) for d in result["datasets"]] return datasets @@ -172,21 +173,20 @@ async def get_dataset(self, dataset_id=None) -> CachedDataset: path_template = "/servicex/datasets/{dataset_id}" url = self.url + path_template.format(dataset_id=dataset_id) async with ClientSession() as session: - async with session.get(headers=headers, url=url) as r: - - if r.status == 403: - raise AuthorizationError( - f"Not authorized to access serviceX at {self.url}" - ) - elif r.status == 404: - raise ValueError(f"Dataset {dataset_id} not found") - elif r.status != 200: - msg = await _extract_message(r) - raise RuntimeError(f"Failed to get dataset {dataset_id} - {msg}") - result = await r.json() + r = await session.get(headers=headers, url=url) + if r.status_code == 403: + raise AuthorizationError( + f"Not authorized to access serviceX at {self.url}" + ) + elif r.status_code == 404: + raise ValueError(f"Dataset {dataset_id} not found") + elif r.status_code != 200: + msg = await _extract_message(r) + raise RuntimeError(f"Failed to get dataset {dataset_id} - {msg}") + result = await r.json() - dataset = CachedDataset(**result) - return dataset + dataset = CachedDataset(**result) + return dataset async def delete_dataset(self, dataset_id=None) -> bool: headers = await self._get_authorization() @@ -194,19 +194,18 @@ async def delete_dataset(self, dataset_id=None) -> bool: url = self.url + path_template.format(dataset_id=dataset_id) async with ClientSession() as session: - async with session.delete(headers=headers, url=url) as r: - - if r.status == 403: - raise AuthorizationError( - f"Not authorized to access serviceX at {self.url}" - ) - elif r.status == 404: - raise ValueError(f"Dataset {dataset_id} not found") - elif r.status != 200: - msg = await _extract_message(r) - raise RuntimeError(f"Failed to delete dataset {dataset_id} - {msg}") - result = await r.json() - return result["stale"] + r = await session.delete(headers=headers, url=url) + if r.status_code == 403: + raise AuthorizationError( + f"Not authorized to access serviceX at {self.url}" + ) + elif r.status_code == 404: + raise ValueError(f"Dataset {dataset_id} not found") + elif r.status_code != 200: + msg = await _extract_message(r) + raise RuntimeError(f"Failed to delete dataset {dataset_id} - {msg}") + result = await r.json() + return result["stale"] async def delete_transform(self, transform_id=None): headers = await self._get_authorization() @@ -214,19 +213,16 @@ async def delete_transform(self, transform_id=None): url = self.url + path_template.format(transform_id=transform_id) async with ClientSession() as session: - async with session.delete(headers=headers, url=url) as r: - - if r.status == 403: - raise AuthorizationError( - f"Not authorized to access serviceX at {self.url}" - ) - elif r.status == 404: - raise ValueError(f"Transform {transform_id} not found") - elif r.status != 200: - msg = await _extract_message(r) - raise RuntimeError( - f"Failed to delete transform {transform_id} - {msg}" - ) + r = await session.delete(headers=headers, url=url) + if r.status_code == 403: + raise AuthorizationError( + f"Not authorized to access serviceX at {self.url}" + ) + elif r.status_code == 404: + raise ValueError(f"Transform {transform_id} not found") + elif r.status_code != 200: + msg = await _extract_message(r) + raise RuntimeError(f"Failed to delete transform {transform_id} - {msg}") async def cancel_transform(self, transform_id=None): headers = await self._get_authorization() @@ -234,50 +230,51 @@ async def cancel_transform(self, transform_id=None): url = self.url + path_template.format(transform_id=transform_id) async with ClientSession() as session: - async with session.get(headers=headers, url=url) as r: - - if r.status == 403: - raise AuthorizationError( - f"Not authorized to access serviceX at {self.url}" - ) - elif r.status == 404: - raise ValueError(f"Transform {transform_id} not found") - elif r.status != 200: - msg = await _extract_message(r) - raise RuntimeError( - f"Failed to cancel transform {transform_id} - {msg}" - ) + r = await session.get(headers=headers, url=url) + if r.status_code == 403: + raise AuthorizationError( + f"Not authorized to access serviceX at {self.url}" + ) + elif r.status_code == 404: + raise ValueError(f"Transform {transform_id} not found") + elif r.status_code != 200: + msg = await _extract_message(r) + raise RuntimeError(f"Failed to cancel transform {transform_id} - {msg}") async def submit_transform(self, transform_request: TransformRequest) -> str: headers = await self._get_authorization() - retry_options = ExponentialRetry(attempts=3, start_timeout=30) - async with RetryClient(retry_options=retry_options) as client: - async with client.post( + retry_options = Retry(total=3, backoff_factor=30) + async with ClientSession( + transport=RetryTransport(retry=retry_options) + ) as client: + r = await client.post( url=f"{self.url}/servicex/transformation", headers=headers, json=transform_request.model_dump(by_alias=True, exclude_none=True), - ) as r: - if r.status == 401: - raise AuthorizationError( - f"Not authorized to access serviceX at {self.url}" - ) - elif r.status == 400: - message = await _extract_message(r) - raise ValueError(f"Invalid transform request: {message}") - elif r.status > 400: - error_message = await _extract_message(r) - raise RuntimeError( - "ServiceX WebAPI Error during transformation " - f"submission: {r.status} - {error_message}" - ) - else: - o = await r.json() - return o["request_id"] + ) + if r.status_code == 401: + raise AuthorizationError( + f"Not authorized to access serviceX at {self.url}" + ) + elif r.status_code == 400: + message = await _extract_message(r) + raise ValueError(f"Invalid transform request: {message}") + elif r.status_code > 400: + error_message = await _extract_message(r) + raise RuntimeError( + "ServiceX WebAPI Error during transformation " + f"submission: {r.status_code} - {error_message}" + ) + else: + o = await r.json() + return o["request_id"] async def get_transform_status(self, request_id: str) -> TransformStatus: headers = await self._get_authorization() - retry_options = ExponentialRetry(attempts=5, start_timeout=3) - async with RetryClient(retry_options=retry_options) as client: + retry_options = Retry(total=5, backoff_factor=3) + async with ClientSession( + transport=RetryTransport(retry=retry_options) + ) as client: try: async for attempt in AsyncRetrying( retry=retry_if_not_exception_type(ValueError), @@ -286,28 +283,29 @@ async def get_transform_status(self, request_id: str) -> TransformStatus: reraise=True, ): with attempt: - async with client.get( + r = await client.get( url=f"{self.url}/servicex/" f"transformation/{request_id}", headers=headers, - ) as r: - if r.status == 401: - # perhaps we just ran out of auth validity the last time? - # refetch auth then raise an error for retry - headers = await self._get_authorization(True) - raise AuthorizationError( - f"Not authorized to access serviceX at {self.url}" - ) - if r.status == 404: - raise ValueError(f"Transform ID {request_id} not found") - elif r.status > 400: - error_message = await _extract_message(r) - raise RuntimeError( - "ServiceX WebAPI Error during transformation: " - f"{r.status} - {error_message}" - ) - o = await r.json() - return TransformStatus(**o) + ) + if r.status_code == 401: + # perhaps we just ran out of auth validity the last time? + # refetch auth then raise an error for retry + headers = await self._get_authorization(True) + raise AuthorizationError( + f"Not authorized to access serviceX at {self.url}" + ) + if r.status_code == 404: + raise ValueError(f"Transform ID {request_id} not found") + elif r.status_code > 400: + error_message = await _extract_message(r) + raise RuntimeError( + "ServiceX WebAPI Error during transformation: " + f"{r.status_code} - {error_message}" + ) + o = await r.json() + return TransformStatus(**o) except RuntimeError as e: raise RuntimeError( "ServiceX WebAPI Error " f"while getting transform status: {e}" ) + raise RuntimeError("ServiceX WebAPI: unable to retrieve transform status") diff --git a/tests/test_servicex_adapter.py b/tests/test_servicex_adapter.py index 99d8f5c0..3fb59985 100644 --- a/tests/test_servicex_adapter.py +++ b/tests/test_servicex_adapter.py @@ -32,7 +32,7 @@ import httpx import pytest -from aiohttp import ContentTypeError +from json import JSONDecodeError from pytest_asyncio import fixture from servicex.models import TransformRequest, ResultDestination, ResultFormat @@ -55,12 +55,10 @@ def test_result_formats(): @pytest.mark.asyncio -@patch("servicex.servicex_adapter.RetryClient.get") +@patch("servicex.servicex_adapter.ClientSession.get") async def test_get_transforms(mock_get, servicex, transform_status_response): - mock_get.return_value.__aenter__.return_value.json.return_value = ( - transform_status_response - ) - mock_get.return_value.__aenter__.return_value.status = 200 + mock_get.return_value.json.return_value = transform_status_response + mock_get.return_value.status_code = 200 t = await servicex.get_transforms() assert len(t) == 1 assert t[0].request_id == "b8c508d0-ccf2-4deb-a1f7-65c839eebabf" @@ -70,12 +68,10 @@ async def test_get_transforms(mock_get, servicex, transform_status_response): @pytest.mark.asyncio -@patch("servicex.servicex_adapter.RetryClient.get") +@patch("servicex.servicex_adapter.ClientSession.get") async def test_get_transforms_error(mock_get, servicex, transform_status_response): - mock_get.return_value.__aenter__.return_value.json.return_value = { - "message": "error_message" - } - mock_get.return_value.__aenter__.return_value.status = 500 + mock_get.return_value.json.return_value = {"message": "error_message"} + mock_get.return_value.status_code = 500 with pytest.raises(RuntimeError) as err: await servicex.get_transforms() assert ( @@ -85,10 +81,10 @@ async def test_get_transforms_error(mock_get, servicex, transform_status_respons @pytest.mark.asyncio -@patch("servicex.servicex_adapter.RetryClient.get") +@patch("servicex.servicex_adapter.ClientSession.get") async def test_get_transforms_auth_error(mock_get, servicex): with pytest.raises(AuthorizationError) as err: - mock_get.return_value.__aenter__.return_value.status = 401 + mock_get.return_value.status_code = 401 await servicex.get_transforms() assert "Not authorized to access serviceX at" in str(err.value) @@ -119,18 +115,14 @@ async def test_get_transforms_wlcg_bearer_token( @pytest.mark.asyncio -@patch("servicex.servicex_adapter.RetryClient.post") -@patch("servicex.servicex_adapter.RetryClient.get") +@patch("servicex.servicex_adapter.ClientSession.post") +@patch("servicex.servicex_adapter.ClientSession.get") async def test_get_transforms_with_refresh(get, post, transform_status_response): servicex = ServiceXAdapter(url="https://servicex.org", refresh_token="refrescas") - post.return_value.__aenter__.return_value.json.return_value = { - "access_token": "luckycharms" - } - post.return_value.__aenter__.return_value.status = 200 - get.return_value.__aenter__.return_value.json.return_value = ( - transform_status_response - ) - get.return_value.__aenter__.return_value.status = 200 + post.return_value.json.return_value = {"access_token": "luckycharms"} + post.return_value.status_code = 200 + get.return_value.json.return_value = transform_status_response + get.return_value.status_code = 200 await servicex.get_transforms() post.assert_called_with( @@ -191,8 +183,8 @@ def dataset(): @pytest.mark.asyncio @patch("servicex.servicex_adapter.ClientSession.get") async def test_get_datasets(get, servicex, dataset): - get.return_value.__aenter__.return_value.json.return_value = {"datasets": [dataset]} - get.return_value.__aenter__.return_value.status = 200 + get.return_value.json.return_value = {"datasets": [dataset]} + get.return_value.status_code = 200 c = await servicex.get_datasets() assert len(c) == 1 @@ -205,8 +197,8 @@ async def test_get_datasets(get, servicex, dataset): @pytest.mark.asyncio @patch("servicex.servicex_adapter.ClientSession.get") async def test_get_datasets_show_deleted(get, servicex, dataset): - get.return_value.__aenter__.return_value.json.return_value = {"datasets": [dataset]} - get.return_value.__aenter__.return_value.status = 200 + get.return_value.json.return_value = {"datasets": [dataset]} + get.return_value.status_code = 200 c = await servicex.get_datasets(show_deleted=True) assert len(c) == 1 assert c[0].id == 123 @@ -220,7 +212,7 @@ async def test_get_datasets_show_deleted(get, servicex, dataset): @pytest.mark.asyncio @patch("servicex.servicex_adapter.ClientSession.get") async def test_get_datasets_auth_error(get, servicex): - get.return_value.__aenter__.return_value.status = 403 + get.return_value.status_code = 403 with pytest.raises(AuthorizationError) as err: await servicex.get_datasets() assert "Not authorized to access serviceX at" in str(err.value) @@ -229,8 +221,8 @@ async def test_get_datasets_auth_error(get, servicex): @pytest.mark.asyncio @patch("servicex.servicex_adapter.ClientSession.get") async def test_get_dataset(get, servicex, dataset): - get.return_value.__aenter__.return_value.json.return_value = dataset - get.return_value.__aenter__.return_value.status = 200 + get.return_value.json.return_value = dataset + get.return_value.status_code = 200 c = await servicex.get_dataset(123) assert c assert c.id == 123 @@ -239,21 +231,19 @@ async def test_get_dataset(get, servicex, dataset): @pytest.mark.asyncio @patch("servicex.servicex_adapter.ClientSession.get") async def test_get_dataset_errors(get, servicex, dataset): - get.return_value.__aenter__.return_value.status = 403 + get.return_value.status_code = 403 with pytest.raises(AuthorizationError) as err: await servicex.get_dataset(123) assert "Not authorized to access serviceX at" in str(err.value) - get.return_value.__aenter__.return_value.status = 404 + get.return_value.status_code = 404 with pytest.raises(ValueError) as err: await servicex.get_dataset(123) assert "Dataset 123 not found" in str(err.value) - get.return_value.__aenter__.return_value.json.side_effect = ContentTypeError( - None, None - ) - get.return_value.__aenter__.return_value.text.return_value = "error_message" - get.return_value.__aenter__.return_value.status = 500 + get.return_value.json.side_effect = JSONDecodeError("", "", 0) + get.return_value.text = "error_message" + get.return_value.status_code = 500 with pytest.raises(RuntimeError) as err: await servicex.get_dataset(123) assert "Failed to get dataset 123 - error_message" in str(err.value) @@ -262,11 +252,11 @@ async def test_get_dataset_errors(get, servicex, dataset): @pytest.mark.asyncio @patch("servicex.servicex_adapter.ClientSession.delete") async def test_delete_dataset(delete, servicex): - delete.return_value.__aenter__.return_value.json.return_value = { + delete.return_value.json.return_value = { "dataset-id": 123, "stale": True, } - delete.return_value.__aenter__.return_value.status = 200 + delete.return_value.status_code = 200 r = await servicex.delete_dataset(123) delete.assert_called_with( @@ -278,21 +268,19 @@ async def test_delete_dataset(delete, servicex): @pytest.mark.asyncio @patch("servicex.servicex_adapter.ClientSession.delete") async def test_delete_dataset_errors(delete, servicex): - delete.return_value.__aenter__.return_value.status = 403 + delete.return_value.status_code = 403 with pytest.raises(AuthorizationError) as err: await servicex.delete_dataset(123) assert "Not authorized to access serviceX at" in str(err.value) - delete.return_value.__aenter__.return_value.status = 404 + delete.return_value.status_code = 404 with pytest.raises(ValueError) as err: await servicex.delete_dataset(123) assert "Dataset 123 not found" in str(err.value) - delete.return_value.__aenter__.return_value.json.side_effect = ContentTypeError( - None, None - ) - delete.return_value.__aenter__.return_value.text.return_value = "error_message" - delete.return_value.__aenter__.return_value.status = 500 + delete.return_value.json.side_effect = JSONDecodeError("", "", 0) + delete.return_value.text = "error_message" + delete.return_value.status_code = 500 with pytest.raises(RuntimeError) as err: await servicex.delete_dataset(123) assert "Failed to delete dataset 123 - error_message" in str(err.value) @@ -301,7 +289,7 @@ async def test_delete_dataset_errors(delete, servicex): @pytest.mark.asyncio @patch("servicex.servicex_adapter.ClientSession.delete") async def test_delete_transform(delete, servicex): - delete.return_value.__aenter__.return_value.status = 200 + delete.return_value.status_code = 200 await servicex.delete_transform("123-45-6789") delete.assert_called_with( url="https://servicex.org/servicex/transformation/123-45-6789", headers={} @@ -311,21 +299,19 @@ async def test_delete_transform(delete, servicex): @pytest.mark.asyncio @patch("servicex.servicex_adapter.ClientSession.delete") async def test_delete_transform_errors(delete, servicex): - delete.return_value.__aenter__.return_value.status = 403 + delete.return_value.status_code = 403 with pytest.raises(AuthorizationError) as err: await servicex.delete_transform("123-45-6789") assert "Not authorized to access serviceX at" in str(err.value) - delete.return_value.__aenter__.return_value.status = 404 + delete.return_value.status_code = 404 with pytest.raises(ValueError) as err: await servicex.delete_transform("123-45-6789") assert "Transform 123-45-6789 not found" in str(err.value) - delete.return_value.__aenter__.return_value.json.side_effect = ContentTypeError( - None, None - ) - delete.return_value.__aenter__.return_value.text.return_value = "error_message" - delete.return_value.__aenter__.return_value.status = 500 + delete.return_value.json.side_effect = JSONDecodeError("", "", 0) + delete.return_value.text = "error_message" + delete.return_value.status_code = 500 with pytest.raises(RuntimeError) as err: await servicex.delete_transform("123-45-6789") assert "Failed to delete transform 123-45-6789 - error_message" in str(err.value) @@ -334,10 +320,10 @@ async def test_delete_transform_errors(delete, servicex): @pytest.mark.asyncio @patch("servicex.servicex_adapter.ClientSession.get") async def test_cancel_transform(get, servicex): - get.return_value.__aenter__.return_value.json.return_value = { + get.return_value.json.return_value = { "message": "Canceled transformation request 123" } - get.return_value.__aenter__.return_value.status = 200 + get.return_value.status_code = 200 await servicex.cancel_transform(123) get.assert_called_with( @@ -348,33 +334,29 @@ async def test_cancel_transform(get, servicex): @pytest.mark.asyncio @patch("servicex.servicex_adapter.ClientSession.get") async def test_cancel_transform_errors(get, servicex): - get.return_value.__aenter__.return_value.status = 403 + get.return_value.status_code = 403 with pytest.raises(AuthorizationError) as err: await servicex.cancel_transform(123) assert "Not authorized to access serviceX at" in str(err.value) - get.return_value.__aenter__.return_value.status = 404 + get.return_value.status_code = 404 with pytest.raises(ValueError) as err: await servicex.cancel_transform(123) assert "Transform 123 not found" in str(err.value) - get.return_value.__aenter__.return_value.json.side_effect = ContentTypeError( - None, None - ) - get.return_value.__aenter__.return_value.text.return_value = "error_message" - get.return_value.__aenter__.return_value.status = 500 + get.return_value.json.side_effect = JSONDecodeError("", "", 0) + get.return_value.text = "error_message" + get.return_value.status_code = 500 with pytest.raises(RuntimeError) as err: await servicex.cancel_transform(123) assert "Failed to cancel transform 123 - error_message" in str(err.value) @pytest.mark.asyncio -@patch("servicex.servicex_adapter.RetryClient.post") +@patch("servicex.servicex_adapter.ClientSession.post") async def test_submit(post, servicex): - post.return_value.__aenter__.return_value.json.return_value = { - "request_id": "123-456-789" - } - post.return_value.__aenter__.return_value.status = 200 + post.return_value.json.return_value = {"request_id": "123-456-789"} + post.return_value.status_code = 200 request = TransformRequest( title="Test submission", did="rucio://foo.bar", @@ -388,9 +370,9 @@ async def test_submit(post, servicex): @pytest.mark.asyncio -@patch("servicex.servicex_adapter.RetryClient.post") +@patch("servicex.servicex_adapter.ClientSession.post") async def test_submit_errors(post, servicex): - post.return_value.__aenter__.return_value.status = 401 + post.return_value.status_code = 401 request = TransformRequest( title="Test submission", did="rucio://foo.bar", @@ -403,11 +385,9 @@ async def test_submit_errors(post, servicex): await servicex.submit_transform(request) assert "Not authorized to access serviceX at" in str(err.value) - post.return_value.__aenter__.return_value.json.side_effect = ContentTypeError( - None, None - ) - post.return_value.__aenter__.return_value.text.return_value = "error_message" - post.return_value.__aenter__.return_value.status = 500 + post.return_value.json.side_effect = JSONDecodeError("", "", 0) + post.return_value.text = "error_message" + post.return_value.status_code = 500 with pytest.raises(RuntimeError) as err: await servicex.submit_transform(request) assert ( @@ -415,19 +395,15 @@ async def test_submit_errors(post, servicex): == str(err.value) ) - post.return_value.__aenter__.return_value.json.reset_mock() - post.return_value.__aenter__.return_value.json.return_value = { - "message": "error_message" - } - post.return_value.__aenter__.return_value.status = 400 + post.return_value.json.reset_mock() + post.return_value.json.return_value = {"message": "error_message"} + post.return_value.status_code = 400 with pytest.raises(ValueError) as err: await servicex.submit_transform(request) assert "Invalid transform request: error_message" == str(err.value) - post.return_value.__aenter__.return_value.json.return_value = { - "message": "error_message" - } - post.return_value.__aenter__.return_value.status = 410 + post.return_value.json.return_value = {"message": "error_message"} + post.return_value.status_code = 410 with pytest.raises(RuntimeError) as err: await servicex.submit_transform(request) assert ( @@ -437,53 +413,53 @@ async def test_submit_errors(post, servicex): @pytest.mark.asyncio -@patch("servicex.servicex_adapter.RetryClient.get") +@patch("servicex.servicex_adapter.ClientSession.get") async def test_get_transform_status(get, servicex, transform_status_response): - get.return_value.__aenter__.return_value.json.return_value = ( - transform_status_response["requests"][0] - ) # NOQA: E501 - get.return_value.__aenter__.return_value.status = 200 + get.return_value.json.return_value = transform_status_response["requests"][ + 0 + ] # NOQA: E501 + get.return_value.status_code = 200 result = await servicex.get_transform_status("b8c508d0-ccf2-4deb-a1f7-65c839eebabf") assert result.request_id == "b8c508d0-ccf2-4deb-a1f7-65c839eebabf" @pytest.mark.asyncio -@patch("servicex.servicex_adapter.RetryClient.get") +@patch("servicex.servicex_adapter.ClientSession.get") async def test_get_transform_status_errors(get, servicex): with pytest.raises(AuthorizationError) as err: - get.return_value.__aenter__.return_value.status = 401 + get.return_value.status_code = 401 await servicex.get_transform_status("b8c508d0-ccf2-4deb-a1f7-65c839eebabf") assert "Not authorized to access serviceX at " in str(err.value) with pytest.raises(ValueError) as err: - get.return_value.__aenter__.return_value.status = 404 + get.return_value.status_code = 404 await servicex.get_transform_status("b8c508d0-ccf2-4deb-a1f7-65c839eebabf") assert "Transform ID b8c508d0-ccf2-4deb-a1f7-65c839eebabf not found" == str( err.value ) with pytest.raises(RuntimeError) as err: - get.return_value.__aenter__.return_value.status = 500 + get.return_value.status_code = 500 async def patch_json(): return {"message": "fifteen"} - get.return_value.__aenter__.return_value.json = patch_json + get.return_value.json = patch_json await servicex.get_transform_status("b8c508d0-ccf2-4deb-a1f7-65c839eebabf") assert "ServiceX WebAPI Error during transformation" in str(err.value) @pytest.mark.asyncio @patch("servicex.servicex_adapter.TransformStatus", side_effect=RuntimeError) -@patch("servicex.servicex_adapter.RetryClient.get") +@patch("servicex.servicex_adapter.ClientSession.get") async def test_get_tranform_status_retry_error( get, mock_transform_status, servicex, transform_status_response ): with pytest.raises(RuntimeError) as err: - get.return_value.__aenter__.return_value.json.return_value = ( - transform_status_response["requests"][0] - ) # NOQA: E501 - get.return_value.__aenter__.return_value.status = 200 + get.return_value.json.return_value = transform_status_response["requests"][ + 0 + ] # NOQA: E501 + get.return_value.status_code = 200 await servicex.get_transform_status("b8c508d0-ccf2-4deb-a1f7-65c839eebabf") assert "ServiceX WebAPI Error while getting transform status:" in str(err.value) From 99504c50aa02856649a4c2ad63c4ba6ac0715d98 Mon Sep 17 00:00:00 2001 From: Peter Onyisi Date: Thu, 22 May 2025 21:16:31 +0000 Subject: [PATCH 2/2] Coverage increases --- servicex/servicex_adapter.py | 4 +++- tests/test_servicex_adapter.py | 9 +++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/servicex/servicex_adapter.py b/servicex/servicex_adapter.py index d01aa06d..41944b03 100644 --- a/servicex/servicex_adapter.py +++ b/servicex/servicex_adapter.py @@ -308,4 +308,6 @@ async def get_transform_status(self, request_id: str) -> TransformStatus: raise RuntimeError( "ServiceX WebAPI Error " f"while getting transform status: {e}" ) - raise RuntimeError("ServiceX WebAPI: unable to retrieve transform status") + raise RuntimeError( + "ServiceX WebAPI: unable to retrieve transform status" + ) # pragma: no cover diff --git a/tests/test_servicex_adapter.py b/tests/test_servicex_adapter.py index 3fb59985..9cc5de78 100644 --- a/tests/test_servicex_adapter.py +++ b/tests/test_servicex_adapter.py @@ -218,6 +218,15 @@ async def test_get_datasets_auth_error(get, servicex): assert "Not authorized to access serviceX at" in str(err.value) +@pytest.mark.asyncio +@patch("servicex.servicex_adapter.ClientSession.get") +async def test_get_datasets_miscellaneous_error(get, servicex): + get.return_value.status_code = 500 + with pytest.raises(RuntimeError) as err: + await servicex.get_datasets() + assert "Failed to get datasets" in str(err.value) + + @pytest.mark.asyncio @patch("servicex.servicex_adapter.ClientSession.get") async def test_get_dataset(get, servicex, dataset):