Commit 04a5472

Use Gemini (#1)

* Updated install files
* Updated readme
* Updated main and rmq
* Updated model file name
* Updated folder name
* Updated model
* Updated requirements
* Use gemini
* Added model field
* Added system_prompt wrapper
* Update to support Submind participation (#2)
  * Resolve import error
  * Update dependencies for chatbotsforum compat.
  * Update license tests
  * Use stable embeddings
* Fix instruction loss of empty history

Co-authored-by: Daniel McKnight <34697904+NeonDaniel@users.noreply.github.com>
Co-authored-by: Daniel McKnight <daniel@neon.ai>
Co-authored-by: NeonBohdan <bohdan@neon.ai>

1 parent 38bb05e commit 04a5472

File tree: 10 files changed (+51, -39 lines)

.github/workflows/license_tests.yml
Lines changed: 1 addition & 1 deletion

@@ -9,4 +9,4 @@ jobs:
   license_tests:
     uses: neongeckocom/.github/.github/workflows/license_tests.yml@master
     with:
-      packages-exclude: '^(neon-llm-palm2|tqdm).*'
+      packages-exclude: '^(neon-llm|tqdm|klat-connector|neon-chatbot|dnspython).*'

Dockerfile
Lines changed: 2 additions & 2 deletions

@@ -1,7 +1,7 @@
 FROM python:3.9-slim
 
 LABEL vendor=neon.ai \
-    ai.neon.name="neon-llm-palm2"
+    ai.neon.name="neon-llm-gemini"
 
 ENV OVOS_CONFIG_BASE_FOLDER neon
 ENV OVOS_CONFIG_FILENAME diana.yaml

@@ -12,4 +12,4 @@ WORKDIR /app
 COPY . /app
 RUN pip install /app
 
-CMD [ "neon-llm-palm2" ]
+CMD [ "neon-llm-gemini" ]

README.md
Lines changed: 8 additions & 7 deletions

@@ -1,5 +1,5 @@
-# NeonAI LLM Palm2
-Proxies API calls to Google Palm2.
+# NeonAI LLM Gemini
+Proxies API calls to Google Gemini.
 
 ## Request Format
 API requests should include `history`, a list of tuples of strings, and the current

@@ -25,10 +25,11 @@ MQ:
   port: <MQ Port>
   server: <MQ Hostname or IP>
   users:
-    neon_llm_palm2:
-      password: <neon_palm2 user's password>
-      user: neon_palm2
-LLM_PALM2:
+    neon_llm_gemini:
+      password: <neon_gemini user's password>
+      user: neon_gemini
+LLM_GEMINI:
+  model: "gemini-pro"
   key_path: ""
   role: "You are trying to give a short answer in less than 40 words."
   context_depth: 3

@@ -39,6 +40,6 @@ LLM_PALM2:
 For example, if your configuration resides in `~/.config`:
 ```shell
 export CONFIG_PATH="/home/${USER}/.config"
-docker run -v ${CONFIG_PATH}:/config neon_llm_palm2
+docker run -v ${CONFIG_PATH}:/config neon_llm_gemini
 ```
 > Note: If connecting to a local MQ server, you may need to specify `--network host`
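
Taken together with the overlay change below, the configuration after this commit is shaped roughly like the following sketch (key names come from the diffs in this commit; the port, server, and password values are placeholders):

```yaml
MQ:
  port: 5672                 # placeholder
  server: mq.example.com     # placeholder
  users:
    neon_llm_gemini:
      user: neon_gemini
      password: supersecret  # placeholder
LLM_GEMINI:
  model: "gemini-pro"        # new key introduced by this commit
  key_path: ""
  role: "You are trying to give a short answer in less than 40 words."
  context_depth: 3
  max_tokens: 100
```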

docker_overlay/etc/neon/diana.yaml
Lines changed: 2 additions & 1 deletion

@@ -14,7 +14,8 @@ MQ:
   mq_handler:
     user: neon_api_utils
     password: Klatchat2021
-LLM_PALM2:
+LLM_GEMINI:
+  model: "gemini-pro"
 role: "You are trying to give a short answer in less than 40 words."
 context_depth: 3
 max_tokens: 100

File renamed without changes.

neon_llm_palm2/__main__.py renamed to neon_llm_gemini/__main__.py
Lines changed: 4 additions & 4 deletions

@@ -24,15 +24,15 @@
 # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-from neon_llm_palm2.rmq import Palm2MQ
+from neon_llm_gemini.rmq import GeminiMQ
 
 
 def main():
     # Run RabbitMQ
-    palm2MQ = Palm2MQ()
-    palm2MQ.run(run_sync=False, run_consumers=True,
+    geminiMQ = GeminiMQ()
+    geminiMQ.run(run_sync=False, run_consumers=True,
                 daemonize_consumers=True)
-    palm2MQ.observer_thread.join()
+    geminiMQ.observer_thread.join()
 
 
 if __name__ == "__main__":

neon_llm_palm2/palm2.py renamed to neon_llm_gemini/gemini.py
Lines changed: 23 additions & 13 deletions

@@ -25,25 +25,27 @@
 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import os
-from vertexai.language_models import ChatModel, ChatMessage, TextEmbeddingModel
+from vertexai.preview.generative_models import GenerativeModel, Content, Part
+from vertexai.language_models import TextEmbeddingModel
 from openai.embeddings_utils import distances_from_embeddings
 
 from typing import List, Dict
 from neon_llm_core.llm import NeonLLM
 
 
-class Palm2(NeonLLM):
+class Gemini(NeonLLM):
 
     mq_to_llm_role = {
         "user": "user",
-        "llm": "bot"
+        "llm": "model"
     }
 
     def __init__(self, config):
         super().__init__(config)
         self._embedding = None
         self._context_depth = 0
 
+        self.model_name = config["model"]
         self.role = config["role"]
         self.context_depth = config["context_depth"]
         self.max_tokens = config["max_tokens"]

@@ -67,9 +69,9 @@ def tokenizer_model_name(self) -> str:
         return ""
 
     @property
-    def model(self) -> ChatModel:
+    def model(self) -> GenerativeModel:
         if self._model is None:
-            self._model = ChatModel.from_pretrained("chat-bison")
+            self._model = GenerativeModel(self.model_name)
         return self._model
 
     @property

@@ -108,19 +110,20 @@ def get_sorted_answer_indexes(self, question: str, answers: List[str], persona:
 
     def _call_model(self, prompt: Dict) -> str:
         """
-            Wrapper for Palm2 Model generation logic
+            Wrapper for Gemini Model generation logic
             :param prompt: Input messages sequence
             :returns: Output text sequence generated by model
         """
 
         chat = self._model.start_chat(
-            context=prompt["system_prompt"],
-            message_history=prompt["chat_history"],
-            max_output_tokens=self.max_tokens,
-            temperature=0,
+            history=prompt["chat_history"],
         )
         response = chat.send_message(
             prompt["message"],
+            generation_config = {
+                "temperature": 0,
+                "max_output_tokens": self.max_tokens,
+            }
         )
         text = response.text
 

@@ -140,16 +143,23 @@ def _assemble_prompt(self, message: str, chat_history: List[List[str]], persona:
         # Context N messages
         messages = []
         for role, content in chat_history[-self.context_depth:]:
-            role_palm2 = self.convert_role(role)
-            messages.append(ChatMessage(content, role_palm2))
+            if ((len(messages) == 0) and (role == "user")):
+                content = self._convert2instruction(content, system_prompt)
+            role_gemini = self.convert_role(role)
+            messages.append(Content(parts=[Part.from_text(content)], role = role_gemini))
+        if (len(messages) == 0):
+            message = self._convert2instruction(message, system_prompt)
         prompt = {
-            "system_prompt": system_prompt,
             "chat_history": messages,
             "message": message
         }
 
         return prompt
 
+    def _convert2instruction(self, content: str, system_prompt: str):
+        instruction = f"{system_prompt.strip()}\n\n{content.strip()}"
+        return instruction
+
     def _score(self, prompt: str, targets: List[str], persona: dict) -> List[float]:
         """
             Calculates logarithmic probabilities for the list of provided text sequences
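
For orientation, here is a minimal, self-contained sketch of the Gemini chat flow that the new `_assemble_prompt` and `_call_model` implement. It uses only the same `vertexai.preview.generative_models` calls as the diff above; the project id, history, and prompts are illustrative, and `vertexai.init()` is assumed to run against valid GCP credentials:

```python
import vertexai
from vertexai.preview.generative_models import GenerativeModel, Content, Part

# Assumes application-default credentials (or the configured key_path) exist.
vertexai.init(project="my-gcp-project")  # hypothetical project id

model = GenerativeModel("gemini-pro")

# History roles are "user"/"model" (not "bot" as with PaLM 2), hence the
# mq_to_llm_role change above. There is no separate system-prompt field in
# this API, so the system prompt is folded into the first user turn
# (what _convert2instruction does).
history = [
    Content(role="user", parts=[Part.from_text(
        "You are trying to give a short answer in less than 40 words.\n\n"
        "What is RabbitMQ?")]),
    Content(role="model", parts=[Part.from_text(
        "RabbitMQ is an open-source message broker.")]),
]

chat = model.start_chat(history=history)
response = chat.send_message(
    "How does it relate to MQTT?",
    generation_config={"temperature": 0, "max_output_tokens": 100},
)
print(response.text)
```

When the incoming history is empty, the code above falls back to folding the system prompt into the outgoing message itself, which is the "instruction loss of empty history" fix named in the commit message.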

neon_llm_palm2/rmq.py renamed to neon_llm_gemini/rmq.py
Lines changed: 5 additions & 5 deletions

@@ -25,12 +25,12 @@
 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 from neon_llm_core.rmq import NeonLLMMQConnector
 
-from neon_llm_palm2.palm2 import Palm2
+from neon_llm_gemini.gemini import Gemini
 
 
-class Palm2MQ(NeonLLMMQConnector):
+class GeminiMQ(NeonLLMMQConnector):
     """
-        Module for processing MQ requests to Palm2
+        Module for processing MQ requests to Gemini
     """
 
     def __init__(self):

@@ -39,12 +39,12 @@ def __init__(self):
 
     @property
     def name(self):
-        return "palm2"
+        return "gemini"
 
     @property
     def model(self):
         if self._model is None:
-            self._model = Palm2(self.model_config)
+            self._model = Gemini(self.model_config)
         return self._model
 
     def warmup(self):

requirements/requirements.txt
Lines changed: 2 additions & 2 deletions

@@ -1,5 +1,5 @@
 # model
-google-cloud-aiplatform
+google-cloud-aiplatform>=1.38
 openai[embeddings]~=0.27
 # networking
-neon_llm_core~=0.1.0
+neon_llm_core[chatbots]~=0.1.0,>=0.1.1a1
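
The `>=1.38` floor on `google-cloud-aiplatform` appears to correspond to the release line that introduced the `vertexai.preview.generative_models` Gemini API imported in `gemini.py` above, and the new `[chatbots]` extra on `neon_llm_core` lines up with the Submind-participation support named in the commit message.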

setup.py
Lines changed: 4 additions & 4 deletions

@@ -67,12 +67,12 @@ def get_requirements(requirements_filename: str):
             version = line.split("'")[1]
 
 setup(
-    name='neon-llm-palm2',
+    name='neon-llm-gemini',
     version=version,
-    description='LLM service for Palm2',
+    description='LLM service for Gemini',
     long_description=long_description,
     long_description_content_type="text/markdown",
-    url='https://github.com/NeonGeckoCom/neon-llm-palm2',
+    url='https://github.com/NeonGeckoCom/neon-llm-gemini',
     author='Neongecko',
     author_email='developers@neon.ai',
     license='BSD-3.0',

@@ -85,7 +85,7 @@ def get_requirements(requirements_filename: str):
     ],
     entry_points={
         'console_scripts': [
-            'neon-llm-palm2=neon_llm_palm2.__main__:main'
+            'neon-llm-gemini=neon_llm_gemini.__main__:main'
         ]
     }
 )
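
With the renamed entry point, a local (non-Docker) run looks like the following sketch, mirroring the `RUN pip install /app` and `CMD [ "neon-llm-gemini" ]` lines in the Dockerfile above (assumes a valid `diana.yaml` at the configured path):

```shell
pip install .     # installs the neon-llm-gemini console script from setup.py
neon-llm-gemini   # equivalent to the Dockerfile CMD
```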
