
Commit cc2cc69

add streaming to cohere and ollama
1 parent 9670aff

5 files changed (+139 -9 lines)

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "ada-python"
-version = "0.4.0"
+version = "0.5.0"
 description = "Ada, making LLMs easier to work with."
 authors = ["Will Beebe"]
 packages = [
```

src/abcs/anthropic.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -198,6 +198,4 @@ async def content_generator():
             raise e
 
     async def handle_tool_call(self, tool_calls, combined_history, tools):
-        # This is a placeholder for handling tool calls in streaming context
-        # You'll need to implement the logic to execute the tool call and generate a response
         pass
```
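The comments deleted here described what `handle_tool_call` still needs to do: execute the requested tool and generate a follow-up response. A minimal sketch of that shape, with the caveat that the `call_tool(name, arguments)` interface is an assumption and this is not the author's implementation (the commit leaves the method as `pass` in every provider):

```python
# Sketch only. Assumes a tool manager exposing call_tool(name, arguments);
# this commit does not define that interface.
async def handle_tool_call(self, tool_calls, combined_history, tools):
    for call in tool_calls:
        # Execute each requested tool and append its output to the history.
        result = self.tool_manager.call_tool(call["name"], call.get("arguments", {}))
        combined_history.append({"role": "tool", "content": str(result)})
    # A follow-up request using combined_history would then produce the
    # final assistant message.
    return combined_history
```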

src/abcs/cohere.py

Lines changed: 75 additions & 1 deletion
```diff
@@ -1,8 +1,10 @@
+import asyncio
 import logging
+from typing import Any, Dict, List, Optional
 
 import cohere
 from abcs.llm import LLM
-from abcs.models import PromptResponse, UsageStats
+from abcs.models import PromptResponse, StreamingPromptResponse, UsageStats
 from tools.tool_manager import ToolManager
 
 logging.basicConfig(level=logging.INFO)
```
```diff
@@ -121,3 +123,75 @@ def _translate_response(self, response) -> PromptResponse:
             )
             raise e
 
+    # https://github.com/cohere-ai/cohere-python/blob/main/src/cohere/types/streamed_chat_response.py
+    # https://docs.cohere.com/docs/streaming#stream-events
+    # https://docs.cohere.com/docs/streaming#example-responses
+    async def generate_text_stream(
+        self,
+        prompt: str,
+        past_messages: List[Dict[str, str]],
+        tools: Optional[List[Dict[str, Any]]] = None,
+        **kwargs,
+    ) -> StreamingPromptResponse:
+        try:
+            # Cohere expects chat history in its own role/message format.
+            combined_history = []
+            for msg in past_messages:
+                combined_history.append({
+                    "role": "CHATBOT" if msg["role"] == "assistant" else "USER",
+                    "message": msg["content"],
+                })
+            stream = self.client.chat_stream(
+                chat_history=combined_history,
+                message=prompt,
+                tools=tools,
+                model=self.model,
+                # Perform a web search before answering; a custom connector also works.
+                # connectors=[{"id": "web-search"}],
+            )
+
+            async def content_generator():
+                for event in stream:
+                    if isinstance(event, cohere.types.StreamedChatResponse_StreamStart):
+                        # Message start event; nothing to emit.
+                        pass
+                    elif isinstance(event, cohere.types.StreamedChatResponse_TextGeneration):
+                        # This event carries the generated text.
+                        if event.text:
+                            yield event.text
+                    elif isinstance(event, cohere.types.StreamedChatResponse_ToolCallsGeneration):
+                        # todo: call tool
+                        pass
+                    elif isinstance(event, cohere.types.StreamedChatResponse_CitationGeneration):
+                        # todo: not sure, but seems useful
+                        pass
+                    elif isinstance(event, cohere.types.StreamedChatResponse_ToolCallsChunk):
+                        # todo: tool response
+                        pass
+                    elif isinstance(event, cohere.types.StreamedChatResponse_SearchQueriesGeneration):
+                        pass
+                    elif isinstance(event, cohere.types.StreamedChatResponse_SearchResults):
+                        pass
+                    elif isinstance(event, cohere.types.StreamedChatResponse_StreamEnd):
+                        # Message stop event; nothing to emit.
+                        pass
+                    # Small delay to allow for cooperative multitasking
+                    await asyncio.sleep(0)
+
+            return StreamingPromptResponse(
+                content=content_generator(),
+                raw_response=stream,
+                error={},
+                usage=UsageStats(
+                    input_tokens=0,  # These will need to be updated after streaming
+                    output_tokens=0,
+                    extra={},
+                ),
+            )
+        except Exception as e:
+            logger.exception(f"An error occurred while streaming from Cohere: {e}")
+            raise e
+
+    async def handle_tool_call(self, tool_calls, combined_history, tools):
+        pass
```
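As a usage note, a caller consumes the returned stream by iterating `StreamingPromptResponse.content`, which this diff defines as an async generator of text chunks. A minimal sketch, assuming only what the diff shows and that `client` is a constructed instance of this Cohere wrapper:

```python
import asyncio

async def print_stream(client):
    # `client` is assumed to be an instance of the Cohere wrapper above.
    response = await client.generate_text_stream(
        prompt="Say hello in one sentence.",
        past_messages=[],
    )
    # content is an async generator; chunks arrive as they are generated.
    async for chunk in response.content:
        print(chunk, end="", flush=True)
    print()

# asyncio.run(print_stream(client))  # with a constructed client instance
```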

src/abcs/ollama.py

Lines changed: 63 additions & 1 deletion
```diff
@@ -1,9 +1,15 @@
+import asyncio
 import logging
 import os
 from typing import Any, Dict, List, Optional
 
 from abcs.llm import LLM
-from abcs.models import OllamaResponse, PromptResponse, UsageStats
+from abcs.models import (
+    OllamaResponse,
+    PromptResponse,
+    StreamingPromptResponse,
+    UsageStats,
+)
 from ollama import Client
 from tools.tool_manager import ToolManager
```

```diff
@@ -110,3 +116,59 @@ def _translate_response(self, response) -> PromptResponse:
         except Exception as e:
             logger.exception(f"An error occurred while translating Ollama response: {e}")
             raise e
+
+    async def generate_text_stream(
+        self,
+        prompt: str,
+        past_messages: List[Dict[str, str]],
+        tools: Optional[List[Dict[str, Any]]] = None,
+        **kwargs,
+    ) -> StreamingPromptResponse:
+        # Build a fresh history rather than mutating the caller's list.
+        combined_history = past_messages + [{"role": "user", "content": prompt}]
+
+        try:
+            # https://github.com/ollama/ollama-python
+            # client = Client(host="https://120d-2606-40-15c-13ba-00-460-7bae.ngrok-free.app",)
+
+            # todo: generate vs chat
+            # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
+            stream = self.client.chat(
+                model=self.model,
+                messages=combined_history,
+                stream=True,
+                # num_predict=4000
+                # todo
+                # system=self.system_prompt
+            )
+
+            async def content_generator():
+                for chunk in stream:
+                    if chunk['message']['content']:
+                        yield chunk['message']['content']
+                    # Small delay to allow for cooperative multitasking
+                    await asyncio.sleep(0)
+
+            return StreamingPromptResponse(
+                content=content_generator(),
+                raw_response=stream,
+                error={},
+                usage=UsageStats(
+                    input_tokens=0,  # These will need to be updated after streaming
+                    output_tokens=0,
+                    extra={},
+                ),
+            )
+        except Exception as e:
+            logger.exception(f"An error occurred while streaming from Ollama: {e}")
+            raise e
+
+    async def handle_tool_call(self, tool_calls, combined_history, tools):
+        pass
```
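The zeroed `UsageStats` is flagged as a TODO in the diff. One way to fill it in later: per the Ollama API docs, the final streamed chunk (where `done` is true) reports `prompt_eval_count` and `eval_count`. A minimal sketch, using a plain dict as an illustrative stand-in for `UsageStats` and the same dict-style chunks the code above reads:

```python
# Sketch only: yield text chunks and capture token counts from the final chunk.
def stream_with_usage(stream, usage):
    for chunk in stream:
        if chunk['message']['content']:
            yield chunk['message']['content']
        if chunk.get('done'):
            # The final chunk carries token counts for the whole exchange.
            usage['input_tokens'] = chunk.get('prompt_eval_count', 0)
            usage['output_tokens'] = chunk.get('eval_count', 0)
```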

src/abcs/openai.py

Lines changed: 0 additions & 4 deletions
```diff
@@ -242,8 +242,4 @@ async def content_generator():
             raise e
 
     async def handle_tool_call(self, collected_content, combined_history, tools):
-        # This is a placeholder for handling tool calls in streaming context
-        # You'll need to implement the logic to parse the tool call, execute it,
-        # and generate a response based on the tool's output
-        # This might involve breaking the streaming and making a new API call
         pass
```
