diff --git a/.github/actions/install-package/action.yml b/.github/actions/install-package/action.yml new file mode 100644 index 00000000..050c8b59 --- /dev/null +++ b/.github/actions/install-package/action.yml @@ -0,0 +1,28 @@ +--- +name: Package +description: Install python and package +inputs: + python-version: + description: "Python version" + required: true + os: + description: "Operating system" + required: true + +runs: + using: "composite" + steps: + - name: Install Python ${{ inputs.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ inputs.python-version }} + + - name: Install dependencies + run: | + pip install -r dev_requirements.txt + shell: bash + + - name: List dependencies + run: | + pip list + shell: bash diff --git a/.github/actions/test-unit/action.yml b/.github/actions/test-unit/action.yml new file mode 100644 index 00000000..b3c5c79c --- /dev/null +++ b/.github/actions/test-unit/action.yml @@ -0,0 +1,47 @@ +--- +name: Test Unit +description: Check unit tests +inputs: + python-version: + description: "Python version" + required: true + os: + description: "Operating system" + required: true + report_job: + description: "Job name to update by JUnit report" + required: true + +runs: + using: "composite" + steps: + - name: Install package + uses: ./.github/actions/install-package + with: + python-version: ${{ inputs.python-version }} + os: ${{ inputs.os }}-latest + + - name: Test + run: | + pytest -v ./tests/unit/ \ + --timeout=120 --timeout_method=thread \ + --color=yes \ + --junitxml="./test-results/test-unit-new-${{ inputs.os }}-${{ inputs.python-version }}.xml" + shell: bash + + - name: Upload test reports + uses: actions/upload-artifact@v3 + if: always() + with: + name: test-artifacts + path: ./test-results + + - name: Report + uses: mikepenz/action-junit-report@v3.6.2 + if: always() + with: + report_paths: './test-results/test-unit-*.xml' + update_check: true + include_passed: true + annotate_notice: true + job_name: ${{ inputs.report_job }} diff --git a/.github/workflows/unit-in-pull-request.yml b/.github/workflows/unit-in-pull-request.yml new file mode 100644 index 00000000..efa0a6a1 --- /dev/null +++ b/.github/workflows/unit-in-pull-request.yml @@ -0,0 +1,30 @@ +name: Unittests + +on: + workflow_dispatch: + push: + branches-ignore: + - main + +jobs: + test: + timeout-minutes: 75 + strategy: + fail-fast: false + matrix: + os: [ubuntu, windows, macos] + python-version: ["3.8"] + name: 'test (${{ matrix.os }} - py${{ matrix.python-version }})' + runs-on: ${{ matrix.os }}-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Run tests + uses: ./.github/actions/test-unit + with: + python-version: ${{ matrix.python-version }} + os: ${{ matrix.os }} + report_job: 'test (${{ matrix.os }} - py${{ matrix.python-version }})' diff --git a/.github/workflows/unit.yml b/.github/workflows/unit.yml new file mode 100644 index 00000000..efff71a4 --- /dev/null +++ b/.github/workflows/unit.yml @@ -0,0 +1,35 @@ +name: unit + +on: + workflow_call: + workflow_dispatch: + schedule: + - cron: "0 4 * * *" # Run every day at arbitrary time (4:00 AM UTC) + push: + branches: + - main + +jobs: + test: + timeout-minutes: 75 + strategy: + fail-fast: false + matrix: + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + os: [ubuntu, windows, macos] + + name: 'test (${{ matrix.os }} - py${{ matrix.python-version }})' + runs-on: ${{ matrix.os }}-latest + steps: + - name: Checkout repository + uses: 
actions/checkout@v4 + with: + fetch-depth: 0 + ref: ${{ github.event.client_payload.pull_request.head.ref }} + + - name: Run tests + uses: ./.github/actions/test-unit + with: + python-version: ${{ matrix.python-version }} + os: ${{ matrix.os }} + report_job: 'test (${{ matrix.os }} - py${{ matrix.python-version }})' diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 86b15790..4ac066f5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,6 +27,7 @@ repos: args: [ --config-file, pyproject.toml ] pass_filenames: false additional_dependencies: - - neptune-api==0.3.0 + - neptune-api==0.4.0 + - more-itertools default_language_version: python: python3 diff --git a/CHANGELOG.md b/CHANGELOG.md index 801acbe6..99ab9add 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1 +1,18 @@ -## [UNRELEASED] neptune-client-scale 0.1.0 +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +### Added +- Added minimal Run classes ([#6](https://github.com/neptune-ai/neptune-client-scale/pull/6)) +- Added support for `max_queue_size` and `max_queue_size_exceeded_callback` parameters in `Run` ([#7](https://github.com/neptune-ai/neptune-client-scale/pull/7)) +- Added support for logging metadata ([#8](https://github.com/neptune-ai/neptune-client-scale/pull/8)) +- Added support for `creation_time` ([#9](https://github.com/neptune-ai/neptune-client-scale/pull/9)) +- Added support for Forking ([#9](https://github.com/neptune-ai/neptune-client-scale/pull/9)) +- Added support for Experiments ([#9](https://github.com/neptune-ai/neptune-client-scale/pull/9)) +- Added support for Run resume ([#9](https://github.com/neptune-ai/neptune-client-scale/pull/9)) +- Added support for env variables for project and api token ([#11](https://github.com/neptune-ai/neptune-client-scale/pull/11)) diff --git a/dev_requirements.txt b/dev_requirements.txt index a19f8a59..73ee4cb9 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -2,3 +2,6 @@ # dev pre-commit +pytest +pytest-timeout +freezegun diff --git a/pyproject.toml b/pyproject.toml index 0ac2ca72..f5fab8f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,8 +11,8 @@ pattern = "default-unprefixed" [tool.poetry.dependencies] python = "^3.8" -# Networking -neptune-api = "0.3.0" +neptune-api = "0.4.0" +more-itertools = "^10.0.0" [tool.poetry] name = "neptune-client-scale" diff --git a/src/neptune_scale/__init__.py b/src/neptune_scale/__init__.py index e69de29b..2356e5ac 100644 --- a/src/neptune_scale/__init__.py +++ b/src/neptune_scale/__init__.py @@ -0,0 +1,273 @@ +""" +Python package +""" + +from __future__ import annotations + +__all__ = ["Run"] + +import os +import threading +from contextlib import AbstractContextManager +from datetime import datetime +from typing import Callable + +from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ForkPoint +from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import Run as CreateRun +from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation + +from neptune_scale.api.api_client import ApiClient +from neptune_scale.core.components.abstract import ( + Resource, + WithResources, +) +from neptune_scale.core.components.errors_monitor import ErrorsMonitor +from neptune_scale.core.components.errors_queue import ErrorsQueue +from 
neptune_scale.core.components.operations_queue import OperationsQueue +from neptune_scale.core.metadata_splitter import MetadataSplitter +from neptune_scale.core.serialization import ( + datetime_to_proto, + make_step, +) +from neptune_scale.core.validation import ( + verify_collection_type, + verify_max_length, + verify_non_empty, + verify_project_qualified_name, + verify_type, +) +from neptune_scale.envs import ( + API_TOKEN_ENV_NAME, + PROJECT_ENV_NAME, +) +from neptune_scale.parameters import ( + MAX_FAMILY_LENGTH, + MAX_QUEUE_SIZE, + MAX_RUN_ID_LENGTH, +) + + +class Run(WithResources, AbstractContextManager): + """ + Representation of tracked metadata. + """ + + def __init__( + self, + *, + family: str, + run_id: str, + project: str | None = None, + api_token: str | None = None, + resume: bool = False, + as_experiment: str | None = None, + creation_time: datetime | None = None, + from_run_id: str | None = None, + from_step: int | float | None = None, + max_queue_size: int = MAX_QUEUE_SIZE, + max_queue_size_exceeded_callback: Callable[[int, BaseException], None] | None = None, + ) -> None: + """ + Initializes a run that logs the model-building metadata to Neptune. + + Args: + family: Identifies related runs. For example, the same value must apply to all runs within a run hierarchy. + Max length: 128 characters. + run_id: Unique identifier of a run. Must be unique within the project. Max length: 128 characters. + project: Name of the project where the metadata is logged, in the form `workspace-name/project-name`. + If not provided, the value of the `NEPTUNE_PROJECT` environment variable is used. + api_token: Your Neptune API token. If not provided, the value of the `NEPTUNE_API_TOKEN` environment + variable is used. + resume: Whether to resume an existing run. + as_experiment: If creating a run as an experiment, ID of an experiment to be associated with the run. + creation_time: Custom creation time of the run. + from_run_id: If forking from an existing run, ID of the run to fork from. + from_step: If forking from an existing run, step number to fork from. + max_queue_size: Maximum number of operations in a queue. + max_queue_size_exceeded_callback: Callback function triggered when a queue is full. + Accepts two arguments: + - Maximum size of the queue. + - Exception that made the queue full. 
+ """ + verify_type("family", family, str) + verify_type("run_id", run_id, str) + verify_type("resume", resume, bool) + verify_type("project", project, (str, type(None))) + verify_type("api_token", api_token, (str, type(None))) + verify_type("as_experiment", as_experiment, (str, type(None))) + verify_type("creation_time", creation_time, (datetime, type(None))) + verify_type("from_run_id", from_run_id, (str, type(None))) + verify_type("from_step", from_step, (int, float, type(None))) + verify_type("max_queue_size", max_queue_size, int) + verify_type("max_queue_size_exceeded_callback", max_queue_size_exceeded_callback, (Callable, type(None))) + + if resume and creation_time is not None: + raise ValueError("`resume` and `creation_time` cannot be used together.") + if resume and as_experiment is not None: + raise ValueError("`resume` and `as_experiment` cannot be used together.") + if (from_run_id is not None and from_step is None) or (from_run_id is None and from_step is not None): + raise ValueError("`from_run_id` and `from_step` must be used together.") + if resume and from_run_id is not None: + raise ValueError("`resume` and `from_run_id` cannot be used together.") + if resume and from_step is not None: + raise ValueError("`resume` and `from_step` cannot be used together.") + + if max_queue_size < 1: + raise ValueError("`max_queue_size` must be greater than 0.") + + project = project or os.environ.get(PROJECT_ENV_NAME) + verify_non_empty("project", project) + assert project is not None # mypy + input_project: str = project + + api_token = api_token or os.environ.get(API_TOKEN_ENV_NAME) + verify_non_empty("api_token", api_token) + assert api_token is not None # mypy + input_api_token: str = api_token + + verify_non_empty("family", family) + verify_non_empty("run_id", run_id) + if as_experiment is not None: + verify_non_empty("as_experiment", as_experiment) + if from_run_id is not None: + verify_non_empty("from_run_id", from_run_id) + + verify_project_qualified_name("project", project) + + verify_max_length("family", family, MAX_FAMILY_LENGTH) + verify_max_length("run_id", run_id, MAX_RUN_ID_LENGTH) + + self._project: str = input_project + self._family: str = family + self._run_id: str = run_id + + self._lock = threading.RLock() + self._operations_queue: OperationsQueue = OperationsQueue( + lock=self._lock, max_size=max_queue_size, max_size_exceeded_callback=max_queue_size_exceeded_callback + ) + self._errors_queue: ErrorsQueue = ErrorsQueue() + self._errors_monitor = ErrorsMonitor(errors_queue=self._errors_queue) + self._backend: ApiClient = ApiClient(api_token=input_api_token) + + self._errors_monitor.start() + + if not resume: + self._create_run( + creation_time=datetime.now() if creation_time is None else creation_time, + as_experiment=as_experiment, + from_run_id=from_run_id, + from_step=from_step, + ) + + def __enter__(self) -> Run: + return self + + @property + def resources(self) -> tuple[Resource, ...]: + return ( + self._operations_queue, + self._backend, + self._errors_monitor, + self._errors_queue, + ) + + def close(self) -> None: + """ + Stops the connection to Neptune and synchronizes all data. 
+ """ + super().close() + + def _create_run( + self, + creation_time: datetime, + as_experiment: str | None, + from_run_id: str | None, + from_step: int | float | None, + ) -> None: + fork_point: ForkPoint | None = None + if from_run_id is not None and from_step is not None: + fork_point = ForkPoint( + parent_project=self._project, parent_run_id=from_run_id, step=make_step(number=from_step) + ) + + operation = RunOperation( + project=self._project, + run_id=self._run_id, + create=CreateRun( + family=self._family, + fork_point=fork_point, + experiment_id=as_experiment, + creation_time=None if creation_time is None else datetime_to_proto(creation_time), + ), + ) + self._backend.submit(operation=operation, family=self._family) + # TODO: Enqueue on the operations queue + # self._operations_queue.enqueue(operation=operation) + + def log( + self, + step: float | int | None = None, + timestamp: datetime | None = None, + fields: dict[str, float | bool | int | str | datetime | list | set] | None = None, + metrics: dict[str, float] | None = None, + add_tags: dict[str, list[str] | set[str]] | None = None, + remove_tags: dict[str, list[str] | set[str]] | None = None, + ) -> None: + """ + Logs the specified metadata to Neptune. + + Args: + step: Index of the log entry, must be increasing. If None, the highest of the already logged indexes is used. + timestamp: Time of logging the metadata. + fields: Dictionary of fields to log. + metrics: Dictionary of metrics to log. + add_tags: Dictionary of tags to add to the run. + remove_tags: Dictionary of tags to remove from the run. + + Examples: + ``` + >>> with Run(...) as run: + ... run.log(step=1, fields={"parameters/learning_rate": 0.001}) + ... run.log(step=2, add_tags={"sys/group_tags": ["group1", "group2"]}) + ... 
run.log(step=3, metrics={"metrics/loss": 0.1}) + ``` + + """ + verify_type("step", step, (float, int, type(None))) + verify_type("timestamp", timestamp, (datetime, type(None))) + verify_type("fields", fields, (dict, type(None))) + verify_type("metrics", metrics, (dict, type(None))) + verify_type("add_tags", add_tags, (dict, type(None))) + verify_type("remove_tags", remove_tags, (dict, type(None))) + + timestamp = datetime.now() if timestamp is None else timestamp + fields = {} if fields is None else fields + metrics = {} if metrics is None else metrics + add_tags = {} if add_tags is None else add_tags + remove_tags = {} if remove_tags is None else remove_tags + + verify_collection_type("`fields` keys", list(fields.keys()), str) + verify_collection_type("`metrics` keys", list(metrics.keys()), str) + verify_collection_type("`add_tags` keys", list(add_tags.keys()), str) + verify_collection_type("`remove_tags` keys", list(remove_tags.keys()), str) + + verify_collection_type("`fields` values", list(fields.values()), (float, bool, int, str, datetime, list, set)) + verify_collection_type("`metrics` values", list(metrics.values()), float) + verify_collection_type("`add_tags` values", list(add_tags.values()), (list, set)) + verify_collection_type("`remove_tags` values", list(remove_tags.values()), (list, set)) + + splitter: MetadataSplitter = MetadataSplitter( + project=self._project, + run_id=self._run_id, + step=step, + timestamp=timestamp, + fields=fields, + metrics=metrics, + add_tags=add_tags, + remove_tags=remove_tags, + ) + + for operation in splitter: + self._backend.submit(operation=operation, family=self._family) + # TODO: Enqueue on the operations queue + # self._operations_queue.enqueue(operation=operation) diff --git a/src/neptune_scale/api/__init__.py b/src/neptune_scale/api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/neptune_scale/api/api_client.py b/src/neptune_scale/api/api_client.py new file mode 100644 index 00000000..80a15d31 --- /dev/null +++ b/src/neptune_scale/api/api_client.py @@ -0,0 +1,87 @@ +# +# Copyright (c) 2024, Neptune Labs Sp. z o.o. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from __future__ import annotations + +__all__ = ["ApiClient"] + + +from dataclasses import dataclass + +from neptune_api import ( + AuthenticatedClient, + Client, +) +from neptune_api.api.backend import get_client_config +from neptune_api.api.data_ingestion import submit_operation +from neptune_api.auth_helpers import exchange_api_key +from neptune_api.credentials import Credentials +from neptune_api.models import ( + ClientConfig, + Error, +) +from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation + +from neptune_scale.core.components.abstract import Resource + + +class ApiClient(Resource): + def __init__(self, api_token: str) -> None: + credentials = Credentials.from_api_key(api_key=api_token) + config, token_urls = get_config_and_token_urls(credentials=credentials) + self._backend = create_auth_api_client(credentials=credentials, config=config, token_refreshing_urls=token_urls) + + def submit(self, operation: RunOperation, family: str) -> None: + _ = submit_operation.sync(client=self._backend, family=family, body=operation) + + def cleanup(self) -> None: + pass + + def close(self) -> None: + self._backend.__exit__() + + +@dataclass +class TokenRefreshingURLs: + authorization_endpoint: str + token_endpoint: str + + @classmethod + def from_dict(cls, data: dict) -> TokenRefreshingURLs: + return TokenRefreshingURLs( + authorization_endpoint=data["authorization_endpoint"], token_endpoint=data["token_endpoint"] + ) + + +def get_config_and_token_urls(*, credentials: Credentials) -> tuple[ClientConfig, TokenRefreshingURLs]: + with Client(base_url=credentials.base_url) as client: + config = get_client_config.sync(client=client) + if config is None or isinstance(config, Error): + raise RuntimeError(f"Failed to get client config: {config}") + response = client.get_httpx_client().get(config.security.open_id_discovery) + token_urls = TokenRefreshingURLs.from_dict(response.json()) + return config, token_urls + + +def create_auth_api_client( + *, credentials: Credentials, config: ClientConfig, token_refreshing_urls: TokenRefreshingURLs +) -> AuthenticatedClient: + return AuthenticatedClient( + base_url=credentials.base_url, + credentials=credentials, + client_id=config.security.client_id, + token_refreshing_endpoint=token_refreshing_urls.token_endpoint, + api_key_exchange_callback=exchange_api_key, + ) diff --git a/src/neptune_scale/core/__init__.py b/src/neptune_scale/core/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/neptune_scale/core/components/__init__.py b/src/neptune_scale/core/components/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/neptune_scale/core/components/abstract.py b/src/neptune_scale/core/components/abstract.py new file mode 100644 index 00000000..00242fa5 --- /dev/null +++ b/src/neptune_scale/core/components/abstract.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from abc import ( + ABC, + abstractmethod, +) +from types import TracebackType + + +class AutoCloseable(ABC): + def __enter__(self) -> AutoCloseable: + return self + + @abstractmethod + def close(self) -> None: ... + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self.close() + + +class Resource(AutoCloseable): + @abstractmethod + def cleanup(self) -> None: ... 
+ + def flush(self) -> None: + pass + + def close(self) -> None: + self.flush() + + +class WithResources(Resource): + @property + @abstractmethod + def resources(self) -> tuple[Resource, ...]: ... + + def flush(self) -> None: + for resource in self.resources: + resource.flush() + + def close(self) -> None: + for resource in self.resources: + resource.close() + + def cleanup(self) -> None: + for resource in self.resources: + resource.cleanup() diff --git a/src/neptune_scale/core/components/daemon.py b/src/neptune_scale/core/components/daemon.py new file mode 100644 index 00000000..d2ef8713 --- /dev/null +++ b/src/neptune_scale/core/components/daemon.py @@ -0,0 +1,84 @@ +__all__ = ["Daemon"] + +import abc +import threading +from enum import Enum + + +class Daemon(threading.Thread): + class DaemonState(Enum): + INIT = 1 + WORKING = 2 + PAUSING = 3 + PAUSED = 4 + INTERRUPTED = 5 + STOPPED = 6 + + def __init__(self, sleep_time: float, name: str) -> None: + super().__init__(daemon=True, name=name) + self._sleep_time = sleep_time + self._state: Daemon.DaemonState = Daemon.DaemonState.INIT + self._wait_condition = threading.Condition() + + def interrupt(self) -> None: + with self._wait_condition: + self._state = Daemon.DaemonState.INTERRUPTED + self._wait_condition.notify_all() + + def pause(self) -> None: + with self._wait_condition: + if self._state != Daemon.DaemonState.PAUSED: + if not self._is_interrupted(): + self._state = Daemon.DaemonState.PAUSING + self._wait_condition.notify_all() + self._wait_condition.wait_for(lambda: self._state != Daemon.DaemonState.PAUSING) + + def resume(self) -> None: + with self._wait_condition: + if not self._is_interrupted(): + self._state = Daemon.DaemonState.WORKING + self._wait_condition.notify_all() + + def wake_up(self) -> None: + with self._wait_condition: + self._wait_condition.notify_all() + + def disable_sleep(self) -> None: + self._sleep_time = 0 + + def is_running(self) -> bool: + with self._wait_condition: + return self._state in ( + Daemon.DaemonState.WORKING, + Daemon.DaemonState.PAUSING, + Daemon.DaemonState.PAUSED, + ) + + def _is_interrupted(self) -> bool: + with self._wait_condition: + return self._state in (Daemon.DaemonState.INTERRUPTED, Daemon.DaemonState.STOPPED) + + def run(self) -> None: + with self._wait_condition: + if not self._is_interrupted(): + self._state = Daemon.DaemonState.WORKING + try: + while not self._is_interrupted(): + with self._wait_condition: + if self._state == Daemon.DaemonState.PAUSING: + self._state = Daemon.DaemonState.PAUSED + self._wait_condition.notify_all() + self._wait_condition.wait_for(lambda: self._state != Daemon.DaemonState.PAUSED) + + if self._state == Daemon.DaemonState.WORKING: + self.work() + with self._wait_condition: + if self._sleep_time > 0 and self._state == Daemon.DaemonState.WORKING: + self._wait_condition.wait(timeout=self._sleep_time) + finally: + with self._wait_condition: + self._state = Daemon.DaemonState.STOPPED + self._wait_condition.notify_all() + + @abc.abstractmethod + def work(self) -> None: ... 
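Note for reviewers: the `Daemon` base class above drives the worker threads added in this PR (e.g. `ErrorsMonitor` below subclasses it). A minimal sketch of how a subclass is expected to use the lifecycle, based only on the interface shown above — the `TickWorker` name and its body are hypothetical, not part of this PR:

```python
from neptune_scale.core.components.daemon import Daemon


class TickWorker(Daemon):
    """Hypothetical subclass: performs one unit of work every `sleep_time` seconds."""

    def __init__(self) -> None:
        super().__init__(sleep_time=2.0, name="TickWorker")
        self._counter = 0

    def work(self) -> None:
        # Called repeatedly by Daemon.run() while the thread is in the WORKING state.
        self._counter += 1


worker = TickWorker()
worker.start()      # thread enters WORKING and calls work() in a loop
worker.pause()      # blocks until the loop acknowledges the PAUSED state
worker.resume()     # back to WORKING
worker.interrupt()  # request shutdown; run() exits and the state becomes STOPPED
worker.join(timeout=10)
```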
diff --git a/src/neptune_scale/core/components/errors_monitor.py b/src/neptune_scale/core/components/errors_monitor.py new file mode 100644 index 00000000..dc9950be --- /dev/null +++ b/src/neptune_scale/core/components/errors_monitor.py @@ -0,0 +1,46 @@ +__all__ = ("ErrorsMonitor",) + +import logging +import queue +from typing import Callable + +from neptune_scale.core.components.abstract import Resource +from neptune_scale.core.components.daemon import Daemon +from neptune_scale.core.components.errors_queue import ErrorsQueue + +logger = logging.getLogger("neptune") +logger.setLevel(level=logging.INFO) + + +def on_error(error: BaseException) -> None: + logger.error(error) + + +class ErrorsMonitor(Daemon, Resource): + def __init__( + self, + errors_queue: ErrorsQueue, + on_error_callback: Callable[[BaseException], None] = on_error, + ): + super().__init__(name="ErrorsMonitor", sleep_time=2) + self._errors_queue = errors_queue + self._on_error_callback = on_error_callback + + def work(self) -> None: + try: + error = self._errors_queue.get(block=False) + if error is not None: + self._on_error_callback(error) + except KeyboardInterrupt: + with self._wait_condition: + self._wait_condition.notify_all() + raise + except queue.Empty: + pass + + def cleanup(self) -> None: + pass + + def close(self) -> None: + self.interrupt() + self.join(timeout=10) diff --git a/src/neptune_scale/core/components/errors_queue.py b/src/neptune_scale/core/components/errors_queue.py new file mode 100644 index 00000000..33bdc38e --- /dev/null +++ b/src/neptune_scale/core/components/errors_queue.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +__all__ = ("ErrorsQueue",) + +from multiprocessing import Queue + +from neptune_scale.core.components.abstract import Resource + + +class ErrorsQueue(Resource): + def __init__(self) -> None: + self._errors_queue: Queue[BaseException] = Queue() + + def put(self, error: BaseException) -> None: + self._errors_queue.put(error) + + def get(self, block: bool = True, timeout: float | None = None) -> BaseException: + return self._errors_queue.get(block=block, timeout=timeout) + + def cleanup(self) -> None: + pass + + def close(self) -> None: + self._errors_queue.close() diff --git a/src/neptune_scale/core/components/operations_queue.py b/src/neptune_scale/core/components/operations_queue.py new file mode 100644 index 00000000..bed30265 --- /dev/null +++ b/src/neptune_scale/core/components/operations_queue.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +__all__ = ("OperationsQueue",) + +from multiprocessing import Queue +from time import monotonic +from typing import ( + TYPE_CHECKING, + Callable, + NamedTuple, +) + +from neptune_scale.core.components.abstract import Resource +from neptune_scale.core.validation import verify_type +from neptune_scale.parameters import MAX_QUEUE_ELEMENT_SIZE + +if TYPE_CHECKING: + from threading import RLock + + from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation + + +class QueueElement(NamedTuple): + sequence_id: int + occured_at: float + operation: bytes + + +def default_max_size_exceeded_callback(max_size: int, e: BaseException) -> None: + raise ValueError(f"Queue is full (max size: {max_size})") from e + + +class OperationsQueue(Resource): + def __init__( + self, + *, + lock: RLock, + max_size: int = 0, + max_size_exceeded_callback: Callable[[int, BaseException], None] | None = None, + ) -> None: + verify_type("max_size", max_size, int) + + self._lock: RLock = lock + self._max_size: int = max_size + 
self._max_size_exceeded_callback: Callable[[int, BaseException], None] = ( + max_size_exceeded_callback if max_size_exceeded_callback is not None else default_max_size_exceeded_callback + ) + + self._sequence_id: int = 0 + self._queue: Queue[QueueElement] = Queue(maxsize=max_size) + + def enqueue(self, *, operation: RunOperation) -> None: + try: + # TODO: This lock could be moved to the Run class + with self._lock: + serialized_operation = operation.SerializeToString() + + if len(serialized_operation) > MAX_QUEUE_ELEMENT_SIZE: + raise ValueError(f"Operation size exceeds the maximum allowed size ({MAX_QUEUE_ELEMENT_SIZE})") + + self._queue.put_nowait(QueueElement(self._sequence_id, monotonic(), serialized_operation)) + self._sequence_id += 1 + except Exception as e: + self._max_size_exceeded_callback(self._max_size, e) + + def cleanup(self) -> None: + pass + + def close(self) -> None: + self._queue.close() diff --git a/src/neptune_scale/core/metadata_splitter.py b/src/neptune_scale/core/metadata_splitter.py new file mode 100644 index 00000000..1aba2656 --- /dev/null +++ b/src/neptune_scale/core/metadata_splitter.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +__all__ = ("MetadataSplitter",) + +from datetime import datetime +from typing import ( + Any, + Callable, + Iterator, + TypeVar, +) + +from more_itertools import peekable +from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ( + SET_OPERATION, + UpdateRunSnapshot, + Value, +) +from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation + +from neptune_scale.core.serialization import ( + datetime_to_proto, + make_step, + make_value, + pb_key_size, +) + +T = TypeVar("T", bound=Any) + + +class MetadataSplitter(Iterator[RunOperation]): + def __init__( + self, + *, + project: str, + run_id: str, + step: int | float | None, + timestamp: datetime, + fields: dict[str, float | bool | int | str | datetime | list | set], + metrics: dict[str, float], + add_tags: dict[str, list[str] | set[str]], + remove_tags: dict[str, list[str] | set[str]], + max_message_bytes_size: int = 1024 * 1024, + ): + self._step = None if step is None else make_step(number=step) + self._timestamp = datetime_to_proto(timestamp) + self._project = project + self._run_id = run_id + self._fields = peekable(fields.items()) + self._metrics = peekable(metrics.items()) + self._add_tags = peekable(add_tags.items()) + self._remove_tags = peekable(remove_tags.items()) + + self._max_update_bytes_size = ( + max_message_bytes_size + - RunOperation( + project=self._project, + run_id=self._run_id, + update=UpdateRunSnapshot(step=self._step, timestamp=self._timestamp), + ).ByteSize() + ) + + self._has_returned = False + + def __iter__(self) -> MetadataSplitter: + self._has_returned = False + return self + + def __next__(self) -> RunOperation: + size = 0 + update = UpdateRunSnapshot( + step=self._step, + timestamp=self._timestamp, + assign={}, + append={}, + modify_sets={}, + ) + + size = self.populate( + assets=self._fields, + update_producer=lambda key, value: update.assign[key].MergeFrom(value), + size=size, + ) + size = self.populate( + assets=self._metrics, + update_producer=lambda key, value: update.append[key].MergeFrom(value), + size=size, + ) + size = self.populate_tags( + update=update, + assets=self._add_tags, + operation=SET_OPERATION.ADD, + size=size, + ) + _ = self.populate_tags( + update=update, + assets=self._remove_tags, + operation=SET_OPERATION.REMOVE, + size=size, + ) + + if not self._has_returned or update.assign or update.append 
or update.modify_sets: + self._has_returned = True + return RunOperation(project=self._project, run_id=self._run_id, update=update) + else: + raise StopIteration + + def populate( + self, + assets: peekable[Any], + update_producer: Callable[[str, Value], None], + size: int, + ) -> int: + while size < self._max_update_bytes_size: + try: + key, value = assets.peek() + except StopIteration: + break + + proto_value = make_value(value) + new_size = size + pb_key_size(key) + proto_value.ByteSize() + 6 + + if new_size > self._max_update_bytes_size: + break + + update_producer(key, proto_value) + size, _ = new_size, next(assets) + + return size + + def populate_tags( + self, update: UpdateRunSnapshot, assets: peekable[Any], operation: SET_OPERATION.ValueType, size: int + ) -> int: + while size < self._max_update_bytes_size: + try: + key, values = assets.peek() + except StopIteration: + break + + if not isinstance(values, peekable): + values = peekable(values) + + is_full = False + new_size = size + pb_key_size(key) + 6 + for value in values: + tag_size = pb_key_size(value) + 6 + if new_size + tag_size > self._max_update_bytes_size: + values.prepend(value) + is_full = True + break + + update.modify_sets[key].string.values[value] = operation + new_size += tag_size + + size, _ = new_size, next(assets) + + if is_full: + assets.prepend((key, list(values))) + break + + return size diff --git a/src/neptune_scale/core/serialization.py b/src/neptune_scale/core/serialization.py new file mode 100644 index 00000000..0858d8bc --- /dev/null +++ b/src/neptune_scale/core/serialization.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +__all__ = ( + "make_value", + "make_step", + "datetime_to_proto", + "pb_key_size", +) + +from datetime import datetime + +from google.protobuf.timestamp_pb2 import Timestamp +from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ( + Step, + StringSet, + Value, +) + + +def make_value(value: Value | float | str | int | bool | datetime | list[str] | set[str]) -> Value: + if isinstance(value, Value): + return value + if isinstance(value, float): + return Value(float64=value) + elif isinstance(value, bool): + return Value(bool=value) + elif isinstance(value, int): + return Value(int64=value) + elif isinstance(value, str): + return Value(string=value) + elif isinstance(value, datetime): + return Value(timestamp=datetime_to_proto(value)) + elif isinstance(value, (list, set)): + fv = Value(string_set=StringSet(values=value)) + return fv + else: + raise ValueError(f"Unsupported ingest field value type: {type(value)}") + + +def datetime_to_proto(dt: datetime) -> Timestamp: + dt_ts = dt.timestamp() + return Timestamp(seconds=int(dt_ts), nanos=int((dt_ts % 1) * 1e9)) + + +def make_step(number: float | int, raise_on_step_precision_loss: bool = False) -> Step: + """ + Converts a number to protobuf Step value. Example: + >>> assert make_step(7.654321, True) == Step(whole=7, micro=654321) + Args: + number: step expressed as number + raise_on_step_precision_loss: inform converter whether it should silently drop precision and + round down to 6 decimal places or raise an error. + + Returns: Step protobuf used in Neptune API. 
+ """ + m = int(1e6) + micro: int = int(number * m) + if raise_on_step_precision_loss and number * m - micro != 0: + raise ValueError(f"step must not use more than 6-decimal points, got: {number}") + + whole = micro // m + micro = micro % m + + return Step(whole=whole, micro=micro) + + +def pb_key_size(key: str) -> int: + """ + Calculates the size of the string in the protobuf message including an overhead of the length prefix (varint) + with an assumption of maximal string length. + """ + key_bin = bytes(key, "utf-8") + return len(key_bin) + 2 + (1 if len(key_bin) > 127 else 0) diff --git a/src/neptune_scale/core/validation.py b/src/neptune_scale/core/validation.py new file mode 100644 index 00000000..95a732e1 --- /dev/null +++ b/src/neptune_scale/core/validation.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +__all__ = ( + "verify_type", + "verify_non_empty", + "verify_max_length", + "verify_project_qualified_name", + "verify_collection_type", +) + +from typing import Any + + +def get_type_name(var_type: type | tuple) -> str: + return var_type.__name__ if hasattr(var_type, "__name__") else str(var_type) + + +def verify_type(var_name: str, var: Any, expected_type: type | tuple) -> None: + try: + if isinstance(expected_type, tuple): + type_name = " or ".join(get_type_name(t) for t in expected_type) + else: + type_name = get_type_name(expected_type) + except Exception as e: + # Just to be sure that nothing weird will be raised here + raise TypeError(f"Incorrect type of {var_name}") from e + + if not isinstance(var, expected_type): + raise TypeError(f"{var_name} must be a {type_name} (was {type(var)})") + + +def verify_non_empty(var_name: str, var: Any) -> None: + if not var: + raise ValueError(f"{var_name} must not be empty") + + +def verify_max_length(var_name: str, var: Any, max_length: int) -> None: + if len(var) > max_length: + raise ValueError(f"{var_name} must not exceed {max_length} characters") + + +def verify_project_qualified_name(var_name: str, var: Any) -> None: + verify_type(var_name, var, str) + verify_non_empty(var_name, var) + + project_parts = var.split("/") + if len(project_parts) != 2: + raise ValueError(f"{var_name} is not in expected format, should be 'workspace-name/project-name") + + +def verify_collection_type(var_name: str, var: list | set | tuple, expected_type: type | tuple) -> None: + verify_type(var_name, var, (list, set, tuple)) + + for value in var: + verify_type(f"elements of collection '{var_name}'", value, expected_type) diff --git a/src/neptune_scale/envs.py b/src/neptune_scale/envs.py new file mode 100644 index 00000000..02681d9b --- /dev/null +++ b/src/neptune_scale/envs.py @@ -0,0 +1,3 @@ +PROJECT_ENV_NAME = "NEPTUNE_PROJECT" + +API_TOKEN_ENV_NAME = "NEPTUNE_API_TOKEN" diff --git a/src/neptune_scale/parameters.py b/src/neptune_scale/parameters.py new file mode 100644 index 00000000..44112374 --- /dev/null +++ b/src/neptune_scale/parameters.py @@ -0,0 +1,4 @@ +MAX_RUN_ID_LENGTH = 128 +MAX_FAMILY_LENGTH = 128 +MAX_QUEUE_SIZE = 32767 +MAX_QUEUE_ELEMENT_SIZE = 1024 * 1024 # 1MB diff --git a/tests/unit/test_errors_monitor.py b/tests/unit/test_errors_monitor.py new file mode 100644 index 00000000..e4352d7e --- /dev/null +++ b/tests/unit/test_errors_monitor.py @@ -0,0 +1,22 @@ +from unittest.mock import Mock + +from neptune_scale.core.components.errors_monitor import ErrorsMonitor +from neptune_scale.core.components.errors_queue import ErrorsQueue + + +def test_errors_monitor(): + # given + callback = Mock() + + # and + errors_queue = ErrorsQueue() 
+ errors_monitor = ErrorsMonitor(errors_queue=errors_queue, on_error_callback=callback) + + # when + errors_queue.put(ValueError("error1")) + errors_monitor.start() + errors_monitor.interrupt() + errors_monitor.join(timeout=1) + + # then + callback.assert_called() diff --git a/tests/unit/test_metadata_splitter.py b/tests/unit/test_metadata_splitter.py new file mode 100644 index 00000000..4d842506 --- /dev/null +++ b/tests/unit/test_metadata_splitter.py @@ -0,0 +1,267 @@ +from datetime import datetime + +from freezegun import freeze_time +from google.protobuf.timestamp_pb2 import Timestamp +from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ( + SET_OPERATION, + ModifySet, + ModifyStringSet, + Step, + StringSet, + UpdateRunSnapshot, + Value, +) +from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation + +from neptune_scale.core.metadata_splitter import MetadataSplitter + + +@freeze_time("2024-07-30 12:12:12.000022") +def test_empty(): + # given + builder = MetadataSplitter( + project="workspace/project", + run_id="run_id", + step=1, + timestamp=datetime.now(), + fields={}, + metrics={}, + add_tags={}, + remove_tags={}, + ) + + # when + result = list(builder) + + # then + assert len(result) == 1 + assert result[0] == RunOperation( + project="workspace/project", + run_id="run_id", + update=UpdateRunSnapshot(step=Step(whole=1, micro=0), timestamp=Timestamp(seconds=1722341532, nanos=21934)), + ) + + +@freeze_time("2024-07-30 12:12:12.000022") +def test_fields(): + # given + builder = MetadataSplitter( + project="workspace/project", + run_id="run_id", + step=1, + timestamp=datetime.now(), + fields={ + "some/string": "value", + "some/int": 2501, + "some/float": 3.14, + "some/bool": True, + "some/datetime": datetime.now(), + "some/tags": {"tag1", "tag2"}, + }, + metrics={}, + add_tags={}, + remove_tags={}, + ) + + # when + result = list(builder) + + # then + assert len(result) == 1 + assert result[0] == RunOperation( + project="workspace/project", + run_id="run_id", + update=UpdateRunSnapshot( + step=Step(whole=1, micro=0), + timestamp=Timestamp(seconds=1722341532, nanos=21934), + assign={ + "some/string": Value(string="value"), + "some/int": Value(int64=2501), + "some/float": Value(float64=3.14), + "some/bool": Value(bool=True), + "some/datetime": Value(timestamp=Timestamp(seconds=1722341532, nanos=21934)), + "some/tags": Value(string_set=StringSet(values={"tag1", "tag2"})), + }, + ), + ) + + +@freeze_time("2024-07-30 12:12:12.000022") +def test_metrics(): + # given + builder = MetadataSplitter( + project="workspace/project", + run_id="run_id", + step=1, + timestamp=datetime.now(), + fields={}, + metrics={ + "some/metric": 3.14, + }, + add_tags={}, + remove_tags={}, + ) + + # when + result = list(builder) + + # then + assert len(result) == 1 + assert result[0] == RunOperation( + project="workspace/project", + run_id="run_id", + update=UpdateRunSnapshot( + step=Step(whole=1, micro=0), + timestamp=Timestamp(seconds=1722341532, nanos=21934), + append={ + "some/metric": Value(float64=3.14), + }, + ), + ) + + +@freeze_time("2024-07-30 12:12:12.000022") +def test_tags(): + # given + builder = MetadataSplitter( + project="workspace/project", + run_id="run_id", + step=1, + timestamp=datetime.now(), + fields={}, + metrics={}, + add_tags={ + "some/tags": {"tag1", "tag2"}, + "some/other_tags2": {"tag2", "tag3"}, + }, + remove_tags={ + "some/group_tags": {"tag0", "tag1"}, + "some/other_tags": {"tag2", "tag3"}, + }, + ) + + # when + result = list(builder) + + # then + assert 
len(result) == 1 + assert result[0] == RunOperation( + project="workspace/project", + run_id="run_id", + update=UpdateRunSnapshot( + step=Step(whole=1, micro=0), + timestamp=Timestamp(seconds=1722341532, nanos=21934), + modify_sets={ + "some/tags": ModifySet( + string=ModifyStringSet(values={"tag1": SET_OPERATION.ADD, "tag2": SET_OPERATION.ADD}) + ), + "some/other_tags2": ModifySet( + string=ModifyStringSet(values={"tag2": SET_OPERATION.ADD, "tag3": SET_OPERATION.ADD}) + ), + "some/group_tags": ModifySet( + string=ModifyStringSet(values={"tag0": SET_OPERATION.REMOVE, "tag1": SET_OPERATION.REMOVE}) + ), + "some/other_tags": ModifySet( + string=ModifyStringSet(values={"tag2": SET_OPERATION.REMOVE, "tag3": SET_OPERATION.REMOVE}) + ), + }, + ), + ) + + +@freeze_time("2024-07-30 12:12:12.000022") +def test_splitting(): + # given + max_size = 1024 + timestamp = datetime.now() + metrics = {f"metric{v}": 7 / 9.0 * v for v in range(1000)} + fields = {f"field{v}": v for v in range(1000)} + add_tags = {f"add/tag{v}": {f"value{v}"} for v in range(1000)} + remove_tags = {f"remove/tag{v}": {f"value{v}"} for v in range(1000)} + + # and + builder = MetadataSplitter( + project="workspace/project", + run_id="run_id", + step=1, + timestamp=timestamp, + fields=fields, + metrics=metrics, + add_tags=add_tags, + remove_tags=remove_tags, + max_message_bytes_size=max_size, + ) + + # when + result = list(builder) + + # then + assert len(result) > 1 + + # Every message should be smaller than max_size + assert all(len(op.SerializeToString()) <= max_size for op in result) + + # Common metadata + assert all(op.project == "workspace/project" for op in result) + assert all(op.run_id == "run_id" for op in result) + assert all(op.update.step.whole == 1 for op in result) + assert all(op.update.timestamp == Timestamp(seconds=1722341532, nanos=21934) for op in result) + + # Check if all metrics, fields and tags are present in the result + assert sorted([key for op in result for key in op.update.append.keys()]) == sorted(list(metrics.keys())) + assert sorted([key for op in result for key in op.update.assign.keys()]) == sorted(list(fields.keys())) + assert sorted([key for op in result for key in op.update.modify_sets.keys()]) == sorted( + list(add_tags.keys()) + list(remove_tags.keys()) + ) + + +@freeze_time("2024-07-30 12:12:12.000022") +def test_split_large_tags(): + # given + max_size = 1024 + timestamp = datetime.now() + metrics = {} + fields = {} + add_tags = {"add/tag": {f"value{v}" for v in range(1000)}} + remove_tags = {"remove/tag": {f"value{v}" for v in range(1000)}} + + # and + builder = MetadataSplitter( + project="workspace/project", + run_id="run_id", + step=1, + timestamp=timestamp, + fields=fields, + metrics=metrics, + add_tags=add_tags, + remove_tags=remove_tags, + max_message_bytes_size=max_size, + ) + + # when + result = list(builder) + + # then + assert len(result) > 1 + + # Every message should be smaller than max_size + assert all(len(op.SerializeToString()) <= max_size for op in result) + + # Common metadata + assert all(op.project == "workspace/project" for op in result) + assert all(op.run_id == "run_id" for op in result) + assert all(op.update.step.whole == 1 for op in result) + assert all(op.update.timestamp == Timestamp(seconds=1722341532, nanos=21934) for op in result) + + # Check if all StringSet values are split correctly + assert set([key for op in result for key in op.update.modify_sets.keys()]) == set( + list(add_tags.keys()) + list(remove_tags.keys()) + ) + + # Check if all tags are present in 
the result + assert {tag for op in result for tag in op.update.modify_sets["add/tag"].string.values.keys()} == add_tags[ + "add/tag" + ] + assert {tag for op in result for tag in op.update.modify_sets["remove/tag"].string.values.keys()} == remove_tags[ + "remove/tag" + ] diff --git a/tests/unit/test_operations_queue.py b/tests/unit/test_operations_queue.py new file mode 100644 index 00000000..f7c4d59a --- /dev/null +++ b/tests/unit/test_operations_queue.py @@ -0,0 +1,63 @@ +import threading +from unittest.mock import MagicMock + +import pytest +from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ( + UpdateRunSnapshot, + Value, +) +from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation + +from neptune_scale.core.components.operations_queue import OperationsQueue + + +def test__enqueue(): + # given + lock = threading.RLock() + queue = OperationsQueue(lock=lock, max_size=0) + + # and + operation = RunOperation() + + # when + queue.enqueue(operation=operation) + + # then + assert queue._sequence_id == 1 + + # when + queue.enqueue(operation=operation) + + # then + assert queue._sequence_id == 2 + + +def test__max_queue_size_exceeded(): + # given + lock = threading.RLock() + callback = MagicMock() + queue = OperationsQueue(lock=lock, max_size=1, max_size_exceeded_callback=callback) + + # and + operation = RunOperation() + + # when + queue.enqueue(operation=operation) + queue.enqueue(operation=operation) + + # then + callback.assert_called_once() + + +def test__max_element_size_exceeded(): + # given + lock = threading.RLock() + queue = OperationsQueue(lock=lock, max_size=1) + + # and + snapshot = UpdateRunSnapshot(assign={f"key_{i}": Value(string=("a" * 1024)) for i in range(1024)}) + operation = RunOperation(update=snapshot) + + # then + with pytest.raises(ValueError): + queue.enqueue(operation=operation) diff --git a/tests/unit/test_run.py b/tests/unit/test_run.py new file mode 100644 index 00000000..a3cb8dc9 --- /dev/null +++ b/tests/unit/test_run.py @@ -0,0 +1,252 @@ +import base64 +import json +import uuid +from datetime import datetime +from unittest.mock import patch + +import pytest +from freezegun import freeze_time + +from neptune_scale import Run + + +@pytest.fixture(scope="session") +def api_token(): + return base64.b64encode(json.dumps({"api_address": "aa", "api_url": "bb"}).encode("utf-8")).decode("utf-8") + + +class MockedApiClient: + def __init__(self, *args, **kwargs) -> None: + pass + + def submit(self, operation, family) -> None: + pass + + def close(self) -> None: + pass + + def cleanup(self) -> None: + pass + + +@patch("neptune_scale.ApiClient", MockedApiClient) +def test_context_manager(api_token): + # given + project = "workspace/project" + run_id = str(uuid.uuid4()) + family = run_id + + # when + with Run(project=project, api_token=api_token, family=family, run_id=run_id): + ... + + # then + assert True + + +@patch("neptune_scale.ApiClient", MockedApiClient) +def test_close(api_token): + # given + project = "workspace/project" + run_id = str(uuid.uuid4()) + family = run_id + + # and + run = Run(project=project, api_token=api_token, family=family, run_id=run_id) + + # when + run.close() + + # then + assert True + + +@patch("neptune_scale.ApiClient", MockedApiClient) +def test_family_too_long(api_token): + # given + project = "workspace/project" + run_id = str(uuid.uuid4()) + + # and + family = "a" * 1000 + + # when + with pytest.raises(ValueError): + with Run(project=project, api_token=api_token, family=family, run_id=run_id): + ... 
+ + +@patch("neptune_scale.ApiClient", MockedApiClient) +def test_run_id_too_long(api_token): + # given + project = "workspace/project" + family = str(uuid.uuid4()) + + # and + run_id = "a" * 1000 + + # then + with pytest.raises(ValueError): + with Run(project=project, api_token=api_token, family=family, run_id=run_id): + ... + + +@patch("neptune_scale.ApiClient", MockedApiClient) +def test_invalid_project_name(api_token): + # given + run_id = str(uuid.uuid4()) + family = run_id + + # and + project = "just-project" + + # then + with pytest.raises(ValueError): + with Run(project=project, api_token=api_token, family=family, run_id=run_id): + ... + + +@patch("neptune_scale.ApiClient", MockedApiClient) +def test_metadata(api_token): + # given + project = "workspace/project" + run_id = str(uuid.uuid4()) + family = run_id + + # then + with Run(project=project, api_token=api_token, family=family, run_id=run_id) as run: + run.log( + step=1, + timestamp=datetime.now(), + fields={ + "int": 1, + "string": "test", + "float": 3.14, + "bool": True, + "datetime": datetime.now(), + }, + metrics={ + "metric": 1.0, + }, + add_tags={ + "tags": ["tag1"], + }, + remove_tags={ + "group_tags": ["tag2"], + }, + ) + + +@patch("neptune_scale.ApiClient", MockedApiClient) +def test_log_without_step(api_token): + # given + project = "workspace/project" + run_id = str(uuid.uuid4()) + family = run_id + + # then + with Run(project=project, api_token=api_token, family=family, run_id=run_id) as run: + run.log( + timestamp=datetime.now(), + fields={ + "int": 1, + }, + ) + + +@patch("neptune_scale.ApiClient", MockedApiClient) +def test_log_step_float(api_token): + # given + project = "workspace/project" + run_id = str(uuid.uuid4()) + family = run_id + + # then + with Run(project=project, api_token=api_token, family=family, run_id=run_id) as run: + run.log( + step=3.14, + timestamp=datetime.now(), + fields={ + "int": 1, + }, + ) + + +@patch("neptune_scale.ApiClient", MockedApiClient) +def test_log_no_timestamp(api_token): + # given + project = "workspace/project" + run_id = str(uuid.uuid4()) + family = run_id + + # then + with Run(project=project, api_token=api_token, family=family, run_id=run_id) as run: + run.log( + step=3.14, + fields={ + "int": 1, + }, + ) + + +@patch("neptune_scale.ApiClient", MockedApiClient) +def test_resume(api_token): + # given + project = "workspace/project" + run_id = str(uuid.uuid4()) + family = run_id + + # when + with Run(project=project, api_token=api_token, family=family, run_id=run_id, resume=True): + ... + + # then + assert True + + +@patch("neptune_scale.ApiClient", MockedApiClient) +@freeze_time("2024-07-30 12:12:12.000022") +def test_creation_time(api_token): + # given + project = "workspace/project" + run_id = str(uuid.uuid4()) + family = run_id + + # when + with Run(project=project, api_token=api_token, family=family, run_id=run_id, creation_time=datetime.now()): + ... + + # then + assert True + + +@patch("neptune_scale.ApiClient", MockedApiClient) +def test_assign_experiment(api_token): + # given + project = "workspace/project" + run_id = str(uuid.uuid4()) + family = run_id + + # when + with Run(project=project, api_token=api_token, family=family, run_id=run_id, as_experiment="experiment_id"): + ... 
+ + # then + assert True + + +@patch("neptune_scale.ApiClient", MockedApiClient) +def test_forking(api_token): + # given + project = "workspace/project" + run_id = str(uuid.uuid4()) + family = run_id + + # when + with Run( + project=project, api_token=api_token, family=family, run_id=run_id, from_run_id="parent-run-id", from_step=3.14 + ): + ... + + # then + assert True
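For reviewers, a minimal end-to-end sketch of the public API introduced in this PR, assembled from the `Run.__init__` and `Run.log` signatures above. Project, run IDs, and attribute names are placeholders; the API token is assumed to come from the `NEPTUNE_API_TOKEN` environment variable (see `envs.py`):

```python
from datetime import datetime

from neptune_scale import Run

# Create a run and log fields, metrics, and tags (values below are illustrative only).
with Run(
    family="my-run-family",
    run_id="my-run-id",
    project="my-workspace/my-project",
) as run:
    run.log(
        step=1,
        timestamp=datetime.now(),
        fields={"parameters/learning_rate": 0.001},
        metrics={"metrics/loss": 0.25},
        add_tags={"sys/group_tags": ["group1"]},
    )

# Forking from an existing run requires both `from_run_id` and `from_step`.
with Run(
    family="my-run-family",
    run_id="child-run-id",
    project="my-workspace/my-project",
    from_run_id="my-run-id",
    from_step=1,
) as run:
    run.log(step=2, metrics={"metrics/loss": 0.2})
```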