Skip to content

Commit d9511c9

Browse files
authored
remove s3 bucket polling when waiting for transformation results (#587)
* remove s3 bucket polling when waiting for transformation results
1 parent ae3379c commit d9511c9

File tree

6 files changed

+634
-46
lines changed

6 files changed

+634
-46
lines changed

servicex/models.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,16 @@ class TransformedResults(DocStringBaseModel):
230230
"""URL for looking up logs on the ServiceX server"""
231231

232232

233+
class ServiceXInfo(DocStringBaseModel):
    r"""
    Model for ServiceX Info properties
    """

    # Read from the "app-version" key of the input mapping (pydantic alias).
    app_version: str = Field(alias="app-version")
    # Read from the "code-gen-image" key; a string-to-string mapping.
    code_gen_image: dict[str, str] = Field(alias="code-gen-image")
    # Optional in the input: a fresh empty list is used when it is absent.
    capabilities: list[str] = Field(default_factory=list)
241+
242+
233243
class DatasetFile(BaseModel):
234244
"""
235245
Model for a file in a cached dataset

servicex/query_core.py

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2828
from __future__ import annotations
2929

30+
import datetime
3031
import abc
3132
import asyncio
3233
from abc import ABC
@@ -342,13 +343,17 @@ def transform_complete(task: Task):
342343

343344
download_files_task = loop.create_task(
344345
self.download_files(
345-
signed_urls_only, expandable_progress, download_progress, cached_record
346+
signed_urls_only,
347+
expandable_progress,
348+
download_progress,
349+
cached_record,
346350
)
347351
)
348352

349353
try:
350354
signed_urls = []
351355
downloaded_files = []
356+
352357
download_result = await download_files_task
353358
if signed_urls_only:
354359
signed_urls = download_result
@@ -522,6 +527,7 @@ async def download_files(
522527
Task to monitor the list of files in the transform output's bucket. Any new files
523528
will be downloaded.
524529
"""
530+
525531
files_seen = set()
526532
result_uris = []
527533
download_tasks = []
@@ -555,40 +561,70 @@ async def get_signed_url(
555561
if progress:
556562
progress.advance(task_id=download_progress, task_type="Download")
557563

564+
later_than = datetime.datetime.min.replace(tzinfo=datetime.timezone.utc)
565+
566+
use_local_polling = (
567+
"poll_local_transformation_results"
568+
in await self.servicex.get_servicex_capabilities()
569+
)
570+
571+
if not use_local_polling:
572+
logger.warning(
573+
"ServiceX is using legacy S3 bucket polling. Future versions of the "
574+
"ServiceX client will not support this method. Please update your "
575+
"ServiceX server to the latest version."
576+
)
577+
558578
while True:
559579
if not cached_record:
560580
await asyncio.sleep(self.minio_polling_interval)
561581
if self.minio:
562582
# if self.minio exists, self.current_status will too
563583
if self.current_status.files_completed > len(files_seen):
564-
files = await self.minio.list_bucket()
584+
if use_local_polling:
585+
files = await self.servicex.get_transformation_results(
586+
self.current_status.request_id, later_than
587+
)
588+
else:
589+
files = await self.minio.list_bucket()
590+
565591
for file in files:
566-
if file.filename not in files_seen:
592+
filename = file.filename
593+
594+
if filename != "" and filename not in files_seen:
567595
if signed_urls_only:
568596
download_tasks.append(
569597
loop.create_task(
570598
get_signed_url(
571599
self.minio,
572-
file.filename,
600+
filename,
573601
progress,
574602
download_progress,
575603
)
576604
)
577605
)
578606
else:
607+
if use_local_polling:
608+
expected_size = file.total_bytes
609+
else:
610+
expected_size = file.size
579611
download_tasks.append(
580612
loop.create_task(
581613
download_file(
582614
self.minio,
583-
file.filename,
615+
filename,
584616
progress,
585617
download_progress,
586618
shorten_filename=self.configuration.shortened_downloaded_filename, # NOQA: E501
587-
expected_size=file.size,
619+
expected_size=expected_size,
588620
)
589621
)
590622
) # NOQA 501
591-
files_seen.add(file.filename)
623+
files_seen.add(filename)
624+
625+
if use_local_polling:
626+
if file.created_at > later_than:
627+
later_than = file.created_at
592628

593629
# Once the transform is complete and all files are seen we can stop polling.
594630
# Also, if we are just downloading or signing urls for a previous transform

servicex/servicex_adapter.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2828
import os
2929
import time
30+
import datetime
3031
from typing import Optional, Dict, List
32+
from dataclasses import dataclass
3133

3234
import httpx
3335
from httpx import AsyncClient, Response
@@ -41,13 +43,25 @@
4143
retry_if_not_exception_type,
4244
)
4345

44-
from servicex.models import TransformRequest, TransformStatus, CachedDataset
46+
from servicex.models import (
47+
TransformRequest,
48+
TransformStatus,
49+
CachedDataset,
50+
ServiceXInfo,
51+
)
4552

4653

4754
class AuthorizationError(BaseException):
4855
pass
4956

5057

58+
@dataclass
class ServiceXFile:
    """
    One entry of a transformation's result listing.

    Field order is part of the interface (positional construction), so it is
    kept as created_at, filename, total_bytes.
    """

    # Creation timestamp of the result entry.
    created_at: datetime.datetime
    # Object name of the result file.
    filename: str
    # Size of the result file in bytes.
    total_bytes: int
63+
64+
5165
async def _extract_message(r: Response):
5266
try:
5367
o = r.json()
@@ -63,6 +77,9 @@ def __init__(self, url: str, refresh_token: Optional[str] = None):
6377
self.refresh_token = refresh_token
6478
self.token = None
6579

80+
# interact with _servicex_info via get_servicex_info
81+
self._servicex_info: Optional[ServiceXInfo] = None
82+
6683
async def _get_token(self):
6784
url = f"{self.url}/token/refresh"
6885
headers = {"Authorization": f"Bearer {self.refresh_token}"}
@@ -120,6 +137,31 @@ async def _get_authorization(self, force_reauth: bool = False) -> Dict[str, str]
120137
await self._get_token()
121138
return {"Authorization": f"Bearer {self.token}"}
122139

140+
async def get_servicex_info(self) -> ServiceXInfo:
    """
    Return the server's info document, fetching it at most once.

    The first successful call issues ``GET {self.url}/servicex``, parses the
    JSON body into a ``ServiceXInfo`` and stores it on ``self._servicex_info``.
    Later calls return the stored object without contacting the server.

    :return: the parsed server info
    :raises AuthorizationError: if the server answers 401
    :raises RuntimeError: if the server answers with any other status >= 400
    """
    if self._servicex_info is not None:
        return self._servicex_info

    headers = await self._get_authorization()
    retry_options = Retry(total=3, backoff_factor=10)
    async with AsyncClient(transport=RetryTransport(retry=retry_options)) as client:
        r = await client.get(url=f"{self.url}/servicex", headers=headers)
        if r.status_code == 401:
            raise AuthorizationError(
                f"Not authorized to access serviceX at {self.url}"
            )
        # ">= 400" so that a plain 400 is reported as an error instead of
        # falling through to JSON/model parsing of an error body.
        elif r.status_code >= 400:
            error_message = await _extract_message(r)
            raise RuntimeError(
                "ServiceX WebAPI Error while fetching server info: "
                f"{r.status_code} - {error_message}"
            )
        servicex_info = r.json()
        self._servicex_info = ServiceXInfo(**servicex_info)
        return self._servicex_info
161+
162+
async def get_servicex_capabilities(self) -> List[str]:
    """
    Return the ``capabilities`` list of the server info document.

    Delegates to ``get_servicex_info`` and returns its ``capabilities``
    attribute unchanged.
    """
    info = await self.get_servicex_info()
    return info.capabilities
164+
123165
async def get_transforms(self) -> List[TransformStatus]:
124166
headers = await self._get_authorization()
125167
retry_options = Retry(total=3, backoff_factor=10)
@@ -232,6 +274,48 @@ async def delete_transform(self, transform_id=None):
232274
msg = await _extract_message(r)
233275
raise RuntimeError(f"Failed to delete transform {transform_id} - {msg}")
234276

277+
async def get_transformation_results(
    self, request_id: str, later_than: Optional[datetime.datetime] = None
):
    """
    List the result files the server has recorded for a transformation.

    Issues ``GET {self.url}/servicex/transformation/{request_id}/results``,
    passing ``later_than`` (ISO-8601) as a query parameter when it is given.

    :param request_id: id of the transformation request
    :param later_than: optional timestamp forwarded to the server as the
        ``later_than`` query parameter
    :return: list of ``ServiceXFile``, one per entry under the response's
        ``"results"`` key (empty list when that key is absent)
    :raises ValueError: if the server does not advertise the
        ``poll_local_transformation_results`` capability, or answers 404
    :raises AuthorizationError: if the server answers 403
    :raises RuntimeError: for any other non-200 answer
    """
    if (
        "poll_local_transformation_results"
        not in await self.get_servicex_capabilities()
    ):
        raise ValueError("ServiceX capabilities not found")

    headers = await self._get_authorization()
    url = self.url + f"/servicex/transformation/{request_id}/results"
    params = {}
    if later_than is not None:
        params["later_than"] = later_than.isoformat()

    async with AsyncClient() as session:
        r = await session.get(headers=headers, url=url, params=params)
        if r.status_code == 403:
            raise AuthorizationError(
                f"Not authorized to access serviceX at {self.url}"
            )

        if r.status_code == 404:
            raise ValueError(f"Request {request_id} not found")

        if r.status_code != 200:
            msg = await _extract_message(r)
            raise RuntimeError(f"Failed with message: {msg}")

        data = r.json()

    def _as_utc(raw: str) -> datetime.datetime:
        # Naive timestamps are taken to be UTC; timestamps that already carry
        # an offset are converted so the instant they denote is preserved
        # (a bare replace() would silently relabel them).
        stamp = datetime.datetime.fromisoformat(raw)
        if stamp.tzinfo is None:
            return stamp.replace(tzinfo=datetime.timezone.utc)
        return stamp.astimezone(datetime.timezone.utc)

    return [
        ServiceXFile(
            filename=result["s3-object-name"],
            created_at=_as_utc(result["created_at"]),
            total_bytes=result["total-bytes"],
        )
        for result in data.get("results", [])
    ]
318+
235319
async def cancel_transform(self, transform_id=None):
236320
headers = await self._get_authorization()
237321
path_template = f"/servicex/transformation/{transform_id}/cancel"

tests/test_dataset.py

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import pytest
2929
import tempfile
3030
import os
31+
import datetime
3132

3233
from unittest.mock import AsyncMock, Mock, patch
3334
from servicex.dataset_identifier import FileListDataset
@@ -44,6 +45,8 @@
4445
)
4546
from rich.progress import Progress
4647

48+
from servicex.servicex_adapter import ServiceXFile
49+
4750

4851
@pytest.mark.asyncio
4952
async def test_as_signed_urls_happy(transformed_result):
@@ -124,12 +127,27 @@ async def test_download_files(python_dataset):
124127
minio_mock = AsyncMock()
125128
config = Configuration(cache_path="temp_dir", api_endpoints=[])
126129
python_dataset.configuration = config
130+
python_dataset.servicex = AsyncMock()
131+
python_dataset.servicex.get_servicex_capabilities = AsyncMock(
132+
return_value=["poll_local_transformation_results"]
133+
)
134+
135+
python_dataset.servicex.get_transformation_results = AsyncMock()
136+
python_dataset.servicex.get_transformation_results.return_value = [
137+
ServiceXFile(
138+
filename="file1.txt",
139+
created_at=datetime.datetime.now(datetime.timezone.utc),
140+
total_bytes=100,
141+
),
142+
ServiceXFile(
143+
filename="file2.txt",
144+
created_at=datetime.datetime.now(datetime.timezone.utc),
145+
total_bytes=100,
146+
),
147+
]
148+
127149
minio_mock.download_file.return_value = Path("/path/to/downloaded_file")
128150
minio_mock.get_signed_url.return_value = Path("http://example.com/signed_url")
129-
minio_mock.list_bucket.return_value = [
130-
Mock(filename="file1.txt"),
131-
Mock(filename="file2.txt"),
132-
]
133151

134152
progress_mock = Mock()
135153
python_dataset.minio_polling_interval = 0
@@ -154,12 +172,27 @@ async def test_download_files_with_signed_urls(python_dataset):
154172
python_dataset.configuration = config
155173
minio_mock.download_file.return_value = "/path/to/downloaded_file"
156174
minio_mock.get_signed_url.return_value = "http://example.com/signed_url"
157-
minio_mock.list_bucket.return_value = [
158-
Mock(filename="file1.txt"),
159-
Mock(filename="file2.txt"),
160-
]
161175
progress_mock = Mock()
162176

177+
python_dataset.servicex = AsyncMock()
178+
python_dataset.servicex.get_servicex_capabilities = AsyncMock(
179+
return_value=["poll_local_transformation_results"]
180+
)
181+
182+
python_dataset.servicex.get_transformation_results = AsyncMock()
183+
python_dataset.servicex.get_transformation_results.return_value = [
184+
ServiceXFile(
185+
filename="file1.txt",
186+
created_at=datetime.datetime.now(datetime.timezone.utc),
187+
total_bytes=100,
188+
),
189+
ServiceXFile(
190+
filename="file2.txt",
191+
created_at=datetime.datetime.now(datetime.timezone.utc),
192+
total_bytes=100,
193+
),
194+
]
195+
163196
python_dataset.minio_polling_interval = 0
164197
python_dataset.minio = minio_mock
165198
python_dataset.current_status = Mock(status="Complete", files_completed=2)

0 commit comments

Comments
 (0)