Skip to content

Commit 7c0a464

Browse files
authored
Typed search attributes (#366)
1 parent fd938c4 commit 7c0a464

File tree

12 files changed

+1417
-261
lines changed

12 files changed

+1417
-261
lines changed

temporalio/client.py

Lines changed: 197 additions & 68 deletions
Large diffs are not rendered by default.

temporalio/common.py

Lines changed: 361 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,29 +4,37 @@
44

55
import inspect
66
import types
7+
import warnings
78
from abc import ABC, abstractmethod
89
from dataclasses import dataclass
910
from datetime import datetime, timedelta
1011
from enum import IntEnum
1112
from typing import (
1213
Any,
1314
Callable,
15+
ClassVar,
16+
Collection,
17+
Generic,
18+
Iterator,
1419
List,
1520
Mapping,
1621
Optional,
1722
Sequence,
1823
Text,
1924
Tuple,
2025
Type,
26+
TypeVar,
2127
Union,
2228
get_type_hints,
29+
overload,
2330
)
2431

2532
import google.protobuf.internal.containers
26-
from typing_extensions import ClassVar, TypeAlias
33+
from typing_extensions import ClassVar, NamedTuple, TypeAlias, get_origin
2734

2835
import temporalio.api.common.v1
2936
import temporalio.api.enums.v1
37+
import temporalio.types
3038

3139

3240
@dataclass
@@ -176,6 +184,358 @@ def __setstate__(self, state: object) -> None:
176184

177185
SearchAttributes: TypeAlias = Mapping[str, SearchAttributeValues]
178186

187+
SearchAttributeValue: TypeAlias = Union[str, int, float, bool, datetime, Sequence[str]]
188+
189+
SearchAttributeValueType = TypeVar(
190+
"SearchAttributeValueType", str, int, float, bool, datetime, Sequence[str]
191+
)
192+
193+
194+
class SearchAttributeIndexedValueType(IntEnum):
195+
"""Server index type of a search attribute."""
196+
197+
TEXT = int(temporalio.api.enums.v1.IndexedValueType.INDEXED_VALUE_TYPE_TEXT)
198+
KEYWORD = int(temporalio.api.enums.v1.IndexedValueType.INDEXED_VALUE_TYPE_KEYWORD)
199+
INT = int(temporalio.api.enums.v1.IndexedValueType.INDEXED_VALUE_TYPE_INT)
200+
DOUBLE = int(temporalio.api.enums.v1.IndexedValueType.INDEXED_VALUE_TYPE_DOUBLE)
201+
BOOL = int(temporalio.api.enums.v1.IndexedValueType.INDEXED_VALUE_TYPE_BOOL)
202+
DATETIME = int(temporalio.api.enums.v1.IndexedValueType.INDEXED_VALUE_TYPE_DATETIME)
203+
KEYWORD_LIST = int(
204+
temporalio.api.enums.v1.IndexedValueType.INDEXED_VALUE_TYPE_KEYWORD_LIST
205+
)
206+
207+
208+
class SearchAttributeKey(ABC, Generic[SearchAttributeValueType]):
209+
"""Typed search attribute key representation.
210+
211+
Use one of the ``for`` static methods here to create a key.
212+
"""
213+
214+
@property
215+
@abstractmethod
216+
def name(self) -> str:
217+
"""Get the name of the key."""
218+
...
219+
220+
@property
221+
@abstractmethod
222+
def indexed_value_type(self) -> SearchAttributeIndexedValueType:
223+
"""Get the server index typed of the key"""
224+
...
225+
226+
@property
227+
@abstractmethod
228+
def value_type(self) -> Type[SearchAttributeValueType]:
229+
"""Get the Python type of value for the key.
230+
231+
This may contain generics which cannot be used in ``isinstance``.
232+
:py:attr:`origin_value_type` can be used instead.
233+
"""
234+
...
235+
236+
@property
237+
def origin_value_type(self) -> Type:
238+
"""Get the Python type of value for the key without generics."""
239+
return get_origin(self.value_type) or self.value_type
240+
241+
@property
242+
def _metadata_type(self) -> str:
243+
index_type = self.indexed_value_type
244+
if index_type == SearchAttributeIndexedValueType.TEXT:
245+
return "Text"
246+
elif index_type == SearchAttributeIndexedValueType.KEYWORD:
247+
return "Keyword"
248+
elif index_type == SearchAttributeIndexedValueType.INT:
249+
return "Int"
250+
elif index_type == SearchAttributeIndexedValueType.DOUBLE:
251+
return "Double"
252+
elif index_type == SearchAttributeIndexedValueType.BOOL:
253+
return "Bool"
254+
elif index_type == SearchAttributeIndexedValueType.DATETIME:
255+
return "Datetime"
256+
elif index_type == SearchAttributeIndexedValueType.KEYWORD_LIST:
257+
return "KeywordList"
258+
raise ValueError(f"Unrecognized type: {self}")
259+
260+
def value_set(
261+
self, value: SearchAttributeValueType
262+
) -> SearchAttributeUpdate[SearchAttributeValueType]:
263+
"""Create a search attribute update to set the given value on this
264+
key.
265+
"""
266+
return _SearchAttributeUpdate[SearchAttributeValueType](self, value)
267+
268+
def value_unset(self) -> SearchAttributeUpdate[SearchAttributeValueType]:
269+
"""Create a search attribute update to unset the value on this key."""
270+
return _SearchAttributeUpdate[SearchAttributeValueType](self, None)
271+
272+
@staticmethod
273+
def for_text(name: str) -> SearchAttributeKey[str]:
274+
"""Create a 'Text' search attribute type."""
275+
return _SearchAttributeKey[str](name, SearchAttributeIndexedValueType.TEXT, str)
276+
277+
@staticmethod
278+
def for_keyword(name: str) -> SearchAttributeKey[str]:
279+
"""Create a 'Keyword' search attribute type."""
280+
return _SearchAttributeKey[str](
281+
name, SearchAttributeIndexedValueType.KEYWORD, str
282+
)
283+
284+
@staticmethod
285+
def for_int(name: str) -> SearchAttributeKey[int]:
286+
"""Create an 'Int' search attribute type."""
287+
return _SearchAttributeKey[int](name, SearchAttributeIndexedValueType.INT, int)
288+
289+
@staticmethod
290+
def for_float(name: str) -> SearchAttributeKey[float]:
291+
"""Create a 'Double' search attribute type."""
292+
return _SearchAttributeKey[float](
293+
name, SearchAttributeIndexedValueType.DOUBLE, float
294+
)
295+
296+
@staticmethod
297+
def for_bool(name: str) -> SearchAttributeKey[bool]:
298+
"""Create a 'Bool' search attribute type."""
299+
return _SearchAttributeKey[bool](
300+
name, SearchAttributeIndexedValueType.BOOL, bool
301+
)
302+
303+
@staticmethod
304+
def for_datetime(name: str) -> SearchAttributeKey[datetime]:
305+
"""Create a 'Datetime' search attribute type."""
306+
return _SearchAttributeKey[datetime](
307+
name, SearchAttributeIndexedValueType.DATETIME, datetime
308+
)
309+
310+
@staticmethod
311+
def for_keyword_list(name: str) -> SearchAttributeKey[Sequence[str]]:
312+
"""Create a 'KeywordList' search attribute type."""
313+
return _SearchAttributeKey[Sequence[str]](
314+
name,
315+
SearchAttributeIndexedValueType.KEYWORD_LIST,
316+
# Generic types not supported yet like this: https://github.com/python/mypy/issues/4717
317+
Sequence[str], # type: ignore
318+
)
319+
320+
@staticmethod
321+
def _from_metadata_type(
322+
name: str, metadata_type: str
323+
) -> Optional[SearchAttributeKey]:
324+
if metadata_type == "Text":
325+
return SearchAttributeKey.for_text(name)
326+
elif metadata_type == "Keyword":
327+
return SearchAttributeKey.for_keyword(name)
328+
elif metadata_type == "Int":
329+
return SearchAttributeKey.for_int(name)
330+
elif metadata_type == "Double":
331+
return SearchAttributeKey.for_float(name)
332+
elif metadata_type == "Bool":
333+
return SearchAttributeKey.for_bool(name)
334+
elif metadata_type == "Datetime":
335+
return SearchAttributeKey.for_datetime(name)
336+
elif metadata_type == "KeywordList":
337+
return SearchAttributeKey.for_keyword_list(name)
338+
return None
339+
340+
@staticmethod
341+
def _guess_from_untyped_values(
342+
name: str, vals: SearchAttributeValues
343+
) -> Optional[SearchAttributeKey]:
344+
if not vals:
345+
return None
346+
elif len(vals) > 1:
347+
if isinstance(vals[0], str):
348+
return SearchAttributeKey.for_keyword_list(name)
349+
elif isinstance(vals[0], str):
350+
return SearchAttributeKey.for_keyword(name)
351+
elif isinstance(vals[0], int):
352+
return SearchAttributeKey.for_int(name)
353+
elif isinstance(vals[0], float):
354+
return SearchAttributeKey.for_float(name)
355+
elif isinstance(vals[0], bool):
356+
return SearchAttributeKey.for_bool(name)
357+
elif isinstance(vals[0], datetime):
358+
return SearchAttributeKey.for_datetime(name)
359+
return None
360+
361+
362+
@dataclass(frozen=True)
363+
class _SearchAttributeKey(SearchAttributeKey[SearchAttributeValueType]):
364+
_name: str
365+
_indexed_value_type: SearchAttributeIndexedValueType
366+
# No supported way in Python to derive this, so we're setting manually
367+
_value_type: Type[SearchAttributeValueType]
368+
369+
@property
370+
def name(self) -> str:
371+
return self._name
372+
373+
@property
374+
def indexed_value_type(self) -> SearchAttributeIndexedValueType:
375+
return self._indexed_value_type
376+
377+
@property
378+
def value_type(self) -> Type[SearchAttributeValueType]:
379+
return self._value_type
380+
381+
382+
class SearchAttributePair(NamedTuple, Generic[SearchAttributeValueType]):
383+
"""A named tuple representing a key/value search attribute pair."""
384+
385+
key: SearchAttributeKey[SearchAttributeValueType]
386+
value: SearchAttributeValueType
387+
388+
389+
class SearchAttributeUpdate(ABC, Generic[SearchAttributeValueType]):
390+
"""Representation of a search attribute update."""
391+
392+
@property
393+
@abstractmethod
394+
def key(self) -> SearchAttributeKey[SearchAttributeValueType]:
395+
"""Key that is being set."""
396+
...
397+
398+
@property
399+
@abstractmethod
400+
def value(self) -> Optional[SearchAttributeValueType]:
401+
"""Value that is being set or ``None`` if being unset."""
402+
...
403+
404+
405+
@dataclass(frozen=True)
406+
class _SearchAttributeUpdate(SearchAttributeUpdate[SearchAttributeValueType]):
407+
_key: SearchAttributeKey[SearchAttributeValueType]
408+
_value: Optional[SearchAttributeValueType]
409+
410+
@property
411+
def key(self) -> SearchAttributeKey[SearchAttributeValueType]:
412+
return self._key
413+
414+
@property
415+
def value(self) -> Optional[SearchAttributeValueType]:
416+
return self._value
417+
418+
419+
@dataclass(frozen=True)
420+
class TypedSearchAttributes(Collection[SearchAttributePair]):
421+
"""Collection of typed search attributes.
422+
423+
This is represented as an immutable collection of
424+
:py:class:`SearchAttributePair`. This can be created passing a sequence of
425+
pairs to the constructor.
426+
"""
427+
428+
search_attributes: Sequence[SearchAttributePair]
429+
"""Underlying sequence of search attribute pairs. Do not mutate this, only
430+
create new ``TypedSearchAttribute`` instances.
431+
432+
These are sorted by key name during construction. Duplicates cannot exist.
433+
"""
434+
435+
empty: ClassVar[TypedSearchAttributes]
436+
"""Class variable representing an empty set of attributes."""
437+
438+
def __post_init__(self):
439+
"""Post-init initialization."""
440+
# Sort
441+
object.__setattr__(
442+
self,
443+
"search_attributes",
444+
sorted(self.search_attributes, key=lambda pair: pair.key.name),
445+
)
446+
# Ensure no duplicates
447+
for i, pair in enumerate(self.search_attributes):
448+
if i > 0 and self.search_attributes[i - 1].key.name == pair.key.name:
449+
raise ValueError(
450+
f"Duplicate search attribute entries found for key {pair.key.name}"
451+
)
452+
453+
def __len__(self) -> int:
454+
"""Get the number of search attributes."""
455+
return len(self.search_attributes)
456+
457+
def __getitem__(
458+
self, key: SearchAttributeKey[SearchAttributeValueType]
459+
) -> SearchAttributeValueType:
460+
"""Get a single search attribute value by key or fail with
461+
``KeyError``.
462+
"""
463+
ret = next((v for k, v in self if k == key), None)
464+
if ret is None:
465+
raise KeyError()
466+
return ret
467+
468+
def __iter__(self) -> Iterator[SearchAttributePair]:
469+
"""Get an iterator over search attribute key/value pairs."""
470+
return iter(self.search_attributes)
471+
472+
def __contains__(self, key: object) -> bool:
473+
"""Check whether this search attribute contains the given key.
474+
475+
This uses key equality so the key must be the same name and type.
476+
"""
477+
return any(v for k, v in self if k == key)
478+
479+
@overload
480+
def get(
481+
self, key: SearchAttributeKey[SearchAttributeValueType]
482+
) -> Optional[SearchAttributeValueType]:
483+
...
484+
485+
@overload
486+
def get(
487+
self,
488+
key: SearchAttributeKey[SearchAttributeValueType],
489+
default: temporalio.types.AnyType,
490+
) -> Union[SearchAttributeValueType, temporalio.types.AnyType]:
491+
...
492+
493+
def get(
494+
self,
495+
key: SearchAttributeKey[SearchAttributeValueType],
496+
default: Optional[Any] = None,
497+
) -> Any:
498+
"""Get an attribute value for a key (or default). This is similar to
499+
dict.get.
500+
"""
501+
try:
502+
return self.__getitem__(key)
503+
except KeyError:
504+
return default
505+
506+
def updated(self, *search_attributes: SearchAttributePair) -> TypedSearchAttributes:
507+
"""Copy this collection, replacing attributes with matching key names or
508+
adding if key name not present.
509+
"""
510+
attrs = list(self.search_attributes)
511+
# Go over each update, replacing matching keys by index or adding
512+
for attr in search_attributes:
513+
existing_index = next(
514+
(i for i, attr in enumerate(attrs) if attr.key.name == attr.key.name),
515+
None,
516+
)
517+
if existing_index is None:
518+
attrs.append(attr)
519+
else:
520+
attrs[existing_index] = attr
521+
return TypedSearchAttributes(attrs)
522+
523+
524+
TypedSearchAttributes.empty = TypedSearchAttributes(search_attributes=[])
525+
526+
527+
def _warn_on_deprecated_search_attributes(
528+
attributes: Optional[Union[SearchAttributes, Any]],
529+
stack_level: int = 2,
530+
) -> None:
531+
if attributes and isinstance(attributes, Mapping):
532+
warnings.warn(
533+
"Dictionary-based search attributes are deprecated",
534+
DeprecationWarning,
535+
stacklevel=1 + stack_level,
536+
)
537+
538+
179539
MetricAttributes: TypeAlias = Mapping[str, Union[str, int, float, bool]]
180540

181541

0 commit comments

Comments
 (0)