diff --git a/src/neptune_scale/api/attribute.py b/src/neptune_scale/api/attribute.py
index ea433137..8b11e47b 100644
--- a/src/neptune_scale/api/attribute.py
+++ b/src/neptune_scale/api/attribute.py
@@ -6,6 +6,7 @@
     Iterator,
 )
 from datetime import datetime
+from pathlib import Path
 from typing import (
     Any,
     Callable,
@@ -14,23 +15,49 @@
     cast,
 )
 
+from neptune_scale.api.validation import (
+    verify_max_length,
+    verify_non_empty,
+    verify_type,
+)
+from neptune_scale.sync.files.queue import FileUploadQueue
 from neptune_scale.sync.metadata_splitter import MetadataSplitter
 from neptune_scale.sync.operations_queue import OperationsQueue
+from neptune_scale.sync.parameters import MAX_FILE_UPLOAD_BUFFER_SIZE
+from neptune_scale.sync.util import arg_to_datetime
 
 __all__ = ("Attribute", "AttributeStore")
 
 
-def warn_unsupported_params(fn: Callable) -> Callable:
-    # Perform some simple heuristics to detect if a method is called with parameters
-    # that are not supported by Scale
+def _extract_named_kwargs(fn: Callable) -> set[str]:
+    """Return the set of a function's named arguments that are not positional-only."""
+    import inspect
+
+    sig = inspect.signature(fn)
+    kwargs = {
+        p.name
+        for p in sig.parameters.values()
+        if p.kind in {inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY}
+    }
+
+    return kwargs
+
+
+def warn_unsupported_kwargs(fn: Callable) -> Callable:
+    """Use simple heuristics to detect when a method is called with parameters that are not supported
+    by Scale. Some methods in the old client accepted a **kwargs argument, which we currently do not
+    inspect in any way, so it's important to notify the user when an argument is being ignored.
+    """
+    warn = functools.partial(warnings.warn, stacklevel=3)
+    known_kwargs = _extract_named_kwargs(fn)
 
     @functools.wraps(fn)
     def wrapper(*args, **kwargs):  # type: ignore
         if kwargs.get("wait") is not None:
             warn("The `wait` parameter is not yet implemented and will be ignored.")
 
-        extra_kwargs = set(kwargs.keys()) - {"wait", "step", "timestamp", "steps", "timestamps"}
+        extra_kwargs = set(kwargs.keys()) - known_kwargs
         if extra_kwargs:
             warn(
                 f"`{fn.__name__}()` was called with additional keyword argument(s): `{', '.join(extra_kwargs)}`. "
@@ -54,11 +81,14 @@ class AttributeStore:
     end consuming the queue (which would be SyncProcess).
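+
+    File uploads follow a separate path: they are submitted to a FileUploadQueue,
+    which is consumed by a FileUploadWorkerThread in the worker process.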
""" - def __init__(self, project: str, run_id: str, operations_queue: OperationsQueue) -> None: + def __init__( + self, project: str, run_id: str, operations_queue: OperationsQueue, file_upload_queue: FileUploadQueue + ) -> None: self._project = project self._run_id = run_id self._operations_queue = operations_queue self._attributes: dict[str, Attribute] = {} + self._file_upload_queue = file_upload_queue def __getitem__(self, path: str) -> "Attribute": path = cleanup_path(path) @@ -92,7 +122,7 @@ def log( project=self._project, run_id=self._run_id, step=step, - timestamp=timestamp, + timestamp=arg_to_datetime(timestamp), configs=configs, metrics=metrics, add_tags=tags_add, @@ -102,6 +132,24 @@ def log( for operation, metadata_size in splitter: self._operations_queue.enqueue(operation=operation, size=metadata_size, key=step) + def upload_file( + self, + attribute_path: str, + local_path: Optional[Path], + data: Optional[Union[str, bytes]], + target_basename: Optional[str], + target_path: Optional[str], + timestamp: Optional[Union[float, datetime]] = None, + ) -> None: + self._file_upload_queue.submit( + timestamp=arg_to_datetime(timestamp), + attribute_path=attribute_path, + local_path=local_path, + data=data.encode("utf-8") if isinstance(data, str) else data, + target_basename=target_basename, + target_path=target_path, + ) + class Attribute: """Objects of this class are returned on dict-like access to Run. Attributes have a path and @@ -118,12 +166,12 @@ def __init__(self, store: AttributeStore, path: str) -> None: self._path = path # TODO: typehint value properly - @warn_unsupported_params + @warn_unsupported_kwargs def assign(self, value: Any, *, wait: bool = False) -> None: data = accumulate_dict_values(value, self._path) self._store.log(configs=data) - @warn_unsupported_params + @warn_unsupported_kwargs def append( self, value: Union[dict[str, Any], float], @@ -136,7 +184,7 @@ def append( data = accumulate_dict_values(value, self._path) self._store.log(metrics=data, step=step, timestamp=timestamp) - @warn_unsupported_params + @warn_unsupported_kwargs # TODO: this should be Iterable in Run as well # def add(self, values: Union[str, Iterable[str]], *, wait: bool = False) -> None: def add(self, values: Union[str, Union[list[str], set[str], tuple[str]]], *, wait: bool = False) -> None: @@ -144,7 +192,7 @@ def add(self, values: Union[str, Union[list[str], set[str], tuple[str]]], *, wai values = (values,) self._store.log(tags_add={self._path: values}) - @warn_unsupported_params + @warn_unsupported_kwargs # TODO: this should be Iterable in Run as well # def remove(self, values: Union[str, Iterable[str]], *, wait: bool = False) -> None: def remove(self, values: Union[str, Union[list[str], set[str], tuple[str]]], *, wait: bool = False) -> None: @@ -152,7 +200,7 @@ def remove(self, values: Union[str, Union[list[str], set[str], tuple[str]]], *, values = (values,) self._store.log(tags_remove={self._path: values}) - @warn_unsupported_params + @warn_unsupported_kwargs def extend( self, values: Collection[Union[float, int]], @@ -173,6 +221,54 @@ def extend( for value, step, timestamp in zip(values, steps, timestamps): self.append(value, step=step, timestamp=timestamp, wait=wait) + @warn_unsupported_kwargs + def upload( + self, + path: Optional[str] = None, + *, + data: Optional[Union[str, bytes]] = None, + mime_type: Optional[str] = None, + target_basename: Optional[str] = None, + target_path: Optional[str] = None, + timestamp: Optional[Union[float, datetime]] = None, + wait: bool = False, + ) -> 
None:
+        verify_type("path", path, (str, type(None)))
+
+        if data is not None:
+            verify_type("data", data, (str, bytes, type(None)))
+            verify_max_length("data", data, MAX_FILE_UPLOAD_BUFFER_SIZE)
+
+        verify_type("mime_type", mime_type, (str, type(None)))
+        verify_type("target_basename", target_basename, (str, type(None)))
+        verify_type("target_path", target_path, (str, type(None)))
+
+        if path is None and data is None:
+            raise ValueError("Either `path` or `data` must be provided")
+
+        if path is not None and data is not None:
+            raise ValueError("Only one of `path` or `data` can be provided")
+
+        local_path: Optional[Path] = None
+        if path is not None:
+            verify_non_empty("path", path)
+
+            local_path = Path(path)
+            if not local_path.exists():
+                raise FileNotFoundError(f"Path `{path}` does not exist")
+
+            if not local_path.is_file():
+                raise ValueError(f"Path `{path}` is not a file")
+
+        self._store.upload_file(
+            attribute_path=self._path,
+            local_path=local_path,
+            data=data.encode("utf-8") if isinstance(data, str) else data,
+            target_basename=target_basename,
+            target_path=target_path,
+            timestamp=timestamp,
+        )
+
 
 # TODO: add value type validation to all the methods
 # TODO: change Run API to typehint timestamp as Union[datetime, float]
diff --git a/src/neptune_scale/api/run.py b/src/neptune_scale/api/run.py
index f6d05f3b..e56fff3e 100644
--- a/src/neptune_scale/api/run.py
+++ b/src/neptune_scale/api/run.py
@@ -7,6 +7,7 @@
 __all__ = ["Run"]
 
 import atexit
+import math
 import os
 import threading
 import time
@@ -47,6 +48,7 @@
     ErrorsMonitor,
     ErrorsQueue,
 )
+from neptune_scale.sync.files.queue import FileUploadQueue
 from neptune_scale.sync.lag_tracking import LagTracker
 from neptune_scale.sync.operations_queue import OperationsQueue
 from neptune_scale.sync.parameters import (
@@ -205,7 +207,10 @@ def __init__(
             max_size=max_queue_size,
         )
 
-        self._attr_store: AttributeStore = AttributeStore(self._project, self._run_id, self._operations_queue)
+        self._file_upload_queue = FileUploadQueue()
+        self._attr_store: AttributeStore = AttributeStore(
+            self._project, self._run_id, self._operations_queue, self._file_upload_queue
+        )
 
         self._errors_queue: ErrorsQueue = ErrorsQueue()
         self._errors_monitor = ErrorsMonitor(
@@ -221,8 +226,10 @@ def __init__(
         self._last_ack_timestamp = SharedFloat(-1)
 
         self._process_link = ProcessLink()
+
         self._sync_process = SyncProcess(
             project=self._project,
+            run_id=self._run_id,
             family=self._run_id,
             operations_queue=self._operations_queue.queue,
             errors_queue=self._errors_queue,
@@ -231,9 +238,11 @@ def __init__(
             last_queued_seq=self._last_queued_seq,
             last_ack_seq=self._last_ack_seq,
             last_ack_timestamp=self._last_ack_timestamp,
+            file_upload_queue=self._file_upload_queue,
             max_queue_size=max_queue_size,
             mode=mode,
         )
+
         self._lag_tracker: Optional[LagTracker] = None
         if async_lag_threshold is not None and on_async_lag_callback is not None:
             self._lag_tracker = LagTracker(
@@ -270,19 +279,18 @@ def _on_child_link_closed(self, _: ProcessLink) -> None:
 
     @property
     def resources(self) -> tuple[Resource, ...]:
-        if self._lag_tracker is not None:
-            return (
-                self._errors_queue,
-                self._operations_queue,
-                self._lag_tracker,
-                self._errors_monitor,
-            )
-        return (
-            self._errors_queue,
+        res: tuple[Resource, ...] = (
             self._operations_queue,
+            self._errors_queue,
             self._errors_monitor,
+            self._file_upload_queue,
         )
+        if self._lag_tracker is not None:
+            res += (self._lag_tracker,)
+
+        return res
+
     def _close(self, *, wait: bool = True) -> None:
         with self._lock:
             if self._is_closing:
@@ -552,6 +560,53 @@ def log(
             step=step, timestamp=timestamp, configs=configs, metrics=metrics, tags_add=tags_add, tags_remove=tags_remove
         )
 
+    def log_file(
+        self,
+        attribute_path: str,
+        *,
+        path: Optional[str] = None,
+        data: Optional[Union[str, bytes]] = None,
+        mime_type: Optional[str] = None,
+        target_basename: Optional[str] = None,
+        target_path: Optional[str] = None,
+        timestamp: Optional[datetime] = None,
+    ) -> None:
+        """
+        Uploads a file under the specified attribute path. The file contents can be read from a local
+        file or provided directly as str/bytes.
+
+            run.log_file("configs/files/foo.txt", path="path/to/local/file.txt")
+            run.log_file("configs/files/bar.txt", data="file content")
+
+        Args:
+            attribute_path: attribute name under which the file will be stored.
+            path: local path to the file. If provided, `data` must be `None`.
+            data: file content as a string or bytes. If provided, `path` must be `None`.
+
+                The maximum length of the data is 10 MB. If the data is larger, use `path` instead.
+                If data is of type `str`, it will be encoded using UTF-8. If you need a different
+                encoding, pass the data as `bytes`.
+            mime_type: MIME type of the file. If not provided, it will be guessed based on the file
+                extension first, then the attribute path.
+            target_basename: basename of the file in the underlying object storage. If not provided,
+                the final path will be generated automatically: from the local file's basename, or
+                randomly if `data` is provided.
+            target_path: the full path to the file in the underlying object storage. It always takes
+                precedence, so caution is advised, as it is possible to overwrite existing files in
+                the object storage.
+            timestamp: timestamp to be recorded for the operation. Defaults to the current time (UTC).
+        """
+
+        verify_type("attribute_path", attribute_path, str)
+        verify_non_empty("attribute_path", attribute_path)
+
+        self._attr_store[attribute_path].upload(
+            path,
+            data=data,
+            mime_type=mime_type,
+            target_basename=target_basename,
+            target_path=target_path,
+            timestamp=timestamp,
+        )
+
     def _wait(
         self,
         phrase: str,
@@ -563,12 +618,14 @@ def _wait(
         if verbose:
             logger.info(f"Waiting for all operations to be {phrase}")
 
-        if timeout is None and verbose:
-            logger.warning("No timeout specified. Waiting indefinitely")
+        if timeout is None:
+            timeout = math.inf
+            if verbose:
+                logger.warning("No timeout specified. Waiting indefinitely")
 
-        begin_time = time.time()
-        wait_time = min(sleep_time, timeout) if timeout is not None else sleep_time
+        begin_time = time.monotonic()
         last_print_timestamp: Optional[float] = None
+        wait_time = min(sleep_time, timeout)
 
         while True:
             try:
@@ -579,12 +636,20 @@ def _wait(
                     logger.warning("Sync process is not running")
                     return  # No need to wait if the sync process is not running
 
+                active_uploads = self._file_upload_queue.active_uploads
+                if active_uploads:
+                    last_print_timestamp = print_message(
+                        f"Waiting for {active_uploads} file uploads to complete",
+                        last_print=last_print_timestamp,
+                        verbose=verbose,
+                    )
+
+                with wait_seq:
                 # Handle the case where we get notified on `wait_seq` before we actually wait.
                 # Otherwise, we would unnecessarily block, waiting on a notify_all() that never happens.
if wait_seq.value >= self._operations_queue.last_sequence_id:
                     break
 
-            with wait_seq:
                 wait_seq.wait(timeout=wait_time)
                 value = wait_seq.value
 
@@ -594,7 +659,7 @@ def _wait(
             if self._operations_queue.last_sequence_id != -1:
                 last_print_timestamp = print_message(
                     f"Waiting. No operations were {phrase} yet. Operations to sync: %s",
-                    self._operations_queue.last_sequence_id + 1,
+                    self._operations_queue.last_sequence_id + 1 + active_uploads,
                     last_print=last_print_timestamp,
                     verbose=verbose,
                 )
@@ -607,19 +672,30 @@ def _wait(
             elif value < last_queued_sequence_id:
                 last_print_timestamp = print_message(
                     f"Waiting for remaining %d operation(s) to be {phrase}",
-                    last_queued_sequence_id - value + 1,
+                    last_queued_sequence_id - value + 1 + active_uploads,
                     last_print=last_print_timestamp,
                     verbose=verbose,
                 )
             else:
                 # Reaching the last queued sequence ID means that all operations were submitted
-                if value >= last_queued_sequence_id or (timeout is not None and time.time() - begin_time > timeout):
+                if value >= last_queued_sequence_id or time.monotonic() - begin_time > timeout:
                     break
         except KeyboardInterrupt:
             if verbose:
                 logger.warning("Waiting interrupted by user")
             return
 
+        if active_uploads := self._file_upload_queue.active_uploads:
+            last_print_timestamp = print_message(
+                f"Waiting for {active_uploads} file uploads to complete",
+                last_print=last_print_timestamp,
+                verbose=verbose,
+            )
+
+        # TODO: properly calculate the timeout based on the actual time passed.
+        # The PR for that is already there, but it's not merged yet.
+        self._file_upload_queue.wait_for_completion(timeout=wait_time)
+
     if verbose:
         logger.info(f"All operations were {phrase}")
 
diff --git a/src/neptune_scale/api/validation.py b/src/neptune_scale/api/validation.py
index 20a75719..24a7acab 100644
--- a/src/neptune_scale/api/validation.py
+++ b/src/neptune_scale/api/validation.py
@@ -38,8 +38,12 @@ def verify_non_empty(var_name: str, var: Any) -> None:
         raise ValueError(f"{var_name} must not be empty")
 
 
-def verify_max_length(var_name: str, var: str, max_length: int) -> None:
-    byte_len = len(var.encode("utf8"))
+def verify_max_length(var_name: str, var: Union[str, bytes], max_length: int) -> None:
+    if isinstance(var, str):
+        byte_len = len(var.encode("utf8"))
+    else:
+        byte_len = len(var)
+
     if byte_len > max_length:
         raise ValueError(f"{var_name} must not exceed {max_length} bytes, got {byte_len} bytes.")
 
diff --git a/src/neptune_scale/exceptions.py b/src/neptune_scale/exceptions.py
index f1e6aa7e..7f6ed3d6 100644
--- a/src/neptune_scale/exceptions.py
+++ b/src/neptune_scale/exceptions.py
@@ -51,8 +51,8 @@ class NeptuneScaleError(Exception):
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         ensure_style_detected()
 
-        message = kwargs.pop("message", self.message)
-        super().__init__(message.format(*args, **STYLES, **kwargs))
+        self.message = kwargs.pop("message", self.message)
+        super().__init__(self.message.format(*args, **STYLES, **kwargs))
 
 
 class NeptuneScaleWarning(Warning):
diff --git a/src/neptune_scale/net/api_client.py b/src/neptune_scale/net/api_client.py
index a6cf90c2..0e6e498a 100644
--- a/src/neptune_scale/net/api_client.py
+++ b/src/neptune_scale/net/api_client.py
@@ -129,6 +129,9 @@ def submit(self, operation: RunOperation, family: str) -> Response[SubmitRespons
     @abc.abstractmethod
     def check_batch(self, request_ids: list[str], project: str) -> Response[BulkRequestStatus]: ...
 
+    @abc.abstractmethod
+    def fetch_file_storage_info(self, project: str, file_path: str, permission: Literal["read", "write"]) -> str: ...
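+    # Expected contract: return a URL granting `permission`-level access to `file_path`
+    # within the project's file storage (the implementations below are still placeholders).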
+ class HostedApiClient(ApiClient): def __init__(self, api_token: str) -> None: @@ -153,6 +156,9 @@ def check_batch(self, request_ids: list[str], project: str) -> Response[BulkRequ body=RequestIdList(ids=[RequestId(value=request_id) for request_id in request_ids]), ) + def fetch_file_storage_info(self, project: str, file_path: str, permission: Literal["read", "write"]) -> str: + return f"https://DUMMY.localhost/{project}/{file_path}" + def close(self) -> None: logger.debug("Closing API client") self.backend.__exit__() @@ -181,6 +187,10 @@ def check_batch(self, request_ids: list[str], project: str) -> Response[BulkRequ ) return Response(content=b"", parsed=response_body, status_code=HTTPStatus.OK, headers={}) + def fetch_file_storage_info(self, project: str, file_path: str, permission: Literal["read", "write"]) -> str: + # TODO: request the actual endpoint + return f"https://localhost:65530/{project}/{file_path}" + def backend_factory(api_token: str, mode: Literal["async", "disabled"]) -> ApiClient: if mode == "disabled": diff --git a/src/neptune_scale/net/ingest_code.py b/src/neptune_scale/net/ingest_code.py new file mode 100644 index 00000000..74fe179b --- /dev/null +++ b/src/neptune_scale/net/ingest_code.py @@ -0,0 +1,55 @@ +from neptune_api.proto.neptune_pb.ingest.v1.ingest_pb2 import IngestCode + +from neptune_scale.exceptions import ( + NeptuneAttributePathEmpty, + NeptuneAttributePathExceedsSizeLimit, + NeptuneAttributePathInvalid, + NeptuneAttributePathNonWritable, + NeptuneAttributeTypeMismatch, + NeptuneAttributeTypeUnsupported, + NeptuneFloatValueNanInfUnsupported, + NeptuneProjectInvalidName, + NeptuneProjectNotFound, + NeptuneRunConflicting, + NeptuneRunDuplicate, + NeptuneRunForkParentNotFound, + NeptuneRunInvalidCreationParameters, + NeptuneRunNotFound, + NeptuneSeriesPointDuplicate, + NeptuneSeriesStepNonIncreasing, + NeptuneSeriesStepNotAfterForkPoint, + NeptuneSeriesTimestampDecreasing, + NeptuneStringSetExceedsSizeLimit, + NeptuneStringValueExceedsSizeLimit, + NeptuneUnexpectedError, +) + +CODE_TO_ERROR: dict[IngestCode.ValueType, type[Exception]] = { + IngestCode.PROJECT_NOT_FOUND: NeptuneProjectNotFound, + IngestCode.PROJECT_INVALID_NAME: NeptuneProjectInvalidName, + IngestCode.RUN_NOT_FOUND: NeptuneRunNotFound, + IngestCode.RUN_DUPLICATE: NeptuneRunDuplicate, + IngestCode.RUN_CONFLICTING: NeptuneRunConflicting, + IngestCode.RUN_FORK_PARENT_NOT_FOUND: NeptuneRunForkParentNotFound, + IngestCode.RUN_INVALID_CREATION_PARAMETERS: NeptuneRunInvalidCreationParameters, + IngestCode.FIELD_PATH_EXCEEDS_SIZE_LIMIT: NeptuneAttributePathExceedsSizeLimit, + IngestCode.FIELD_PATH_EMPTY: NeptuneAttributePathEmpty, + IngestCode.FIELD_PATH_INVALID: NeptuneAttributePathInvalid, + IngestCode.FIELD_PATH_NON_WRITABLE: NeptuneAttributePathNonWritable, + IngestCode.FIELD_TYPE_UNSUPPORTED: NeptuneAttributeTypeUnsupported, + IngestCode.FIELD_TYPE_CONFLICTING: NeptuneAttributeTypeMismatch, + IngestCode.SERIES_POINT_DUPLICATE: NeptuneSeriesPointDuplicate, + IngestCode.SERIES_STEP_NON_INCREASING: NeptuneSeriesStepNonIncreasing, + IngestCode.SERIES_STEP_NOT_AFTER_FORK_POINT: NeptuneSeriesStepNotAfterForkPoint, + IngestCode.SERIES_TIMESTAMP_DECREASING: NeptuneSeriesTimestampDecreasing, + IngestCode.FLOAT_VALUE_NAN_INF_UNSUPPORTED: NeptuneFloatValueNanInfUnsupported, + IngestCode.STRING_VALUE_EXCEEDS_SIZE_LIMIT: NeptuneStringValueExceedsSizeLimit, + IngestCode.STRING_SET_EXCEEDS_SIZE_LIMIT: NeptuneStringSetExceedsSizeLimit, +} + + +def code_to_exception(code: IngestCode.ValueType) -> type[Exception]: 
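+    """Map an ingest status code to the matching exception class.
+
+    Any unrecognized code falls back to NeptuneUnexpectedError.
+    """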
+    if exc_class := CODE_TO_ERROR.get(code):
+        return exc_class
+
+    return NeptuneUnexpectedError
diff --git a/src/neptune_scale/sync/errors_tracking.py b/src/neptune_scale/sync/errors_tracking.py
index fbf48b88..f8182e12 100644
--- a/src/neptune_scale/sync/errors_tracking.py
+++ b/src/neptune_scale/sync/errors_tracking.py
@@ -5,6 +5,7 @@
 import multiprocessing
 import queue
 import time
+import traceback
 from collections.abc import Callable
 from typing import Optional
 
@@ -17,7 +18,10 @@
     NeptuneUnexpectedError,
 )
 from neptune_scale.sync.parameters import ERRORS_MONITOR_THREAD_SLEEP_TIME
-from neptune_scale.util import get_logger
+from neptune_scale.util import (
+    envs,
+    get_logger,
+)
 from neptune_scale.util.abstract import Resource
 from neptune_scale.util.daemon import Daemon
 from neptune_scale.util.process_killer import kill_me
@@ -30,6 +34,10 @@ def __init__(self) -> None:
         self._errors_queue: multiprocessing.Queue[BaseException] = multiprocessing.Queue()
 
     def put(self, error: BaseException) -> None:
+        if envs.get_bool(envs.LOG_TRACEBACKS, True):
+            logger.error("An error occurred in Neptune:")
+            logger.error("".join(traceback.format_exception(type(error), error, error.__traceback__)))
+
         self._errors_queue.put(error)
 
     def get(self, block: bool = True, timeout: Optional[float] = None) -> BaseException:
diff --git a/src/neptune_scale/sync/files/__init__.py b/src/neptune_scale/sync/files/__init__.py
new file mode 100644
index 00000000..54016b44
--- /dev/null
+++ b/src/neptune_scale/sync/files/__init__.py
@@ -0,0 +1 @@
+"""This subpackage contains code for syncing files with Neptune."""
diff --git a/src/neptune_scale/sync/files/queue.py b/src/neptune_scale/sync/files/queue.py
new file mode 100644
index 00000000..45fae5eb
--- /dev/null
+++ b/src/neptune_scale/sync/files/queue.py
@@ -0,0 +1,75 @@
+import multiprocessing
+import pathlib
+from datetime import datetime
+from typing import (
+    NamedTuple,
+    Optional,
+)
+
+from neptune_scale.util import SharedInt
+from neptune_scale.util.abstract import Resource
+
+
+class UploadMessage(NamedTuple):
+    timestamp: datetime
+    attribute_path: str
+    local_path: Optional[pathlib.Path]
+    data: Optional[bytes]
+    target_path: Optional[str]
+    target_basename: Optional[str]
+
+
+class FileUploadQueue(Resource):
+    """Queue for submitting file upload requests from the main process to a
+    FileUploadWorkerThread, which is spawned in the worker process.
+
+    The main process submits requests by calling the `submit` method, and waits
+    for all uploads to complete by calling the `wait_for_completion` method.
+    """
+
+    def __init__(self) -> None:
+        self._queue: multiprocessing.Queue[UploadMessage] = multiprocessing.Queue(maxsize=4096)
+        self._active_uploads = SharedInt(0)
+
+    @property
+    def active_uploads(self) -> int:
+        """Returns the number of currently active uploads."""
+        with self._active_uploads:
+            return self._active_uploads.value
+
+    # Main process API
+    def submit(
+        self,
+        *,
+        timestamp: datetime,
+        attribute_path: str,
+        local_path: Optional[pathlib.Path],
+        data: Optional[bytes],
+        target_path: Optional[str],
+        target_basename: Optional[str],
+    ) -> None:
+        assert data is not None or local_path
+        with self._active_uploads:
+            self._active_uploads.value += 1
+            self._queue.put(UploadMessage(timestamp, attribute_path, local_path, data, target_path, target_basename))
+
+    def wait_for_completion(self, timeout: Optional[float] = None) -> bool:
+        """Blocks until all uploads are completed or the timeout is reached.
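+        Completion is tracked via the shared active-uploads counter, which the worker
+        decrements as each upload finishes.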
+ Returns True if all uploads completed, False if the timeout was reached. + """ + with self._active_uploads: + return self._active_uploads.wait_for(lambda: self._active_uploads.value == 0, timeout=timeout) + + def close(self) -> None: + self._queue.close() + self._queue.cancel_join_thread() + + # Worker process API + def decrement_active(self) -> None: + with self._active_uploads: + self._active_uploads.value -= 1 + assert self._active_uploads.value >= 0 + self._active_uploads.notify_all() + + def get(self, timeout: float) -> UploadMessage: + return self._queue.get(timeout=timeout) diff --git a/src/neptune_scale/sync/files/worker.py b/src/neptune_scale/sync/files/worker.py new file mode 100644 index 00000000..31d6463f --- /dev/null +++ b/src/neptune_scale/sync/files/worker.py @@ -0,0 +1,266 @@ +import io +import mimetypes +import time +import uuid +from collections.abc import ( + Callable, + Sequence, +) +from concurrent import futures +from datetime import datetime +from pathlib import Path +from queue import Empty +from typing import ( + BinaryIO, + Literal, + Optional, +) + +import backoff +from neptune_api.proto.google_rpc.code_pb2 import Code +from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ( + UpdateRunSnapshot, + Value, +) +from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation +from neptune_api.types import Response + +from neptune_scale.exceptions import ( + NeptuneInternalServerError, + NeptuneRetryableError, + NeptuneUnauthorizedError, + NeptuneUnexpectedResponseError, +) +from neptune_scale.net.api_client import ( + ApiClient, + backend_factory, + with_api_errors_handling, +) +from neptune_scale.net.ingest_code import code_to_exception +from neptune_scale.net.serialization import datetime_to_proto +from neptune_scale.sync.errors_tracking import ErrorsQueue +from neptune_scale.sync.files.queue import ( + FileUploadQueue, + UploadMessage, +) +from neptune_scale.sync.parameters import MAX_REQUEST_RETRY_SECONDS +from neptune_scale.util import ( + Daemon, + get_logger, +) +from neptune_scale.util.abstract import Resource + +logger = get_logger() + + +class FileUploadWorkerThread(Daemon, Resource): + """Consumes messages from the provided FileUploadQueue and performs the upload operation + in a pool of worker threads. 
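+
+    Every finished upload, successful or not, decrements the active-uploads counter of
+    the input queue; failures are additionally pushed to the errors queue.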
+    """
+
+    def __init__(
+        self,
+        *,
+        project: str,
+        run_id: str,
+        api_token: str,
+        family: str,
+        mode: Literal["async", "disabled"],
+        input_queue: FileUploadQueue,
+        errors_queue: ErrorsQueue,
+    ) -> None:
+        super().__init__(sleep_time=0.5, name="FileUploader")
+
+        self._project = project
+        self._run_id = run_id
+        self._api_token = api_token
+        self._family = family
+        self._mode = mode
+        self._input_queue = input_queue
+        self._errors_queue = errors_queue
+        self._executor = futures.ThreadPoolExecutor()
+
+        self._backend: Optional[ApiClient] = None
+
+    def work(self) -> None:
+        while self.is_running():
+            try:
+                msg = self._input_queue.get(timeout=1)
+            except Empty:
+                continue
+
+            try:
+                if self._backend is None:
+                    self._backend = backend_factory(self._api_token, self._mode)
+
+                future = self._executor.submit(
+                    self._do_upload,
+                    msg.timestamp,
+                    msg.attribute_path,
+                    msg.local_path,
+                    msg.data,
+                    msg.target_path,
+                    msg.target_basename,
+                )
+                future.add_done_callback(self._make_done_callback(msg))
+            except Exception as e:
+                logger.error(f"Failed to submit file upload task for `{msg.attribute_path}`: {e}")
+                self._input_queue.decrement_active()
+                self._errors_queue.put(e)
+
+    def close(self) -> None:
+        self._executor.shutdown()
+
+    def _do_upload(
+        self,
+        timestamp: datetime,
+        attribute_path: str,
+        local_path: Optional[Path],
+        data: Optional[bytes],
+        target_path: Optional[str],
+        target_basename: Optional[str],
+    ) -> None:
+        path, mime_type = determine_path_and_mime_type(
+            self._run_id, attribute_path, local_path, target_path, target_basename
+        )
+
+        url = self._request_upload_url(path)
+        src = local_path.open("rb") if local_path else io.BytesIO(data)  # type: ignore
+        with src:
+            upload_file(src, url, mime_type)
+
+        request_ids = self._submit_attribute(attribute_path, path, timestamp)
+        self._wait_for_completion(request_ids)
+
+    @backoff.on_exception(backoff.expo, NeptuneRetryableError, max_time=MAX_REQUEST_RETRY_SECONDS)
+    @with_api_errors_handling
+    def _request_upload_url(self, file_path: str) -> str:
+        assert self._backend is not None
+        return self._backend.fetch_file_storage_info(self._project, file_path, "write")
+
+    @backoff.on_exception(backoff.expo, NeptuneRetryableError, max_time=MAX_REQUEST_RETRY_SECONDS)
+    @with_api_errors_handling
+    def _submit_attribute(self, attribute_path: str, file_path: str, timestamp: datetime) -> Sequence[str]:
+        """Request the Ingest API to save a File type attribute under `attribute_path`.
+        Returns the request ids for tracking the status of the operation.
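+        The returned ids can be passed to `check_batch` to poll the operation status.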
+ """ + + assert self._backend is not None # mypy + + op = RunOperation( + project=self._project, + run_id=self._run_id, + # TODO: replace with the actual Value type once it's introduced to protobuf + update=UpdateRunSnapshot( + timestamp=datetime_to_proto(timestamp), assign={attribute_path: Value(string=file_path)} + ), + ) + + response = self._backend.submit(operation=op, family=self._family) + raise_on_response(response) + assert response.parsed # mypy + + return response.parsed.request_ids + + @backoff.on_exception(backoff.expo, NeptuneRetryableError, max_time=MAX_REQUEST_RETRY_SECONDS) + @with_api_errors_handling + def _wait_for_completion(self, request_ids: list[str]) -> None: + assert self._backend is not None # mypy + + while self.is_running(): + response = self._backend.check_batch(request_ids, self._project) + raise_on_response(response) + assert response.parsed # mypy + + status = response.parsed.statuses[0] + if any(code_status.code == Code.UNAVAILABLE for code_status in status.code_by_count): + # The request is still being processed, check back in a moment + time.sleep(1) + continue + + for code_status in status.code_by_count: + if code_status.code != Code.OK: + exc_class = code_to_exception(code_status.detail) + self._errors_queue.put(exc_class()) + + # The request finished successfully or with an error, either way we can break + break + + def _make_done_callback(self, message: UploadMessage) -> Callable[[futures.Future], None]: + """Returns a callback function suitable for use with Future.add_done_callback(). Decreases the active upload + count and propagates any exception to the errors queue. + """ + + def _on_task_completed(future: futures.Future) -> None: + self._input_queue.decrement_active() + + try: + future.result() + except Exception as e: + logger.error(f"Failed to upload file as `{message.attribute_path}`: {e}") + self._errors_queue.put(e) + + return _on_task_completed + + +def determine_path_and_mime_type( + run_id: str, + attribute_path: str, + local_path: Optional[Path], + target_path: Optional[str], + target_basename: Optional[str], +) -> tuple[str, str]: + mime_type = guess_mime_type(attribute_path, local_path) + + # Target path always takes precedence as-is + if target_path: + return target_path, mime_type + + if local_path: + local_basename = local_path.name + else: + local_basename = f"{uuid.uuid4()}{mimetypes.guess_extension(mime_type)}" + + if target_basename: + parts: tuple[str, ...] 
= (run_id, attribute_path, target_basename)
+    else:
+        parts = (run_id, attribute_path, str(uuid.uuid4()), local_basename)
+
+    return "/".join(parts), mime_type
+
+
+def upload_file(source: BinaryIO, url: str, mime_type: str) -> None:
+    # TODO: do the actual work :)
+    assert source and url and mime_type
+    time.sleep(1)
+
+
+def guess_mime_type(attribute_path: str, local_path: Optional[Path]) -> str:
+    if local_path:
+        mime_type, _ = mimetypes.guess_type(local_path)
+        if mime_type is not None:
+            return mime_type
+
+    mime_type, _ = mimetypes.guess_type(attribute_path)
+    return mime_type or "application/octet-stream"
+
+
+def raise_on_response(response: Response, allow_empty_response: bool = False) -> None:
+    if response.status_code == 200:
+        return
+
+    if response.parsed is None and not allow_empty_response:
+        raise NeptuneUnexpectedResponseError(reason="Empty server response")
+
+    if response.status_code == 403:
+        raise NeptuneUnauthorizedError()
+
+    logger.error("HTTP response error: %s", response.status_code)
+    if response.status_code // 100 == 5:
+        raise NeptuneInternalServerError()
+    else:
+        raise NeptuneUnexpectedResponseError()
diff --git a/src/neptune_scale/sync/parameters.py b/src/neptune_scale/sync/parameters.py
index f3d4b7fe..dbdf91e2 100644
--- a/src/neptune_scale/sync/parameters.py
+++ b/src/neptune_scale/sync/parameters.py
@@ -31,3 +31,8 @@
 
 # Status tracking
 MAX_REQUESTS_STATUS_BATCH_SIZE = 1000
+
+# Files
+
+# Maximum size of file data provided via a buffer (as opposed to a file on a filesystem)
+MAX_FILE_UPLOAD_BUFFER_SIZE = 10 * 1024**2  # 10 MB
diff --git a/src/neptune_scale/sync/sync_process.py b/src/neptune_scale/sync/sync_process.py
index 2a08f916..444f87a9 100644
--- a/src/neptune_scale/sync/sync_process.py
+++ b/src/neptune_scale/sync/sync_process.py
@@ -17,11 +17,11 @@
     NamedTuple,
     Optional,
     TypeVar,
+    cast,
 )
 
 import backoff
 from neptune_api.proto.google_rpc.code_pb2 import Code
-from neptune_api.proto.neptune_pb.ingest.v1.ingest_pb2 import IngestCode
 from neptune_api.proto.neptune_pb.ingest.v1.pub.client_pb2 import (
     BulkRequestStatus,
     SubmitResponse,
@@ -29,30 +29,10 @@
 from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation
 
 from neptune_scale.exceptions import (
-    NeptuneAttributePathEmpty,
-    NeptuneAttributePathExceedsSizeLimit,
-    NeptuneAttributePathInvalid,
-    NeptuneAttributePathNonWritable,
-    NeptuneAttributeTypeMismatch,
-    NeptuneAttributeTypeUnsupported,
     NeptuneConnectionLostError,
-    NeptuneFloatValueNanInfUnsupported,
     NeptuneInternalServerError,
     NeptuneOperationsQueueMaxSizeExceeded,
-    NeptuneProjectInvalidName,
-    NeptuneProjectNotFound,
     NeptuneRetryableError,
-    NeptuneRunConflicting,
-    NeptuneRunDuplicate,
-    NeptuneRunForkParentNotFound,
-    NeptuneRunInvalidCreationParameters,
-    NeptuneRunNotFound,
-    NeptuneSeriesPointDuplicate,
-    NeptuneSeriesStepNonIncreasing,
-    NeptuneSeriesStepNotAfterForkPoint,
-    NeptuneSeriesTimestampDecreasing,
-    NeptuneStringSetExceedsSizeLimit,
-    NeptuneStringValueExceedsSizeLimit,
     NeptuneSynchronizationStopped,
     NeptuneUnauthorizedError,
     NeptuneUnexpectedError,
@@ -63,8 +43,11 @@
     backend_factory,
     with_api_errors_handling,
 )
+from neptune_scale.net.ingest_code import code_to_exception
 from neptune_scale.sync.aggregating_queue import AggregatingQueue
 from neptune_scale.sync.errors_tracking import ErrorsQueue
+from neptune_scale.sync.files.queue import FileUploadQueue
+from neptune_scale.sync.files.worker import FileUploadWorkerThread
 from neptune_scale.sync.parameters import (
INTERNAL_QUEUE_FEEDER_THREAD_SLEEP_TIME, MAX_QUEUE_SIZE, @@ -96,30 +79,6 @@ logger = get_logger() -CODE_TO_ERROR: dict[IngestCode.ValueType, Optional[type[Exception]]] = { - IngestCode.OK: None, - IngestCode.PROJECT_NOT_FOUND: NeptuneProjectNotFound, - IngestCode.PROJECT_INVALID_NAME: NeptuneProjectInvalidName, - IngestCode.RUN_NOT_FOUND: NeptuneRunNotFound, - IngestCode.RUN_DUPLICATE: NeptuneRunDuplicate, - IngestCode.RUN_CONFLICTING: NeptuneRunConflicting, - IngestCode.RUN_FORK_PARENT_NOT_FOUND: NeptuneRunForkParentNotFound, - IngestCode.RUN_INVALID_CREATION_PARAMETERS: NeptuneRunInvalidCreationParameters, - IngestCode.FIELD_PATH_EXCEEDS_SIZE_LIMIT: NeptuneAttributePathExceedsSizeLimit, - IngestCode.FIELD_PATH_EMPTY: NeptuneAttributePathEmpty, - IngestCode.FIELD_PATH_INVALID: NeptuneAttributePathInvalid, - IngestCode.FIELD_PATH_NON_WRITABLE: NeptuneAttributePathNonWritable, - IngestCode.FIELD_TYPE_UNSUPPORTED: NeptuneAttributeTypeUnsupported, - IngestCode.FIELD_TYPE_CONFLICTING: NeptuneAttributeTypeMismatch, - IngestCode.SERIES_POINT_DUPLICATE: NeptuneSeriesPointDuplicate, - IngestCode.SERIES_STEP_NON_INCREASING: NeptuneSeriesStepNonIncreasing, - IngestCode.SERIES_STEP_NOT_AFTER_FORK_POINT: NeptuneSeriesStepNotAfterForkPoint, - IngestCode.SERIES_TIMESTAMP_DECREASING: NeptuneSeriesTimestampDecreasing, - IngestCode.FLOAT_VALUE_NAN_INF_UNSUPPORTED: NeptuneFloatValueNanInfUnsupported, - IngestCode.STRING_VALUE_EXCEEDS_SIZE_LIMIT: NeptuneStringValueExceedsSizeLimit, - IngestCode.STRING_SET_EXCEEDS_SIZE_LIMIT: NeptuneStringSetExceedsSizeLimit, -} - class StatusTrackingElement(NamedTuple): sequence_id: int @@ -127,12 +86,6 @@ class StatusTrackingElement(NamedTuple): request_id: str -def code_to_exception(code: IngestCode.ValueType) -> Optional[type[Exception]]: - if code in CODE_TO_ERROR: - return CODE_TO_ERROR[code] - return NeptuneUnexpectedError - - class PeekableQueue(Generic[T]): def __init__(self) -> None: self._lock: threading.RLock = threading.RLock() @@ -167,8 +120,10 @@ def __init__( operations_queue: Queue, errors_queue: ErrorsQueue, process_link: ProcessLink, + file_upload_queue: FileUploadQueue, api_token: str, project: str, + run_id: str, family: str, mode: Literal["async", "disabled"], last_queued_seq: SharedInt, @@ -178,11 +133,13 @@ def __init__( ) -> None: super().__init__(name="SyncProcess") - self._external_operations_queue: Queue[SingleOperation] = operations_queue + self._input_operations_queue: Queue[SingleOperation] = operations_queue self._errors_queue: ErrorsQueue = errors_queue self._process_link: ProcessLink = process_link + self._file_upload_queue: FileUploadQueue = file_upload_queue self._api_token: str = api_token self._project: str = project + self._run_id: str = run_id self._family: str = family self._last_queued_seq: SharedInt = last_queued_seq self._last_ack_seq: SharedInt = last_ack_seq @@ -209,10 +166,12 @@ def run(self) -> None: worker = SyncProcessWorker( project=self._project, + run_id=self._run_id, family=self._family, api_token=self._api_token, errors_queue=self._errors_queue, - external_operations_queue=self._external_operations_queue, + input_queue=self._input_operations_queue, + file_upload_queue=self._file_upload_queue, last_queued_seq=self._last_queued_seq, last_ack_seq=self._last_ack_seq, max_queue_size=self._max_queue_size, @@ -240,10 +199,12 @@ def __init__( *, api_token: str, project: str, + run_id: str, family: str, mode: Literal["async", "disabled"], errors_queue: ErrorsQueue, - external_operations_queue: 
multiprocessing.Queue[SingleOperation],
+        input_queue: multiprocessing.Queue[SingleOperation],
+        file_upload_queue: FileUploadQueue,
         last_queued_seq: SharedInt,
         last_ack_seq: SharedInt,
         last_ack_timestamp: SharedFloat,
@@ -262,9 +223,9 @@ def __init__(
             last_queued_seq=last_queued_seq,
             mode=mode,
         )
-        self._external_to_internal_thread = InternalQueueFeederThread(
-            external=external_operations_queue,
-            internal=self._internal_operations_queue,
+        self._operation_dispatcher_thread = OperationDispatcherThread(
+            input_queue=input_queue,
+            operations_queue=self._internal_operations_queue,
             errors_queue=self._errors_queue,
         )
         self._status_tracking_thread = StatusTrackingThread(
@@ -276,14 +237,28 @@ def __init__(
             last_ack_seq=last_ack_seq,
             last_ack_timestamp=last_ack_timestamp,
         )
+        self._file_upload_thread = FileUploadWorkerThread(
+            project=project,
+            run_id=run_id,
+            api_token=api_token,
+            family=family,
+            mode=mode,
+            input_queue=file_upload_queue,
+            errors_queue=self._errors_queue,
+        )
 
     @property
     def threads(self) -> tuple[Daemon, ...]:
-        return self._external_to_internal_thread, self._sync_thread, self._status_tracking_thread
+        return (
+            self._operation_dispatcher_thread,
+            self._sync_thread,
+            self._status_tracking_thread,
+            self._file_upload_thread,
+        )
 
     @property
     def resources(self) -> tuple[Resource, ...]:
-        return self._external_to_internal_thread, self._sync_thread, self._status_tracking_thread
+        return cast(tuple[Resource, ...], self.threads)
 
     def interrupt(self) -> None:
         for thread in self.threads:
@@ -303,17 +278,17 @@ def join(self, timeout: Optional[int] = None) -> None:
             thread.join(timeout=timeout)
 
 
-class InternalQueueFeederThread(Daemon, Resource):
+class OperationDispatcherThread(Daemon, Resource):
     def __init__(
         self,
-        external: multiprocessing.Queue[SingleOperation],
-        internal: AggregatingQueue,
+        input_queue: multiprocessing.Queue[SingleOperation],
+        operations_queue: AggregatingQueue,
         errors_queue: ErrorsQueue,
     ) -> None:
-        super().__init__(name="InternalQueueFeederThread", sleep_time=INTERNAL_QUEUE_FEEDER_THREAD_SLEEP_TIME)
+        super().__init__(name="OperationDispatcherThread", sleep_time=INTERNAL_QUEUE_FEEDER_THREAD_SLEEP_TIME)
 
-        self._external: multiprocessing.Queue[SingleOperation] = external
-        self._internal: AggregatingQueue = internal
+        self._input_queue: multiprocessing.Queue[SingleOperation] = input_queue
+        self._operations_queue: AggregatingQueue = operations_queue
         self._errors_queue: ErrorsQueue = errors_queue
 
         self._latest_unprocessed: Optional[SingleOperation] = None
@@ -323,7 +298,7 @@ def get_next(self) -> Optional[SingleOperation]:
             return self._latest_unprocessed
 
         try:
-            self._latest_unprocessed = self._external.get(timeout=INTERNAL_QUEUE_FEEDER_THREAD_SLEEP_TIME)
+            self._latest_unprocessed = self._input_queue.get(timeout=INTERNAL_QUEUE_FEEDER_THREAD_SLEEP_TIME)
             return self._latest_unprocessed
         except queue.Empty:
             return None
@@ -339,11 +314,15 @@ def work(self) -> None:
                 continue
 
             try:
-                self._internal.put_nowait(operation)
+                self._operations_queue.put_nowait(operation)
                 self.commit()
             except queue.Full:
-                logger.debug("Internal queue is full (%d elements), waiting for free space", self._internal.maxsize)
-                self._errors_queue.put(NeptuneOperationsQueueMaxSizeExceeded(max_size=self._internal.maxsize))
+                logger.debug(
+                    "Operations queue is full (%d elements), waiting for free space", self._operations_queue.maxsize
+                )
+                self._errors_queue.put(
+                    NeptuneOperationsQueueMaxSizeExceeded(max_size=self._operations_queue.maxsize)
+                )
                 # Sleep before retry
                 break
             except Exception as e:
@@ -532,8 +511,9 @@ def work(self) -> None:
                     break
 
                 for code_status in request_status.code_by_count:
-                    if code_status.code != Code.OK and (error := code_to_exception(code_status.detail)) is not None:
-                        self._errors_queue.put(error())
+                    if code_status.code != Code.OK:
+                        exc_class = code_to_exception(code_status.detail)
+                        self._errors_queue.put(exc_class())
 
                 operations_to_commit += 1
                 processed_sequence_id, processed_timestamp = request_sequence_id, timestamp
diff --git a/src/neptune_scale/sync/util.py b/src/neptune_scale/sync/util.py
index 60fe4b0b..00c164a2 100644
--- a/src/neptune_scale/sync/util.py
+++ b/src/neptune_scale/sync/util.py
@@ -1,4 +1,12 @@
 import signal
+from datetime import (
+    datetime,
+    timezone,
+)
+from typing import (
+    Optional,
+    Union,
+)
 
 
 def safe_signal_name(signum: int) -> str:
@@ -8,3 +16,13 @@ def safe_signal_name(signum: int) -> str:
         signame = str(signum)
 
     return signame
+
+
+def arg_to_datetime(timestamp: Optional[Union[float, datetime]] = None) -> datetime:
+    """Convert the provided float timestamp to datetime. If None, return the current time in UTC."""
+
+    if timestamp is None:
+        timestamp = datetime.now(timezone.utc)
+    elif isinstance(timestamp, (float, int)):
+        timestamp = datetime.fromtimestamp(timestamp, timezone.utc)
+    return timestamp
diff --git a/src/neptune_scale/util/envs.py b/src/neptune_scale/util/envs.py
index 843fa9c3..6a135b82 100644
--- a/src/neptune_scale/util/envs.py
+++ b/src/neptune_scale/util/envs.py
@@ -8,6 +8,9 @@
 DEBUG_MODE = "NEPTUNE_DEBUG_MODE"
 
+# Log tracebacks of any exceptions that make it into the ErrorsQueue. Default: True.
+LOG_TRACEBACKS = "NEPTUNE_LOG_TRACEBACKS"
+
 SUBPROCESS_KILL_TIMEOUT = "NEPTUNE_SUBPROCESS_KILL_TIMEOUT"
 
 ALLOW_SELF_SIGNED_CERTIFICATE = "NEPTUNE_ALLOW_SELF_SIGNED_CERTIFICATE"
diff --git a/src/neptune_scale/util/shared_var.py b/src/neptune_scale/util/shared_var.py
index 84f48c28..7c0a4354 100644
--- a/src/neptune_scale/util/shared_var.py
+++ b/src/neptune_scale/util/shared_var.py
@@ -25,7 +25,7 @@ class SharedVar(Generic[T]):
     # In one process
     with var:
         var.value += 1
-        var.notify()  # Notify the waiting process
+        var.notify_all()  # Notify the waiting process
 
     # In another process
     with var:
@@ -47,13 +47,13 @@ def value(self, new_value: T) -> None:
         with self._condition:
             self._value.value = new_value
 
-    def wait(self, timeout: Optional[float] = None) -> None:
+    def wait(self, timeout: Optional[float] = None) -> bool:
         with self._condition:
-            self._condition.wait(timeout)
+            return self._condition.wait(timeout)
 
-    def wait_for(self, predicate: Callable[[], bool], timeout: Optional[float] = None) -> None:
+    def wait_for(self, predicate: Callable[[], bool], timeout: Optional[float] = None) -> bool:
         with self._condition:
-            self._condition.wait_for(predicate, timeout)
+            return self._condition.wait_for(predicate, timeout)
 
     def notify_all(self) -> None:
         with self._condition:
diff --git a/tests/unit/test_file_upload.py b/tests/unit/test_file_upload.py
new file mode 100644
index 00000000..ce5afa24
--- /dev/null
+++ b/tests/unit/test_file_upload.py
@@ -0,0 +1,223 @@
+import tempfile
+from datetime import datetime
+from pathlib import Path
+from queue import Empty
+from unittest.mock import (
+    Mock,
+    patch,
+)
+
+import pytest
+from pytest import (
+    fixture,
+    mark,
+)
+
+from neptune_scale.exceptions import NeptuneScaleError
+from neptune_scale.sync.errors_tracking import ErrorsQueue
+from neptune_scale.sync.files.queue import FileUploadQueue
+from neptune_scale.sync.files.worker import (
+    FileUploadWorkerThread,
determine_path_and_mime_type, +) + + +@fixture +def queue(): + return FileUploadQueue() + + +@fixture +def errors_queue(): + return ErrorsQueue() + + +@fixture +def worker(queue, api_token, errors_queue): + worker = FileUploadWorkerThread( + project="project", + run_id="run_id", + api_token=api_token, + family="family", + input_queue=queue, + errors_queue=errors_queue, + mode="disabled", + ) + + worker._request_upload_url = Mock(return_value="http://DUMMY.localhost/") + worker._submit_attribute = Mock() + worker._wait_for_completion = Mock() + + worker.start() + + return worker + + +@mark.parametrize( + "local, full, basename, expected", + ( + ("some/file.py", None, None, "RUN/ATTR/UUID4/file.py"), + ("some/file.py", None, "file.txt", "RUN/ATTR/file.txt"), + ("some/file.py", "full/path.txt", None, "full/path.txt"), + ("some/file.py", "full/path.txt", "basename", "full/path.txt"), + ), +) +def test_determine_path(local, full, basename, expected): + with patch("uuid.uuid4", return_value="UUID4"): + path, mimetype = determine_path_and_mime_type("RUN", "ATTR", Path(local), full, basename) + assert path == expected + + +@mark.parametrize( + "attr, local, expected", + ( + ("attr", None, "application/octet-stream"), + ("attr.jpg", None, "image/jpeg"), + ("attr.jpg", Path("local/file.py"), "text/x-python"), + ("attr.jpg", Path("local/file"), "image/jpeg"), + ), +) +def test_determine_mime_type(attr, local, expected): + path, mimetype = determine_path_and_mime_type("RUN", attr, local, None, None) + assert mimetype == expected + + +def test_queue_wait_for_completion(queue): + queue.submit( + attribute_path="attr", + local_path=None, + data=b"test", + target_path=None, + target_basename=None, + timestamp=datetime.now(), + ) + queue.submit( + attribute_path="attr2", + local_path=None, + data=b"test", + target_path=None, + target_basename=None, + timestamp=datetime.now(), + ) + + assert queue.active_uploads == 2 + + queue.decrement_active() + assert queue.active_uploads == 1 + + assert not queue.wait_for_completion(timeout=0.5) + + queue.decrement_active() + assert queue.active_uploads == 0 + + assert queue.wait_for_completion(timeout=1) + + +def test_successful_upload_from_buffer(worker, queue, errors_queue): + data = b"test" + + def expect_bytes(source, _url, _mime_type): + assert source.read() == data + + with patch("neptune_scale.sync.files.worker.upload_file", Mock(side_effect=expect_bytes)) as upload_file: + queue.submit( + attribute_path="attr.txt", + local_path=None, + data=data, + target_path=None, + target_basename=None, + timestamp=datetime.now(), + ) + assert queue.wait_for_completion(timeout=10) + assert queue.active_uploads == 0 + + worker.close() + + worker._request_upload_url.assert_called_once() + worker._submit_attribute.assert_called_once() + worker._wait_for_completion.assert_called_once() + + upload_file.assert_called_once() + with pytest.raises(Empty): + errors_queue.get(timeout=1) + + +def test_successful_upload_from_file(worker, queue, errors_queue): + data = b"test" + + def expect_bytes(source, _url, _mime_type): + assert source.read() == data + + # Note that we cannot use NamedTemporaryFile here, because the test will fail on Windows. Windows opens + # temporary files in a way that prevents them from being opened by another process / thread, which is + # exactly our case. 
+ with ( + patch("neptune_scale.sync.files.worker.upload_file", Mock(side_effect=expect_bytes)) as upload_file, + tempfile.TemporaryDirectory() as temp_dir, + ): + file_path = Path(temp_dir) / "file.txt" + with file_path.open("wb") as temp_file: + temp_file.write(data) + + queue.submit( + attribute_path="attr.txt", + local_path=file_path, + data=None, + target_path=None, + target_basename=None, + timestamp=datetime.now(), + ) + + assert queue.wait_for_completion(timeout=10) + assert queue.active_uploads == 0 + + worker.close() + + worker._request_upload_url.assert_called_once() + worker._submit_attribute.assert_called_once() + worker._wait_for_completion.assert_called_once() + + upload_file.assert_called_once() + with pytest.raises(Empty): + errors_queue.get(timeout=1) + + +def test_file_does_not_exist(worker, queue, errors_queue): + queue.submit( + attribute_path="attr.txt", + local_path=Path("/does/not/exist"), + data=None, + target_path=None, + target_basename=None, + timestamp=datetime.now(), + ) + assert queue.wait_for_completion(timeout=10) + assert queue.active_uploads == 0 + + assert isinstance(errors_queue.get(timeout=1), FileNotFoundError) + + +def test_upload_error(worker, queue, errors_queue): + """Trigger an error in upload_file and check if the error is propagated to the errors_queue.""" + error = NeptuneScaleError(message="This error is expected to happen") + + with patch("neptune_scale.sync.files.worker.upload_file", Mock(side_effect=error)) as upload_file: + queue.submit( + attribute_path="attr.txt", + local_path=None, + data=b"", + target_path=None, + target_basename=None, + timestamp=datetime.now(), + ) + assert queue.wait_for_completion(timeout=10) + assert queue.active_uploads == 0 + + worker.close() + + worker._request_upload_url.assert_called_once() + worker._submit_attribute.assert_not_called() + worker._wait_for_completion.assert_not_called() + + upload_file.assert_called_once() + assert errors_queue.get(timeout=1).message == error.message diff --git a/tests/unit/test_log_file.py b/tests/unit/test_log_file.py new file mode 100644 index 00000000..4ebb90f6 --- /dev/null +++ b/tests/unit/test_log_file.py @@ -0,0 +1,80 @@ +import tempfile +from pathlib import Path +from unittest.mock import Mock + +import pytest +from pytest import fixture + +from neptune_scale import Run +from neptune_scale.sync.parameters import MAX_FILE_UPLOAD_BUFFER_SIZE + + +@fixture +def run(api_token): + run = Run(project="workspace/project", api_token=api_token, run_id="run_id", mode="disabled") + run._attr_store.upload_file = Mock() + + return run + + +def test_data_and_path_arguments(run): + with pytest.raises(ValueError) as exc: + run.log_file("file.txt") + + exc.match("Either `path` or `data` must be provided") + + with pytest.raises(ValueError) as exc: + run.log_file("file.txt", data=b"", path="/some/file.txt") + + exc.match("Only one of `path` or `data` can be provided") + + +def test_too_large_data(run): + with pytest.raises(ValueError) as exc: + run.log_file("file.txt", data=b"a" * MAX_FILE_UPLOAD_BUFFER_SIZE + b"foo") + + exc.match("must not exceed") + + +def test_file_upload_not_a_file(run): + with pytest.raises(ValueError) as exc, tempfile.TemporaryDirectory() as temp_dir: + run.log_file("file.txt", path=temp_dir) + + exc.match("is not a file") + + +def test_file_upload_file_does_not_exist(run): + with pytest.raises(FileNotFoundError) as exc: + run.log_file("file.txt", path="/does/not/exist") + + exc.match("does not exist") + + +def test_file_upload_with_data(run): + 
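+    """Uploading from a buffer should forward `data` to the attribute store with `local_path=None`."""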
run.log_file("file.txt", data=b"foo") + + run._attr_store.upload_file.assert_called_once_with( + attribute_path="file.txt", + data=b"foo", + local_path=None, + target_basename=None, + target_path=None, + timestamp=None, + ) + + +def test_file_upload_with_local_file(run): + with tempfile.NamedTemporaryFile() as temp_file: + temp_file.write(b"foo") + temp_file.flush() + + run.log_file("file.txt", path=temp_file.name) + + run._attr_store.upload_file.assert_called_once_with( + attribute_path="file.txt", + data=None, + local_path=Path(temp_file.name), + target_basename=None, + target_path=None, + timestamp=None, + ) diff --git a/tests/unit/test_shared_var.py b/tests/unit/test_shared_var.py index d4cf9e07..31d92e16 100644 --- a/tests/unit/test_shared_var.py +++ b/tests/unit/test_shared_var.py @@ -13,7 +13,7 @@ def _child(var): var.value = 1 var.notify_all() - var.wait(timeout=1) + assert var.wait(timeout=10) assert var.value == 2 @@ -24,7 +24,7 @@ def test_set_and_notify(tp): process = Process(target=_child, args=(var,)) process.start() - var.wait(timeout=1) + assert var.wait(timeout=10) assert var.value == 1 with var: