
Move the core logic from Run.log() to AttributeStore #96


Merged · 3 commits · Dec 19, 2024
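In short, this PR reverses the dependency between Run and AttributeStore: previously AttributeStore.log() called back into Run.log(), which owned the MetadataSplitter/OperationsQueue plumbing; now AttributeStore receives the project, run ID, and queue directly and does the work itself, while Run.log() shrinks to validation plus delegation. A minimal, self-contained sketch of the new flow — FakeOperationsQueue and AttributeStoreSketch are illustrative stand-ins, not the real classes from neptune_scale.sync and neptune_scale.api.attribute:

```python
from datetime import datetime
from typing import Any, Optional


class FakeOperationsQueue:
    """Stand-in for OperationsQueue: records what gets enqueued."""

    def __init__(self) -> None:
        self.items: list[tuple[Any, int, Optional[float]]] = []

    def enqueue(self, *, operation: Any, size: int, key: Optional[float]) -> None:
        self.items.append((operation, size, key))


class AttributeStoreSketch:
    """After this PR: the store owns project/run_id/queue and builds
    operations itself, instead of calling back into Run.log()."""

    def __init__(self, project: str, run_id: str, operations_queue: FakeOperationsQueue) -> None:
        self._project = project
        self._run_id = run_id
        self._operations_queue = operations_queue

    def log(self, step: Optional[float] = None, timestamp: Any = None, **metadata: Any) -> None:
        # Same coercion rule as the new AttributeStore.log() below.
        if timestamp is None:
            timestamp = datetime.now()
        elif isinstance(timestamp, float):
            timestamp = datetime.fromtimestamp(timestamp)
        # The real code splits metadata into operations via MetadataSplitter;
        # here a single dict stands in for one operation.
        operation = {"project": self._project, "run_id": self._run_id, "step": step, **metadata}
        self._operations_queue.enqueue(operation=operation, size=len(str(operation)), key=step)


queue = FakeOperationsQueue()
store = AttributeStoreSketch("workspace/project", "run-1", queue)
store.log(step=1.0, metrics={"loss": 0.25})
assert len(queue.items) == 1
```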
1 change: 1 addition & 0 deletions dev_requirements.txt
@@ -1,6 +1,7 @@
 -e .

 # dev
 black
 pre-commit
 pytest
+pytest-timeout
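pytest-timeout is presumably added so tests exercising the queue/sync machinery cannot hang CI. A typical use (illustrative test name, not from this PR):

```python
import pytest


@pytest.mark.timeout(10)  # fail instead of hanging if the queue consumer stalls
def test_enqueue_does_not_block():
    ...
```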
160 changes: 96 additions & 64 deletions src/neptune_scale/api/attribute.py
@@ -7,16 +7,15 @@
 )
 from datetime import datetime
 from typing import (
-    TYPE_CHECKING,
     Any,
     Callable,
     Optional,
     Union,
-    cast,
 )

-if TYPE_CHECKING:
-    from neptune_scale.api.run import Run
+from neptune_scale.sync.metadata_splitter import MetadataSplitter
+from neptune_scale.sync.operations_queue import OperationsQueue

 __all__ = ("Attribute", "AttributeStore")

@@ -44,12 +43,21 @@ def wrapper(*args, **kwargs): # type: ignore


 # TODO: proper typehinting
 # AtomType = Union[float, bool, int, str, datetime, list, set, tuple]
 ValueType = Any  # Union[float, int, str, bool, datetime, Tuple, List, Dict, Set]


 class AttributeStore:
-    def __init__(self, run: "Run") -> None:
-        self._run = run
+    """
+    Responsible for managing the local attribute store, and for pushing log() operations
+    to the provided OperationsQueue -- assuming that there is something on the other
+    end consuming the queue (which would be SyncProcess).
+    """
+
+    def __init__(self, project: str, run_id: str, operations_queue: OperationsQueue) -> None:
+        self._project = project
+        self._run_id = run_id
+        self._operations_queue = operations_queue
         self._attributes: dict[str, Attribute] = {}

@@ -62,7 +70,7 @@ def __getitem__(self, path: str) -> "Attribute":
         return attr

     def __setitem__(self, key: str, value: ValueType) -> None:
-        # TODO: validate type if attr is already known
+        # TODO: validate type if attr is already known?
         attr = self[key]
         attr.assign(value)

@@ -75,31 +83,34 @@ def log(
         tags_add: Optional[dict[str, Union[list[str], set[str], tuple[str]]]] = None,
         tags_remove: Optional[dict[str, Union[list[str], set[str], tuple[str]]]] = None,
     ) -> None:
-        # TODO: This should not call Run.log, but do the actual work. Reverse the current dependency so that this
-        # class handles all the logging
-        timestamp = datetime.now() if timestamp is None else timestamp
-
-        # TODO: Remove this and teach MetadataSplitter to handle Nones
-        configs = {} if configs is None else configs
-        metrics = {} if metrics is None else metrics
-        tags_add = {} if tags_add is None else tags_add
-        tags_remove = {} if tags_remove is None else tags_remove
-
-        # TODO: remove once Run.log accepts Union[datetime, float]
-        timestamp = cast(datetime, timestamp)
-        self._run.log(
-            step=step, timestamp=timestamp, configs=configs, metrics=metrics, tags_add=tags_add, tags_remove=tags_remove
-        )
+        if timestamp is None:
+            timestamp = datetime.now()
+        elif isinstance(timestamp, float):
+            timestamp = datetime.fromtimestamp(timestamp)
+
+        splitter: MetadataSplitter = MetadataSplitter(
+            project=self._project,
+            run_id=self._run_id,
+            step=step,
+            timestamp=timestamp,
+            configs=configs,
+            metrics=metrics,
+            add_tags=tags_add,
+            remove_tags=tags_remove,
+        )
+
+        for operation, metadata_size in splitter:
+            self._operations_queue.enqueue(operation=operation, size=metadata_size, key=step)


 class Attribute:
-    """Objects of this class are returned on dict-like access to Run. Attributes have a path and
-    allow logging values under it.
-
-    run = Run(...)
-    run['foo'] = 1
-    run['nested'] = {'foo': 1, {'bar': {'baz': 2}}}
-    run['bar'].append(1, step=10)
+    """Objects of this class are returned on dict-like access to Run. Attributes have a path and
+    allow logging values under it. Example:
+
+        run = Run(...)
+        run['foo'] = 1
+        run['nested'] = {'foo': 1, 'bar': {'baz': 2}}
+        run['bar'].append(1, step=10)
     """

     def __init__(self, store: AttributeStore, path: str) -> None:
@@ -166,37 +177,6 @@ def extend(
         # TODO: change Run API to typehint timestamp as Union[datetime, float]


-def iter_nested(dict_: dict[str, ValueType], path: str) -> Iterator[tuple[tuple[str, ...], ValueType]]:
-    """Iterate a nested dictionary, yielding a tuple of path components and value.
-
-    >>> list(iter_nested({"foo": 1, "bar": {"baz": 2}}, "base"))
-    [(('base', 'foo'), 1), (('base', 'bar', 'baz'), 2)]
-    >>> list(iter_nested({"foo":{"bar": 1}, "bar":{"baz": 2}}, "base"))
-    [(('base', 'foo', 'bar'), 1), (('base', 'bar', 'baz'), 2)]
-    >>> list(iter_nested({"foo": 1, "bar": 2}, "base"))
-    [(('base', 'foo'), 1), (('base', 'bar'), 2)]
-    >>> list(iter_nested({"foo": {}}, ""))
-    Traceback (most recent call last):
-    ...
-    ValueError: The dictionary cannot be empty or contain empty nested dictionaries.
-    """
-
-    parts = tuple(path.split("/"))
-    yield from _iter_nested(dict_, parts)
-
-
-def _iter_nested(dict_: dict[str, ValueType], path_acc: tuple[str, ...]) -> Iterator[tuple[tuple[str, ...], ValueType]]:
-    if not dict_:
-        raise ValueError("The dictionary cannot be empty or contain empty nested dictionaries.")
-
-    for key, value in dict_.items():
-        current_path = path_acc + (key,)
-        if isinstance(value, dict):
-            yield from _iter_nested(value, current_path)
-        else:
-            yield current_path, value
-
-
 def cleanup_path(path: str) -> str:
     """
     >>> cleanup_path('/a/b/c')
@@ -209,25 +189,46 @@ def cleanup_path(path: str) -> str:
     Traceback (most recent call last):
     ...
     ValueError: Invalid path: `a//b/c`. Path components must not be empty.
+    >>> cleanup_path('a/ /b/c')
+    Traceback (most recent call last):
+    ...
+    ValueError: Invalid path: `a/ /b/c`. Path components cannot contain leading or trailing whitespace.
+    >>> cleanup_path('a/b/c ')
+    Traceback (most recent call last):
+    ...
+    ValueError: Invalid path: `a/b/c `. Path components cannot contain leading or trailing whitespace.
     """

-    path = path.strip()
-    if path in ("", "/"):
+    if path.strip() in ("", "/"):
Comment from @michalsosn (Contributor), Dec 10, 2024:

Hmm, since you're not assigning the stripped path back to `path`, any whitespace around it stays and breaks the two checks immediately below.

E.g.
previously: " /path/to/sth/ " -> strip() -> "/path/to/sth/" -> path[1:] -> "path/to/sth/" -> error, because it ends with "/"
now: " /path/to/sth/ " -> strip() is not saved -> " /path/to/sth/ " -> path[1:] is skipped because the path starts with " " -> " /path/to/sth/ " -> no error, because it ends with " "
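A minimal repro of the behaviour described above — a hypothetical helper mirroring the intermediate version of cleanup_path, where the result of strip() was discarded:

```python
def cleanup_path_unsaved_strip(path: str) -> str:
    # Intermediate version: strip() result is checked but never assigned back.
    if path.strip() in ("", "/"):
        raise ValueError(f"Invalid path: `{path}`.")
    if path.startswith("/"):
        path = path[1:]
    if path.endswith("/"):
        raise ValueError(f"Invalid path: `{path}`. Path must not end with a slash.")
    if not all(path.split("/")):
        raise ValueError(f"Invalid path: `{path}`. Path components must not be empty.")
    return path


# Starts with " ", so the leading "/" is never removed; ends with " ",
# so the trailing-slash check never fires -- the bad path passes through:
print(cleanup_path_unsaved_strip(" /path/to/sth/ "))
```

The final version of the diff addresses this by stripping each component and raising on any whitespace, which is what the parts/orig_parts comparison below does.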

Comment from a Collaborator:

please add tests

raise ValueError(f"Invalid path: `{path}`.")

if path.startswith("/"):
path = path[1:]
orig_parts = path.split("/")
parts = [x.strip() for x in orig_parts]

for i, part in enumerate(parts):
if part != orig_parts[i]:
raise ValueError(f"Invalid path: `{path}`. Path components cannot contain leading or trailing whitespace.")

# Skip the first slash, if present
if parts[0] == "":
parts = parts[1:]

if path.endswith("/"):
if parts[-1] == "":
raise ValueError(f"Invalid path: `{path}`. Path must not end with a slash.")
if not all(path.split("/")):

if not all(parts):
raise ValueError(f"Invalid path: `{path}`. Path components must not be empty.")

return path
return "/".join(parts)


-def accumulate_dict_values(value: Union[ValueType, dict[str, ValueType]], path_or_base: str) -> dict:
+def accumulate_dict_values(value: Union[ValueType, dict[str, ValueType]], path_or_base: str) -> dict[str, Any]:
"""
If value is a dict, flatten nested dictionaries into a single dict with unwrapped paths, each
starting with `path_or_base`.

If value is an atom, return a dict with a single entry `path_or_base` -> `value`.

>>> accumulate_dict_values(1, "foo")
{'foo': 1}
>>> accumulate_dict_values({"bar": 1, 'l0/l1': 2, 'l3':{"l4": 3}}, "foo")
@@ -240,3 +241,34 @@ def accumulate_dict_values(value: Union[ValueType, dict[str, ValueType]], path_or_base: str) -> dict:
         data = {path_or_base: value}

     return data


+def iter_nested(dict_: dict[str, ValueType], path: str) -> Iterator[tuple[tuple[str, ...], ValueType]]:
+    """Iterate a nested dictionary, yielding a tuple of path components and value.
+
+    >>> list(iter_nested({"foo": 1, "bar": {"baz": 2}}, "base"))
+    [(('base', 'foo'), 1), (('base', 'bar', 'baz'), 2)]
+    >>> list(iter_nested({"foo":{"bar": 1}, "bar":{"baz": 2}}, "base"))
+    [(('base', 'foo', 'bar'), 1), (('base', 'bar', 'baz'), 2)]
+    >>> list(iter_nested({"foo": 1, "bar": 2}, "base"))
+    [(('base', 'foo'), 1), (('base', 'bar'), 2)]
+    >>> list(iter_nested({"foo": {}}, ""))
+    Traceback (most recent call last):
+    ...
+    ValueError: The dictionary cannot be empty or contain empty nested dictionaries.
+    """
+
+    parts = tuple(path.split("/"))
+    yield from _iter_nested(dict_, parts)
+
+
+def _iter_nested(dict_: dict[str, ValueType], path_acc: tuple[str, ...]) -> Iterator[tuple[tuple[str, ...], ValueType]]:
+    if not dict_:
+        raise ValueError("The dictionary cannot be empty or contain empty nested dictionaries.")
+
+    for key, value in dict_.items():
+        current_path = path_acc + (key,)
+        if isinstance(value, dict):
+            yield from _iter_nested(value, current_path)
+        else:
+            yield current_path, value
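The body of accumulate_dict_values is mostly collapsed in the diff above. A plausible reconstruction of how it composes with iter_nested — an assumption based on the visible doctests and the else branch, not the verbatim source:

```python
from typing import Any, Iterator, Union


def iter_nested(dict_: dict[str, Any], path: str) -> Iterator[tuple[tuple[str, ...], Any]]:
    yield from _iter_nested(dict_, tuple(path.split("/")))


def _iter_nested(dict_: dict[str, Any], path_acc: tuple[str, ...]) -> Iterator[tuple[tuple[str, ...], Any]]:
    if not dict_:
        raise ValueError("The dictionary cannot be empty or contain empty nested dictionaries.")
    for key, value in dict_.items():
        if isinstance(value, dict):
            yield from _iter_nested(value, path_acc + (key,))
        else:
            yield path_acc + (key,), value


def accumulate_dict_values(value: Union[Any, dict[str, Any]], path_or_base: str) -> dict[str, Any]:
    if isinstance(value, dict):
        # Flatten nested dicts into "a/b/c"-style paths rooted at path_or_base.
        data = {"/".join(path): leaf for path, leaf in iter_nested(value, path_or_base)}
    else:
        data = {path_or_base: value}
    return data


assert accumulate_dict_values(1, "foo") == {"foo": 1}
assert accumulate_dict_values({"bar": 1, "l3": {"l4": 3}}, "foo") == {"foo/bar": 1, "foo/l3/l4": 3}
```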
45 changes: 10 additions & 35 deletions src/neptune_scale/api/run.py
@@ -29,7 +29,7 @@
     AttributeStore,
 )
 from neptune_scale.api.validation import (
-    verify_collection_type,
+    verify_dict_type,
     verify_max_length,
     verify_non_empty,
     verify_project_qualified_name,
@@ -48,7 +48,6 @@
     ErrorsQueue,
 )
 from neptune_scale.sync.lag_tracking import LagTracker
-from neptune_scale.sync.metadata_splitter import MetadataSplitter
 from neptune_scale.sync.operations_queue import OperationsQueue
 from neptune_scale.sync.parameters import (
     MAX_EXPERIMENT_NAME_LENGTH,
@@ -199,13 +198,15 @@ def __init__(

         self._project: str = input_project
         self._run_id: str = run_id
-        self._attr_store: AttributeStore = AttributeStore(self)

         self._lock = threading.RLock()
         self._operations_queue: OperationsQueue = OperationsQueue(
             lock=self._lock,
             max_size=max_queue_size,
         )

+        self._attr_store: AttributeStore = AttributeStore(self._project, self._run_id, self._operations_queue)

         self._errors_queue: ErrorsQueue = ErrorsQueue()
         self._errors_monitor = ErrorsMonitor(
             errors_queue=self._errors_queue,
@@ -536,47 +537,21 @@ def log(
verify_type("tags_add", tags_add, (dict, type(None)))
verify_type("tags_remove", tags_remove, (dict, type(None)))

timestamp = datetime.now() if timestamp is None else timestamp
# TODO: move this into AttributeStore
configs = {} if configs is None else configs
metrics = {} if metrics is None else metrics
tags_add = {} if tags_add is None else tags_add
tags_remove = {} if tags_remove is None else tags_remove

# TODO: refactor this into something like `verify_dict_types(name, allowed_key_types, allowed_value_types)`
verify_collection_type("`configs` keys", list(configs.keys()), str)
verify_collection_type("`metrics` keys", list(metrics.keys()), str)
verify_collection_type("`tags_add` keys", list(tags_add.keys()), str)
verify_collection_type("`tags_remove` keys", list(tags_remove.keys()), str)

verify_collection_type(
"`configs` values", list(configs.values()), (float, bool, int, str, datetime, list, set, tuple)
)
verify_collection_type("`metrics` values", list(metrics.values()), (float, int))
verify_collection_type("`tags_add` values", list(tags_add.values()), (list, set, tuple))
verify_collection_type("`tags_remove` values", list(tags_remove.values()), (list, set, tuple))
verify_dict_type("configs", configs, (float, bool, int, str, datetime, list, set, tuple))
verify_dict_type("metrics", metrics, (float, int))
verify_dict_type("tags_add", tags_add, (list, set, tuple))
verify_dict_type("tags_remove", tags_remove, (list, set, tuple))

# Don't log anything after we've been stopped. This allows continuing the training script
# after a non-recoverable error happened. Note we don't to use self._lock in this check,
# to keep the common path faster, because the benefit of locking here is minimal.
if self._is_closing:
return

# TODO: move this to a separate process or thread, to make the .log call as lightweight as possible
splitter: MetadataSplitter = MetadataSplitter(
project=self._project,
run_id=self._run_id,
step=step,
timestamp=timestamp,
configs=configs,
metrics=metrics,
add_tags=tags_add,
remove_tags=tags_remove,
self._attr_store.log(
step=step, timestamp=timestamp, configs=configs, metrics=metrics, tags_add=tags_add, tags_remove=tags_remove
)

for operation, metadata_size in splitter:
self._operations_queue.enqueue(operation=operation, size=metadata_size, key=step)

def _wait(
self,
phrase: str,
24 changes: 22 additions & 2 deletions src/neptune_scale/api/validation.py
@@ -10,6 +10,7 @@

 from typing import (
     Any,
+    Optional,
     Union,
 )

@@ -55,10 +56,29 @@ def verify_project_qualified_name(var_name: str, var: Any) -> None:
 def verify_collection_type(
     var_name: str, var: Union[list, set, tuple], expected_type: Union[type, tuple], allow_none: bool = True
 ) -> None:
-    if var is None and not allow_none:
-        raise ValueError(f"{var_name} must not be None")
+    if var is None:
+        if not allow_none:
+            raise ValueError(f"{var_name} must not be None")
+        return

     verify_type(var_name, var, (list, set, tuple))

     for value in var:
         verify_type(f"elements of collection '{var_name}'", value, expected_type)


+def verify_dict_type(
+    var_name: str, var: Optional[dict[Any, Any]], expected_type: Union[type, tuple], allow_none: bool = True
+) -> None:
+    if var is None:
+        if not allow_none:
+            raise ValueError(f"{var_name} must not be None")
+        return
+
+    verify_type(var_name, var, dict)
+
+    for key, value in var.items():
+        if not isinstance(key, str):
+            raise TypeError(f"Keys of dictionary '{var_name}' must be strings (got `{key}`)")
+
+        verify_type(f"Values of dictionary '{var_name}'", value, expected_type)