From 448034d15bf9c8673a613d54e4690dc11bacd180 Mon Sep 17 00:00:00 2001
From: Krzysztof Godlewski
Date: Wed, 27 Nov 2024 17:51:38 +0100
Subject: [PATCH 1/3] Add `black` to `dev_requirements.txt`

---
 dev_requirements.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/dev_requirements.txt b/dev_requirements.txt
index 59d76f06..e6cb5378 100644
--- a/dev_requirements.txt
+++ b/dev_requirements.txt
@@ -1,6 +1,7 @@
 -e .
 
 # dev
+black
 pre-commit
 pytest
 pytest-timeout

From 34063135dbe8da069c4b5231ddcdbf4574c90f97 Mon Sep 17 00:00:00 2001
From: Krzysztof Godlewski
Date: Wed, 27 Nov 2024 18:03:31 +0100
Subject: [PATCH 2/3] Move the core logic from `Run.log()` to `AttributeStore.log()`

`AttributeStore.log()` is now responsible for feeding data to the
`OperationsQueue`.
---
 src/neptune_scale/api/attribute.py          | 147 ++++++++++++--------
 src/neptune_scale/api/run.py                |  45 ++----
 src/neptune_scale/api/validation.py         |  24 +++-
 src/neptune_scale/sync/metadata_splitter.py |  26 ++--
 4 files changed, 135 insertions(+), 107 deletions(-)

diff --git a/src/neptune_scale/api/attribute.py b/src/neptune_scale/api/attribute.py
index fb9b7c8a..3d140025 100644
--- a/src/neptune_scale/api/attribute.py
+++ b/src/neptune_scale/api/attribute.py
@@ -7,7 +7,6 @@
 )
 from datetime import datetime
 from typing import (
-    TYPE_CHECKING,
     Any,
     Callable,
     Optional,
@@ -15,8 +14,8 @@
     cast,
 )
 
-if TYPE_CHECKING:
-    from neptune_scale.api.run import Run
+from neptune_scale.sync.metadata_splitter import MetadataSplitter
+from neptune_scale.sync.operations_queue import OperationsQueue
 
 __all__ = ("Attribute", "AttributeStore")
 
@@ -44,12 +43,21 @@ def wrapper(*args, **kwargs):  # type: ignore
 
 
 # TODO: proper typehinting
+# AtomType = Union[float, bool, int, str, datetime, list, set, tuple]
 ValueType = Any  # Union[float, int, str, bool, datetime, Tuple, List, Dict, Set]
 
 
 class AttributeStore:
-    def __init__(self, run: "Run") -> None:
-        self._run = run
+    """
+    Responsible for managing the local attribute store and for pushing log() operations
+    to the provided OperationsQueue -- assuming that there is something on the other
+    end consuming the queue (which would be SyncProcess).
+    """
+
+    def __init__(self, project: str, run_id: str, operations_queue: OperationsQueue) -> None:
+        self._project = project
+        self._run_id = run_id
+        self._operations_queue = operations_queue
         self._attributes: dict[str, Attribute] = {}
 
     def __getitem__(self, path: str) -> "Attribute":
@@ -62,7 +70,7 @@ def __getitem__(self, path: str) -> "Attribute":
         return attr
 
     def __setitem__(self, key: str, value: ValueType) -> None:
-        # TODO: validate type if attr is already known
+        # TODO: validate type if attr is already known?
        attr = self[key]
         attr.assign(value)
 
@@ -75,31 +83,34 @@ def log(
         tags_add: Optional[dict[str, Union[list[str], set[str], tuple[str]]]] = None,
         tags_remove: Optional[dict[str, Union[list[str], set[str], tuple[str]]]] = None,
     ) -> None:
-        # TODO: This should not call Run.log, but do the actual work. Reverse the current dependency so that this
-        # class handles all the logging
-        timestamp = datetime.now() if timestamp is None else timestamp
-
-        # TODO: Remove this and teach MetadataSplitter to handle Nones
-        configs = {} if configs is None else configs
-        metrics = {} if metrics is None else metrics
-        tags_add = {} if tags_add is None else tags_add
-        tags_remove = {} if tags_remove is None else tags_remove
-
-        # TODO: remove once Run.log accepts Union[datetime, float]
-        timestamp = cast(datetime, timestamp)
-        self._run.log(
-            step=step, timestamp=timestamp, configs=configs, metrics=metrics, tags_add=tags_add, tags_remove=tags_remove
+        if timestamp is None:
+            timestamp = datetime.now()
+        elif isinstance(timestamp, float):
+            timestamp = datetime.fromtimestamp(timestamp)
+
+        splitter: MetadataSplitter = MetadataSplitter(
+            project=self._project,
+            run_id=self._run_id,
+            step=step,
+            timestamp=timestamp,
+            configs=configs,
+            metrics=metrics,
+            add_tags=tags_add,
+            remove_tags=tags_remove,
         )
+        for operation, metadata_size in splitter:
+            self._operations_queue.enqueue(operation=operation, size=metadata_size, key=step)
+
 
 
 class Attribute:
     """Objects of this class are returned on dict-like access to Run. Attributes have a path and
-    allow logging values under it.
+    allow logging values under it. Example:
 
-    run = Run(...)
-    run['foo'] = 1
-    run['nested'] = {'foo': 1, {'bar': {'baz': 2}}}
-    run['bar'].append(1, step=10)
+        run = Run(...)
+        run['foo'] = 1
+        run['nested'] = {'foo': 1, 'bar': {'baz': 2}}
+        run['bar'].append(1, step=10)
     """
 
     def __init__(self, store: AttributeStore, path: str) -> None:
@@ -166,37 +177,6 @@ def extend(
 
 
 # TODO: change Run API to typehint timestamp as Union[datetime, float]
-def iter_nested(dict_: dict[str, ValueType], path: str) -> Iterator[tuple[tuple[str, ...], ValueType]]:
-    """Iterate a nested dictionary, yielding a tuple of path components and value.
-
-    >>> list(iter_nested({"foo": 1, "bar": {"baz": 2}}, "base"))
-    [(('base', 'foo'), 1), (('base', 'bar', 'baz'), 2)]
-    >>> list(iter_nested({"foo":{"bar": 1}, "bar":{"baz": 2}}, "base"))
-    [(('base', 'foo', 'bar'), 1), (('base', 'bar', 'baz'), 2)]
-    >>> list(iter_nested({"foo": 1, "bar": 2}, "base"))
-    [(('base', 'foo'), 1), (('base', 'bar'), 2)]
-    >>> list(iter_nested({"foo": {}}, ""))
-    Traceback (most recent call last):
-    ...
-    ValueError: The dictionary cannot be empty or contain empty nested dictionaries.
-    """
-
-    parts = tuple(path.split("/"))
-    yield from _iter_nested(dict_, parts)
-
-
-def _iter_nested(dict_: dict[str, ValueType], path_acc: tuple[str, ...]) -> Iterator[tuple[tuple[str, ...], ValueType]]:
-    if not dict_:
-        raise ValueError("The dictionary cannot be empty or contain empty nested dictionaries.")
-
-    for key, value in dict_.items():
-        current_path = path_acc + (key,)
-        if isinstance(value, dict):
-            yield from _iter_nested(value, current_path)
-        else:
-            yield current_path, value
-
-
 def cleanup_path(path: str) -> str:
     """
     >>> cleanup_path('/a/b/c')
@@ -209,10 +189,17 @@ def cleanup_path(path: str) -> str:
     Traceback (most recent call last):
     ...
     ValueError: Invalid path: `a//b/c`. Path components must not be empty.
+    >>> cleanup_path('a/ /b/c')
+    Traceback (most recent call last):
+    ...
+    ValueError: Invalid path: `a/ /b/c`. Path components must not be empty.
+    >>> cleanup_path('a/b/c ')
+    Traceback (most recent call last):
+    ...
+    ValueError: Invalid path: `a/b/c `. Path cannot contain leading or trailing whitespace.
""" - path = path.strip() - if path in ("", "/"): + if path.strip() in ("", "/"): raise ValueError(f"Invalid path: `{path}`.") if path.startswith("/"): @@ -220,14 +207,23 @@ def cleanup_path(path: str) -> str: if path.endswith("/"): raise ValueError(f"Invalid path: `{path}`. Path must not end with a slash.") - if not all(path.split("/")): + + if not all(x.strip() for x in path.split("/")): raise ValueError(f"Invalid path: `{path}`. Path components must not be empty.") + if path[0].lstrip() != path[0] or path[-1].rstrip() != path[-1]: + raise ValueError(f"Invalid path: `{path}`. Path cannot contain leading or trailing whitespace.") + return path -def accumulate_dict_values(value: Union[ValueType, dict[str, ValueType]], path_or_base: str) -> dict: +def accumulate_dict_values(value: Union[ValueType, dict[str, ValueType]], path_or_base: str) -> dict[str, Any]: """ + If value is a dict, flatten nested dictionaries into a single dict with unwrapped paths, each + starting with `path_or_base`. + + If value is an atom, return a dict with a single entry `path_or_base` -> `value`. + >>> accumulate_dict_values(1, "foo") {'foo': 1} >>> accumulate_dict_values({"bar": 1, 'l0/l1': 2, 'l3':{"l4": 3}}, "foo") @@ -240,3 +236,34 @@ def accumulate_dict_values(value: Union[ValueType, dict[str, ValueType]], path_o data = {path_or_base: value} return data + + +def iter_nested(dict_: dict[str, ValueType], path: str) -> Iterator[tuple[tuple[str, ...], ValueType]]: + """Iterate a nested dictionary, yielding a tuple of path components and value. + + >>> list(iter_nested({"foo": 1, "bar": {"baz": 2}}, "base")) + [(('base', 'foo'), 1), (('base', 'bar', 'baz'), 2)] + >>> list(iter_nested({"foo":{"bar": 1}, "bar":{"baz": 2}}, "base")) + [(('base', 'foo', 'bar'), 1), (('base', 'bar', 'baz'), 2)] + >>> list(iter_nested({"foo": 1, "bar": 2}, "base")) + [(('base', 'foo'), 1), (('base', 'bar'), 2)] + >>> list(iter_nested({"foo": {}}, "")) + Traceback (most recent call last): + ... + ValueError: The dictionary cannot be empty or contain empty nested dictionaries. 
+ """ + + parts = tuple(path.split("/")) + yield from _iter_nested(dict_, parts) + + +def _iter_nested(dict_: dict[str, ValueType], path_acc: tuple[str, ...]) -> Iterator[tuple[tuple[str, ...], ValueType]]: + if not dict_: + raise ValueError("The dictionary cannot be empty or contain empty nested dictionaries.") + + for key, value in dict_.items(): + current_path = path_acc + (key,) + if isinstance(value, dict): + yield from _iter_nested(value, current_path) + else: + yield current_path, value diff --git a/src/neptune_scale/api/run.py b/src/neptune_scale/api/run.py index 74302bb0..f6d05f3b 100644 --- a/src/neptune_scale/api/run.py +++ b/src/neptune_scale/api/run.py @@ -29,7 +29,7 @@ AttributeStore, ) from neptune_scale.api.validation import ( - verify_collection_type, + verify_dict_type, verify_max_length, verify_non_empty, verify_project_qualified_name, @@ -48,7 +48,6 @@ ErrorsQueue, ) from neptune_scale.sync.lag_tracking import LagTracker -from neptune_scale.sync.metadata_splitter import MetadataSplitter from neptune_scale.sync.operations_queue import OperationsQueue from neptune_scale.sync.parameters import ( MAX_EXPERIMENT_NAME_LENGTH, @@ -199,13 +198,15 @@ def __init__( self._project: str = input_project self._run_id: str = run_id - self._attr_store: AttributeStore = AttributeStore(self) self._lock = threading.RLock() self._operations_queue: OperationsQueue = OperationsQueue( lock=self._lock, max_size=max_queue_size, ) + + self._attr_store: AttributeStore = AttributeStore(self._project, self._run_id, self._operations_queue) + self._errors_queue: ErrorsQueue = ErrorsQueue() self._errors_monitor = ErrorsMonitor( errors_queue=self._errors_queue, @@ -536,25 +537,10 @@ def log( verify_type("tags_add", tags_add, (dict, type(None))) verify_type("tags_remove", tags_remove, (dict, type(None))) - timestamp = datetime.now() if timestamp is None else timestamp - # TODO: move this into AttributeStore - configs = {} if configs is None else configs - metrics = {} if metrics is None else metrics - tags_add = {} if tags_add is None else tags_add - tags_remove = {} if tags_remove is None else tags_remove - - # TODO: refactor this into something like `verify_dict_types(name, allowed_key_types, allowed_value_types)` - verify_collection_type("`configs` keys", list(configs.keys()), str) - verify_collection_type("`metrics` keys", list(metrics.keys()), str) - verify_collection_type("`tags_add` keys", list(tags_add.keys()), str) - verify_collection_type("`tags_remove` keys", list(tags_remove.keys()), str) - - verify_collection_type( - "`configs` values", list(configs.values()), (float, bool, int, str, datetime, list, set, tuple) - ) - verify_collection_type("`metrics` values", list(metrics.values()), (float, int)) - verify_collection_type("`tags_add` values", list(tags_add.values()), (list, set, tuple)) - verify_collection_type("`tags_remove` values", list(tags_remove.values()), (list, set, tuple)) + verify_dict_type("configs", configs, (float, bool, int, str, datetime, list, set, tuple)) + verify_dict_type("metrics", metrics, (float, int)) + verify_dict_type("tags_add", tags_add, (list, set, tuple)) + verify_dict_type("tags_remove", tags_remove, (list, set, tuple)) # Don't log anything after we've been stopped. This allows continuing the training script # after a non-recoverable error happened. 
@@ -561,21 +547,10 @@ def log(
         if self._is_closing:
             return
 
-        # TODO: move this to a separate process or thread, to make the .log call as lightweight as possible
-        splitter: MetadataSplitter = MetadataSplitter(
-            project=self._project,
-            run_id=self._run_id,
-            step=step,
-            timestamp=timestamp,
-            configs=configs,
-            metrics=metrics,
-            add_tags=tags_add,
-            remove_tags=tags_remove,
+        self._attr_store.log(
+            step=step, timestamp=timestamp, configs=configs, metrics=metrics, tags_add=tags_add, tags_remove=tags_remove
         )
 
-        for operation, metadata_size in splitter:
-            self._operations_queue.enqueue(operation=operation, size=metadata_size, key=step)
-
     def _wait(
         self,
         phrase: str,
diff --git a/src/neptune_scale/api/validation.py b/src/neptune_scale/api/validation.py
index 7142b1b0..20a75719 100644
--- a/src/neptune_scale/api/validation.py
+++ b/src/neptune_scale/api/validation.py
@@ -10,6 +10,7 @@
 
 from typing import (
     Any,
+    Optional,
     Union,
 )
 
@@ -55,10 +56,29 @@ def verify_project_qualified_name(var_name: str, var: Any) -> None:
 def verify_collection_type(
     var_name: str, var: Union[list, set, tuple], expected_type: Union[type, tuple], allow_none: bool = True
 ) -> None:
-    if var is None and not allow_none:
-        raise ValueError(f"{var_name} must not be None")
+    if var is None:
+        if not allow_none:
+            raise ValueError(f"{var_name} must not be None")
+        return
 
     verify_type(var_name, var, (list, set, tuple))
 
     for value in var:
         verify_type(f"elements of collection '{var_name}'", value, expected_type)
+
+
+def verify_dict_type(
+    var_name: str, var: Optional[dict[Any, Any]], expected_type: Union[type, tuple], allow_none: bool = True
+) -> None:
+    if var is None:
+        if not allow_none:
+            raise ValueError(f"{var_name} must not be None")
+        return
+
+    verify_type(var_name, var, dict)
+
+    for key, value in var.items():
+        if not isinstance(key, str):
+            raise TypeError(f"Keys of dictionary '{var_name}' must be strings (got `{key}`)")
+
+        verify_type(f"Values of dictionary '{var_name}'", value, expected_type)
diff --git a/src/neptune_scale/sync/metadata_splitter.py b/src/neptune_scale/sync/metadata_splitter.py
index 01743495..e92524af 100644
--- a/src/neptune_scale/sync/metadata_splitter.py
+++ b/src/neptune_scale/sync/metadata_splitter.py
@@ -52,20 +52,20 @@ def __init__(
         run_id: str,
         step: Optional[Union[int, float]],
         timestamp: datetime,
-        configs: dict[str, Union[float, bool, int, str, datetime, list, set, tuple]],
-        metrics: dict[str, float],
-        add_tags: dict[str, Union[list[str], set[str], tuple[str]]],
-        remove_tags: dict[str, Union[list[str], set[str], tuple[str]]],
+        configs: Optional[dict[str, Union[float, bool, int, str, datetime, list, set, tuple]]],
+        metrics: Optional[dict[str, float]],
+        add_tags: Optional[dict[str, Union[list[str], set[str], tuple[str]]]],
+        remove_tags: Optional[dict[str, Union[list[str], set[str], tuple[str]]]],
         max_message_bytes_size: int = 1024 * 1024,
     ):
         self._step = None if step is None else make_step(number=step)
         self._timestamp = datetime_to_proto(timestamp)
         self._project = project
         self._run_id = run_id
-        self._configs = peekable(configs.items())
-        self._metrics = peekable(self._skip_non_finite(step, metrics))
-        self._add_tags = peekable(add_tags.items())
-        self._remove_tags = peekable(remove_tags.items())
+        self._configs = peekable(configs.items()) if configs else None
+        self._metrics = peekable(self._skip_non_finite(step, metrics)) if metrics else None
+        self._add_tags = peekable(add_tags.items()) if add_tags else None
+        self._remove_tags = peekable(remove_tags.items()) if remove_tags else None
 
         self._max_update_bytes_size = (
             max_message_bytes_size
@@ -124,10 +124,13 @@ def __next__(self) -> tuple[RunOperation, int]:
 
     def populate(
         self,
-        assets: peekable[Any],
+        assets: Optional[peekable[Any]],
         update_producer: Callable[[str, Value], None],
         size: int,
     ) -> int:
+        if not assets:
+            return size
+
         while size < self._max_update_bytes_size:
             try:
                 key, value = assets.peek()
@@ -146,8 +149,11 @@ def populate(
         return size
 
     def populate_tags(
-        self, update: UpdateRunSnapshot, assets: peekable[Any], operation: SET_OPERATION.ValueType, size: int
+        self, update: UpdateRunSnapshot, assets: Optional[peekable[Any]], operation: SET_OPERATION.ValueType, size: int
     ) -> int:
+        if not assets:
+            return size
+
         while size < self._max_update_bytes_size:
             try:
                 key, values = assets.peek()

From f933b088fa438d5fb1cf448fae00a5cc882d53b3 Mon Sep 17 00:00:00 2001
From: Krzysztof Godlewski
Date: Wed, 18 Dec 2024 16:44:51 +0100
Subject: [PATCH 3/3] Update `cleanup_path()` and add some more tests

---
 src/neptune_scale/api/attribute.py | 25 +++++++++++++++----------
 tests/unit/test_attribute.py       | 23 +++++++++++++++++++++++
 2 files changed, 38 insertions(+), 10 deletions(-)

diff --git a/src/neptune_scale/api/attribute.py b/src/neptune_scale/api/attribute.py
index 3d140025..ea433137 100644
--- a/src/neptune_scale/api/attribute.py
+++ b/src/neptune_scale/api/attribute.py
@@ -192,29 +192,34 @@ def cleanup_path(path: str) -> str:
     >>> cleanup_path('a/ /b/c')
     Traceback (most recent call last):
     ...
-    ValueError: Invalid path: `a/ /b/c`. Path components must not be empty.
+    ValueError: Invalid path: `a/ /b/c`. Path components cannot contain leading or trailing whitespace.
     >>> cleanup_path('a/b/c ')
     Traceback (most recent call last):
     ...
-    ValueError: Invalid path: `a/b/c `. Path cannot contain leading or trailing whitespace.
+    ValueError: Invalid path: `a/b/c `. Path components cannot contain leading or trailing whitespace.
     """
 
     if path.strip() in ("", "/"):
         raise ValueError(f"Invalid path: `{path}`.")
 
-    if path.startswith("/"):
-        path = path[1:]
+    orig_parts = path.split("/")
+    parts = [x.strip() for x in orig_parts]
 
-    if path.endswith("/"):
+    for i, part in enumerate(parts):
+        if part != orig_parts[i]:
+            raise ValueError(f"Invalid path: `{path}`. Path components cannot contain leading or trailing whitespace.")
+
+    # Skip the first slash, if present
+    if parts[0] == "":
+        parts = parts[1:]
+
+    if parts[-1] == "":
         raise ValueError(f"Invalid path: `{path}`. Path must not end with a slash.")
 
-    if not all(x.strip() for x in path.split("/")):
+    if not all(parts):
         raise ValueError(f"Invalid path: `{path}`. Path components must not be empty.")
 
-    if path[0].lstrip() != path[0] or path[-1].rstrip() != path[-1]:
-        raise ValueError(f"Invalid path: `{path}`. Path cannot contain leading or trailing whitespace.")
-
-    return path
+    return "/".join(parts)
 
 
 def accumulate_dict_values(value: Union[ValueType, dict[str, ValueType]], path_or_base: str) -> dict[str, Any]:
diff --git a/tests/unit/test_attribute.py b/tests/unit/test_attribute.py
index 146b9d93..e919dbd0 100644
--- a/tests/unit/test_attribute.py
+++ b/tests/unit/test_attribute.py
@@ -8,6 +8,7 @@
 )
 
 from neptune_scale import Run
+from neptune_scale.api.attribute import cleanup_path
 
 
 @fixture
@@ -71,3 +72,25 @@ def test_series(run, store):
     run["sys/series"].append({"foo": 1, "bar": 2}, step=2)
 
     store.log.assert_called_with(metrics={"sys/series/foo": 1, "sys/series/bar": 2}, step=2, timestamp=None)
+
+
+@pytest.mark.parametrize(
+    "path", ["", " ", "/", " /", "/ ", "///", "/a ", "/a/b /", "a/b /c", "a /b/c", "a/b/", "a/b ", " /a/b"]
+)
+def test_cleanup_path_invalid_path(path):
+    with pytest.raises(ValueError) as exc:
+        cleanup_path(path)
+
+    exc.match("Invalid path:")
+
+
+@pytest.mark.parametrize(
+    "path, expected",
+    (
+        ("/a/b/c", "a/b/c"),
+        ("a a/b/c", "a a/b/c"),
+        ("/a a/b/c", "a a/b/c"),
+    ),
+)
+def test_cleanup_path_valid_path(path, expected):
+    assert cleanup_path(path) == expected
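
Note: the following sketch is not part of the patches above. It only illustrates how the pieces introduced in PATCH 2/3 fit together, using the constructor signatures visible in the diffs (`OperationsQueue(lock=..., max_size=...)` and `AttributeStore(project, run_id, operations_queue)`). The project name, run ID, and queue size are placeholder values.

    # Approximate wiring, mirroring what Run.__init__() and Run.log() now do internally.
    import threading

    from neptune_scale.api.attribute import AttributeStore
    from neptune_scale.sync.operations_queue import OperationsQueue

    lock = threading.RLock()
    queue = OperationsQueue(lock=lock, max_size=1000)  # max_size is an illustrative value
    store = AttributeStore("my-workspace/my-project", "my-run-id", queue)

    # Run.log() now only validates its arguments (verify_dict_type) and delegates
    # here. AttributeStore.log() normalizes the timestamp, feeds the metadata
    # through MetadataSplitter, and enqueues each resulting operation keyed by step.
    store.log(metrics={"loss": 0.25}, step=1)

The design consequence of this split is that `AttributeStore` no longer depends on `Run` at all (the old `TYPE_CHECKING` import is gone), so the store can be constructed and tested against a bare `OperationsQueue`, provided something on the other end, such as `SyncProcess`, consumes the queue.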