Skip to content

Commit 9a937f8

Browse files
[Client] Add guided_grammar and other missing fields (#532)
Add guided_grammar to the client, and add some missing fields to some code paths
1 parent 7d4cc3e commit 9a937f8

File tree

5 files changed

+25
-3
lines changed

5 files changed

+25
-3
lines changed

clients/python/llmengine/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
__version__ = "0.0.0b33"
15+
__version__ = "0.0.0b34"
1616

1717
import os
1818
from typing import Sequence

clients/python/llmengine/completion.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ async def acreate(
4747
guided_json: Optional[Dict[str, Any]] = None,
4848
guided_regex: Optional[str] = None,
4949
guided_choice: Optional[List[str]] = None,
50+
guided_grammar: Optional[str] = None,
5051
timeout: int = COMPLETION_TIMEOUT,
5152
stream: bool = False,
5253
) -> Union[CompletionSyncResponse, AsyncIterable[CompletionStreamResponse]]:
@@ -118,6 +119,9 @@ async def acreate(
118119
guided_choice (Optional[List[str]]):
119120
If specified, the output will be exactly one of the choices.
120121
122+
guided_grammar (Optional[str]):
123+
If specified, the output will follow the context-free grammar provided.
124+
121125
timeout (int):
122126
Timeout in seconds. This is the maximum amount of time you are willing to wait for a response.
123127
@@ -218,6 +222,7 @@ async def _acreate_stream(
218222
guided_json=guided_json,
219223
guided_regex=guided_regex,
220224
guided_choice=guided_choice,
225+
guided_grammar=guided_grammar,
221226
timeout=timeout,
222227
)
223228

@@ -242,6 +247,11 @@ async def _acreate_sync(**kwargs) -> CompletionSyncResponse:
242247
frequency_penalty=frequency_penalty,
243248
top_k=top_k,
244249
top_p=top_p,
250+
include_stop_str_in_output=include_stop_str_in_output,
251+
guided_json=guided_json,
252+
guided_regex=guided_regex,
253+
guided_choice=guided_choice,
254+
guided_grammar=guided_grammar,
245255
)
246256

247257
@classmethod
@@ -261,6 +271,7 @@ def create(
261271
guided_json: Optional[Dict[str, Any]] = None,
262272
guided_regex: Optional[str] = None,
263273
guided_choice: Optional[List[str]] = None,
274+
guided_grammar: Optional[str] = None,
264275
timeout: int = COMPLETION_TIMEOUT,
265276
stream: bool = False,
266277
) -> Union[CompletionSyncResponse, Iterator[CompletionStreamResponse]]:
@@ -333,6 +344,9 @@ def create(
333344
guided_choice (Optional[List[str]]):
334345
If specified, the output will be exactly one of the choices.
335346
347+
guided_grammar (Optional[str]):
348+
If specified, the output will follow the context-free grammar provided.
349+
336350
timeout (int):
337351
Timeout in seconds. This is the maximum amount of time you are willing to wait for a response.
338352
@@ -419,6 +433,11 @@ def _create_stream(**kwargs):
419433
frequency_penalty=frequency_penalty,
420434
top_k=top_k,
421435
top_p=top_p,
436+
include_stop_str_in_output=include_stop_str_in_output,
437+
guided_json=guided_json,
438+
guided_regex=guided_regex,
439+
guided_choice=guided_choice,
440+
guided_grammar=guided_grammar,
422441
)
423442

424443
else:
@@ -436,6 +455,7 @@ def _create_stream(**kwargs):
436455
guided_json=guided_json,
437456
guided_regex=guided_regex,
438457
guided_choice=guided_choice,
458+
guided_grammar=guided_grammar,
439459
).dict()
440460
response = cls.post_sync(
441461
resource_name=f"v1/llm/completions-sync?model_endpoint_name={model}",

clients/python/llmengine/data_types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ class CompletionSyncV1Request(BaseModel):
331331
guided_json: Optional[Dict[str, Any]] = Field(default=None)
332332
guided_regex: Optional[str] = Field(default=None)
333333
guided_choice: Optional[List[str]] = Field(default=None)
334+
guided_grammar: Optional[str] = Field(default=None)
334335

335336

336337
class TokenOutput(BaseModel):
@@ -405,6 +406,7 @@ class CompletionStreamV1Request(BaseModel):
405406
guided_json: Optional[Dict[str, Any]] = Field(default=None)
406407
guided_regex: Optional[str] = Field(default=None)
407408
guided_choice: Optional[List[str]] = Field(default=None)
409+
guided_grammar: Optional[str] = Field(default=None)
408410

409411

410412
class CompletionStreamOutput(BaseModel):

clients/python/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "scale-llm-engine"
3-
version = "0.0.0.beta33"
3+
version = "0.0.0.beta34"
44
description = "Scale LLM Engine Python client"
55
license = "Apache-2.0"
66
authors = ["Phil Chen <phil.chen@scale.com>"]

clients/python/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="scale-llm-engine",
55
python_requires=">=3.7",
6-
version="0.0.0.beta33",
6+
version="0.0.0.beta34",
77
packages=find_packages(),
88
package_data={"llmengine": ["py.typed"]},
99
)

0 commit comments

Comments
 (0)