
Commit b0fe522

expose experimental streaming methods for anthropic, openai, and groq
1 parent c1a8bef commit b0fe522

File tree

src/abcs/anthropic.py
src/abcs/llm.py
src/abcs/models.py
src/abcs/openai.py
src/agents/agent.py

5 files changed: +172 −6 lines changed

src/abcs/anthropic.py

Lines changed: 60 additions & 1 deletion
@@ -1,10 +1,11 @@
+import asyncio
 import logging
 import os
 from typing import Any, Dict, List, Optional

 import anthropic
 from abcs.llm import LLM
-from abcs.models import PromptResponse, UsageStats
+from abcs.models import PromptResponse, StreamingPromptResponse, UsageStats
 from tools.tool_manager import ToolManager

 logging.basicConfig(level=logging.INFO)

@@ -142,3 +143,61 @@ def _translate_response(self, response) -> PromptResponse:
         except Exception as e:
             logger.exception(f"error: {e}\nresponse: {response}")
             raise e
+
+    async def generate_text_stream(
+        self,
+        prompt: str,
+        past_messages: List[Dict[str, str]],
+        tools: Optional[List[Dict[str, Any]]] = None,
+        **kwargs,
+    ) -> StreamingPromptResponse:
+        combined_history = past_messages + [{"role": "user", "content": prompt}]
+
+        try:
+            stream = self.client.messages.create(
+                model=self.model,
+                max_tokens=4096,
+                messages=combined_history,
+                system=self.system_prompt,
+                stream=True,
+            )
+
+            async def content_generator():
+                for event in stream:
+                    if isinstance(event, anthropic.types.MessageStartEvent):
+                        # Message start event, we can ignore this
+                        pass
+                    elif isinstance(event, anthropic.types.ContentBlockStartEvent):
+                        # Content block start event, we can ignore this
+                        pass
+                    elif isinstance(event, anthropic.types.ContentBlockDeltaEvent):
+                        # This is the event that contains the actual text
+                        if event.delta.text:
+                            yield event.delta.text
+                    elif isinstance(event, anthropic.types.ContentBlockStopEvent):
+                        # Content block stop event, we can ignore this
+                        pass
+                    elif isinstance(event, anthropic.types.MessageStopEvent):
+                        # Message stop event, we can ignore this
+                        pass
+                    # Small delay to allow for cooperative multitasking
+                    await asyncio.sleep(0)
+
+            return StreamingPromptResponse(
+                content=content_generator(),
+                raw_response=stream,
+                error={},
+                usage=UsageStats(
+                    input_tokens=0,  # These will need to be updated after streaming
+                    output_tokens=0,
+                    extra={},
+                ),
+            )
+        except Exception as e:
+            logger.exception(f"An error occurred while streaming from Claude: {e}")
+            raise e
+
+    async def handle_tool_call(self, tool_calls, combined_history, tools):
+        # This is a placeholder for handling tool calls in streaming context
+        # You'll need to implement the logic to execute the tool call and generate a response
+        pass

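For reference, a minimal sketch of consuming the new Anthropic streaming method. The wrapper class name (ClaudeLLM) and its constructor arguments are assumptions for illustration; the class definition is not part of this diff.

    import asyncio

    from abcs.anthropic import ClaudeLLM  # assumed class name; not shown in this diff


    async def main():
        # Constructor arguments are assumed for illustration.
        llm = ClaudeLLM(api_key="...", model="claude-3-opus-20240229", system_prompt="You are helpful.")
        response = await llm.generate_text_stream("Tell me a short joke", past_messages=[])

        # StreamingPromptResponse.content is an async iterator of text deltas.
        async for chunk in response.content:
            print(chunk, end="", flush=True)


    asyncio.run(main())
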
src/abcs/llm.py

Lines changed: 10 additions & 1 deletion
@@ -4,7 +4,7 @@
 from importlib import resources

 import yaml
-from abcs.models import PromptResponse
+from abcs.models import PromptResponse, StreamingPromptResponse
 from abcs.tools import gen_anthropic, gen_cohere, gen_google, gen_openai

 # Add the project root to the Python path

@@ -33,6 +33,15 @@ def generate_text(self,
         """Generates text based on the given prompt and additional arguments."""
         pass

+    @abstractmethod
+    async def generate_text_stream(self,
+                                   prompt: str,
+                                   past_messages,
+                                   tools,
+                                   **kwargs) -> StreamingPromptResponse:
+        """Generates streaming text based on the given prompt and additional arguments."""
+        pass
+
     @abstractmethod
     def call_tool(self, past_messages, tool_msg) -> str:
         """Calls a specific tool with the given arguments and returns the response."""

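Because generate_text_stream is now abstract on LLM, every concrete provider must implement it. A toy sketch of a conforming implementation, assuming only the models shown in this commit; the other abstract methods on LLM are omitted for brevity.

    from abcs.llm import LLM
    from abcs.models import StreamingPromptResponse, UsageStats


    class EchoLLM(LLM):
        # Toy sketch: only the new streaming method is shown; the remaining
        # abstract methods on LLM would also need implementations.
        async def generate_text_stream(self, prompt, past_messages, tools, **kwargs) -> StreamingPromptResponse:
            async def gen():
                yield prompt  # echo the prompt back as a single chunk

            return StreamingPromptResponse(
                content=gen(),
                raw_response=None,
                error={},
                usage=UsageStats(input_tokens=0, output_tokens=0, extra={}),
            )
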
src/abcs/models.py

Lines changed: 10 additions & 1 deletion
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, AsyncIterator

 from pydantic import BaseModel

@@ -29,3 +29,12 @@ class OllamaResponse(BaseModel):
     prompt_eval_duration: int
     eval_count: int
     eval_duration: int
+
+class StreamingPromptResponse(BaseModel):
+    content: AsyncIterator[str]
+    raw_response: Any
+    error: object
+    usage: UsageStats
+
+    class Config:
+        arbitrary_types_allowed = True

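The content field holds an async iterator, which pydantic cannot validate natively; that is what arbitrary_types_allowed is for. A small, self-contained sketch of building and draining a StreamingPromptResponse by hand:

    import asyncio

    from abcs.models import StreamingPromptResponse, UsageStats


    async def demo():
        async def fake_stream():
            for piece in ("Hello", ", ", "world"):
                yield piece

        resp = StreamingPromptResponse(
            content=fake_stream(),
            raw_response=None,
            error={},
            usage=UsageStats(input_tokens=0, output_tokens=0, extra={}),
        )
        async for piece in resp.content:
            print(piece, end="")


    asyncio.run(demo())
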
src/abcs/openai.py

Lines changed: 61 additions & 2 deletions
@@ -1,11 +1,14 @@
+import asyncio
 import json
 import logging
 import os
 from typing import Any, Dict, List, Optional

-import openai_multi_tool_use_parallel_patch  # type: ignore # noqa: F401
+# todo: need to support this for multi tool use, maybe upstream package has it fixed now.
+# commented out because it's not working with streams
+# import openai_multi_tool_use_parallel_patch  # type: ignore # noqa: F401
 from abcs.llm import LLM
-from abcs.models import PromptResponse, UsageStats
+from abcs.models import PromptResponse, StreamingPromptResponse, UsageStats
 from openai import OpenAI
 from tools.tool_manager import ToolManager

@@ -188,3 +191,59 @@ def _translate_response(self, response) -> PromptResponse:
             # logger.error("An error occurred while translating OpenAI response: %s", e, exc_info=True)
             logger.exception(f"error: {e}\nresponse: {response}")
             raise e
+
+    # https://cookbook.openai.com/examples/how_to_stream_completions
+    async def generate_text_stream(
+        self,
+        prompt: str,
+        past_messages: List[Dict[str, str]],
+        tools: Optional[List[Dict[str, Any]]] = None,
+        **kwargs,
+    ) -> StreamingPromptResponse:
+        system_message = [{"role": "system", "content": self.system_prompt}] if self.system_prompt else []
+        combined_history = system_message + past_messages + [{"role": "user", "content": prompt}]
+
+        try:
+            stream = self.client.chat.completions.create(
+                model=self.model,
+                messages=combined_history,
+                tools=tools,
+                stream=True,
+            )
+
+            async def content_generator():
+                for event in stream:
+                    if event.choices[0].delta.content is not None:
+                        yield event.choices[0].delta.content
+                    # Small delay to allow for cooperative multitasking
+                    await asyncio.sleep(0)
+
+            # # After the stream is complete, you might want to handle tool calls here
+            # # This is a simplification and may need to be adjusted based on your needs
+            # if tools and collected_content.strip().startswith('{"function":'):
+            #     # Handle tool calls (simplified example)
+            #     tool_response = await self.handle_tool_call(collected_content, combined_history, tools)
+            #     yield tool_response
+
+            return StreamingPromptResponse(
+                content=content_generator(),
+                raw_response=stream,
+                error={},
+                usage=UsageStats(
+                    input_tokens=0,  # These will need to be updated after streaming
+                    output_tokens=0,
+                    extra={},
+                ),
+            )
+        except Exception as e:
+            logger.error("Error generating text stream: %s", e, exc_info=True)
+            raise e
+
+    async def handle_tool_call(self, collected_content, combined_history, tools):
+        # This is a placeholder for handling tool calls in streaming context
+        # You'll need to implement the logic to parse the tool call, execute it,
+        # and generate a response based on the tool's output
+        # This might involve breaking the streaming and making a new API call
+        pass

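The usage stats are zeroed here because token counts are not known until the stream is exhausted. A consumption sketch follows; the wrapper class name (OpenAILLM) and constructor arguments are assumptions, not shown in this diff.

    import asyncio

    from abcs.openai import OpenAILLM  # assumed class name; not shown in this diff


    async def main():
        # Constructor arguments are assumed for illustration.
        llm = OpenAILLM(api_key="...", model="gpt-4o", system_prompt="You are helpful.")
        response = await llm.generate_text_stream("Summarize streaming in one sentence.", past_messages=[], tools=None)

        collected = []
        async for chunk in response.content:
            collected.append(chunk)
            print(chunk, end="", flush=True)

        # Usage fields are placeholders in the streaming path; if token counts
        # matter, they have to be computed after the stream completes.
        full_text = "".join(collected)


    asyncio.run(main())
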
src/agents/agent.py

Lines changed: 31 additions & 1 deletion
@@ -1,7 +1,7 @@
 import logging

 from abcs.llm import LLM
-from abcs.models import PromptResponse
+from abcs.models import PromptResponse, StreamingPromptResponse

 # from metrics.main import call_tool_counter, generate_text_counter
 from storage.storage_manager import StorageManager

@@ -13,6 +13,8 @@

 class Agent(LLM):
     def __init__(self, client, tool_manager: ToolManager, system_prompt: str = "", tools=[], storage_manager: StorageManager = None):
+        if len(tools) == 0 and (client.provider == "openai" or client.provider == "groq"):
+            tools = None
         self.tools = tools
         logger.debug("Initializing Agent with tools: %s and system prompt: '%s'", tools, system_prompt)
         super().__init__(

@@ -90,3 +92,31 @@ def _translate_response(self, response) -> PromptResponse:
     #     except Exception as e:
     #         logger.error("Error translating response: %s", e, exc_info=True)
     #         raise e
+
+    async def generate_text_stream(self,
+                                   prompt: str,
+                                   **kwargs) -> StreamingPromptResponse:
+        """Generates streaming text based on the given prompt and additional arguments."""
+        past_messages = []
+        if self.storage_manager is not None:
+            past_messages = self.storage_manager.get_past_messages()
+            logger.debug("Fetched %d past messages", len(past_messages))
+        if self.storage_manager is not None:
+            self.storage_manager.store_message("user", prompt)
+        try:
+            response = await self.client.generate_text_stream(prompt, past_messages, self.tools)
+        except Exception as err:
+            if self.storage_manager is not None:
+                self.storage_manager.remove_last()
+            raise err
+
+        # TODO: can't do this with streaming. have to handle this in the API
+        # if self.storage_manager is not None:
+        #     try:
+        #         # translated = self._translate_response(response)
+        #         self.storage_manager.store_message("assistant", response.content)
+        #     except Exception as e:
+        #         logger.error("Error storing messages: %s", e, exc_info=True)
+        #         raise e
+
+        return response

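End to end, streaming goes through Agent.generate_text_stream. A sketch of how a caller might drive it; the provider client construction and the ToolManager arguments are assumed, and per the TODO in the diff, persisting the assistant message is left to the caller once the stream is drained.

    import asyncio

    from agents.agent import Agent
    from storage.storage_manager import StorageManager
    from tools.tool_manager import ToolManager


    async def run(client, storage: StorageManager = None):
        # `client` is any provider wrapper exposing generate_text_stream
        # (anthropic, openai, or groq); its construction is omitted here.
        # ToolManager construction arguments are assumed; adjust to its actual signature.
        agent = Agent(client, ToolManager(), system_prompt="You are helpful.", tools=[], storage_manager=storage)

        response = await agent.generate_text_stream("Stream a haiku about commits.")

        parts = []
        async for chunk in response.content:
            parts.append(chunk)
            print(chunk, end="", flush=True)

        # The Agent does not store the assistant message in the streaming path;
        # the caller/API layer should persist "".join(parts) once streaming ends.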