
Commit 57f8207

Added Top_p #28 (#29)
* bug fix
* added top_p
* added top_p for google and BaseLLMClientApi
* changes after cr
* json fix
* fix anthropic tests
* fix ai21 tests
* small changes
1 parent 50bc2c8 commit 57f8207

File tree

14 files changed (+41, -35 lines)
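
Summary of the change: every client's text_completion (and OpenAI's chat_completion) gains a top_p nucleus-sampling parameter that is forwarded to the provider API. A minimal usage sketch follows; OpenAIClient and LLMAPIClientConfig appear in the diffs below, but the import path and session wiring are assumptions about the surrounding library, not part of this commit.

import asyncio
from aiohttp import ClientSession

from llm_client import LLMAPIClientConfig, OpenAIClient  # assumed exports

async def main():
    async with ClientSession() as session:
        client = OpenAIClient(LLMAPIClientConfig("api-key", session,
                                                 default_model="text-davinci-003"))
        # top_p=1 (the default) applies no nucleus truncation; lower values
        # restrict sampling to the smallest set of tokens whose cumulative
        # probability reaches top_p.
        completions = await client.text_completion(prompt="These are a few of my favorite",
                                                   max_tokens=16, top_p=0.9)
        print(completions)

asyncio.run(main())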

llm_client/llm_api_client/ai21_client.py

Lines changed: 2 additions & 1 deletion
@@ -22,9 +22,10 @@ def __init__(self, config: LLMAPIClientConfig):
         self._headers[AUTH_HEADER] = BEARER_TOKEN + self._api_key
 
     async def text_completion(self, prompt: str, model: Optional[str] = None, max_tokens: int = 16,
-                              temperature: float = 0.7, **kwargs) -> list[str]:
+                              temperature: float = 0.7, top_p: float = 1, **kwargs) -> list[str]:
         model = model or self._default_model
         kwargs[PROMPT_KEY] = prompt
+        kwargs["topP"] = kwargs.pop("topP", top_p)
         kwargs["maxTokens"] = kwargs.pop("maxTokens", max_tokens)
         kwargs["temperature"] = temperature
         response = await self._session.post(self._base_url + model + "/" + COMPLETE_PATH,
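
The kwargs.pop("topP", top_p) idiom above keeps backward compatibility: a caller who already passes topP through **kwargs overrides the new parameter's default. A standalone illustration of the precedence (not library code):

def build_payload(top_p: float = 1, **kwargs) -> dict:
    # caller-supplied "topP" wins over the top_p parameter's value
    kwargs["topP"] = kwargs.pop("topP", top_p)
    return kwargs

assert build_payload() == {"topP": 1}             # default applied
assert build_payload(top_p=0.9) == {"topP": 0.9}  # parameter forwarded
assert build_payload(topP=0.5) == {"topP": 0.5}   # explicit kwarg wins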

llm_client/llm_api_client/aleph_alpha_client.py

Lines changed: 2 additions & 1 deletion
@@ -27,12 +27,13 @@ def __init__(self, config: LLMAPIClientConfig):
         self._headers[AUTH_HEADER] = BEARER_TOKEN + self._api_key
 
     async def text_completion(self, prompt: str, model: Optional[str] = None, max_tokens: Optional[int] = None,
-                              temperature: float = 0, **kwargs) -> \
+                              temperature: float = 0, top_p: float = 0, **kwargs) -> \
             list[str]:
         self._set_model_in_kwargs(kwargs, model)
         if max_tokens is None:
             raise ValueError("max_tokens must be specified")
         kwargs[PROMPT_KEY] = prompt
+        kwargs["top_p"] = top_p
         kwargs["maximum_tokens"] = kwargs.pop("maximum_tokens", max_tokens)
         kwargs["temperature"] = temperature
         response = await self._session.post(self._base_url + COMPLETE_PATH,

llm_client/llm_api_client/anthropic_client.py

Lines changed: 3 additions & 1 deletion
@@ -27,11 +27,13 @@ def __init__(self, config: LLMAPIClientConfig):
         self._headers[AUTH_HEADER] = self._api_key
 
     async def text_completion(self, prompt: str, model: Optional[str] = None, max_tokens: Optional[int] = None,
-                              temperature: float = 1,
+                              temperature: float = 1, top_p: Optional[float] = None,
                               **kwargs) -> \
             list[str]:
         if max_tokens is None and kwargs.get(MAX_TOKENS_KEY) is None:
             raise ValueError(f"max_tokens or {MAX_TOKENS_KEY} must be specified")
+        if top_p:
+            kwargs["top_p"] = top_p
         self._set_model_in_kwargs(kwargs, model)
         kwargs[PROMPT_KEY] = prompt
         kwargs[MAX_TOKENS_KEY] = kwargs.pop(MAX_TOKENS_KEY, max_tokens)
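
Unlike the AI21 client, the Anthropic client only forwards top_p when it is truthy, so the key is omitted from the request body by default. A standalone sketch of that behavior (note that an explicit top_p=0.0 is also skipped, since 0.0 is falsy):

from typing import Optional

def build_payload(top_p: Optional[float] = None, **kwargs) -> dict:
    if top_p:  # truthiness check: None and 0.0 both fall through
        kwargs["top_p"] = top_p
    return kwargs

assert build_payload() == {}                       # default: key omitted
assert build_payload(top_p=0.5) == {"top_p": 0.5}  # forwarded when set
assert build_payload(top_p=0.0) == {}              # falsy 0.0 also omitted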

llm_client/llm_api_client/base_llm_api_client.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ def __init__(self, config: LLMAPIClientConfig):
 
     @abstractmethod
     async def text_completion(self, prompt: str, model: Optional[str] = None, max_tokens: Optional[int] = None,
-                              temperature: Optional[float] = None, **kwargs) -> list[str]:
+                              temperature: Optional[float] = None, top_p: Optional[float] = None, **kwargs) -> list[str]:
        raise NotImplementedError()
 
     async def embedding(self, text: str, model: Optional[str] = None, **kwargs) -> list[float]:
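
For implementers, the widened abstract signature means every concrete client now accepts top_p and chooses its own provider-specific mapping. A hypothetical subclass outline (MyProviderClient is invented for illustration; the module path follows the file path in this diff):

from typing import Optional

from llm_client.llm_api_client.base_llm_api_client import BaseLLMAPIClient

class MyProviderClient(BaseLLMAPIClient):
    async def text_completion(self, prompt: str, model: Optional[str] = None,
                              max_tokens: Optional[int] = None,
                              temperature: Optional[float] = None,
                              top_p: Optional[float] = None, **kwargs) -> list[str]:
        # map top_p onto the provider's request schema here
        raise NotImplementedError()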

llm_client/llm_api_client/google_client.py

Lines changed: 3 additions & 1 deletion
@@ -33,10 +33,12 @@ def __init__(self, config: LLMAPIClientConfig):
         self._params = {AUTH_PARAM: self._api_key}
 
     async def text_completion(self, prompt: str, model: Optional[str] = None, max_tokens: Optional[int] = 64,
-                              temperature: Optional[float] = None, **kwargs) -> list[str]:
+                              temperature: Optional[float] = None, top_p: Optional[float] = None, **kwargs) -> list[str]:
         model = model or self._default_model
         kwargs[PROMPT_KEY] = {TEXT_KEY: prompt}
         kwargs[MAX_TOKENS_KEY] = kwargs.pop(MAX_TOKENS_KEY, max_tokens)
+        if top_p:
+            kwargs["topP"] = top_p
         kwargs["temperature"] = kwargs.pop("temperature", temperature)
         response = await self._session.post(self._base_url + model + ":" + COMPLETE_PATH,
                                             params=self._params,
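
Here the snake_case top_p argument is mapped to the camelCase "topP" body key that the Google endpoint appears to expect (matching the diff above), again only when truthy. A standalone illustration:

def build_payload(top_p=None, **kwargs) -> dict:
    if top_p:
        kwargs["topP"] = top_p  # camelCase key, matching the diff above
    return kwargs

assert build_payload(top_p=0.95) == {"topP": 0.95}
assert build_payload() == {}  # omitted when unset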

llm_client/llm_api_client/huggingface_client.py

Lines changed: 2 additions & 1 deletion
@@ -29,8 +29,9 @@ def __init__(self, config: LLMAPIClientConfig):
         self._headers[AUTH_HEADER] = BEARER_TOKEN + self._api_key
 
     async def text_completion(self, prompt: str, max_tokens: Optional[int] = None, temperature: float = 1.0,
-                              model: Optional[str] = None, **kwargs) -> list[str]:
+                              model: Optional[str] = None, top_p: Optional[float] = None, **kwargs) -> list[str]:
         model = model or self._default_model
+        kwargs["top_p"] = top_p
         kwargs[INPUT_KEY] = prompt
         kwargs[TEMPERATURE_KEY] = temperature
         kwargs[TOKENS_KEY] = kwargs.pop(TOKENS_KEY, max_tokens)

llm_client/llm_api_client/openai_client.py

Lines changed: 4 additions & 2 deletions
@@ -36,19 +36,21 @@ def __init__(self, config: LLMAPIClientConfig):
         self._client = openai
 
     async def text_completion(self, prompt: str, model: Optional[str] = None, temperature: float = 0,
-                              max_tokens: int = 16, **kwargs) -> list[str]:
+                              max_tokens: int = 16, top_p: float = 1, **kwargs) -> list[str]:
         self._set_model_in_kwargs(kwargs, model)
         kwargs[PROMPT_KEY] = prompt
+        kwargs["top_p"] = top_p
         kwargs["temperature"] = temperature
         kwargs["max_tokens"] = max_tokens
         completions = await self._client.Completion.acreate(headers=self._headers, **kwargs)
         return [choice.text for choice in completions.choices]
 
     async def chat_completion(self, messages: list[ChatMessage], temperature: float = 0,
-                              max_tokens: int = 16, model: Optional[str] = None, **kwargs) -> list[str]:
+                              max_tokens: int = 16, top_p: float = 1, model: Optional[str] = None, **kwargs) -> list[str]:
         self._set_model_in_kwargs(kwargs, model)
         kwargs["messages"] = [message.to_dict() for message in messages]
         kwargs["temperature"] = temperature
+        kwargs["top_p"] = top_p
         kwargs["max_tokens"] = max_tokens
         completions = await self._client.ChatCompletion.acreate(headers=self._headers, **kwargs)
         return [choice.message.content for choice in completions.choices]
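
Both OpenAI methods now always send top_p (defaulting to 1) alongside temperature and max_tokens. A hedged call sketch; ChatMessage and Role appear in this commit's tests, and the import path is an assumption:

from llm_client import ChatMessage, Role  # assumed exports

async def demo(client):  # client: an OpenAIClient instance
    texts = await client.text_completion("These are a few of my favorite",
                                         max_tokens=16, top_p=0.9)
    replies = await client.chat_completion([ChatMessage(Role.USER, "Hello!")],
                                           max_tokens=16, top_p=0.9)
    return texts, replies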

tests/llm_api_client/ai21_client/test_ai21.py

Lines changed: 4 additions & 4 deletions
@@ -30,7 +30,7 @@ async def test_text_completion__sanity(mock_aioresponse, llm_client, url):
                       'friends, entertaining family...you get the point! One of my favorite things to do is plan parties']
     mock_aioresponse.assert_called_once_with(url, method='POST',
                                              headers={AUTH_HEADER: BEARER_TOKEN + llm_client._api_key},
-                                             json={'prompt': 'These are a few of my favorite', "maxTokens": 16, "temperature": 0.7},
+                                             json={'prompt': 'These are a few of my favorite', "maxTokens": 16, "temperature": 0.7, "topP": 1},
                                              raise_for_status=True)
 
 
@@ -49,7 +49,7 @@ async def test_text_completion__return_multiple_completions(mock_aioresponse, ll
     ]
     mock_aioresponse.assert_called_once_with(url, method='POST',
                                              headers={AUTH_HEADER: BEARER_TOKEN + llm_client._api_key},
-                                             json={'prompt': 'These are a few of my favorite', "maxTokens": 16, "temperature": 0.7},
+                                             json={'prompt': 'These are a few of my favorite', "maxTokens": 16, "temperature": 0.7, "topP": 1},
                                              raise_for_status=True)
 
 
@@ -69,7 +69,7 @@ async def test_text_completion__override_model(mock_aioresponse, llm_client):
                       'friends, entertaining family...you get the point! One of my favorite things to do is plan parties']
     mock_aioresponse.assert_called_once_with(url, method='POST',
                                              headers={AUTH_HEADER: BEARER_TOKEN + llm_client._api_key},
-                                             json={'prompt': 'These are a few of my favorite', "maxTokens": 16, "temperature": 0.7},
+                                             json={'prompt': 'These are a few of my favorite', "maxTokens": 16, "temperature": 0.7, "topP": 1},
                                              raise_for_status=True)
 
 
@@ -87,7 +87,7 @@ async def test_text_completion__with_kwargs(mock_aioresponse, llm_client, url):
                       'friends, entertaining family...you get the point! One of my favorite things to do is plan parties']
     mock_aioresponse.assert_called_once_with(url, method='POST',
                                              headers={AUTH_HEADER: BEARER_TOKEN + llm_client._api_key},
-                                             json={'prompt': 'These are a few of my favorite', "maxTokens": 10, "temperature": 0.7},
+                                             json={'prompt': 'These are a few of my favorite', "maxTokens": 10, "temperature": 0.7, "topP": 1},
                                              raise_for_status=True)

tests/llm_api_client/anthropic_client/test_anthropic_client.py

Lines changed: 3 additions & 4 deletions
@@ -14,15 +14,14 @@ async def test_get_llm_api_client__with_anthropic(config):
 
     assert isinstance(actual, AnthropicClient)
 
-
 @pytest.mark.asyncio
 async def test_text_completion__sanity(mock_aioresponse, llm_client, complete_url, anthropic_version):
     mock_aioresponse.post(
         complete_url,
         payload={COMPLETIONS_KEY: "completion text"}
     )
 
-    actual = await llm_client.text_completion(prompt="These are a few of my favorite", max_tokens=10)
+    actual = await llm_client.text_completion(prompt="These are a few of my favorite", max_tokens=10,)
 
     assert actual == ["completion text"]
     mock_aioresponse.assert_called_once_with(complete_url, method='POST',
@@ -92,7 +91,7 @@ async def test_text_completion__with_kwargs(mock_aioresponse, llm_client, comple
         payload={COMPLETIONS_KEY: "completion text"}
     )
 
-    actual = await llm_client.text_completion(prompt="These are a few of my favorite", max_tokens=10, temperature=0.5)
+    actual = await llm_client.text_completion(prompt="These are a few of my favorite", max_tokens=10, temperature=0.5, top_p=0.5)
 
     assert actual == ["completion text"]
     mock_aioresponse.assert_called_once_with(complete_url, method='POST',
@@ -102,7 +101,7 @@ async def test_text_completion__with_kwargs(mock_aioresponse, llm_client, comple
                                              json={PROMPT_KEY: 'These are a few of my favorite',
                                                    MAX_TOKENS_KEY: 10,
                                                    MODEL_KEY: llm_client._default_model,
-                                                   "temperature": 0.5},
+                                                   "temperature": 0.5, "top_p": 0.5},
                                              raise_for_status=True)
 
 

tests/llm_api_client/google_client/test_google_client.py

Lines changed: 2 additions & 2 deletions
@@ -68,15 +68,15 @@ async def test_text_completion__with_kwargs(mock_aioresponse, llm_client, params
         payload=load_json_resource("google/text_completion.json")
     )
 
-    actual = await llm_client.text_completion(prompt="These are a few of my favorite", max_tokens=10, blabla="aaa")
+    actual = await llm_client.text_completion(prompt="These are a few of my favorite", max_tokens=10, blabla="aaa", top_p=0.95)
 
     assert actual == ['Once upon a time, there was a young girl named Lily...',
                       'Once upon a time, there was a young boy named Billy...']
     mock_aioresponse.assert_called_once_with(url, method='POST', params={AUTH_PARAM: llm_client._api_key},
                                              json={PROMPT_KEY: {TEXT_KEY: 'These are a few of my favorite'},
                                                    MAX_TOKENS_KEY: 10,
                                                    'temperature': None,
-                                                   'blabla': 'aaa'},
+                                                   'blabla': 'aaa', "topP": 0.95},
                                              headers=llm_client._headers,
                                              raise_for_status=True,
                                              )

tests/llm_api_client/huggingface_client/test_huggingface.py

Lines changed: 2 additions & 2 deletions
@@ -28,7 +28,7 @@ async def test_text_completion__sanity(mock_aioresponse, llm_client, url):
     assert actual == ['Kobe Bryant is a retired professional basketball player who played for the Los Angeles Lakers of']
     mock_aioresponse.assert_called_once_with(url, method='POST',
                                              headers={AUTH_HEADER: BEARER_TOKEN + llm_client._api_key},
-                                             json={'inputs': 'who is kobe bryant', "max_length": None, "temperature": 1.0},
+                                             json={'inputs': 'who is kobe bryant', "max_length": None, "temperature": 1.0, "top_p": None},
                                              raise_for_status=True)
 
 
@@ -44,7 +44,7 @@ async def test_text_completion__with_kwargs(mock_aioresponse, llm_client, url):
     assert actual == ['Kobe Bryant is a retired professional basketball player who played for the Los Angeles Lakers of']
     mock_aioresponse.assert_called_once_with(url, method='POST',
                                              headers={AUTH_HEADER: BEARER_TOKEN + llm_client._api_key},
-                                             json={'inputs': 'who is kobe bryant', "max_length": 10, "temperature": 1.0},
+                                             json={'inputs': 'who is kobe bryant', "max_length": 10, "temperature": 1.0, "top_p": None},
                                              raise_for_status=True)

tests/llm_api_client/openai_client/test_openai.py

Lines changed: 11 additions & 11 deletions
@@ -37,7 +37,7 @@ async def test_text_completion__sanity(openai_mock, open_ai_client, model_name):
     openai_mock.Completion.acreate.assert_awaited_once_with(
         model=model_name,
         prompt="These are a few of my favorite",
-        headers={}, temperature=0, max_tokens=16)
+        headers={}, temperature=0, max_tokens=16, top_p=1)
 
 
 @pytest.mark.asyncio
@@ -52,7 +52,7 @@ async def test_text_completion__return_multiple_completions(openai_mock, open_ai
     openai_mock.Completion.acreate.assert_awaited_once_with(
         model=model_name,
         prompt="These are a few of my favorite",
-        headers={}, temperature=0, max_tokens=16)
+        headers={}, temperature=0, max_tokens=16, top_p=1)
 
 
 @pytest.mark.asyncio
@@ -67,7 +67,7 @@ async def test_text_completion__override_model(openai_mock, open_ai_client, mode
     openai_mock.Completion.acreate.assert_awaited_once_with(
         model=new_model_name,
         prompt="These are a few of my favorite",
-        headers={}, temperature=0, max_tokens=16)
+        headers={}, temperature=0, max_tokens=16, top_p=1)
 
 
 @pytest.mark.asyncio
@@ -81,7 +81,7 @@ async def test_text_completion__with_kwargs(openai_mock, open_ai_client, model_n
     openai_mock.Completion.acreate.assert_awaited_once_with(
         model=model_name,
         prompt="These are a few of my favorite",
-        temperature=0, max_tokens=10,
+        temperature=0, max_tokens=10, top_p=1,
         headers={})
 
 
@@ -98,7 +98,7 @@ async def test_text_completion__with_headers(openai_mock, model_name):
     openai_mock.Completion.acreate.assert_awaited_once_with(
         model=model_name,
         prompt="These are a few of my favorite",
-        headers={"header_name": "header_value"}, temperature=0, max_tokens=16)
+        headers={"header_name": "header_value"}, temperature=0, max_tokens=16, top_p=1)
 
 
 @pytest.mark.asyncio
@@ -112,7 +112,7 @@ async def test_chat_completion__sanity(openai_mock, open_ai_client, model_name):
     openai_mock.ChatCompletion.acreate.assert_awaited_once_with(
         model=model_name,
         messages=[{'content': 'Hello!', 'role': 'user'}],
-        headers={}, temperature=0, max_tokens=16)
+        headers={}, temperature=0, max_tokens=16, top_p=1)
 
 
 @pytest.mark.asyncio
@@ -127,7 +127,7 @@ async def test_chat_completion__return_multiple_completions(openai_mock, open_ai
     openai_mock.ChatCompletion.acreate.assert_awaited_once_with(
         model=model_name,
         messages=[{'content': 'Hello!', 'role': 'user'}],
-        headers={}, temperature=0, max_tokens=16)
+        headers={}, temperature=0, max_tokens=16, top_p=1)
 
 
 @pytest.mark.asyncio
@@ -142,22 +142,22 @@ async def test_chat_completion__override_model(openai_mock, open_ai_client, mode
     openai_mock.ChatCompletion.acreate.assert_awaited_once_with(
         model=new_model_name,
         messages=[{'content': 'Hello!', 'role': 'user'}],
-        headers={}, temperature=0, max_tokens=16)
+        headers={}, temperature=0, max_tokens=16, top_p=1)
 
 
 @pytest.mark.asyncio
 async def test_chat_completion__with_kwargs(openai_mock, open_ai_client, model_name):
     openai_mock.ChatCompletion.acreate = AsyncMock(
         return_value=OpenAIObject.construct_from(load_json_resource("openai/chat_completion.json")))
 
-    actual = await open_ai_client.chat_completion([ChatMessage(Role.USER, "Hello!")], max_tokens=10)
+    actual = await open_ai_client.chat_completion([ChatMessage(Role.USER, "Hello!")], max_tokens=10, top_p=1)
 
     assert actual == ["\n\nHello there, how may I assist you today?"]
     openai_mock.ChatCompletion.acreate.assert_awaited_once_with(
         model=model_name,
         messages=[{'content': 'Hello!', 'role': 'user'}],
         max_tokens=10,
-        headers={}, temperature=0)
+        headers={}, temperature=0, top_p=1)
 
 
 @pytest.mark.asyncio
@@ -173,7 +173,7 @@ async def test_chat_completion__with_headers(openai_mock, model_name):
     openai_mock.ChatCompletion.acreate.assert_awaited_once_with(
         model=model_name,
         messages=[{'content': 'Hello!', 'role': 'user'}],
-        headers={"header_name": "header_value"}, temperature=0, max_tokens=16)
+        headers={"header_name": "header_value"}, temperature=0, max_tokens=16, top_p=1)
 
 
 @pytest.mark.asyncio

tests/resources/openai/chat_completion.json

Lines changed: 1 addition & 2 deletions
@@ -13,6 +13,5 @@
   "usage": {
     "prompt_tokens": 9,
     "completion_tokens": 12,
-    "total_tokens": 21
-  }
+    "total_tokens": 21}
 }

tests/resources/openai/text_completion.json

Lines changed: 1 addition & 2 deletions
@@ -14,6 +14,5 @@
   "usage": {
     "prompt_tokens": 5,
     "completion_tokens": 7,
-    "total_tokens": 12
-  }
+    "total_tokens": 12}
 }
