diff --git a/nominal/core/connection.py b/nominal/core/connection.py
index bb335449..326be660 100644
--- a/nominal/core/connection.py
+++ b/nominal/core/connection.py
@@ -3,6 +3,7 @@
 import itertools
 import logging
 from dataclasses import dataclass, field
+from datetime import timedelta
 from itertools import groupby
 from typing import Iterable, Mapping, Protocol, Sequence, cast
 
@@ -18,7 +19,7 @@
 from nominal.core._clientsbunch import HasAuthHeader
 from nominal.core._utils import HasRid
 from nominal.core.channel import Channel
-from nominal.core.stream import BatchItem, NominalWriteStream
+from nominal.core.stream import BatchItem, WriteStream
 from nominal.ts import _SecondsNanos
 
 
@@ -137,19 +138,33 @@ def get_channel(self, name: str, tags: dict[str, str] | None = None) -> Channel:
         series = self._clients.logical_series.get_logical_series(self._clients.auth_header, resolved_series.rid)
         return Channel._from_conjure_logicalseries_api(self._clients, series)
 
-    def get_nominal_write_stream(self, batch_size: int = 10, max_wait_sec: int = 5) -> NominalWriteStream:
-        """Nominal Stream to write non-blocking messages to a datasource.
+    def get_nominal_write_stream(self, batch_size: int = 10, max_wait_sec: int = 5) -> WriteStream:
+        """get_nominal_write_stream is deprecated and will be removed in a future version,
+        use get_write_stream instead.
+        """
+        import warnings
+
+        warnings.warn(
+            "get_nominal_write_stream is deprecated and will be removed in a future version,"
+            " use get_write_stream instead.",
+            UserWarning,
+            stacklevel=2,
+        )
+        return self.get_write_stream(batch_size, timedelta(seconds=max_wait_sec))
+
+    def get_write_stream(self, batch_size: int = 10, max_wait: timedelta = timedelta(seconds=5)) -> WriteStream:
+        """Stream to write non-blocking messages to a datasource.
 
         Args:
         ----
             batch_size (int): How big the batch can get before writing to Nominal. Default 10
-            max_wait_sec (int): How long a batch can exist before being flushed to Nominal. Default 5
+            max_wait (timedelta): How long a batch can exist before being flushed to Nominal. Default 5 seconds
 
         Examples:
         --------
             Standard Usage:
             ```py
-            with connection.get_nominal_write_stream() as stream:
+            with connection.get_write_stream() as stream:
                 stream.enqueue("my_channel_name", "2021-01-01T00:00:00Z", 42.0)
                 stream.enqueue("my_channel_name2", "2021-01-01T00:00:01Z", 43.0, {"tag1": "value1"})
                 ...
@@ -157,7 +172,7 @@ def get_nominal_write_stream(self, batch_size: int = 10, max_wait_sec: int = 5)
 
             Without a context manager:
             ```py
-            stream = connection.get_nominal_write_stream()
+            stream = connection.get_write_stream()
             stream.enqueue("my_channel_name", "2021-01-01T00:00:00Z", 42.0)
             stream.enqueue("my_channel_name2", "2021-01-01T00:00:01Z", 43.0, {"tag1": "value1"})
             ...
@@ -166,7 +181,7 @@ def get_nominal_write_stream(self, batch_size: int = 10, max_wait_sec: int = 5)
 
         """
         if self._nominal_data_source_rid is not None:
-            return NominalWriteStream(self._process_batch, batch_size, max_wait_sec)
+            return WriteStream.create(batch_size, max_wait, self._process_batch)
         else:
             raise ValueError("Writing not implemented for this connection type")
 
diff --git a/nominal/core/stream.py b/nominal/core/stream.py
index c4586855..a4b9b977 100644
--- a/nominal/core/stream.py
+++ b/nominal/core/stream.py
@@ -4,13 +4,28 @@
 import logging
 import threading
 import time
+import warnings
 from dataclasses import dataclass
-from datetime import datetime
+from datetime import datetime, timedelta
 from types import TracebackType
-from typing import Callable, Dict, Sequence, Type
+from typing import Any, Callable, Sequence, Type
+
+from typing_extensions import Self
 
 from nominal.ts import IntegralNanosecondsUTC
 
+
+def __getattr__(name: str) -> Any:  # PEP 562 module hook: resolves the deprecated NominalWriteStream name
+    if name == "NominalWriteStream":
+        warnings.warn(
+            "NominalWriteStream is deprecated, use WriteStream instead",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return WriteStream
+    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+
+
 logger = logging.getLogger(__name__)
 
 
@@ -19,31 +34,50 @@ class BatchItem:
     channel_name: str
     timestamp: str | datetime | IntegralNanosecondsUTC
     value: float | str
-    tags: Dict[str, str] | None = None
+    tags: dict[str, str] | None = None
 
 
-class NominalWriteStream:
-    def __init__(
-        self,
+@dataclass(frozen=True)
+class WriteStream:
+    """Buffer `BatchItem`s and hand batches to `_process_batch` on a thread pool.
+
+    `enqueue_batch` flushes once at least `batch_size` items are pending, and a background
+    task flushes after `max_wait` has passed since the previous flush. Build with `create`.
+    """
+
+    batch_size: int
+    max_wait: timedelta
+    _process_batch: Callable[[Sequence[BatchItem]], None]
+    _executor: concurrent.futures.ThreadPoolExecutor
+    _thread_safe_batch: ThreadSafeBatch
+    _stop: threading.Event
+    _pending_jobs: threading.BoundedSemaphore
+
+    @classmethod
+    def create(
+        cls,
+        batch_size: int,
+        max_wait: timedelta,
         process_batch: Callable[[Sequence[BatchItem]], None],
-        batch_size: int = 10,
-        max_wait_sec: int = 5,
-        max_workers: int | None = None,
-    ):
+    ) -> Self:
         """Create the stream."""
-        self._process_batch = process_batch
-        self.batch_size = batch_size
-        self.max_wait_sec = max_wait_sec
-        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
-        self._batch: list[BatchItem] = []
-        self._batch_lock = threading.Lock()
-        self._last_batch_time = time.time()
-        self._running = True
+        executor = concurrent.futures.ThreadPoolExecutor()
+
+        instance = cls(
+            batch_size,
+            max_wait,
+            process_batch,
+            executor,
+            ThreadSafeBatch(),
+            threading.Event(),
+            threading.BoundedSemaphore(3),  # at most three batches submitted but not yet finished
+        )
 
-        self._timeout_thread = threading.Thread(target=self._process_timeout_batches, daemon=True)
-        self._timeout_thread.start()
+        executor.submit(instance._process_timeout_batches)  # background max_wait flusher
 
-    def __enter__(self) -> "NominalWriteStream":
+        return instance
+
+    def __enter__(self) -> Self:
         """Create the stream as a context manager."""
         return self
 
@@ -58,7 +92,7 @@ def enqueue(
         channel_name: str,
         timestamp: str | datetime | IntegralNanosecondsUTC,
         value: float | str,
-        tags: Dict[str, str] | None = None,
+        tags: dict[str, str] | None = None,
     ) -> None:
         """Add a message to the queue.
 
@@ -71,7 +105,7 @@ def enqueue_batch(
         channel_name: str,
         timestamps: Sequence[str | datetime | IntegralNanosecondsUTC],
         values: Sequence[float | str],
-        tags: Dict[str, str] | None = None,
+        tags: dict[str, str] | None = None,
     ) -> None:
         """Add a sequence of messages to the queue.
 
@@ -82,56 +116,84 @@ def enqueue_batch(
                 f"Expected equal numbers of timestamps and values! Received: {len(timestamps)} vs. {len(values)}"
             )
 
-        with self._batch_lock:
-            for timestamp, value in zip(timestamps, values):
-                self._batch.append(BatchItem(channel_name, timestamp, value, tags))
-
-            if len(self._batch) >= self.batch_size:
-                self.flush()
+        self._thread_safe_batch.add(
+            [BatchItem(channel_name, timestamp, value, tags) for timestamp, value in zip(timestamps, values)]
+        )
+        self._flush(condition=lambda size: size >= self.batch_size)
 
-    def flush(self, wait: bool = False, timeout: float | None = None) -> None:
-        """Flush current batch of records to nominal in a background thread.
-
-        Args:
-        ----
-            wait: If true, wait for the batch to complete uploading before returning
-            timeout: If wait is true, the time to wait for flush completion.
-                NOTE: If none, waits indefinitely.
+    def _flush(self, condition: Callable[[int], bool] | None = None) -> concurrent.futures.Future[None] | None:
+        """Swap out the pending batch and submit it to the executor.
+
+        Returns the upload future, or None if `condition` rejected the swap or nothing was enqueued.
+        """
+        batch = self._thread_safe_batch.swap(condition)
 
-        """
-        if not self._batch:
+        if batch is None:
+            return None
+        if not batch:
             logger.debug("Not flushing... no enqueued batch")
-            return
+            return None
+
+        # Bound the number of in-flight uploads; blocks until a slot is free
+        self._pending_jobs.acquire()
 
         def process_future(fut: concurrent.futures.Future) -> None:  # type: ignore[type-arg]
             """Callback to print errors to the console if a batch upload fails."""
+            self._pending_jobs.release()
+            if fut.cancelled():
+                # close(wait=False) cancels queued uploads; fut.exception() would raise CancelledError
+                logger.debug("Batched upload task was cancelled")
+                return
             maybe_ex = fut.exception()
             if maybe_ex is not None:
                 logger.error("Batched upload task failed with exception", exc_info=maybe_ex)
             else:
                 logger.debug("Batched upload task succeeded")
 
-        logger.debug(f"Starting flush with {len(self._batch)} records")
-        future = self._executor.submit(self._process_batch, self._batch)
+        logger.debug(f"Starting flush with {len(batch)} records")
+        try:
+            future = self._executor.submit(self._process_batch, batch)
+        except BaseException:
+            # submit() raises once the executor is shut down; hand the slot back so it is not leaked
+            self._pending_jobs.release()
+            raise
         future.add_done_callback(process_future)
+        return future
 
-        # Clear metadata
-        self._batch = []
-        self._last_batch_time = time.time()
+    def flush(self, wait: bool = False, timeout: float | None = None) -> None:
+        """Flush current batch of records to nominal in a background thread.
+
+        Args:
+        ----
+            wait: If true, wait for the batch to complete uploading before returning
+            timeout: If wait is true, the time to wait for flush completion.
+                NOTE: If none, waits indefinitely.
+
+        """
+        future = self._flush()
 
         # Synchronously wait, if requested
-        if wait:
+        if wait and future is not None:
             # Warn user if timeout is too short
             _, pending = concurrent.futures.wait([future], timeout)
             if pending:
                 logger.warning("Upload task still pending after flushing batch... increase timeout or setting to None")
 
     def _process_timeout_batches(self) -> None:
-        while self._running:
-            time.sleep(self.max_wait_sec / 10)
-            with self._batch_lock:
-                if self._batch and (time.time() - self._last_batch_time) >= self.max_wait_sec:
-                    self.flush()
+        """Background loop: flush when `max_wait` passes without a flush, until `_stop` is set."""
+        while not self._stop.is_set():
+            now = time.time()
+
+            last_batch_time = self._thread_safe_batch.last_time
+            # total_seconds(), not .seconds: the latter drops the microseconds and days components
+            timeout = max(self.max_wait.total_seconds() - (now - last_batch_time), 0)
+            self._stop.wait(timeout=timeout)
+
+            # check if flush has been called in the mean time
+            if self._thread_safe_batch.last_time > last_batch_time:
+                continue
+
+            self._flush()
 
     def close(self, wait: bool = True) -> None:
         """Close the Nominal Stream.
@@ -139,10 +201,40 @@ def close(self, wait: bool = True) -> None:
         Stop the process timeout thread
         Flush any remaining batches
         """
-        self._running = False
-        self._timeout_thread.join()
+        self._stop.set()
 
-        with self._batch_lock:
-            self.flush()
+        self._flush()
 
         self._executor.shutdown(wait=wait, cancel_futures=not wait)
+
+
+class ThreadSafeBatch:
+    def __init__(self) -> None:
+        """Thread-safe access to batch and last swap time."""
+        self._batch: list[BatchItem] = []
+        self._last_time = time.time()
+        self._lock = threading.Lock()
+
+    def swap(self, condition: Callable[[int], bool] | None = None) -> list[BatchItem] | None:
+        """Swap the current batch with an empty one and return the old batch.
+
+        If condition is provided, the swap will only occur if the condition is met, otherwise None is returned.
+        """
+        with self._lock:
+            if condition and not condition(len(self._batch)):
+                return None
+            batch = self._batch
+            self._batch = []
+            self._last_time = time.time()
+        return batch
+
+    def add(self, items: Sequence[BatchItem]) -> None:
+        """Append items to the current batch under the lock."""
+        with self._lock:
+            self._batch.extend(items)
+
+    @property
+    def last_time(self) -> float:
+        """Time (from `time.time()`) of the last swap, or of construction if none happened yet."""
+        with self._lock:
+            return self._last_time