Dev/minimal flow #16
New file:

@@ -0,0 +1,161 @@
from __future__ import annotations

__all__ = ("MetadataSplitter",)

from datetime import datetime
from typing import (
    Any,
    Callable,
    Iterator,
    TypeVar,
)

from more_itertools import peekable
from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import (
    SET_OPERATION,
    UpdateRunSnapshot,
    Value,
)
from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation

from neptune_scale.core.serialization import (
    datetime_to_proto,
    make_step,
    make_value,
    pb_key_size,
)

T = TypeVar("T", bound=Any)


class MetadataSplitter(Iterator[RunOperation]):
    """Splits one batch of run metadata into RunOperation messages, each
    no larger than `max_message_bytes_size` when serialized."""

    def __init__(
        self,
        *,
        project: str,
        run_id: str,
        step: int | float | None,
        timestamp: datetime,
        fields: dict[str, float | bool | int | str | datetime | list | set],
        metrics: dict[str, float],
        add_tags: dict[str, list[str] | set[str]],
        remove_tags: dict[str, list[str] | set[str]],
        max_message_bytes_size: int = 1024 * 1024,
    ):
        self._step = None if step is None else make_step(number=step)
        self._timestamp = datetime_to_proto(timestamp)
        self._project = project
        self._run_id = run_id
        self._fields = peekable(fields.items())
        self._metrics = peekable(metrics.items())
        self._add_tags = peekable(add_tags.items())
        self._remove_tags = peekable(remove_tags.items())

        # Budget for the payload: the message cap minus the fixed envelope
        # (project, run_id, step, timestamp) that every chunk carries.
        self._max_update_bytes_size = (
            max_message_bytes_size
            - RunOperation(
                project=self._project,
                run_id=self._run_id,
                update=UpdateRunSnapshot(step=self._step, timestamp=self._timestamp),
            ).ByteSize()
        )

        self._has_returned = False

    def __iter__(self) -> MetadataSplitter:
        self._has_returned = False
        return self

    def __next__(self) -> RunOperation:
        size = 0
        update = UpdateRunSnapshot(
            step=self._step,
            timestamp=self._timestamp,
            assign={},
            append={},
            modify_sets={},
        )

        size = self.populate(
            assets=self._fields,
            update_producer=lambda key, value: update.assign[key].MergeFrom(value),
            size=size,
        )
        size = self.populate(
            assets=self._metrics,
            update_producer=lambda key, value: update.append[key].MergeFrom(value),
            size=size,
        )
        size = self.populate_tags(
            update=update,
            assets=self._add_tags,
            operation=SET_OPERATION.ADD,
            size=size,
        )
        _ = self.populate_tags(
            update=update,
            assets=self._remove_tags,
            operation=SET_OPERATION.REMOVE,
            size=size,
        )

        # Always yield at least one (possibly empty) operation per batch,
        # then stop once there is nothing left to send.
        if not self._has_returned or update.assign or update.append or update.modify_sets:
            self._has_returned = True
            return RunOperation(project=self._project, run_id=self._run_id, update=update)
        else:
            raise StopIteration

    def populate(
        self,
        assets: peekable[Any],
        update_producer: Callable[[str, Value], None],
        size: int,
    ) -> int:
        while size < self._max_update_bytes_size:
            try:
                key, value = assets.peek()
            except StopIteration:
                break

            proto_value = make_value(value)
            new_size = size + pb_key_size(key) + proto_value.ByteSize() + 6

            if new_size > self._max_update_bytes_size:
                break

            # Consume the peeked item only once we know it fits.
            update_producer(key, proto_value)
            size, _ = new_size, next(assets)

        return size

    def populate_tags(
        self, update: UpdateRunSnapshot, assets: peekable[Any], operation: SET_OPERATION.ValueType, size: int
    ) -> int:
        while size < self._max_update_bytes_size:
            try:
                key, values = assets.peek()
            except StopIteration:
                break

            if not isinstance(values, peekable):
                values = peekable(values)

            is_full = False
            new_size = size + pb_key_size(key) + 6
            for value in values:
                tag_size = pb_key_size(value) + 6
                if new_size + tag_size > self._max_update_bytes_size:
                    # Push the tag back so the next chunk picks it up.
                    values.prepend(value)
                    is_full = True
                    break

                update.modify_sets[key].string.values[value] = operation
                new_size += tag_size

            size, _ = new_size, next(assets)

            if is_full:
                # Re-queue the remaining tags under the same key for the next chunk.
                assets.prepend((key, list(values)))
                break

        return size
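
For orientation, a minimal usage sketch of the splitter follows. The project name, run id, and logged values are made up for illustration; only `MetadataSplitter` and its keyword arguments come from the diff above.

```python
from datetime import datetime, timezone

# Hypothetical inputs; the identifiers and values are illustrative.
splitter = MetadataSplitter(
    project="my-workspace/my-project",
    run_id="run-42",
    step=10,
    timestamp=datetime.now(timezone.utc),
    fields={"config/lr": 0.001, "config/optimizer": "adam"},
    metrics={"metrics/loss": 0.37},
    add_tags={"sys/tags": ["baseline"]},
    remove_tags={},
)

for operation in splitter:
    # Each RunOperation fits within the default 1 MiB cap and can be
    # serialized and sent independently.
    assert operation.ByteSize() <= 1024 * 1024
```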
@@ -1,11 +1,40 @@
 from __future__ import annotations

-__all__ = ("datetime_to_proto", "make_step")
+__all__ = (
+    "make_value",
+    "make_step",
+    "datetime_to_proto",
+    "pb_key_size",
+)

 from datetime import datetime

 from google.protobuf.timestamp_pb2 import Timestamp
-from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import Step
+from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import (
+    Step,
+    StringSet,
+    Value,
+)
+
+
+def make_value(value: Value | float | str | int | bool | datetime | list[str] | set[str]) -> Value:
+    if isinstance(value, Value):
+        return value
+    if isinstance(value, float):
+        return Value(float64=value)
+    elif isinstance(value, bool):
+        return Value(bool=value)
+    elif isinstance(value, int):
+        return Value(int64=value)
+    elif isinstance(value, str):
+        return Value(string=value)
+    elif isinstance(value, datetime):
+        return Value(timestamp=datetime_to_proto(value))
+    elif isinstance(value, (list, set)):
+        return Value(string_set=StringSet(values=value))
+    else:
+        raise ValueError(f"Unsupported ingest field value type: {type(value)}")


 def datetime_to_proto(dt: datetime) -> Timestamp:

@@ -33,3 +62,8 @@ def make_step(number: float | int, raise_on_step_precision_loss: bool = False) -> Step:
     micro = micro % m

     return Step(whole=whole, micro=micro)
+
+
+def pb_key_size(key: str) -> int:
+    key_bin = bytes(key, "utf-8")
+    return len(key_bin) + 2 + (1 if len(key_bin) > 127 else 0)

Review comment on pb_key_size:
    A short explanation of how this is calculated and why would be great.

Reply:
    It came from our previous script, but I think it derives from a maximum key-length assumption (10k, if I remember correctly, so at most 2 bytes for the varint length representation) plus the field type definition overhead.
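
To make that reply concrete, here is a hedged sketch of how the estimate lines up with the protobuf wire format: a length-delimited field with a field number of 15 or less costs one tag byte, a length varint, and the payload. The `varint_len` and `string_field_size` helpers below are illustrative, not part of the PR.

```python
def pb_key_size(key: str) -> int:  # copied from the diff above
    key_bin = bytes(key, "utf-8")
    return len(key_bin) + 2 + (1 if len(key_bin) > 127 else 0)

def varint_len(n: int) -> int:
    # bytes needed to encode n as a protobuf unsigned varint (7 bits per byte)
    size = 1
    while n > 0x7F:
        n >>= 7
        size += 1
    return size

def string_field_size(key: str, field_number: int = 1) -> int:
    # wire cost of a length-delimited field: tag varint + length varint + payload
    payload = len(key.encode("utf-8"))
    tag = varint_len((field_number << 3) | 2)  # wire type 2 = length-delimited
    return tag + varint_len(payload) + payload

# The heuristic matches exactly for small field numbers and keys under ~16 KiB:
assert pb_key_size("loss") == string_field_size("loss")        # 6 == 6
assert pb_key_size("k" * 200) == string_field_size("k" * 200)  # 203 == 203
```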
Review comment on the size estimate in populate:
    Where does the `+6` come from?

Reply:
    It was based on our previous internal script; I believe it is an overhead allowance for the field type and length definitions.
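
As a rough illustration of that answer (not the PR's own proto), the standard google.protobuf `Struct`, which is a `map<string, Value>`, can stand in for the ingest proto's assign/append maps: the framing around a short map entry comes to 6 bytes here. Note that `pb_key_size` already includes the key's own tag and length prefix, so in `populate` the constant effectively leaves a couple of bytes of slack for longer varints.

```python
# Illustration only: Struct stands in for the ingest proto's
# map<string, Value> fields; the neptune_api messages differ in detail
# but use the same wire format.
from google.protobuf.struct_pb2 import Struct

s = Struct()
s.fields["loss"].number_value = 0.25

key_payload = len("loss".encode("utf-8"))  # 4 bytes of raw key
value_size = s.fields["loss"].ByteSize()   # 9 bytes (tag + 8-byte double)
framing = s.ByteSize() - key_payload - value_size

# 6 bytes of framing: map-entry tag + entry length, plus the tag and
# length prefix for each of the key and value sub-fields.
assert framing == 6
```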