diff --git a/src/neptune_scale/api/attribute.py b/src/neptune_scale/api/attribute.py
index 2613b49c..643f25c4 100644
--- a/src/neptune_scale/api/attribute.py
+++ b/src/neptune_scale/api/attribute.py
@@ -1,7 +1,11 @@
 import functools
 import itertools
+import threading
 import warnings
-from datetime import datetime
+from datetime import (
+    datetime,
+    timezone,
+)
 from typing import (
     Any,
     Callable,
@@ -16,6 +20,7 @@
     cast,
 )

+from neptune_scale.exceptions import NeptuneSeriesStepNonIncreasing
 from neptune_scale.sync.metadata_splitter import MetadataSplitter
 from neptune_scale.sync.operations_queue import OperationsQueue

@@ -61,6 +66,11 @@ def __init__(self, project: str, run_id: str, operations_queue: OperationsQueue)
         self._run_id = run_id
         self._operations_queue = operations_queue
         self._attributes: Dict[str, Attribute] = {}
+        # Keep a list of path -> (last step, last value) mappings to detect non-increasing steps
+        # at call site. The backend will detect this error as well, but it's more convenient for the user
+        # to get the error as soon as possible.
+        self._metric_state: Dict[str, Tuple[float, float]] = {}
+        self._lock = threading.RLock()

     def __getitem__(self, path: str) -> "Attribute":
         path = cleanup_path(path)
@@ -87,22 +97,41 @@ def log(
     ) -> None:
         if timestamp is None:
             timestamp = datetime.now()
-        elif isinstance(timestamp, float):
-            timestamp = datetime.fromtimestamp(timestamp)
+        elif isinstance(timestamp, (float, int)):
+            timestamp = datetime.fromtimestamp(timestamp, timezone.utc)
+
+        with self._lock:
+            self._verify_and_update_metrics_state(step, metrics)
+
+            # TODO: Move splitting into the worker process. Here we should just send messages as they are.
+            splitter: MetadataSplitter = MetadataSplitter(
+                project=self._project,
+                run_id=self._run_id,
+                step=step,
+                timestamp=timestamp,
+                configs=configs,
+                metrics=metrics,
+                add_tags=tags_add,
+                remove_tags=tags_remove,
+            )
+
+            for operation, metadata_size in splitter:
+                self._operations_queue.enqueue(operation=operation, size=metadata_size, step=step)

-        splitter: MetadataSplitter = MetadataSplitter(
-            project=self._project,
-            run_id=self._run_id,
-            step=step,
-            timestamp=timestamp,
-            configs=configs,
-            metrics=metrics,
-            add_tags=tags_add,
-            remove_tags=tags_remove,
-        )
+    def _verify_and_update_metrics_state(self, step: Optional[float], metrics: Optional[Dict[str, float]]) -> None:
+        """Check if step in provided metrics is increasing, raise `NeptuneSeriesStepNonIncreasing` if not."""

-        for operation, metadata_size in splitter:
-            self._operations_queue.enqueue(operation=operation, size=metadata_size, key=step)
+        if step is None or metrics is None:
+            return
+
+        for metric, value in metrics.items():
+            if (state := self._metric_state.get(metric)) is not None:
+                last_step, last_value = state
+                # Repeating a step is fine as long as the value does not change
+                if step < last_step or (step == last_step and value != last_value):
+                    raise NeptuneSeriesStepNonIncreasing()
+
+            self._metric_state[metric] = (step, value)


 class Attribute:
@@ -130,7 +159,7 @@ def append(
         self,
         value: Union[Dict[str, Any], float],
         *,
-        step: Union[float, int],
+        step: Optional[Union[float, int]] = None,
         timestamp: Optional[Union[float, datetime]] = None,
         wait: bool = False,
         **kwargs: Any,
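
A minimal sketch of the call-site behavior introduced above (illustrative only, not part of the patch; assumes `run` is an already-created neptune_scale.Run):

    run["train/loss"].append(0.5, step=2)
    run["train/loss"].append(0.5, step=2)  # same step, same value: accepted
    run["train/loss"].append(0.4)          # step=None: the backend assigns the next step
    run["train/loss"].append(0.6, step=1)  # step lower than the last one recorded for this metric:
                                           # raises NeptuneSeriesStepNonIncreasing in the calling process
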
diff --git a/src/neptune_scale/api/run.py b/src/neptune_scale/api/run.py
index 921d85e1..5e053b83 100644
--- a/src/neptune_scale/api/run.py
+++ b/src/neptune_scale/api/run.py
@@ -403,7 +403,7 @@ def __setitem__(self, key: str, value: Any) -> None:
     def log_metrics(
         self,
         data: Dict[str, Union[float, int]],
-        step: Union[float, int],
+        step: Optional[Union[float, int]] = None,
         *,
         timestamp: Optional[datetime] = None,
     ) -> None:
diff --git a/src/neptune_scale/sync/aggregating_queue.py b/src/neptune_scale/sync/aggregating_queue.py
index a2adcafd..735c0b04 100644
--- a/src/neptune_scale/sync/aggregating_queue.py
+++ b/src/neptune_scale/sync/aggregating_queue.py
@@ -72,7 +72,7 @@ def get(self) -> BatchedOperations:
         start = time.monotonic()

         batch_operations: list[RunOperation] = []
-        last_batch_key: Optional[float] = None
+        current_batch_step: Optional[float] = None
         batch_sequence_id: Optional[int] = None
         batch_timestamp: Optional[float] = None

@@ -97,7 +97,7 @@ def get(self) -> BatchedOperations:
                 new_operation = RunOperation()
                 new_operation.ParseFromString(element.operation)
                 batch_operations.append(new_operation)
-                last_batch_key = element.batch_key
+                current_batch_step = element.step
                 batch_bytes += len(element.operation)
             else:
                 if not element.is_batchable:
@@ -112,9 +112,26 @@ def get(self) -> BatchedOperations:

                 new_operation = RunOperation()
                 new_operation.ParseFromString(element.operation)
-                if element.batch_key != last_batch_key:
+
+                # This is where we decide if we need to wrap up the current UpdateSnapshot and start a new one.
+                # This happens if the step changes, but also if it is None.
+                # On None, the backend will assign the next available step. This is why we cannot merge here,
+                # especially considering metrics, since we would overwrite them:
+                #
+                #   log metric1=1.0, step=None
+                #   log metric1=1.2, step=None
+                #
+                # After merging by step, we would end up with a single value (the most recent one).
+                #
+                # TODO: we could potentially keep merging until we encounter a metric already seen in this batch.
+                #       Something to optimize in the future. Given the metrics:
+                #           m1, m2, m3, m4, m1, m2, m3, ...
+                #       we could batch up to m4 and close the batch when encountering m1, as long as steps are None
+                #       We could also keep batching if there are no metrics in a given operation, although this would
+                #       not be a common case.
+                if element.step is None or element.step != current_batch_step:
                     batch_operations.append(new_operation)
-                    last_batch_key = element.batch_key
+                    current_batch_step = element.step
                 else:
                     merge_run_operation(batch_operations[-1], new_operation)
                 batch_bytes += element.metadata_size
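
The batching rule above can be summarized as a tiny predicate (an illustrative sketch, not part of the patch; the function name is made up):

    from typing import Optional

    def starts_new_snapshot(element_step: Optional[float], current_batch_step: Optional[float]) -> bool:
        # A None step means "let the backend assign the step", so two None-step operations
        # may end up on different backend-assigned steps and must never be merged.
        return element_step is None or element_step != current_batch_step
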
diff --git a/src/neptune_scale/sync/operations_queue.py b/src/neptune_scale/sync/operations_queue.py
index 0069b472..f1ec0df3 100644
--- a/src/neptune_scale/sync/operations_queue.py
+++ b/src/neptune_scale/sync/operations_queue.py
@@ -57,7 +57,7 @@ def last_timestamp(self) -> Optional[float]:
         with self._lock:
             return self._last_timestamp

-    def enqueue(self, *, operation: RunOperation, size: Optional[int] = None, key: Optional[float] = None) -> None:
+    def enqueue(self, *, operation: RunOperation, size: Optional[int] = None, step: Optional[float] = None) -> None:
         try:
             is_metadata_update = operation.HasField("update")
             serialized_operation = operation.SerializeToString()
@@ -75,7 +75,7 @@ def last_timestamp(self) -> Optional[float]:
                     operation=serialized_operation,
                     metadata_size=size,
                     is_batchable=is_metadata_update,
-                    batch_key=key,
+                    step=step,
                 ),
                 block=True,
                 timeout=None,
diff --git a/src/neptune_scale/sync/queue_element.py b/src/neptune_scale/sync/queue_element.py
index 1c37ff11..da33eaed 100644
--- a/src/neptune_scale/sync/queue_element.py
+++ b/src/neptune_scale/sync/queue_element.py
@@ -26,5 +26,5 @@ class SingleOperation(NamedTuple):
     is_batchable: bool
     # Size of the metadata in the operation (without project, family, run_id etc.)
     metadata_size: Optional[int]
-    # Update metadata key
-    batch_key: Optional[float]
+    # Step provided by the user
+    step: Optional[float]
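
With the rename, every queued element now carries the user-provided step (or None when the backend should assign it). A hypothetical construction for illustration, mirroring the unit tests further below (`run_operation` and `update` are assumed to exist; they are not defined here):

    element = SingleOperation(
        sequence_id=1,
        timestamp=time.monotonic(),
        operation=run_operation.SerializeToString(),  # a serialized RunOperation
        is_batchable=True,
        metadata_size=update.ByteSize(),              # size of the UpdateRunSnapshot payload
        step=2.0,                                     # None would mean "backend assigns the step"
    )
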
diff --git a/src/neptune_scale/sync/sync_process.py b/src/neptune_scale/sync/sync_process.py
index ed76337a..9230b431 100644
--- a/src/neptune_scale/sync/sync_process.py
+++ b/src/neptune_scale/sync/sync_process.py
@@ -420,6 +420,8 @@ def submit(self, *, operation: RunOperation) -> Optional[SubmitResponse]:

     def work(self) -> None:
         try:
+            # TODO: is there a point in serializing the data on AggregatingQueue? It does not move between processes,
+            # so we could just pass around instances of RunOperation
             while (operation := self.get_next()) is not None:
                 sequence_id, timestamp, data = operation

diff --git a/tests/e2e/test_log_and_fetch.py b/tests/e2e/test_log_and_fetch.py
index fc132bf0..456be551 100644
--- a/tests/e2e/test_log_and_fetch.py
+++ b/tests/e2e/test_log_and_fetch.py
@@ -164,3 +164,66 @@ def test_single_non_finite_metric(value, sync_run, ro_run):
     path = unique_path("test_series/non_finite")
     sync_run.log_metrics(data={path: value}, step=1)
     assert path not in refresh(ro_run).field_names
+
+
+@mark.parametrize("first_step", [None, 0, 10])
+def test_auto_step_with_initial_step(run, ro_run, first_step):
+    """Logging series values with step=None results in backend-side step assignment"""
+
+    path = unique_path(f"test_series/auto_step_{first_step}")
+
+    _, values = random_series()
+
+    run.log_metrics(data={path: values[0]}, step=first_step)
+    for value in values[1:]:
+        run.log_metrics(data={path: value})
+
+    run.wait_for_processing()
+
+    # Backend will assign steps starting from zero by default,
+    # so handle this test case properly
+    if first_step is None:
+        first_step = 0
+
+    df = ro_run[path].fetch_values()
+    assert df["step"].tolist() == [float(x) for x in list(range(first_step, first_step + len(values)))]
+    assert df["value"].tolist() == values
+
+
+def test_auto_step_with_manual_increase(run, ro_run):
+    """Increase step manually at a single point in series, then use auto-step"""
+
+    path = unique_path("test_series/auto_step_increase")
+    run.log_metrics(data={path: 1})
+    run.log_metrics(data={path: 2}, step=10)
+    run.log_metrics(data={path: 3})
+
+    run.wait_for_processing()
+
+    df = ro_run[path].fetch_values()
+    assert df["step"].tolist() == [0, 10, 11]
+    assert df["value"].tolist() == [1, 2, 3]
+
+
+def test_auto_step_with_different_metrics(run, ro_run):
+    path1 = unique_path("test_series/auto_step_different_metrics1")
+    path2 = unique_path("test_series/auto_step_different_metrics2")
+
+    run.log_metrics(data={path1: 1})
+    run.log_metrics(data={path2: 1}, step=10)
+
+    run.log_metrics(data={path1: 2})
+    run.log_metrics(data={path2: 2}, step=20)
+
+    run.log_metrics(data={path1: 3}, step=5)
+    run.log_metrics(data={path2: 3})
+
+    run.wait_for_processing()
+
+    df1 = ro_run[path1].fetch_values()
+    assert df1["step"].tolist() == [0.0, 1.0, 5.0]
+    assert df1["value"].tolist() == [1, 2, 3]
+
+    df2 = ro_run[path2].fetch_values()
+    assert df2["step"].tolist() == [10.0, 20.0, 21.0]
+    assert df2["value"].tolist() == [1, 2, 3]
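
From the user's perspective, the end-to-end tests above boil down to the following usage pattern (illustrative sketch; the project name is a placeholder and credentials are expected to come from the environment):

    from neptune_scale import Run

    with Run(project="my-workspace/my-project", run_id="auto-step-demo") as run:
        run.log_metrics(data={"accuracy": 0.71}, step=5)  # explicit step
        run.log_metrics(data={"accuracy": 0.74})          # step omitted: the backend assigns step 6
        run.wait_for_processing()
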
diff --git a/tests/unit/test_aggregating_queue.py b/tests/unit/test_aggregating_queue.py
index d988f696..6ca605f5 100644
--- a/tests/unit/test_aggregating_queue.py
+++ b/tests/unit/test_aggregating_queue.py
@@ -32,7 +32,7 @@ def test__simple():
         operation=operation.SerializeToString(),
         is_batchable=True,
         metadata_size=update.ByteSize(),
-        batch_key=None,
+        step=None,
     )

     # and
@@ -60,7 +60,7 @@ def test__max_size_exceeded():
         operation=operation1.SerializeToString(),
         is_batchable=True,
         metadata_size=0,
-        batch_key=None,
+        step=None,
     )
     element2 = SingleOperation(
         sequence_id=2,
@@ -68,7 +68,7 @@ def test__max_size_exceeded():
         operation=operation2.SerializeToString(),
         is_batchable=True,
         metadata_size=0,
-        batch_key=None,
+        step=None,
     )

     # and
@@ -108,7 +108,7 @@ def test__batch_size_limit():
         operation=operation1.SerializeToString(),
         is_batchable=True,
         metadata_size=update1.ByteSize(),
-        batch_key=None,
+        step=None,
     )
     element2 = SingleOperation(
         sequence_id=2,
@@ -116,7 +116,7 @@ def test__batch_size_limit():
         operation=operation2.SerializeToString(),
         is_batchable=True,
         metadata_size=update2.ByteSize(),
-        batch_key=None,
+        step=None,
     )

     # and
@@ -132,7 +132,7 @@


 @freeze_time("2024-09-01")
-def test__batching():
+def test__batching_with_step():
     # given
     update1 = UpdateRunSnapshot(step=None, assign={f"aa{i}": Value(int64=(i * 97)) for i in range(2)})
     update2 = UpdateRunSnapshot(step=None, assign={f"bb{i}": Value(int64=(i * 25)) for i in range(2)})
@@ -144,19 +144,19 @@ def test__batching():
     # and
     element1 = SingleOperation(
         sequence_id=1,
-        timestamp=time.process_time(),
+        timestamp=time.monotonic(),
         operation=operation1.SerializeToString(),
         is_batchable=True,
         metadata_size=update1.ByteSize(),
-        batch_key=None,
+        step=1,
     )
     element2 = SingleOperation(
         sequence_id=2,
-        timestamp=time.process_time(),
+        timestamp=time.monotonic(),
         operation=operation2.SerializeToString(),
         is_batchable=True,
         metadata_size=update2.ByteSize(),
-        batch_key=None,
+        step=1,
     )

     # and
@@ -182,6 +182,55 @@ def test__batching():
     assert all(k in batch.update.assign for k in ["aa0", "aa1", "bb0", "bb1"])


+@freeze_time("2024-09-01")
+def test__batch_metrics_no_step():
+    """
+    Log a single metric 2 times, with step being None. The two operations should be batched, but not merged.
+    """
+
+    # given
+    updates = [UpdateRunSnapshot(step=None, append={"a": Value(int64=value)}) for value in range(2)]
+
+    # and
+    operations = [RunOperation(update=update, project="project", run_id="run_id") for update in updates]
+
+    elements = [
+        SingleOperation(
+            sequence_id=seq,
+            timestamp=time.monotonic(),
+            operation=operations[seq].SerializeToString(),
+            is_batchable=True,
+            metadata_size=updates[seq].ByteSize(),
+            step=None,
+        )
+        for seq in range(2)
+    ]
+
+    # and
+    queue = AggregatingQueue(max_queue_size=10, max_elements_in_batch=10, wait_time=0.1)
+
+    # and
+    for e in elements:
+        queue.put_nowait(element=e)
+
+    # when
+    result = queue.get()
+
+    # then
+    assert result.sequence_id == elements[-1].sequence_id
+    assert result.timestamp == elements[-1].timestamp
+
+    batch = RunOperation()
+    batch.ParseFromString(result.operation)
+
+    assert batch.project == "project"
+    assert batch.run_id == "run_id"
+    assert len(batch.update_batch.snapshots) == len(operations)
+
+    for i, update in enumerate(batch.update_batch.snapshots):
+        assert update.append == {"a": Value(int64=i)}
+
+
 @freeze_time("2024-09-01")
 def test__queue_element_size_limit_with_different_steps():
     # given
@@ -195,7 +244,7 @@ def test__queue_element_size_limit_with_different_steps():
         operation=operation1.SerializeToString(),
         is_batchable=True,
         metadata_size=update1.ByteSize(),
-        batch_key=1.0,
+        step=1.0,
     )
     element2 = SingleOperation(
         sequence_id=2,
@@ -203,7 +252,7 @@ def test__queue_element_size_limit_with_different_steps():
         operation=operation2.SerializeToString(),
         is_batchable=True,
         metadata_size=update2.ByteSize(),
-        batch_key=2.0,
+        step=2.0,
     )

     # and
@@ -235,7 +284,7 @@ def test__not_merge_two_run_creation():
         operation=operation1.SerializeToString(),
         is_batchable=False,
         metadata_size=0,
-        batch_key=None,
+        step=None,
     )
     element2 = SingleOperation(
         sequence_id=2,
@@ -243,7 +292,7 @@ def test__not_merge_two_run_creation():
         operation=operation2.SerializeToString(),
         is_batchable=False,
         metadata_size=0,
-        batch_key=None,
+        step=None,
     )

     # and
@@ -301,7 +350,7 @@ def test__not_merge_run_creation_with_metadata_update():
         operation=operation1.SerializeToString(),
         is_batchable=False,
         metadata_size=0,
-        batch_key=None,
+        step=None,
     )
     element2 = SingleOperation(
         sequence_id=2,
@@ -309,7 +358,7 @@ def test__not_merge_run_creation_with_metadata_update():
         operation=operation2.SerializeToString(),
         is_batchable=True,
         metadata_size=update.ByteSize(),
-        batch_key=None,
+        step=None,
     )

     # and
@@ -367,7 +416,7 @@ def test__merge_same_key():
         operation=operation1.SerializeToString(),
         is_batchable=True,
         metadata_size=update1.ByteSize(),
-        batch_key=1.0,
+        step=1.0,
     )
     element2 = SingleOperation(
         sequence_id=2,
@@ -375,7 +424,7 @@ def test__merge_same_key():
         operation=operation2.SerializeToString(),
         is_batchable=True,
         metadata_size=update2.ByteSize(),
-        batch_key=1.0,
+        step=1.0,
     )

     # and
@@ -419,7 +468,7 @@ def test__merge_two_different_steps():
         operation=operation1.SerializeToString(),
         is_batchable=True,
         metadata_size=0,
-        batch_key=1.0,
+        step=1.0,
     )
     element2 = SingleOperation(
         sequence_id=2,
@@ -427,7 +476,7 @@ def test__merge_two_different_steps():
         operation=operation2.SerializeToString(),
         is_batchable=True,
         metadata_size=0,
-        batch_key=2.0,
+        step=2.0,
     )

     # and
@@ -470,7 +519,7 @@ def test__merge_step_with_none():
         operation=operation1.SerializeToString(),
         is_batchable=True,
         metadata_size=0,
-        batch_key=1.0,
+        step=1.0,
     )
     element2 = SingleOperation(
         sequence_id=2,
@@ -478,7 +527,7 @@ def test__merge_step_with_none():
         operation=operation2.SerializeToString(),
         is_batchable=True,
         metadata_size=0,
-        batch_key=None,
+        step=None,
     )

     # and
diff --git a/tests/unit/test_attribute.py b/tests/unit/test_attribute.py
index 146b9d93..29f32efc 100644
--- a/tests/unit/test_attribute.py
+++ b/tests/unit/test_attribute.py
@@ -1,3 +1,4 @@
+import time
 from datetime import datetime
 from unittest.mock import Mock

@@ -8,12 +9,14 @@
 )

 from neptune_scale import Run
+from neptune_scale.exceptions import NeptuneSeriesStepNonIncreasing


 @fixture
 def run(api_token):
     run = Run(project="dummy/project", run_id="dummy-run", mode="disabled", api_token=api_token)
-    run._attr_store.log = Mock()
+    # Mock log to be able to assert calls, but also proxy to the actual method so it does its job
+    run._attr_store.log = Mock(side_effect=run._attr_store.log)
     with run:
         yield run

@@ -66,8 +69,27 @@ def test_tags(run, store):


 def test_series(run, store):
-    run["sys/series"].append(1, step=1, timestamp=10)
-    store.log.assert_called_with(metrics={"sys/series": 1}, step=1, timestamp=10)
+    timestamp = time.time()
+    run["my/series"].append(1, step=1, timestamp=timestamp)
+    store.log.assert_called_with(metrics={"my/series": 1}, step=1, timestamp=timestamp)

-    run["sys/series"].append({"foo": 1, "bar": 2}, step=2)
-    store.log.assert_called_with(metrics={"sys/series/foo": 1, "sys/series/bar": 2}, step=2, timestamp=None)
+    run["my/series"].append({"foo": 1, "bar": 2}, step=2)
+    store.log.assert_called_with(metrics={"my/series/foo": 1, "my/series/bar": 2}, step=2, timestamp=None)
+
+
+def test_error_on_non_increasing_step(run):
+    run["series"].append(1, step=2)
+
+    # Step lower than previous
+    with pytest.raises(NeptuneSeriesStepNonIncreasing):
+        run["series"].append(2, step=1)
+
+    # Equal to previous, but different value
+    with pytest.raises(NeptuneSeriesStepNonIncreasing):
+        run["series"].append(3, step=2)
+
+    # Equal to previous, same value -> should pass
+    run["series"].append(1, step=2)
+
+    # None should pass, as it means auto-increment
+    run["series"].append(4, step=None)
diff --git a/tests/unit/test_sync_process.py b/tests/unit/test_sync_process.py
index 7f62aa46..e67264bb 100644
--- a/tests/unit/test_sync_process.py
+++ b/tests/unit/test_sync_process.py
@@ -34,7 +34,7 @@ def single_operation(update: UpdateRunSnapshot, sequence_id):
         operation=operation.SerializeToString(),
         is_batchable=True,
         metadata_size=update.ByteSize(),
-        batch_key=None,
+        step=None,
     )