diff --git a/servicex/models.py b/servicex/models.py index 37bde544..23b9bd27 100644 --- a/servicex/models.py +++ b/servicex/models.py @@ -230,6 +230,16 @@ class TransformedResults(DocStringBaseModel): """URL for looking up logs on the ServiceX server""" +class ServiceXInfo(DocStringBaseModel): + r""" + Model for ServiceX Info properties + """ + + app_version: str = Field(alias="app-version") + code_gen_image: dict[str, str] = Field(alias="code-gen-image") + capabilities: list[str] = Field(default_factory=list) + + class DatasetFile(BaseModel): """ Model for a file in a cached dataset diff --git a/servicex/query_core.py b/servicex/query_core.py index 6f54a311..ad2341ef 100644 --- a/servicex/query_core.py +++ b/servicex/query_core.py @@ -27,6 +27,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. from __future__ import annotations +import datetime import abc import asyncio from abc import ABC @@ -342,13 +343,17 @@ def transform_complete(task: Task): download_files_task = loop.create_task( self.download_files( - signed_urls_only, expandable_progress, download_progress, cached_record + signed_urls_only, + expandable_progress, + download_progress, + cached_record, ) ) try: signed_urls = [] downloaded_files = [] + download_result = await download_files_task if signed_urls_only: signed_urls = download_result @@ -522,6 +527,7 @@ async def download_files( Task to monitor the list of files in the transform output's bucket. Any new files will be downloaded. """ + files_seen = set() result_uris = [] download_tasks = [] @@ -555,40 +561,70 @@ async def get_signed_url( if progress: progress.advance(task_id=download_progress, task_type="Download") + later_than = datetime.datetime.min.replace(tzinfo=datetime.timezone.utc) + + use_local_polling = ( + "poll_local_transformation_results" + in await self.servicex.get_servicex_capabilities() + ) + + if not use_local_polling: + logger.warning( + "ServiceX is using legacy S3 bucket polling. Future versions of the " + "ServiceX client will not support this method. Please update your " + "ServiceX server to the latest version." + ) + while True: if not cached_record: await asyncio.sleep(self.minio_polling_interval) if self.minio: # if self.minio exists, self.current_status will too if self.current_status.files_completed > len(files_seen): - files = await self.minio.list_bucket() + if use_local_polling: + files = await self.servicex.get_transformation_results( + self.current_status.request_id, later_than + ) + else: + files = await self.minio.list_bucket() + for file in files: - if file.filename not in files_seen: + filename = file.filename + + if filename != "" and filename not in files_seen: if signed_urls_only: download_tasks.append( loop.create_task( get_signed_url( self.minio, - file.filename, + filename, progress, download_progress, ) ) ) else: + if use_local_polling: + expected_size = file.total_bytes + else: + expected_size = file.size download_tasks.append( loop.create_task( download_file( self.minio, - file.filename, + filename, progress, download_progress, shorten_filename=self.configuration.shortened_downloaded_filename, # NOQA: E501 - expected_size=file.size, + expected_size=expected_size, ) ) ) # NOQA 501 - files_seen.add(file.filename) + files_seen.add(filename) + + if use_local_polling: + if file.created_at > later_than: + later_than = file.created_at # Once the transform is complete and all files are seen we can stop polling. # Also, if we are just downloading or signing urls for a previous transform diff --git a/servicex/servicex_adapter.py b/servicex/servicex_adapter.py index ad0d9fb3..25f7b8db 100644 --- a/servicex/servicex_adapter.py +++ b/servicex/servicex_adapter.py @@ -27,7 +27,9 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import os import time +import datetime from typing import Optional, Dict, List +from dataclasses import dataclass import httpx from httpx import AsyncClient, Response @@ -41,13 +43,25 @@ retry_if_not_exception_type, ) -from servicex.models import TransformRequest, TransformStatus, CachedDataset +from servicex.models import ( + TransformRequest, + TransformStatus, + CachedDataset, + ServiceXInfo, +) class AuthorizationError(BaseException): pass +@dataclass +class ServiceXFile: + created_at: datetime.datetime + filename: str + total_bytes: int + + async def _extract_message(r: Response): try: o = r.json() @@ -63,6 +77,9 @@ def __init__(self, url: str, refresh_token: Optional[str] = None): self.refresh_token = refresh_token self.token = None + # interact with _servicex_info via get_servicex_info + self._servicex_info: Optional[ServiceXInfo] = None + async def _get_token(self): url = f"{self.url}/token/refresh" headers = {"Authorization": f"Bearer {self.refresh_token}"} @@ -120,6 +137,31 @@ async def _get_authorization(self, force_reauth: bool = False) -> Dict[str, str] await self._get_token() return {"Authorization": f"Bearer {self.token}"} + async def get_servicex_info(self) -> ServiceXInfo: + if self._servicex_info: + return self._servicex_info + + headers = await self._get_authorization() + retry_options = Retry(total=3, backoff_factor=10) + async with AsyncClient(transport=RetryTransport(retry=retry_options)) as client: + r = await client.get(url=f"{self.url}/servicex", headers=headers) + if r.status_code == 401: + raise AuthorizationError( + f"Not authorized to access serviceX at {self.url}" + ) + elif r.status_code > 400: + error_message = await _extract_message(r) + raise RuntimeError( + "ServiceX WebAPI Error during transformation " + f"submission: {r.status_code} - {error_message}" + ) + servicex_info = r.json() + self._servicex_info = ServiceXInfo(**servicex_info) + return self._servicex_info + + async def get_servicex_capabilities(self) -> List[str]: + return (await self.get_servicex_info()).capabilities + async def get_transforms(self) -> List[TransformStatus]: headers = await self._get_authorization() retry_options = Retry(total=3, backoff_factor=10) @@ -232,6 +274,48 @@ async def delete_transform(self, transform_id=None): msg = await _extract_message(r) raise RuntimeError(f"Failed to delete transform {transform_id} - {msg}") + async def get_transformation_results( + self, request_id: str, later_than: Optional[datetime.datetime] = None + ): + if ( + "poll_local_transformation_results" + not in await self.get_servicex_capabilities() + ): + raise ValueError("ServiceX capabilities not found") + + headers = await self._get_authorization() + url = self.url + f"/servicex/transformation/{request_id}/results" + params = {} + if later_than: + params["later_than"] = later_than.isoformat() + + async with AsyncClient() as session: + r = await session.get(headers=headers, url=url, params=params) + if r.status_code == 403: + raise AuthorizationError( + f"Not authorized to access serviceX at {self.url}" + ) + + if r.status_code == 404: + raise ValueError(f"Request {request_id} not found") + + if r.status_code != 200: + msg = await _extract_message(r) + raise RuntimeError(f"Failed with message: {msg}") + + data = r.json() + response = list() + for result in data.get("results", []): + file = ServiceXFile( + filename=result["s3-object-name"], + created_at=datetime.datetime.fromisoformat( + result["created_at"] + ).replace(tzinfo=datetime.timezone.utc), + total_bytes=result["total-bytes"], + ) + response.append(file) + return response + async def cancel_transform(self, transform_id=None): headers = await self._get_authorization() path_template = f"/servicex/transformation/{transform_id}/cancel" diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 2ef18283..1bf2446f 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -28,6 +28,7 @@ import pytest import tempfile import os +import datetime from unittest.mock import AsyncMock, Mock, patch from servicex.dataset_identifier import FileListDataset @@ -44,6 +45,8 @@ ) from rich.progress import Progress +from servicex.servicex_adapter import ServiceXFile + @pytest.mark.asyncio async def test_as_signed_urls_happy(transformed_result): @@ -124,12 +127,27 @@ async def test_download_files(python_dataset): minio_mock = AsyncMock() config = Configuration(cache_path="temp_dir", api_endpoints=[]) python_dataset.configuration = config + python_dataset.servicex = AsyncMock() + python_dataset.servicex.get_servicex_capabilities = AsyncMock( + return_value=["poll_local_transformation_results"] + ) + + python_dataset.servicex.get_transformation_results = AsyncMock() + python_dataset.servicex.get_transformation_results.return_value = [ + ServiceXFile( + filename="file1.txt", + created_at=datetime.datetime.now(datetime.timezone.utc), + total_bytes=100, + ), + ServiceXFile( + filename="file2.txt", + created_at=datetime.datetime.now(datetime.timezone.utc), + total_bytes=100, + ), + ] + minio_mock.download_file.return_value = Path("/path/to/downloaded_file") minio_mock.get_signed_url.return_value = Path("http://example.com/signed_url") - minio_mock.list_bucket.return_value = [ - Mock(filename="file1.txt"), - Mock(filename="file2.txt"), - ] progress_mock = Mock() python_dataset.minio_polling_interval = 0 @@ -154,12 +172,27 @@ async def test_download_files_with_signed_urls(python_dataset): python_dataset.configuration = config minio_mock.download_file.return_value = "/path/to/downloaded_file" minio_mock.get_signed_url.return_value = "http://example.com/signed_url" - minio_mock.list_bucket.return_value = [ - Mock(filename="file1.txt"), - Mock(filename="file2.txt"), - ] progress_mock = Mock() + python_dataset.servicex = AsyncMock() + python_dataset.servicex.get_servicex_capabilities = AsyncMock( + return_value=["poll_local_transformation_results"] + ) + + python_dataset.servicex.get_transformation_results = AsyncMock() + python_dataset.servicex.get_transformation_results.return_value = [ + ServiceXFile( + filename="file1.txt", + created_at=datetime.datetime.now(datetime.timezone.utc), + total_bytes=100, + ), + ServiceXFile( + filename="file2.txt", + created_at=datetime.datetime.now(datetime.timezone.utc), + total_bytes=100, + ), + ] + python_dataset.minio_polling_interval = 0 python_dataset.minio = minio_mock python_dataset.current_status = Mock(status="Complete", files_completed=2) diff --git a/tests/test_servicex_adapter.py b/tests/test_servicex_adapter.py index 9c5810a8..4d9f4bd1 100644 --- a/tests/test_servicex_adapter.py +++ b/tests/test_servicex_adapter.py @@ -28,14 +28,20 @@ import os import tempfile import time -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, AsyncMock +import datetime import httpx import pytest from json import JSONDecodeError from pytest_asyncio import fixture -from servicex.models import TransformRequest, ResultDestination, ResultFormat +from servicex.models import ( + TransformRequest, + ResultDestination, + ResultFormat, + ServiceXInfo, +) from servicex.servicex_adapter import ServiceXAdapter, AuthorizationError @@ -92,13 +98,23 @@ async def test_get_transforms_auth_error(mock_get, servicex): @pytest.mark.asyncio +@patch("servicex.servicex_adapter.AsyncClient.get") @patch("servicex.servicex_adapter.AsyncClient.post") @patch("servicex.servicex_adapter.jwt.decode") async def test_get_transforms_wlcg_bearer_token( - decode, post, servicex, transform_status_response + decode, post, http_get, servicex, transform_status_response ): post.return_value.json.return_value = {"access_token": "luckycharms"} post.return_value.status_code = 401 + http_get.return_value.__aenter__.return_value.json.return_value = ( + transform_status_response + ) + http_get.return_value.__aenter__.return_value.status = 200 + servicex.get_servicex_capabilities = AsyncMock(return_value=[]) + post.return_value.__aenter__.return_value.json.return_value = { + "access_token": "luckycharms" + } + post.return_value.__aenter__.return_value.status = 401 token_file = tempfile.NamedTemporaryFile(mode="w+t", delete=False) token_file.write( """" @@ -506,3 +522,295 @@ async def test_get_authorization(servicex): with patch("google.auth.jwt.decode", return_value={"exp": time.time() - 90}): r = await servicex._get_authorization() get_token.assert_called_once() + + +@pytest.mark.asyncio +@patch("servicex.servicex_adapter.AsyncClient.get") +async def test_get_transformation_results_success(get, servicex): + servicex.get_servicex_capabilities = AsyncMock( + return_value=["poll_local_transformation_results"] + ) + get.return_value = MagicMock() + get.return_value.json.return_value = { + "results": [ + { + "file-path": "file1.txt", + "total-bytes": 100, + "s3-object-name": "file1.txt", + "created_at": datetime.datetime.now(datetime.timezone.utc).isoformat(), + }, + { + "file-path": "file2.txt", + "total-bytes": 100, + "s3-object-name": "file2.txt", + "created_at": datetime.datetime.now(datetime.timezone.utc).isoformat(), + }, + ] + } + get.return_value.status_code = 200 + + request_id = "123-45-6789" + now = datetime.datetime.now(datetime.timezone.utc) + await servicex.get_transformation_results(request_id, now) + + get.assert_called_with( + url=f"https://servicex.org/servicex/transformation/{request_id}/results", + headers={}, + params={ + "later_than": now.isoformat(), + }, + ) + + +@pytest.mark.asyncio +@patch("servicex.servicex_adapter.AsyncClient.get") +async def test_get_transformation_results_no_feature_flag(get, servicex): + servicex.get_servicex_capabilities = AsyncMock(return_value=[]) + request_id = "123-45-6789" + now = datetime.datetime.now(datetime.timezone.utc) + with pytest.raises(ValueError): + await servicex.get_transformation_results(request_id, now) + + +@pytest.mark.asyncio +@patch("servicex.servicex_adapter.AsyncClient.get") +async def test_get_transformation_results_not_found( + get_transformation_results, servicex +): + servicex.get_servicex_capabilities = AsyncMock( + return_value=["poll_local_transformation_results"] + ) + get_transformation_results.return_value = MagicMock() + get_transformation_results.return_value.status_code = 404 + + request_id = "123-45-6789" + now = datetime.datetime.now(datetime.timezone.utc) + + with pytest.raises(ValueError): + await servicex.get_transformation_results(request_id, now) + + get_transformation_results.assert_called_with( + url=f"https://servicex.org/servicex/transformation/{request_id}/results", + headers={}, + params={ + "later_than": now.isoformat(), + }, + ) + + +@pytest.mark.asyncio +@patch("servicex.servicex_adapter.AsyncClient.get") +async def test_get_transformation_results_not_authorized( + get_transformation_results, servicex +): + servicex.get_servicex_capabilities = AsyncMock( + return_value=["poll_local_transformation_results"] + ) + get_transformation_results.return_value = MagicMock() + get_transformation_results.return_value.status_code = 403 + request_id = "123-45-6789" + now = datetime.datetime.now(datetime.timezone.utc) + + with pytest.raises(AuthorizationError): + await servicex.get_transformation_results(request_id, now) + + get_transformation_results.assert_called_with( + url=f"https://servicex.org/servicex/transformation/{request_id}/results", + headers={}, + params={ + "later_than": now.isoformat(), + }, + ) + + +@pytest.mark.asyncio +@patch("servicex.servicex_adapter.AsyncClient.get") +async def test_get_transformation_results_server_error( + get_transformation_results, servicex +): + servicex.get_servicex_capabilities = AsyncMock( + return_value=["poll_local_transformation_results"] + ) + get_transformation_results.return_value = MagicMock() + get_transformation_results.return_value.status = 500 + request_id = "123-45-6789" + now = datetime.datetime.now(datetime.timezone.utc) + + with pytest.raises(RuntimeError): + await servicex.get_transformation_results(request_id, now) + + get_transformation_results.assert_called_with( + url=f"https://servicex.org/servicex/transformation/{request_id}/results", + headers={}, + params={ + "later_than": now.isoformat(), + }, + ) + + +def test_get_bearer_token_file(tmp_path, monkeypatch): + token_file = tmp_path / "btf" + token_file.write_text("bearer123") + monkeypatch.setenv("BEARER_TOKEN_FILE", str(token_file)) + assert ServiceXAdapter._get_bearer_token_file() == "bearer123" + monkeypatch.delenv("BEARER_TOKEN_FILE", raising=False) + assert ServiceXAdapter._get_bearer_token_file() is None + + +@patch("servicex.servicex_adapter.jwt.decode", return_value={"exp": 1600000000}) +def test_get_token_expiration_success(decode): + assert ServiceXAdapter._get_token_expiration("dummy") == 1600000000 + + +@patch("servicex.servicex_adapter.jwt.decode", return_value={"sub": "noexp"}) +def test_get_token_expiration_no_exp(decode): + with pytest.raises(RuntimeError): + ServiceXAdapter._get_token_expiration("dummy") + + +@pytest.mark.asyncio +async def test_get_authorization_no_token_no_refresh(servicex, monkeypatch): + monkeypatch.delenv("BEARER_TOKEN_FILE", raising=False) + headers = await servicex._get_authorization() + assert headers == {} + + +@pytest.mark.asyncio +@patch("servicex.servicex_adapter.jwt.decode", return_value={"exp": time.time() + 120}) +async def test_get_authorization_with_valid_token(decode, servicex): + servicex.token = "tok123" + headers = await servicex._get_authorization() + assert headers == {"Authorization": "Bearer tok123"} + + +@pytest.mark.asyncio +async def test_get_authorization_with_refresh(monkeypatch): + s = ServiceXAdapter("https://servicex.org", refresh_token="rftok") + monkeypatch.delenv("BEARER_TOKEN_FILE", raising=False) + + async def fake_get_token(self): + self.token = "newtoken" + + monkeypatch.setattr(ServiceXAdapter, "_get_token", fake_get_token) + headers = await s._get_authorization() + assert headers == {"Authorization": "Bearer newtoken"} + + +@pytest.mark.asyncio +@patch("servicex.servicex_adapter.AsyncClient.get") +async def test_get_servicex_info_success(mock_get, servicex): + mock_get.return_value.status_code = 200 + mock_get.return_value.json = MagicMock( + return_value={ + "capabilities": ["a", "b"], + "app-version": "1.0", + "code-gen-image": {"func_adl": "image1", "uproot": "image2"}, + } + ) + info = await servicex.get_servicex_info() + assert isinstance(info, ServiceXInfo) + assert info.capabilities == ["a", "b"] + + +@pytest.mark.asyncio +@patch("servicex.servicex_adapter.AsyncClient.get") +async def test_get_servicex_info_auth_error(mock_get, servicex): + mock_get.return_value.status_code = 401 + with pytest.raises(AuthorizationError): + await servicex.get_servicex_info() + + +@pytest.mark.asyncio +@patch("servicex.servicex_adapter.AsyncClient.get") +async def test_get_servicex_info_server_error(mock_get, servicex): + mock_get.return_value.status_code = 500 + mock_get.return_value.json = MagicMock(return_value={"message": "oops"}) + with pytest.raises(RuntimeError) as e: + await servicex.get_servicex_info() + assert "ServiceX WebAPI Error during transformation submission: 500 - oops" in str( + e.value + ) + + +@pytest.mark.asyncio +async def test_get_servicex_info_caching(servicex): + servicex_info_data = { + "capabilities": ["a", "b"], + "app-version": "1.0", + "code-gen-image": {"func_adl": "image1", "uproot": "image2"}, + } + + with patch("servicex.servicex_adapter.AsyncClient.get") as mock_get: + mock_get.return_value.status_code = 200 + mock_get.return_value.json = MagicMock(return_value=servicex_info_data) + + info1 = await servicex.get_servicex_info() + assert isinstance(info1, ServiceXInfo) + assert info1.capabilities == ["a", "b"] + assert mock_get.call_count == 1 + + # Second call should use cached ServiceXInfo without additional HTTP request + info2 = await servicex.get_servicex_info() + assert info2 is info1 + assert mock_get.call_count == 1 + + +@pytest.mark.asyncio +async def test_get_servicex_capabilities(servicex): + servicex_info_data = { + "capabilities": ["feature1", "feature2", "feature3"], + "app-version": "1.0", + "code-gen-image": {"func_adl": "image1", "uproot": "image2"}, + } + + with patch("servicex.servicex_adapter.AsyncClient.get") as mock_get: + mock_get.return_value.status_code = 200 + mock_get.return_value.json = MagicMock(return_value=servicex_info_data) + + capabilities1 = await servicex.get_servicex_capabilities() + assert capabilities1 == ["feature1", "feature2", "feature3"] + assert mock_get.call_count == 1 + + # Second call should use cached ServiceXInfo without additional HTTP request + capabilities2 = await servicex.get_servicex_capabilities() + assert capabilities2 == ["feature1", "feature2", "feature3"] + assert mock_get.call_count == 1 + + +@pytest.mark.asyncio +@patch("servicex.servicex_adapter.AsyncClient.get") +async def test_get_transformation_results_parsing(mock_get, servicex): + servicex.get_servicex_capabilities = AsyncMock( + return_value=["poll_local_transformation_results"] + ) + msg_time = datetime.datetime(2025, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc) + mock_get.return_value = MagicMock() + mock_get.return_value.status_code = 200 + mock_get.return_value.json = MagicMock( + return_value={ + "results": [ + { + "file-path": "dir1/file.txt", + "s3-object-name": "dir1:file.txt", + "total-bytes": 100, + "created_at": msg_time.isoformat(), + } + ] + } + ) + res = await servicex.get_transformation_results("id123", None) + assert len(res) == 1 + assert res[0].filename == "dir1:file.txt" + assert res[0].created_at == msg_time + + +@pytest.mark.asyncio +@patch("servicex.servicex_adapter.AsyncClient.get") +async def test_get_transformation_results_empty(mock_get, servicex): + servicex.get_servicex_capabilities = AsyncMock( + return_value=["poll_local_transformation_results"] + ) + mock_get.return_value.status_code = 200 + mock_get.return_value.json = MagicMock(return_value={"results": []}) + res = await servicex.get_transformation_results("id123", None) + assert res == [] diff --git a/tests/test_servicex_dataset.py b/tests/test_servicex_dataset.py index 28cbd0f7..4299a3c2 100644 --- a/tests/test_servicex_dataset.py +++ b/tests/test_servicex_dataset.py @@ -25,6 +25,7 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import datetime import tempfile from typing import List from unittest.mock import AsyncMock, patch @@ -46,6 +47,7 @@ ) from servicex.query_cache import QueryCache from servicex.query_core import ServiceXException, Query +from servicex.servicex_adapter import ServiceXFile from servicex.servicex_client import ServiceXClient from servicex.uproot_raw.uproot_raw import UprootRawQuery @@ -201,11 +203,42 @@ def cache_transform(record: TransformedResults): return +@pytest.mark.parametrize("use_s3_polling", [False, True]) @pytest.mark.asyncio -async def test_submit(mocker): +async def test_submit(mocker, use_s3_polling): servicex = AsyncMock() servicex.submit_transform = AsyncMock() - servicex.submit_transform.return_value = {"request_id": '123-456-789"'} + servicex.submit_transform.return_value = {"request_id": "123-456-789"} + + # Configure capabilities based on polling type + capabilities = [] if use_s3_polling else ["poll_local_transformation_results"] + servicex.get_servicex_capabilities = AsyncMock(return_value=capabilities) + + if not use_s3_polling: + servicex.get_transformation_results = AsyncMock( + side_effect=[ + [ + ServiceXFile( + filename="file1", + total_bytes=100, + created_at=datetime.datetime.now(datetime.timezone.utc), + ) + ], + [ + ServiceXFile( + filename="file1", + total_bytes=100, + created_at=datetime.datetime.now(datetime.timezone.utc), + ), + ServiceXFile( + filename="file2", + total_bytes=100, + created_at=datetime.datetime.now(datetime.timezone.utc), + ), + ], + ] + ) + servicex.get_transform_status = AsyncMock() servicex.get_transform_status.side_effect = [ transform_status1, @@ -214,17 +247,21 @@ async def test_submit(mocker): ] mock_minio = AsyncMock() - mock_minio.list_bucket = AsyncMock(side_effect=[[file1], [file1, file2]]) mock_minio.download_file = AsyncMock( side_effect=lambda a, _, shorten_filename, expected_size: PurePath(a) ) + if use_s3_polling: + mock_minio.list_bucket = AsyncMock(side_effect=[[file1], [file1, file2]]) + mock_cache = mocker.MagicMock(QueryCache) mock_cache.get_transform_by_hash = mocker.MagicMock(return_value=None) mock_cache.transformed_results = mocker.MagicMock(side_effect=transformed_results) mock_cache.cache_transform = mocker.MagicMock(side_effect=cache_transform) mock_cache.cache_path_for_transform = mocker.MagicMock(return_value=PurePath(".")) + mocker.patch("servicex.minio_adapter.MinioAdapter", return_value=mock_minio) + did = FileListDataset("/foo/bar/baz.root") datasource = Query( dataset_identifier=did, @@ -235,21 +272,47 @@ async def test_submit(mocker): config=Configuration(api_endpoints=[]), ) datasource.query_string_generator = FuncADLQuery_Uproot().FromTree("nominal") + with ExpandableProgress(display_progress=False) as progress: datasource.result_format = ResultFormat.parquet result = await datasource.submit_and_download( signed_urls_only=False, expandable_progress=progress ) - print(mock_minio.download_file.call_args) + assert result.file_list == ["file1", "file2"] mock_cache.cache_transform.assert_called_once() +@pytest.mark.parametrize("use_s3_polling", [False, True]) @pytest.mark.asyncio -async def test_submit_partial_success(mocker): +async def test_submit_partial_success(mocker, use_s3_polling): servicex = AsyncMock() servicex.submit_transform = AsyncMock() - servicex.submit_transform.return_value = {"request_id": '123-456-789"'} + servicex.submit_transform.return_value = {"request_id": "123-456-789"} + + capabilities = [] if use_s3_polling else ["poll_local_transformation_results"] + servicex.get_servicex_capabilities = AsyncMock(return_value=capabilities) + + if not use_s3_polling: + servicex.get_transformation_results = AsyncMock( + side_effect=[ + [ + ServiceXFile( + filename="file1", + created_at=datetime.datetime.now(datetime.timezone.utc), + total_bytes=100, + ) + ], + [ + ServiceXFile( + filename="file1", + created_at=datetime.datetime.now(datetime.timezone.utc), + total_bytes=100, + ) + ], + ] + ) + servicex.get_transform_status = AsyncMock() servicex.get_transform_status.side_effect = [ transform_status1, @@ -258,17 +321,21 @@ async def test_submit_partial_success(mocker): ] mock_minio = AsyncMock() - mock_minio.list_bucket = AsyncMock(side_effect=[[file1], [file1]]) mock_minio.download_file = AsyncMock( side_effect=lambda a, _, shorten_filename, expected_size: PurePath(a) ) + if use_s3_polling: + mock_minio.list_bucket = AsyncMock(side_effect=[[file1], [file1]]) + mock_cache = mocker.MagicMock(QueryCache) mock_cache.get_transform_by_hash = mocker.MagicMock(return_value=None) mock_cache.transformed_results = mocker.MagicMock(side_effect=transformed_results) mock_cache.cache_transform = mocker.MagicMock(side_effect=cache_transform) mock_cache.cache_path_for_transform = mocker.MagicMock(return_value=PurePath(".")) + mocker.patch("servicex.minio_adapter.MinioAdapter", return_value=mock_minio) + did = FileListDataset("/foo/bar/baz.root") datasource = Query( dataset_identifier=did, @@ -279,34 +346,58 @@ async def test_submit_partial_success(mocker): config=Configuration(api_endpoints=[]), ) datasource.query_string_generator = FuncADLQuery_Uproot().FromTree("nominal") + with ExpandableProgress(display_progress=False) as progress: datasource.result_format = ResultFormat.parquet result = await datasource.submit_and_download( signed_urls_only=False, expandable_progress=progress ) - print(mock_minio.download_file.call_args) + assert result.file_list == ["file1"] mock_cache.cache_transform.assert_not_called() +@pytest.mark.parametrize("use_s3_polling", [False, True]) @pytest.mark.asyncio -async def test_use_of_cache(mocker): +async def test_use_of_cache(mocker, use_s3_polling): """Do we pick up the cache on the second request for the same transform?""" servicex = AsyncMock() servicex.submit_transform = AsyncMock() - servicex.submit_transform.return_value = {"request_id": '123-456-789"'} + servicex.submit_transform.return_value = {"request_id": "123-456-789"} + + capabilities = [] if use_s3_polling else ["poll_local_transformation_results"] + servicex.get_servicex_capabilities = AsyncMock(return_value=capabilities) + servicex.get_transform_status = AsyncMock() servicex.get_transform_status.side_effect = [ transform_status1, transform_status3, ] + + if not use_s3_polling: + servicex.get_transformation_results = AsyncMock() + servicex.get_transformation_results.return_value = [ + ServiceXFile( + filename="file1.txt", + total_bytes=100, + created_at=datetime.datetime.now(datetime.timezone.utc), + ), + ServiceXFile( + filename="file2.txt", + total_bytes=100, + created_at=datetime.datetime.now(datetime.timezone.utc), + ), + ] + mock_minio = AsyncMock() - mock_minio.list_bucket = AsyncMock(return_value=[file1, file2]) mock_minio.download_file = AsyncMock( side_effect=lambda a, _, shorten_filename, expected_size: PurePath(a) ) mock_minio.get_signed_url = AsyncMock(side_effect=["http://file1", "http://file2"]) + if use_s3_polling: + mock_minio.list_bucket = AsyncMock(return_value=[file1, file2]) + mocker.patch("servicex.minio_adapter.MinioAdapter", return_value=mock_minio) did = FileListDataset("/foo/bar/baz.root") @@ -333,10 +424,13 @@ async def test_use_of_cache(mocker): upd.assert_not_called() upd.reset_mock() assert mock_minio.get_signed_url.await_count == 2 - # second round, should hit the cache (and not call the sx_adapter, minio, or update_record) + with ExpandableProgress(display_progress=False) as progress: servicex2 = AsyncMock() - mock_minio.list_bucket.reset_mock() + if use_s3_polling: + mock_minio.list_bucket.reset_mock() + else: + servicex.get_transformation_results.reset_mock() mock_minio.get_signed_url.reset_mock() datasource2 = Query( dataset_identifier=did, @@ -354,15 +448,22 @@ async def test_use_of_cache(mocker): signed_urls_only=True, expandable_progress=progress ) servicex2.assert_not_awaited() - mock_minio.list_bucket.assert_not_awaited() + if use_s3_polling: + mock_minio.list_bucket.assert_not_awaited() + else: + servicex.get_transformation_results.assert_not_awaited() mock_minio.get_signed_url.assert_not_awaited() upd.assert_not_called() assert result1 == result2 upd.reset_mock() servicex.get_transform_status.reset_mock(side_effect=True) servicex.get_transform_status.return_value = transform_status3 - mock_minio.list_bucket.reset_mock(side_effect=True) - # third round, should hit the cache and download files (and call update_record) + + if use_s3_polling: + mock_minio.list_bucket.reset_mock(side_effect=True) + else: + servicex.get_transformation_results.reset_mock(side_effect=True) + with ExpandableProgress(display_progress=False) as progress: await datasource.submit_and_download( signed_urls_only=False, expandable_progress=progress @@ -370,15 +471,21 @@ async def test_use_of_cache(mocker): servicex.assert_not_awaited() assert mock_minio.download_file.await_count == 2 upd.assert_called_once() - # fourth round, should hit the cache (and nothing else) - mock_minio.list_bucket.reset_mock() + + if use_s3_polling: + mock_minio.list_bucket.reset_mock() + else: + servicex.get_transformation_results.reset_mock() mock_minio.download_file.reset_mock() with ExpandableProgress(display_progress=False) as progress: await datasource.submit_and_download( signed_urls_only=False, expandable_progress=progress ) servicex.assert_not_awaited() - mock_minio.list_bucket.assert_not_awaited() + if use_s3_polling: + mock_minio.list_bucket.assert_not_awaited() + else: + servicex.get_transformation_results.assert_not_awaited() mock_minio.download_file.assert_not_awaited() upd.assert_called_once() cache.close() @@ -396,7 +503,6 @@ async def test_submit_cancel(mocker): ] mock_minio = AsyncMock() - mock_minio.list_bucket = AsyncMock(side_effect=[[file1], [file1]]) mock_minio.download_file = AsyncMock( side_effect=lambda a, _, shorten_filename: PurePath(a) ) @@ -437,7 +543,6 @@ async def test_submit_fatal(mocker): ] mock_minio = AsyncMock() - mock_minio.list_bucket = AsyncMock(side_effect=[[file1], [file1]]) mock_minio.download_file = AsyncMock( side_effect=lambda a, _, shorten_filename: PurePath(a) ) @@ -482,7 +587,6 @@ async def test_submit_generic(mocker, codegen_list): ] mock_minio = AsyncMock() - mock_minio.list_bucket = AsyncMock(side_effect=[[file1], [file1, file2]]) mock_minio.download_file = AsyncMock() mock_cache = mocker.MagicMock(QueryCache) @@ -531,7 +635,6 @@ async def test_submit_cancelled(mocker, codegen_list): sx.get_transform_status.side_effect = [transform_status4] mock_minio = AsyncMock() - mock_minio.list_bucket = AsyncMock(side_effect=[[file1], [file1, file2]]) mock_minio.download_file = AsyncMock() mock_cache = mocker.MagicMock(QueryCache) @@ -601,10 +704,24 @@ async def test_use_of_ignore_cache(mocker, servicex): transform_status3, ] ) - + servicex.get_servicex_capabilities = AsyncMock( + return_value=["poll_local_transformation_results"] + ) + servicex.get_transformation_results = AsyncMock() + servicex.get_transformation_results.return_value = [ + ServiceXFile( + filename="file1.txt", + total_bytes=100, + created_at=datetime.datetime.now(datetime.timezone.utc), + ), + ServiceXFile( + filename="file2.txt", + total_bytes=100, + created_at=datetime.datetime.now(datetime.timezone.utc), + ), + ] # Prepare Minio mock_minio = AsyncMock() - mock_minio.list_bucket = AsyncMock(return_value=[file1, file2]) mock_minio.get_signed_url = AsyncMock(side_effect=["http://file1", "http://file2"]) mocker.patch("servicex.minio_adapter.MinioAdapter", return_value=mock_minio) did = FileListDataset("/foo/bar/baz.root") @@ -674,13 +791,13 @@ async def test_use_of_ignore_cache(mocker, servicex): transform_status1, transform_status3, ] - mock_minio.list_bucket.reset_mock() + servicex.get_transformation_results.reset_mock() mock_minio.download_file.reset_mock() with ExpandableProgress(display_progress=False) as progress: res = await datasource_without_ignore_cache.submit_and_download( signed_urls_only=True, expandable_progress=progress ) # noqa - mock_minio.list_bucket.assert_not_awaited() + servicex.get_transformation_results.assert_not_awaited() mock_minio.download_file.assert_not_awaited() assert len(res.signed_url_list) == 2 cache.close()