Commit 528f76e

fix(client): using httpx for running calls within async context

This is so that client.query works within an async context.

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

1 parent b3d924e commit 528f76e
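
For context on the failure mode: a synchronous bentoml client call typically drives its own event loop, so invoking it from code that is already running inside a loop raises a RuntimeError. A minimal sketch of the call pattern this commit makes work; the server address is an assumption, and any running OpenLLM server would do:

import asyncio

import openllm

async def main() -> None:
    # Hypothetical local server; one could be started with `openllm start flan-t5`.
    client = openllm.client.HTTPClient("http://localhost:3000")
    # in_async_context() is True here, so query() issues a plain httpx POST
    # instead of going through bentoml's sync client, which would otherwise
    # try to run an event loop inside this already-running one.
    print(client.query("Explain asynchronous I/O in one sentence."))

asyncio.run(main())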

File tree

5 files changed: +36, -7 lines changed

pyproject.toml

Lines changed: 3 additions & 0 deletions
@@ -39,6 +39,9 @@ dependencies = [
     # tabulate for CLI with CJK support
     # >=0.9.0 for some bug fixes
     "tabulate[widechars]>=0.9.0",
+    # httpx used within openllm.client
+    "httpx",
+    # for typing support
     "typing_extensions",
 ]
 description = 'OpenLLM: REST/gRPC API server for running any open Large-Language Model - StableLM, Llama, Alpaca, Dolly, Flan-T5, Custom'

src/openllm/_configuration.py

Lines changed: 2 additions & 2 deletions
@@ -1026,11 +1026,11 @@ def model_validate_click(self, **attrs: t.Any) -> tuple[LLMConfig, DictStrAny]:
         return self.model_construct_env(**llm_config_attrs), {k: v for k, v in attrs.items() if k not in key_to_remove}

     @t.overload
-    def to_generation_config(self, return_as_dict: t.Literal[True] = ...) -> DictStrAny:
+    def to_generation_config(self, return_as_dict: t.Literal[False] = ...) -> transformers.GenerationConfig:
         ...

     @t.overload
-    def to_generation_config(self, return_as_dict: t.Literal[False] = ...) -> transformers.GenerationConfig:
+    def to_generation_config(self, return_as_dict: t.Literal[True] = ...) -> DictStrAny:
         ...

     def to_generation_config(self, return_as_dict: bool = False) -> transformers.GenerationConfig | DictStrAny:
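
The swap matters because the implementation defaults to return_as_dict=False: with the corrected order, a bare call is matched by the GenerationConfig overload rather than the dict one. A minimal sketch of the resulting inference, assuming openllm.AutoConfig.for_model is available as shown in the project README:

import openllm

config = openllm.AutoConfig.for_model("flan-t5")

gen_cfg = config.to_generation_config()                     # transformers.GenerationConfig
as_dict = config.to_generation_config(return_as_dict=True)  # DictStrAny (dict[str, t.Any])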

src/openllm/utils/dantic.py

Lines changed: 2 additions & 2 deletions
@@ -280,7 +280,7 @@ def __init__(self, enum: Enum, case_sensitive: bool = False):
         self.internal_type = enum
         super().__init__([e.name for e in self.mapping], case_sensitive)

-    def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
+    def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> Enum:
         if isinstance(value, self.internal_type):
             return value
         result = super().convert(value, param, ctx)
@@ -292,7 +292,7 @@ def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
 class LiteralChoice(EnumChoice):
     name = "literal"

-    def __init__(self, enum: t.Literal, case_sensitive: bool = False):
+    def __init__(self, enum: t.LiteralString, case_sensitive: bool = False):
         # expect every literal value to belong to the same primitive type
         values = list(enum.__args__)
         item_type = type(values[0])
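
These annotation fixes do not change runtime behavior; they tighten what type checkers see. A hedged sketch of how LiteralChoice might be wired into a click option, assuming the import path from this diff and that string defaults are resolved through EnumChoice.convert:

import typing as t

import click

from openllm.utils.dantic import LiteralChoice  # import path taken from this diff

@click.command()
@click.option("--output", type=LiteralChoice(t.Literal["json", "pretty"]), default="json")
def cli(output: str) -> None:
    # click validates --output against the literal values "json" and "pretty"
    click.echo(f"output format: {output}")

if __name__ == "__main__":
    cli()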

src/openllm_client/runtimes/base.py

Lines changed: 25 additions & 2 deletions
@@ -14,10 +14,13 @@

 from __future__ import annotations

+import asyncio
 import typing as t
 from abc import abstractmethod
+from urllib.parse import urljoin

 import bentoml
+import httpx

 import openllm

@@ -43,6 +46,14 @@ def metadata_v1(self) -> dict[str, t.Any]:
         ...


+def in_async_context() -> bool:
+    try:
+        _ = asyncio.get_running_loop()
+        return True
+    except RuntimeError:
+        return False
+
+
 class ClientMixin:
     _api_version: str
     _client_class: type[bentoml.client.Client]
@@ -57,12 +68,17 @@ def __init__(self, address: str, timeout: int = 30):
         self._address = address
         self._timeout = timeout
         assert self._host and self._port, "Make sure to setup _host and _port based on your client implementation."
-        self._metadata = self.call("metadata")

     def __init_subclass__(cls, *, client_type: t.Literal["http", "grpc"] = "http", api_version: str = "v1"):
         cls._client_class = bentoml.client.HTTPClient if client_type == "http" else bentoml.client.GrpcClient
         cls._api_version = api_version

+    @property
+    def _metadata(self) -> dict[str, t.Any]:
+        if in_async_context():
+            return httpx.post(urljoin(self._address, f"/{self._api_version}/metadata")).json()
+        return self.call("metadata")
+
     @property
     @abstractmethod
     def model_name(self) -> str:
@@ -140,7 +156,14 @@ def query(self, prompt: str, *, return_raw_response: t.Literal[True] = ..., **attrs: t.Any)
     def query(self, prompt: str, **attrs: t.Any) -> dict[str, t.Any] | str:
         return_raw_response, prompt, generate_kwargs, postprocess_kwargs = self.prepare(prompt, **attrs)
         inputs = openllm.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs))
-        result = self.call("generate", inputs)
+        if in_async_context():
+            result = httpx.post(
+                urljoin(self._address, f"/{self._api_version}/generate"),
+                json=openllm.utils.bentoml_cattr.unstructure(inputs),
+                timeout=self.timeout,
+            ).json()
+        else:
+            result = self.call("generate", inputs)
         r = self.postprocess(result)

         if return_raw_response:
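
The new helper relies on asyncio.get_running_loop, which raises RuntimeError when no loop is running in the current thread, so it doubles as a cheap "am I inside async code?" probe. A standalone sketch of that behavior:

import asyncio

def in_async_context() -> bool:
    # Mirrors the helper added above: get_running_loop() raises RuntimeError
    # when no event loop is currently running in this thread.
    try:
        asyncio.get_running_loop()
        return True
    except RuntimeError:
        return False

print(in_async_context())       # False: plain synchronous code

async def probe() -> None:
    print(in_async_context())   # True: called inside asyncio.run's loop

asyncio.run(probe())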

src/openllm_client/runtimes/grpc.py

Lines changed: 4 additions & 1 deletion
@@ -73,7 +73,10 @@ def configuration(self) -> dict[str, t.Any]:
         except KeyError:
             raise RuntimeError("Malformed service endpoint. (Possible malicious)")

-    def postprocess(self, result: Response) -> openllm.GenerationOutput:
+    def postprocess(self, result: Response | dict[str, t.Any]) -> openllm.GenerationOutput:
+        if isinstance(result, dict):
+            return openllm.GenerationOutput(**result)
+
         from google.protobuf.json_format import MessageToDict

         return openllm.GenerationOutput(**MessageToDict(result.json, preserving_proto_field_name=True))
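
postprocess now dispatches on the result type: the dict branch covers decoded JSON bodies from the httpx fallback, while protobuf Response objects still go through MessageToDict. A minimal illustration of the dict branch; the payload field names here are assumptions about GenerationOutput's schema, not taken from this diff:

import openllm

# Shape of a decoded JSON body from the httpx fallback (field names assumed)
result = {"responses": ["42"], "configuration": {"max_new_tokens": 256}}

# The dict branch constructs the output directly, with no protobuf conversion
output = openllm.GenerationOutput(**result)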
