Skip to content

Commit da341e3

Browse files
authored
refactor: extract search_iterable out from get and optimize it (#920)
* refactor: extract `search_iterable` out from `get` and optimize it * refactor: clean up `get`'s parameter and docstring * docs(typestub): remove `_search_iterable` from `get.pyi`
1 parent f2905b8 commit da341e3

File tree

4 files changed

+88
-133
lines changed

4 files changed

+88
-133
lines changed

docs/migration.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ portal and add the intent to your current intents when connecting:
6666
4.1.0 → 4.3.0
6767
~~~~~~~~~~~~~~~
6868

69-
A new big change in this release is the implementation of the ``get` utility method.
69+
A new big change in this release is the implementation of the ``get`` utility method.
7070
It allows you to no longer use ``**await bot._http...``.
7171

7272
You can get more information by reading the `get-documentation`_.

interactions/client/get.py

Lines changed: 50 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from asyncio import sleep
22
from enum import Enum
3-
from inspect import isawaitable, isfunction
3+
from inspect import isawaitable
44
from logging import getLogger
5-
from typing import Coroutine, Iterable, List, Optional, Type, TypeVar, Union, get_args
5+
from typing import Coroutine, List, Optional, Type, TypeVar, Union, get_args
66

77
try:
88
from typing import _GenericAlias
@@ -43,7 +43,7 @@ class Force(str, Enum):
4343
HTTP = "http"
4444

4545

46-
def get(*args, **kwargs):
46+
def get(client: Client, obj: Type[_T], **kwargs) -> Optional[_T]:
4747
"""
4848
Helper method to get an object.
4949
@@ -56,17 +56,11 @@ def get(*args, **kwargs):
5656
* purely from cache
5757
* purely from HTTP
5858
* from HTTP if not found in cache else from cache
59-
* Get an object from an iterable
60-
* based of its name
61-
* based of its ID
62-
* based of a custom check
63-
* based of any other attribute the object inside the iterable has
6459
6560
The method has to be awaited when:
6661
* You don't force anything
6762
* You force HTTP
6863
The method doesn't have to be awaited when:
69-
* You get from an iterable
7064
* You force cache
7165
7266
.. note ::
@@ -93,22 +87,6 @@ def get(*args, **kwargs):
9387
enforce HTTP. To prevent this bug from happening it is suggested using ``force="http"`` instead of
9488
the enum.
9589
96-
Getting from an iterable:
97-
98-
.. code-block:: python
99-
100-
# Getting an object from an iterable
101-
check = lambda role: role.name == "ADMIN" and role.color == 0xff0000
102-
roles = [
103-
interactions.Role(name="NOT ADMIN", color=0xff0000),
104-
interactions.Role(name="ADMIN", color=0xff0000),
105-
]
106-
role = get(roles, check=check)
107-
# role will be `interactions.Role(name="ADMIN", color=0xff0000)`
108-
109-
You can specify *any* attribute to check that the object could have (although only ``check``, ``id`` and
110-
``name`` are type-hinted) and the method will check for a match.
111-
11290
Getting an object:
11391
11492
Here you will see two examples on how to get a single objects and the variations of how the object can be
@@ -198,81 +176,68 @@ def _check():
198176
def _check():
199177
return False
200178

201-
if len(args) == 2:
179+
if not isinstance(obj, type) and not isinstance(obj, _GenericAlias):
180+
client: Client
181+
obj: Union[Type[_T], Type[List[_T]]]
182+
raise LibraryException(message="The object must not be an instance of a class!", code=12)
202183

203-
if any(isinstance(_, Iterable) for _ in args):
204-
raise LibraryException(
205-
message="You can only use Iterables as single-argument!", code=12
206-
)
184+
kwargs = _resolve_kwargs(obj, **kwargs)
185+
http_name = f"get_{obj.__name__.lower()}"
186+
kwarg_name = f"{obj.__name__.lower()}_id"
187+
if isinstance(obj, _GenericAlias) or _check():
188+
_obj: Type[_T] = get_args(obj)[0]
189+
_objects: List[Union[_obj, Coroutine]] = []
190+
kwarg_name += "s"
207191

208-
client, obj = args
209-
if not isinstance(obj, type) and not isinstance(obj, _GenericAlias):
210-
client: Client
211-
obj: Union[Type[_T], Type[List[_T]]]
212-
raise LibraryException(
213-
message="The object must not be an instance of a class!", code=12
214-
)
215-
216-
kwargs = _resolve_kwargs(obj, **kwargs)
217-
http_name = f"get_{obj.__name__.lower()}"
218-
kwarg_name = f"{obj.__name__.lower()}_id"
219-
if isinstance(obj, _GenericAlias) or _check():
220-
_obj: Type[_T] = get_args(obj)[0]
221-
_objects: List[Union[_obj, Coroutine]] = []
222-
kwarg_name += "s"
192+
force_cache = kwargs.pop("force", None) == "cache"
193+
force_http = kwargs.pop("force", None) == "http"
223194

224-
force_cache = kwargs.pop("force", None) == "cache"
225-
force_http = kwargs.pop("force", None) == "http"
195+
if not force_http:
196+
_objects = _get_cache(_obj, client, kwarg_name, _list=True, **kwargs)
226197

227-
if not force_http:
228-
_objects = _get_cache(_obj, client, kwarg_name, _list=True, **kwargs)
198+
if force_cache:
199+
return _objects
229200

230-
if force_cache:
231-
return _objects
201+
elif not force_http and None not in _objects:
202+
return _return_cache(_objects)
232203

233-
elif not force_http and None not in _objects:
234-
return _return_cache(_objects)
204+
elif force_http:
205+
_objects.clear()
206+
_func = getattr(client._http, http_name)
207+
for _id in kwargs.get(kwarg_name):
208+
_kwargs = kwargs
209+
_kwargs.pop(kwarg_name)
210+
_kwargs[kwarg_name[:-1]] = _id
211+
_objects.append(_func(**_kwargs))
212+
return _http_request(_obj, http=client._http, request=_objects)
235213

236-
elif force_http:
237-
_objects.clear()
238-
_func = getattr(client._http, http_name)
239-
for _id in kwargs.get(kwarg_name):
214+
else:
215+
_func = getattr(client._http, http_name)
216+
for _index, __obj in enumerate(_objects):
217+
if __obj is None:
218+
_id = kwargs.get(kwarg_name)[_index]
240219
_kwargs = kwargs
241220
_kwargs.pop(kwarg_name)
242221
_kwargs[kwarg_name[:-1]] = _id
243-
_objects.append(_func(**_kwargs))
244-
return _http_request(_obj, http=client._http, request=_objects)
245-
246-
else:
247-
_func = getattr(client._http, http_name)
248-
for _index, __obj in enumerate(_objects):
249-
if __obj is None:
250-
_id = kwargs.get(kwarg_name)[_index]
251-
_kwargs = kwargs
252-
_kwargs.pop(kwarg_name)
253-
_kwargs[kwarg_name[:-1]] = _id
254-
_request = _func(**_kwargs)
255-
_objects[_index] = _request
256-
return _http_request(_obj, http=client._http, request=_objects)
257-
258-
_obj: Optional[_T] = None
222+
_request = _func(**_kwargs)
223+
_objects[_index] = _request
224+
return _http_request(_obj, http=client._http, request=_objects)
259225

260-
force_cache = kwargs.pop("force", None) == "cache"
261-
force_http = kwargs.pop("force", None) == "http"
262-
if not force_http:
263-
_obj = _get_cache(obj, client, kwarg_name, **kwargs)
226+
_obj: Optional[_T] = None
264227

265-
if force_cache:
266-
return _obj
228+
force_cache = kwargs.pop("force", None) == "cache"
229+
force_http = kwargs.pop("force", None) == "http"
230+
if not force_http:
231+
_obj = _get_cache(obj, client, kwarg_name, **kwargs)
267232

268-
elif not force_http and _obj:
269-
return _return_cache(_obj)
233+
if force_cache:
234+
return _obj
270235

271-
else:
272-
return _http_request(obj=obj, http=client._http, _name=http_name, **kwargs)
236+
elif not force_http and _obj:
237+
return _return_cache(_obj)
273238

274-
elif len(args) == 1:
275-
return _search_iterable(*args, **kwargs)
239+
else:
240+
return _http_request(obj=obj, http=client._http, _name=http_name, **kwargs)
276241

277242

278243
async def _http_request(
@@ -339,40 +304,6 @@ def _get_cache(
339304
return _obj
340305

341306

342-
def _search_iterable(items: Iterable[_T], **kwargs) -> Optional[_T]:
343-
if not isinstance(items, Iterable):
344-
raise LibraryException(message="The specified items must be an iterable!", code=12)
345-
346-
if not kwargs:
347-
raise LibraryException(
348-
message="You have to specify either a custom check or a keyword argument to check against!",
349-
code=12,
350-
)
351-
352-
if len(list(kwargs)) > 1:
353-
raise LibraryException(
354-
message="Only one keyword argument to check against is allowed!", code=12
355-
)
356-
357-
_arg = str(list(kwargs)[0])
358-
kwarg = kwargs.get(_arg)
359-
kwarg_is_function: bool = isfunction(kwarg)
360-
361-
__obj = next(
362-
(
363-
item
364-
for item in items
365-
if (
366-
str(getattr(item, _arg, None)) == str(kwarg)
367-
if not kwarg_is_function
368-
else kwarg(item)
369-
)
370-
),
371-
None,
372-
)
373-
return __obj
374-
375-
376307
def _resolve_kwargs(obj, **kwargs):
377308
# This function is needed to get correct kwarg names
378309
if __id := kwargs.pop("parent_id", None):

interactions/client/get.pyi

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1-
from typing import overload, Type, TypeVar, List, Iterable, Optional, Callable, Awaitable, Literal, Coroutine, Union
1+
from enum import Enum
2+
from typing import Awaitable, Coroutine, List, Literal, Optional, Type, TypeVar, Union, overload
23

34
from interactions.client.bot import Client
4-
from enum import Enum
5+
from ..api.http.client import HTTPClient
56
from ..api.models.channel import Channel
67
from ..api.models.guild import Guild
78
from ..api.models.member import Member
8-
from ..api.models.message import Message, Emoji, Sticker
9+
from ..api.models.message import Emoji, Message, Sticker
10+
from ..api.models.role import Role
911
from ..api.models.user import User
1012
from ..api.models.webhook import Webhook
11-
from ..api.models.role import Role
12-
from ..api.http.client import HTTPClient
1313

1414
_SA = TypeVar("_SA", Channel, Guild, Webhook, User, Sticker)
1515
_MA = TypeVar("_MA", Member, Emoji, Role, Message)
@@ -24,11 +24,6 @@ class Force(str, Enum):
2424
CACHE: str
2525
HTTP: str
2626

27-
# not API-object related
28-
@overload
29-
def get(
30-
items: Iterable[_T], /, *, id: Optional[int] = None, name: Optional[str] = None, check: Callable[..., bool], **kwargs
31-
) -> Optional[_T]: ...
3227

3328
# API-object related
3429

@@ -101,7 +96,6 @@ def get(
10196
# Having a not-overloaded definition stops showing a warning/complaint from the IDE if wrong arguments are put in,
10297
# so we'll leave that out
10398

104-
def _search_iterable(item: Iterable[_T], **kwargs) -> Optional[_T]:...
10599
def _get_cache(
106100
_object: Type[_T], client: Client, kwarg_name: str, _list: bool = False, **kwargs
107101
) -> Union[Optional[_T], List[Optional[_T]]]:...

interactions/client/models/utils.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
from asyncio import Task, get_running_loop, sleep
22
from functools import wraps
3-
from typing import TYPE_CHECKING, Awaitable, Callable, List, Union
3+
from typing import TYPE_CHECKING, Awaitable, Callable, Iterable, List, Optional, TypeVar, Union
44

55
from ...api.error import LibraryException
66
from .component import ActionRow, Button, SelectMenu
77

88
if TYPE_CHECKING:
99
from ..context import CommandContext
1010

11-
1211
__all__ = ("autodefer", "spread_to_rows")
1312

13+
_T = TypeVar("_T")
14+
1415

1516
def autodefer(
1617
delay: Union[float, int] = 2,
@@ -138,3 +139,32 @@ async def command(ctx):
138139
raise LibraryException(code=12, message="Number of rows exceeds 5.")
139140

140141
return rows
142+
143+
144+
def search_iterable(
145+
iterable: Iterable[_T], check: Optional[Callable[[_T], bool]] = None, /, **kwargs
146+
) -> List[_T]:
147+
"""
148+
Searches through an iterable for items that:
149+
- Are True for the check, if one is given
150+
- Have attributes that match the keyword arguments (e.x. passing `id=your_id` will only return objects with that id)
151+
152+
:param iterable: The iterable to search through
153+
:type iterable: Iterable
154+
:param check: The check that items will be checked against
155+
:type check: Callable[[Any], bool]
156+
:param kwargs: Any attributes the items should have
157+
:type kwargs: Any
158+
:return: All items that match the check and keywords
159+
:rtype: list
160+
"""
161+
if check:
162+
iterable = filter(check, iterable)
163+
164+
if kwargs:
165+
iterable = filter(
166+
lambda item: all(getattr(item, attr) == value for attr, value in kwargs.items()),
167+
iterable,
168+
)
169+
170+
return list(iterable)

0 commit comments

Comments
 (0)