
Commit 66af541

add chat_completion (#31)
* add chat_completion in BaseLLMAPIClient + AnthropicClient
* add get_chat_tokens_count in BaseLLMAPIClient + AnthropicClient + OpenAIClient
1 parent 980f72e commit 66af541

8 files changed: +238 −47 lines changed

README.md

Lines changed: 29 additions & 4 deletions
````diff
@@ -12,12 +12,14 @@ any flexibility (API params, endpoints etc.). *We also provide sync version, see
 more details below in Usage section.
 
 ## Base Interface
-The package exposes two simple interfaces for communicating with LLMs (In the future, we
+The package exposes two simple interfaces for seamless integration with LLMs (In the future, we
 will expand the interface to support more tasks like list models, edits, etc.):
 ```python
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from typing import Any, Optional
+from enum import Enum
+from dataclasses_json import dataclass_json, config
 from aiohttp import ClientSession
 
 
@@ -30,6 +32,20 @@ class BaseLLMClient(ABC):
         raise NotImplementedError()
 
 
+class Role(Enum):
+    SYSTEM = "system"
+    USER = "user"
+    ASSISTANT = "assistant"
+
+
+@dataclass_json
+@dataclass
+class ChatMessage:
+    role: Role = field(metadata=config(encoder=lambda role: role.value, decoder=Role))
+    content: str
+    name: Optional[str] = field(default=None, metadata=config(exclude=lambda name: name is None))
+    example: bool = field(default=False, metadata=config(exclude=lambda _: True))
+
 
 @dataclass
 class LLMAPIClientConfig:
@@ -49,8 +65,15 @@ class BaseLLMAPIClient(BaseLLMClient, ABC):
                               temperature: Optional[float] = None, top_p: Optional[float] = None, **kwargs) -> list[str]:
         raise NotImplementedError()
 
+    async def chat_completion(self, messages: list[ChatMessage], temperature: float = 0,
+                              max_tokens: int = 16, model: Optional[str] = None, **kwargs) -> list[str]:
+        raise NotImplementedError()
+
     async def embedding(self, text: str, model: Optional[str] = None, **kwargs) -> list[float]:
         raise NotImplementedError()
+
+    async def get_chat_tokens_count(self, messages: list[ChatMessage], **kwargs) -> int:
+        raise NotImplementedError()
 ```
 
 ## Requirements
@@ -109,10 +132,12 @@ async def main():
     llm_client = OpenAIClient(LLMAPIClientConfig(OPENAI_API_KEY, session, default_model="text-davinci-003",
                                                  headers={"OpenAI-Organization": OPENAI_ORG_ID})) # The headers are optional
     text = "This is indeed a test"
+    messages = [ChatMessage(role=Role.USER, content="Hello!"),
+                ChatMessage(role=Role.SYSTEM, content="Hi there! How can I assist you today?")]
 
     print("number of tokens:", await llm_client.get_tokens_count(text)) # 5
-    print("generated chat:", await llm_client.chat_completion(
-        messages=[ChatMessage(role=Role.USER, content="Hello!")], model="gpt-3.5-turbo")) # ['Hi there! How can I assist you today?']
+    print("number of tokens for chat completion:", await llm_client.get_chat_tokens_count(messages, model="gpt-3.5-turbo")) # 23
+    print("generated chat:", await llm_client.chat_completion(messages, model="gpt-3.5-turbo")) # ['Hi there! How can I assist you today?']
     print("generated text:", await llm_client.text_completion(text)) # [' string\n\nYes, this is a test string. Test strings are used to']
     print("generated embedding:", await llm_client.embedding(text)) # [0.0023064255, -0.009327292, ...]
 ```
@@ -190,7 +215,7 @@ Contributions are welcome! Please check out the todos below, and feel free to op
 - [ ] Cohere
 - [x] Add support for more functions via LLMs
   - [x] embeddings
-  - [ ] chat
+  - [x] chat
   - [ ] list models
   - [ ] edits
   - [ ] more
````
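
For context (not part of the commit): the README example above only exercises the OpenAI client. A minimal sketch of the same chat flow against the new Anthropic support might look as follows; `ANTHROPIC_API_KEY` and the `claude-v1` default model are assumed placeholders, not values taken from this diff.

```python
import os

from aiohttp import ClientSession

from llm_client import LLMAPIClientConfig, ChatMessage, Role
from llm_client.llm_api_client.anthropic_client import AnthropicClient


async def anthropic_chat_example():
    async with ClientSession() as session:
        # placeholder api key / model name; any Anthropic completion model should work here
        llm_client = AnthropicClient(LLMAPIClientConfig(os.environ["ANTHROPIC_API_KEY"], session,
                                                        default_model="claude-v1"))
        messages = [ChatMessage(role=Role.USER, content="Hello!")]
        print("chat tokens:", await llm_client.get_chat_tokens_count(messages))
        print("generated chat:", await llm_client.chat_completion(messages, max_tokens=10))
```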

llm_client/__init__.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -1,10 +1,10 @@
-__version__ = "0.7.0"
+__version__ = "0.8.0"
 
 from llm_client.base_llm_client import BaseLLMClient
 
 # load api clients
 try:
-    from llm_client.llm_api_client.base_llm_api_client import BaseLLMAPIClient, LLMAPIClientConfig
+    from llm_client.llm_api_client.base_llm_api_client import BaseLLMAPIClient, LLMAPIClientConfig, ChatMessage, Role
     from llm_client.llm_api_client.llm_api_client_factory import LLMAPIClientFactory, LLMAPIClientType
     # load base-api clients
     try:
@@ -15,7 +15,7 @@
         pass
     # load apis with different dependencies
     try:
-        from llm_client.llm_api_client.openai_client import OpenAIClient, ChatMessage, Role
+        from llm_client.llm_api_client.openai_client import OpenAIClient
     except ImportError:
         pass
     try:
```
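
Side note (not part of the diff): because `ChatMessage` and `Role` are now re-exported from the base client module rather than the OpenAI module, they are importable from the package root even when the `openai` extra is not installed, e.g.:

```python
from llm_client import ChatMessage, Role

# works without the openai extra, since the classes live in base_llm_api_client now
message = ChatMessage(role=Role.USER, content="Hello!")
```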

llm_client/llm_api_client/anthropic_client.py

Lines changed: 31 additions & 1 deletion
```diff
@@ -2,7 +2,7 @@
 
 from anthropic import AsyncAnthropic
 
-from llm_client.llm_api_client.base_llm_api_client import BaseLLMAPIClient, LLMAPIClientConfig
+from llm_client.llm_api_client.base_llm_api_client import BaseLLMAPIClient, LLMAPIClientConfig, ChatMessage, Role
 from llm_client.consts import PROMPT_KEY
 
 COMPLETE_PATH = "complete"
@@ -13,6 +13,11 @@
 VERSION_HEADER = "anthropic-version"
 ACCEPT_VALUE = "application/json"
 MAX_TOKENS_KEY = "max_tokens_to_sample"
+USER_PREFIX = "Human:"
+ASSISTANT_PREFIX = "Assistant:"
+START_PREFIX = "\n\n"
+SYSTEM_START_PREFIX = "<admin>"
+SYSTEM_END_PREFIX = "</admin>"
 
 
 class AnthropicClient(BaseLLMAPIClient):
@@ -26,6 +31,10 @@ def __init__(self, config: LLMAPIClientConfig):
         self._headers[ACCEPT_HEADER] = ACCEPT_VALUE
         self._headers[AUTH_HEADER] = self._api_key
 
+    async def chat_completion(self, messages: list[ChatMessage], model: Optional[str] = None,
+                              max_tokens: Optional[int] = None, temperature: float = 1, **kwargs) -> list[str]:
+        return await self.text_completion(self.messages_to_text(messages), model, max_tokens, temperature, **kwargs)
+
     async def text_completion(self, prompt: str, model: Optional[str] = None, max_tokens: Optional[int] = None,
                               temperature: float = 1, top_p: Optional[float] = None,
                               **kwargs) -> \
@@ -45,5 +54,26 @@ async def text_completion(self, prompt: str, model: Optional[str] = None, max_to
         response_json = await response.json()
         return [response_json[COMPLETIONS_KEY]]
 
+    async def get_chat_tokens_count(self, messages: list[ChatMessage], **kwargs) -> int:
+        return await self.get_tokens_count(self.messages_to_text(messages), **kwargs)
+
     async def get_tokens_count(self, text: str, **kwargs) -> int:
         return await self._anthropic.count_tokens(text)
+
+    def messages_to_text(self, messages: list[ChatMessage]) -> str:
+        prompt = START_PREFIX
+        prompt += START_PREFIX.join(map(self._message_to_prompt, messages))
+        if messages[-1].role != Role.ASSISTANT:
+            prompt += START_PREFIX
+            prompt += self._message_to_prompt(ChatMessage(role=Role.ASSISTANT, content=""))
+        return prompt.rstrip()
+
+    @staticmethod
+    def _message_to_prompt(message: ChatMessage) -> str:
+        if message.role == Role.USER:
+            return f"{USER_PREFIX} {message.content}"
+        if message.role == Role.ASSISTANT:
+            return f"{ASSISTANT_PREFIX} {message.content}"
+        if message.role == Role.SYSTEM:
+            return f"{USER_PREFIX} {SYSTEM_START_PREFIX}{message.content}{SYSTEM_END_PREFIX}"
+        raise ValueError(f"Unknown role: {message.role}")
```

llm_client/llm_api_client/base_llm_api_client.py

Lines changed: 25 additions & 0 deletions
```diff
@@ -1,7 +1,10 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
+from enum import Enum
 from typing import Any, Optional
 
+from dataclasses_json import dataclass_json, config
+
 try:
     from aiohttp import ClientSession
 except ImportError:
@@ -11,6 +14,21 @@
 from llm_client.consts import MODEL_KEY
 
 
+class Role(Enum):
+    SYSTEM = "system"
+    USER = "user"
+    ASSISTANT = "assistant"
+
+
+@dataclass_json
+@dataclass
+class ChatMessage:
+    role: Role = field(metadata=config(encoder=lambda role: role.value, decoder=Role))
+    content: str
+    name: Optional[str] = field(default=None, metadata=config(exclude=lambda name: name is None))
+    example: bool = field(default=False, metadata=config(exclude=lambda _: True))
+
+
 @dataclass
 class LLMAPIClientConfig:
     api_key: str
@@ -33,9 +51,16 @@ async def text_completion(self, prompt: str, model: Optional[str] = None, max_to
                               temperature: Optional[float] = None, top_p: Optional[float] = None, **kwargs) -> list[str]:
         raise NotImplementedError()
 
+    async def chat_completion(self, messages: list[ChatMessage], temperature: float = 0,
+                              max_tokens: int = 16, model: Optional[str] = None, **kwargs) -> list[str]:
+        raise NotImplementedError()
+
     async def embedding(self, text: str, model: Optional[str] = None, **kwargs) -> list[float]:
         raise NotImplementedError()
 
+    async def get_chat_tokens_count(self, messages: list[ChatMessage], **kwargs) -> int:
+        raise NotImplementedError()
+
     def _set_model_in_kwargs(self, kwargs, model: Optional[str]) -> None:
         if model is not None:
             kwargs[MODEL_KEY] = model
```
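
For reference (not part of the diff): the `dataclasses_json` config above controls what `ChatMessage.to_dict()` emits, which is what the OpenAI client sends as the `messages` payload. A rough sketch of the expected behaviour:

```python
from llm_client import ChatMessage, Role

# role is serialised via its enum value, name is dropped when None, and example is always
# excluded, so this is expected to yield {'role': 'user', 'content': 'Hello!'}
print(ChatMessage(role=Role.USER, content="Hello!").to_dict())

# with a name set it should be kept: {'role': 'user', 'content': 'Hello!', 'name': 'alice'}
print(ChatMessage(role=Role.USER, content="Hello!", name="alice").to_dict())
```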

llm_client/llm_api_client/openai_client.py

Lines changed: 58 additions & 19 deletions
```diff
@@ -1,31 +1,24 @@
-from dataclasses import dataclass, field
-from enum import Enum
 from functools import lru_cache
 from typing import Optional
 
 import openai
 import tiktoken
-from dataclasses_json import dataclass_json, config
 from tiktoken import Encoding
 
-from llm_client.llm_api_client.base_llm_api_client import BaseLLMAPIClient, LLMAPIClientConfig
+from llm_client.llm_api_client.base_llm_api_client import BaseLLMAPIClient, LLMAPIClientConfig, ChatMessage
 from llm_client.consts import PROMPT_KEY
 
 INPUT_KEY = "input"
-
-
-class Role(Enum):
-    SYSTEM = "system"
-    USER = "user"
-    ASSISTANT = "assistant"
-
-
-@dataclass_json
-@dataclass
-class ChatMessage:
-    role: Role = field(metadata=config(encoder=lambda role: role.value, decoder=Role))
-    content: str
-    name: Optional[str] = field(default=None, metadata=config(exclude=lambda name: name is None))
+MODEL_NAME_TO_TOKENS_PER_MESSAGE_AND_TOKENS_PER_NAME = {
+    "gpt-3.5-turbo-0613": (3, 1),
+    "gpt-3.5-turbo-16k-0613": (3, 1),
+    "gpt-4-0314": (3, 1),
+    "gpt-4-32k-0314": (3, 1),
+    "gpt-4-0613": (3, 1),
+    "gpt-4-32k-0613": (3, 1),
+    # every message follows <|start|>{role/name}\n{content}<|end|>\n, if there's a name, the role is omitted
+    "gpt-3.5-turbo-0301": (4, -1),
+}
 
 
 class OpenAIClient(BaseLLMAPIClient):
@@ -46,7 +39,8 @@ async def text_completion(self, prompt: str, model: Optional[str] = None, temper
         return [choice.text for choice in completions.choices]
 
     async def chat_completion(self, messages: list[ChatMessage], temperature: float = 0,
-                              max_tokens: int = 16, top_p: float = 1, model: Optional[str] = None, **kwargs) -> list[str]:
+                              max_tokens: int = 16, top_p: float = 1, model: Optional[str] = None, **kwargs) \
+            -> list[str]:
         self._set_model_in_kwargs(kwargs, model)
         kwargs["messages"] = [message.to_dict() for message in messages]
         kwargs["temperature"] = temperature
@@ -66,6 +60,51 @@ async def get_tokens_count(self, text: str, model: Optional[str] = None, **kwarg
             model = self._default_model
         return len(self._get_relevant_tokeniser(model).encode(text))
 
+    async def get_chat_tokens_count(self, messages: list[ChatMessage], model: Optional[str] = None) -> int:
+        """
+        This is based on:
+        https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+        """
+        model = self._get_model_name_for_tokeniser(model)
+        encoding = self._get_relevant_tokeniser(model)
+        tokens_per_message, tokens_per_name = MODEL_NAME_TO_TOKENS_PER_MESSAGE_AND_TOKENS_PER_NAME[model]
+        num_tokens = 0
+        for message in messages:
+            num_tokens += tokens_per_message
+            num_tokens += len(encoding.encode(message.content))
+            num_tokens += len(encoding.encode(message.role.value))
+            if message.name:
+                num_tokens += len(encoding.encode(message.name))
+                num_tokens += tokens_per_name
+        num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
+        return num_tokens
+
+    def _get_model_name_for_tokeniser(self, model: Optional[str] = None) -> str:
+        if model is None:
+            model = self._default_model
+        if model in {
+            "gpt-3.5-turbo-0613",
+            "gpt-3.5-turbo-16k-0613",
+            "gpt-4-0314",
+            "gpt-4-32k-0314",
+            "gpt-4-0613",
+            "gpt-4-32k-0613",
+        }:
+            return model
+        elif model == "gpt-3.5-turbo-0301":
+            return model
+        elif "gpt-3.5-turbo" in model:
+            print("Warning: gpt-3.5-turbo may update over time. Returning tokeniser assuming gpt-3.5-turbo-0613.")
+            return "gpt-3.5-turbo-0613"
+        elif "gpt-4" in model:
+            print("Warning: gpt-4 may update over time. Returning tokeniser assuming gpt-4-0613.")
+            return "gpt-4-0613"
+        else:
+            raise NotImplementedError(
+                f"""not implemented for model {model}.
+See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
+            )
+
     @staticmethod
     @lru_cache(maxsize=40)
     def _get_relevant_tokeniser(model: str) -> Encoding:
```
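
As a standalone sanity check (not part of the commit), the 23-token figure used in the README example can be reproduced with the same counting rule and `tiktoken` directly:

```python
import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")
messages = [("user", "Hello!"), ("system", "Hi there! How can I assist you today?")]
tokens_per_message = 3  # value for the gpt-3.5-turbo-0613 family in the mapping above

# per message: fixed overhead + content tokens + role tokens; then +3 for the primed reply
num_tokens = sum(tokens_per_message + len(encoding.encode(content)) + len(encoding.encode(role))
                 for role, content in messages)
num_tokens += 3
print(num_tokens)  # expected: 23, matching the README example
```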

pyproject.toml

Lines changed: 4 additions & 2 deletions
```diff
@@ -20,7 +20,10 @@ classifiers = [
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
 ]
-dependencies = ["aiohttp >=3.0.0,<4.0.0"]
+dependencies = [
+    "aiohttp >=3.0.0,<4.0.0",
+    "dataclasses_json >= 0.5.0"
+]
 dynamic = ["version"]
 
 [project.urls]
@@ -37,7 +40,6 @@ test = [
 openai = [
     "openai >=0.27.4",
     "tiktoken >=0.3.3",
-    "dataclasses_json >= 0.5.0"
 ]
 huggingface = [
     "transformers >= 4.0.0"
```

tests/llm_api_client/anthropic_client/test_anthropic_client.py

Lines changed: 48 additions & 2 deletions
```diff
@@ -1,9 +1,13 @@
+from unittest.mock import AsyncMock
+
 import pytest
 
-from llm_client import LLMAPIClientFactory, LLMAPIClientType
+from llm_client import LLMAPIClientFactory, LLMAPIClientType, ChatMessage
 from llm_client.consts import PROMPT_KEY, MODEL_KEY
 from llm_client.llm_api_client.anthropic_client import AUTH_HEADER, COMPLETIONS_KEY, MAX_TOKENS_KEY, ACCEPT_HEADER, \
-    ACCEPT_VALUE, VERSION_HEADER, AnthropicClient
+    ACCEPT_VALUE, VERSION_HEADER, AnthropicClient, USER_PREFIX, ASSISTANT_PREFIX, START_PREFIX, SYSTEM_START_PREFIX, \
+    SYSTEM_END_PREFIX
+from llm_client.llm_api_client.base_llm_api_client import Role
 
 
 @pytest.mark.asyncio
@@ -14,6 +18,48 @@ async def test_get_llm_api_client__with_anthropic(config):
 
     assert isinstance(actual, AnthropicClient)
 
+@pytest.mark.asyncio
+async def test_chat_completion_sanity(llm_client):
+    text_completion_mock = AsyncMock(return_value=["completion text"])
+    llm_client.text_completion = text_completion_mock
+
+    actual = await llm_client.chat_completion(messages=[ChatMessage(Role.USER, "Why is the sky blue?")], max_tokens=10)
+
+    assert actual == ["completion text"]
+    text_completion_mock.assert_awaited_once_with(f"{START_PREFIX}{USER_PREFIX} Why is the sky blue?"
+                                                  f"{START_PREFIX}{ASSISTANT_PREFIX}", None, 10, 1)
+
+
+@pytest.mark.asyncio
+async def test_chat_completion_with_assistant_in_the_end(llm_client):
+    text_completion_mock = AsyncMock(return_value=["completion text"])
+    llm_client.text_completion = text_completion_mock
+
+    actual = await llm_client.chat_completion(messages=[ChatMessage(Role.USER, "Why is the sky blue?"),
+                                                        ChatMessage(Role.ASSISTANT, "Answer - ")], temperature=10)
+
+    assert actual == ["completion text"]
+    text_completion_mock.assert_awaited_once_with(f"{START_PREFIX}{USER_PREFIX} Why is the sky blue?"
+                                                  f"{START_PREFIX}{ASSISTANT_PREFIX} Answer -", None, None,
+                                                  10)
+
+
+@pytest.mark.asyncio
+async def test_chat_completion_with_system(llm_client):
+    text_completion_mock = AsyncMock(return_value=["completion text"])
+    llm_client.text_completion = text_completion_mock
+
+    actual = await llm_client.chat_completion(messages=[ChatMessage(Role.SYSTEM, "Be nice!"),
+                                                        ChatMessage(Role.USER, "Why is the sky blue?")], max_tokens=10,
+                                              temperature=2)
+
+    assert actual == ["completion text"]
+    text_completion_mock.assert_awaited_once_with(f"{START_PREFIX}{USER_PREFIX} "
+                                                  f"{SYSTEM_START_PREFIX}Be nice!{SYSTEM_END_PREFIX}{START_PREFIX}"
+                                                  f"{USER_PREFIX} Why is the sky blue?"
+                                                  f"{START_PREFIX}{ASSISTANT_PREFIX}", None, 10, 2)
+
+
 @pytest.mark.asyncio
 async def test_text_completion__sanity(mock_aioresponse, llm_client, complete_url, anthropic_version):
     mock_aioresponse.post(
```
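
For readability (not part of the commit): the f-string expectations in these tests expand, via the constants exported by `anthropic_client`, to plain Anthropic-style prompts, and the trailing positional arguments mirror how `chat_completion` forwards `(prompt, model, max_tokens, temperature)` to `text_completion`. A quick standalone check of the first test's expected prompt:

```python
from llm_client.llm_api_client.anthropic_client import START_PREFIX, USER_PREFIX, ASSISTANT_PREFIX

# the prompt test_chat_completion_sanity expects, forwarded with model=None, max_tokens=10, temperature=1
expected_prompt = f"{START_PREFIX}{USER_PREFIX} Why is the sky blue?{START_PREFIX}{ASSISTANT_PREFIX}"
assert expected_prompt == "\n\nHuman: Why is the sky blue?\n\nAssistant:"
```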
