Skip to content

add supported server resources checks #592

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 12 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 34 additions & 6 deletions servicex/query_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from __future__ import annotations

import datetime
import abc
import asyncio
from abc import ABC
Expand Down Expand Up @@ -318,6 +319,7 @@ def transform_complete(task: Task):
else None
)

begin_at = datetime.datetime.now(tz=datetime.timezone.utc)
if not cached_record:

if self.cache.is_transform_request_submitted(sx_request_hash):
Expand All @@ -342,13 +344,18 @@ def transform_complete(task: Task):

download_files_task = loop.create_task(
self.download_files(
signed_urls_only, expandable_progress, download_progress, cached_record
signed_urls_only,
expandable_progress,
download_progress,
cached_record,
begin_at,
)
)

try:
signed_urls = []
downloaded_files = []

download_result = await download_files_task
if signed_urls_only:
signed_urls = download_result
Expand Down Expand Up @@ -517,11 +524,13 @@ async def download_files(
progress: ExpandableProgress,
download_progress: TaskID,
cached_record: Optional[TransformedResults],
begin_at: datetime.datetime,
) -> List[str]:
"""
Task to monitor the list of files in the transform output's bucket. Any new files
will be downloaded.
"""

files_seen = set()
result_uris = []
download_tasks = []
Expand Down Expand Up @@ -551,21 +560,40 @@ async def get_signed_url(
if progress:
progress.advance(task_id=download_progress, task_type="Download")

transformation_results_enabled = (
"transformationresults" in await self.servicex.get_resources()
)

while True:
if not cached_record:
await asyncio.sleep(self.minio_polling_interval)
if self.minio:
# if self.minio exists, self.current_status will too
if self.current_status.files_completed > len(files_seen):
files = await self.minio.list_bucket()
if transformation_results_enabled:
new_begin_at = datetime.datetime.now(tz=datetime.timezone.utc)
files = await self.servicex.get_transformation_results(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there any reason that get_transformation_results can't just return the same list of objects that self.minio.list_bucket() would? This would simplify the failover logic

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

list_bucket returns a list of objects with a filename attribute, while the ServiceX adapter call is just turning JSON into dictionaries which have a 'file-path' key. I can change the key name to 'filename', but I don't think it's a good idea to change the dictionaries to objects with filename attributes given that pattern isn't reused anywhere else in the ServiceX adapter.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In principle you could of course change the minio adapter to list the files with a dictionary instead of a list of objects... I would just like the "list of files to be downloaded" to be a type that doesn't depend on the source of the information.

Note also my comment on the backend PR - I'm concerned that this thing may need a new column in the database anyway, and whether you want to keep the same name may be up for discussion.

self.current_status.request_id, begin_at
)
begin_at = new_begin_at
else:
files = await self.minio.list_bucket()

for file in files:
if file.filename not in files_seen:
if transformation_results_enabled:
if "file-path" not in file:
continue
file_path = file.get("file-path", "").replace("/", ":")
else:
file_path = file.filename

if file_path not in files_seen:
if signed_urls_only:
download_tasks.append(
loop.create_task(
get_signed_url(
self.minio,
file.filename,
file_path,
progress,
download_progress,
)
Expand All @@ -576,14 +604,14 @@ async def get_signed_url(
loop.create_task(
download_file(
self.minio,
file.filename,
file_path,
progress,
download_progress,
shorten_filename=self.configuration.shortened_downloaded_filename, # NOQA: E501
)
)
) # NOQA 501
files_seen.add(file.filename)
files_seen.add(file_path)

# Once the transform is complete and all files are seen we can stop polling.
# Also, if we are just downloading or signing urls for a previous transform
Expand Down
154 changes: 152 additions & 2 deletions servicex/servicex_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
import time
from typing import Optional, Dict, List
import datetime
import asyncio
from typing import Optional, Dict, List, Any, TypeVar, Callable, cast
from functools import wraps

from aiohttp import ClientSession
import httpx
Expand All @@ -43,6 +46,85 @@

from servicex.models import TransformRequest, TransformStatus, CachedDataset

T = TypeVar("T")


def requires_resource(
    resource_name: str,
) -> Callable[[Callable[..., T]], Callable[..., T]]:
    """
    Decorator to check if a specific API resource is available on the server
    before invoking the wrapped (sync or async) method.

    Args:
        resource_name: The name of the resource that needs to be available

    Returns:
        A decorator function that wraps around class methods

    Raises:
        ResourceNotAvailableError: If the required resource is not available
            on the server (raised when the wrapped method is called).
    """

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        # Determine if function is async at decoration time (not runtime)
        is_async = asyncio.iscoroutinefunction(func)
        func_name = func.__name__

        # Cache attribute for sync method resources. Kept on the INSTANCE
        # (not the class): two adapters pointing at different servers must
        # not share each other's capability lists.
        sync_cache_key = f"_sync_resources_for_{resource_name}"

        if is_async:

            @wraps(func)
            async def async_wrapper(self, *args: Any, **kwargs: Any) -> T:
                # get_resources() already memoises per instance, so the async
                # path needs no extra caching here.
                if resource_name not in await self.get_resources():
                    raise ResourceNotAvailableError(
                        f"Resource '{resource_name}' required for '{func_name}' is unavailable"
                    )
                return await func(self, *args, **kwargs)

            return cast(Callable[..., T], async_wrapper)
        else:

            @wraps(func)
            def sync_wrapper(self, *args: Any, **kwargs: Any) -> T:
                # (resources, timestamp); (None, 0) means "never fetched"
                cached_resources, timestamp = getattr(
                    self, sync_cache_key, (None, 0)
                )
                cache_ttl = getattr(self, "_resources_cache_ttl", 300)
                current_time = time.time()

                # Refresh when empty or stale. get_resources() is a
                # coroutine, so drive it to completion on a private loop.
                if cached_resources is None or (current_time - timestamp) >= cache_ttl:
                    loop = asyncio.new_event_loop()
                    try:
                        cached_resources = loop.run_until_complete(
                            self.get_resources()
                        )
                        setattr(
                            self, sync_cache_key, (cached_resources, current_time)
                        )
                    finally:
                        loop.close()

                # Check resource availability
                if resource_name not in cached_resources:
                    raise ResourceNotAvailableError(
                        f"Resource '{resource_name}' required for '{func_name}' is unavailable"
                    )

                return func(self, *args, **kwargs)

            return cast(Callable[..., T], sync_wrapper)

    return decorator


class ResourceNotAvailableError(Exception):
    """Raised when the ServiceX server does not offer a required resource."""


class AuthorizationError(BaseException):
pass
Expand All @@ -63,6 +145,47 @@ def __init__(self, url: str, refresh_token: Optional[str] = None):
self.refresh_token = refresh_token
self.token = None

self._available_resources: Optional[Dict[str, Any]] = None
self._resources_last_updated: Optional[float] = None
self._resources_cache_ttl = 60 * 5

async def get_resources(self) -> Dict[str, Any]:
    """
    Return the server's advertised resources, fetching at most once per TTL.

    The response of ``GET {url}/servicex/resources`` is memoised on the
    instance for five minutes so repeated capability checks do not hammer
    the API.

    Returns:
        A dictionary of available resources with their properties

    Raises:
        AuthorizationError: on an HTTP 403 response.
        RuntimeError: on any other non-200 response.
    """
    now = time.time()

    # Serve from the cache while it is still fresh.
    cached = self._available_resources
    stamp = self._resources_last_updated
    if cached is not None and stamp is not None:
        if now - stamp < self._resources_cache_ttl:
            return cached

    # Cache is empty or stale: fetch a fresh copy from the server.
    headers = await self._get_authorization()
    async with ClientSession() as session:
        async with session.get(
            url=f"{self.url}/servicex/resources", headers=headers
        ) as r:
            if r.status == 403:
                raise AuthorizationError(
                    f"Not authorized to access serviceX at {self.url}"
                )
            if r.status != 200:
                msg = await _extract_message(r)
                raise RuntimeError(f"Failed to get resources: {r.status} - {msg}")

            self._available_resources = await r.json()
            self._resources_last_updated = now

    return self._available_resources

async def _get_token(self):
url = f"{self.url}/token/refresh"
headers = {"Authorization": f"Bearer {self.refresh_token}"}
Expand Down Expand Up @@ -228,14 +351,41 @@ async def delete_transform(self, transform_id=None):
f"Failed to delete transform {transform_id} - {msg}"
)

@requires_resource("transformationresults")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In principle, in your code, we should never fail this check, right?

async def get_transformation_results(
    self, request_id: str, begin_at: datetime.datetime
):
    """
    Fetch the transform's per-file result records from the server.

    Args:
        request_id: Transform request whose results are wanted.
        begin_at: Lower time bound for the results; when set it is sent as
            an ISO-8601 ``begin_at`` query parameter.

    Returns:
        The ``results`` entry of the server's JSON response.

    Raises:
        AuthorizationError: on an HTTP 403 response.
        ValueError: when the request id is unknown (HTTP 404).
        RuntimeError: on any other non-200 response.
    """
    headers = await self._get_authorization()
    endpoint = self.url + f"/servicex/transformation/{request_id}/results"

    query = {}
    if begin_at:
        query["begin_at"] = begin_at.isoformat()

    async with ClientSession() as session:
        async with session.get(url=endpoint, headers=headers, params=query) as r:
            status = r.status
            if status == 403:
                raise AuthorizationError(
                    f"Not authorized to access serviceX at {self.url}"
                )

            if status == 404:
                raise ValueError(f"Request {request_id} not found")

            if status != 200:
                msg = await _extract_message(r)
                raise RuntimeError(f"Failed with message: {msg}")

            payload = await r.json()
            return payload.get("results")

async def cancel_transform(self, transform_id=None):
headers = await self._get_authorization()
path_template = f"/servicex/transformation/{transform_id}/cancel"
url = self.url + path_template.format(transform_id=transform_id)

async with ClientSession() as session:
async with session.get(headers=headers, url=url) as r:

if r.status == 403:
raise AuthorizationError(
f"Not authorized to access serviceX at {self.url}"
Expand Down
22 changes: 20 additions & 2 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import pytest
import tempfile
import os
import datetime

from unittest.mock import AsyncMock, Mock, patch
from servicex.dataset_identifier import FileListDataset
Expand Down Expand Up @@ -120,10 +121,19 @@ async def test_as_files_cached(transformed_result, python_dataset):
@pytest.mark.asyncio
async def test_download_files(python_dataset):
signed_urls_only = False
begin_at = datetime.datetime.now(tz=datetime.timezone.utc)
download_progress = "download_task_id"
minio_mock = AsyncMock()
config = Configuration(cache_path="temp_dir", api_endpoints=[])
python_dataset.configuration = config
python_dataset.servicex = AsyncMock()
python_dataset.servicex.get_transformation_results = AsyncMock(
side_effect=[
[{"file-path": "file1.txt"}],
[{"file-path": "file2.txt"}],
]
)

minio_mock.download_file.return_value = Path("/path/to/downloaded_file")
minio_mock.get_signed_url.return_value = Path("http://example.com/signed_url")
minio_mock.list_bucket.return_value = [
Expand All @@ -138,7 +148,7 @@ async def test_download_files(python_dataset):
python_dataset.configuration.shortened_downloaded_filename = False

result_uris = await python_dataset.download_files(
signed_urls_only, progress_mock, download_progress, None
signed_urls_only, progress_mock, download_progress, None, begin_at
)
minio_mock.download_file.assert_awaited()
minio_mock.get_signed_url.assert_not_awaited()
Expand All @@ -148,6 +158,7 @@ async def test_download_files(python_dataset):
@pytest.mark.asyncio
async def test_download_files_with_signed_urls(python_dataset):
signed_urls_only = True
begin_at = datetime.datetime.now(tz=datetime.timezone.utc)
download_progress = "download_task_id"
minio_mock = AsyncMock()
config = Configuration(cache_path="temp_dir", api_endpoints=[])
Expand All @@ -160,13 +171,20 @@ async def test_download_files_with_signed_urls(python_dataset):
]
progress_mock = Mock()

python_dataset.servicex = AsyncMock()
python_dataset.servicex.get_transformation_results = AsyncMock(
side_effect=[
[{"file-path": "file1.txt"}],
[{"file-path": "file2.txt"}],
]
)
python_dataset.minio_polling_interval = 0
python_dataset.minio = minio_mock
python_dataset.current_status = Mock(status="Complete", files_completed=2)
python_dataset.configuration.shortened_downloaded_filename = False

result_uris = await python_dataset.download_files(
signed_urls_only, progress_mock, download_progress, None
signed_urls_only, progress_mock, download_progress, None, begin_at
)
minio_mock.download_file.assert_not_called()
minio_mock.get_signed_url.assert_called()
Expand Down
Loading
Loading