
Commit ce534dd

Authored by jkwatson, ewilliams-cloudera, mliu-cloudera, and baasitsharief
Remove crew, replace with custom tool-calling (#243)
* wip on moving away from crew
* wip lastFile:llm-service/app/services/query/agents/crewai_querier.py
* drop databases lastFile:llm-service/app/services/query/agents/crewai_querier.py
* WIP basic tool calling and minor refactoring of retriever tool lastFile:llm-service/app/services/query/tools/retriever.py
* not working usage of openai agent
* basic streaming example
* wip on openai tool using
* provide summaries to the retrieval tool
* added source nodes to OpenAI agent stream
* minor change to prompt
* wip on supporting bedrock
* WIP streaming for non open ai models
* add conditional for openai
* support for non openai agents and refactor
* minor fix to return types
* remove crew
* minor change to check for empty response
* fix mypy, ruff issues
* Revert "bump mui packages that seemed to address issues with x-charts" (reverts commit 864b770)
* Revert "update hide legend" (reverts commit b4db0dd)
* remove the last crew bits
* added verbosity for non openai agents
* added missing print dashes
* better error handling for bedrock model availability
* fix mypy
* fix writing source nodes

---------

Co-authored-by: Elijah Williams <ewilliams@cloudera.com>
Co-authored-by: Michael Liu <mliu@cloudera.com>
Co-authored-by: Baasit Sharief <baasitsharief@gmail.com>
1 parent f96c664 commit ce534dd

File tree

30 files changed: +806, -2957 lines changed

llm-service/app/routers/index/sessions/__init__.py

Lines changed: 13 additions & 12 deletions
@@ -63,8 +63,8 @@
 from ....services.chat_history.paginator import paginate
 from ....services.metadata_apis import session_metadata_api
 from ....services.mlflow import rating_mlflow_log_metric, feedback_mlflow_log_table
-from ....services.query.agents.crewai_querier import poison_pill
-from ....services.query.crew_events import CrewEvent
+from ....services.query.agents.tool_calling_querier import poison_pill
+from ....services.query.chat_events import ToolEvent
 from ....services.session import rename_session

 logger = logging.getLogger(__name__)

@@ -258,15 +258,15 @@ def stream_chat_completion(
     session = session_metadata_api.get_session(session_id, user_name=origin_remote_user)
     configuration = request.configuration or RagPredictConfiguration()

-    crew_events_queue: queue.Queue[CrewEvent] = queue.Queue()
+    tool_events_queue: queue.Queue[ToolEvent] = queue.Queue()
     # Create a cancellation event to signal when the client disconnects
     cancel_event = threading.Event()

-    def crew_callback(chat_future: Future[Any]) -> Generator[str, None, None]:
+    def tools_callback(chat_future: Future[Any]) -> Generator[str, None, None]:
         while True:
             # Check if client has disconnected
             if cancel_event.is_set():
-                logger.info("Client disconnected, stopping crew callback")
+                logger.info("Client disconnected, stopping tool callback")
                 # Try to cancel the future if it's still running
                 if not chat_future.done():
                     chat_future.cancel()

@@ -276,14 +276,14 @@ def crew_callback(chat_future: Future[Any]) -> Generator[str, None, None]:
                 raise e

             try:
-                event_data = crew_events_queue.get(block=True, timeout=1.0)
+                event_data = tool_events_queue.get(block=True, timeout=1.0)
                 if event_data.type == poison_pill:
                     break
                 event_json = json.dumps({"event": event_data.model_dump()})
                 yield f"data: {event_json}\n\n"
             except queue.Empty:
                 # Send a heartbeat event every second to keep the connection alive
-                heartbeat = CrewEvent(
+                heartbeat = ToolEvent(
                     type="event", name="Processing", timestamp=time.time()
                 )
                 event_json = json.dumps({"event": heartbeat.model_dump()})

@@ -303,27 +303,28 @@ def generate_stream() -> Generator[str, None, None]:
                 query=request.query,
                 configuration=configuration,
                 user_name=origin_remote_user,
-                crew_events_queue=crew_events_queue,
+                tool_events_queue=tool_events_queue,
             )

-            # Yield from crew_callback, which will check for cancellation
-            yield from crew_callback(future)
+            # Yield from tools_callback, which will check for cancellation
+            yield from tools_callback(future)

             # If we get here and the cancel_event is set, the client has disconnected
             if cancel_event.is_set():
                 logger.info("Client disconnected, not processing results")
                 return

             first_message = True
-            for response in future.result():
+            stream = future.result()
+            for response in stream:
                 # Check for cancellation between each response
                 if cancel_event.is_set():
                     logger.info("Client disconnected during result processing")
                     break

                 # send an initial message to let the client know the response stream is starting
                 if first_message:
-                    done = CrewEvent(type="done", name="done", timestamp=time.time())
+                    done = ToolEvent(type="done", name="done", timestamp=time.time())
                     event_json = json.dumps({"event": done.model_dump()})
                     yield f"data: {event_json}\n\n"
                     first_message = False
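
Taken together, this handler pairs a background chat future with a queue-draining generator: tool events become SSE data frames, queue.Empty timeouts become heartbeat frames, and a poison-pill event type ends the stream. Below is a runnable sketch of that pattern under simplified assumptions: this ToolEvent is a plain dataclass stand-in for the real Pydantic model, and POISON_PILL and fake_chat are illustrative names only.

# Minimal sketch of the queue-draining SSE pattern (simplified stand-ins, not the real classes).
import json
import queue
import threading
import time
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import asdict, dataclass
from typing import Any, Generator

POISON_PILL = "poison_pill"  # illustrative sentinel event type that ends the stream


@dataclass
class ToolEvent:
    type: str
    name: str
    timestamp: float = 0.0

    def model_dump(self) -> dict[str, Any]:
        return asdict(self)


def tools_callback(
    chat_future: Future[Any],
    tool_events_queue: "queue.Queue[ToolEvent]",
    cancel_event: threading.Event,
) -> Generator[str, None, None]:
    """Drain tool events into SSE frames until the poison pill arrives."""
    while True:
        if cancel_event.is_set():
            # Client went away: try to stop the background chat and exit.
            if not chat_future.done():
                chat_future.cancel()
            break
        try:
            event_data = tool_events_queue.get(block=True, timeout=1.0)
            if event_data.type == POISON_PILL:
                break
            yield f"data: {json.dumps({'event': event_data.model_dump()})}\n\n"
        except queue.Empty:
            # Heartbeat keeps the SSE connection alive while the agent works.
            heartbeat = ToolEvent(type="event", name="Processing", timestamp=time.time())
            yield f"data: {json.dumps({'event': heartbeat.model_dump()})}\n\n"


def fake_chat(events: "queue.Queue[ToolEvent]") -> str:
    # Hypothetical worker: emits one tool event, then closes the event stream.
    events.put(ToolEvent(type="event", name="calling retriever", timestamp=time.time()))
    time.sleep(0.1)
    events.put(ToolEvent(type=POISON_PILL, name="no-op"))
    return "answer"


if __name__ == "__main__":
    events: "queue.Queue[ToolEvent]" = queue.Queue()
    cancel = threading.Event()
    with ThreadPoolExecutor() as pool:
        future = pool.submit(fake_chat, events)
        for frame in tools_callback(future, events, cancel):
            print(frame, end="")
        print(future.result())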

llm-service/app/services/caii/caii.py

Lines changed: 0 additions & 1 deletion
@@ -131,7 +131,6 @@ def get_llm(
         base_url=api_base,
         model=model,
         http_client=http_client
-        # api_base=api_base, # todo: figure out how to integrate with Crew models
     )


llm-service/app/services/chat/streaming_chat.py

Lines changed: 11 additions & 14 deletions
@@ -59,12 +59,12 @@
 from app.services.metadata_apis.session_metadata_api import Session
 from app.services.mlflow import record_direct_llm_mlflow_run
 from app.services.query import querier
-from app.services.query.agents.crewai_querier import poison_pill
+from app.services.query.agents.tool_calling_querier import poison_pill
 from app.services.query.chat_engine import (
     FlexibleContextChatEngine,
     build_flexible_chat_engine,
 )
-from app.services.query.crew_events import CrewEvent
+from app.services.query.chat_events import ToolEvent
 from app.services.query.querier import (
     build_retriever,
 )

@@ -76,7 +76,7 @@ def stream_chat(
     query: str,
     configuration: RagPredictConfiguration,
     user_name: Optional[str],
-    crew_events_queue: Queue[CrewEvent],
+    tool_events_queue: Queue[ToolEvent],
 ) -> Generator[ChatResponse, None, None]:
     query_configuration = QueryConfiguration(
         top_k=session.response_chunks,

@@ -99,14 +99,12 @@
     if not query_configuration.use_tool_calling and (
         len(session.data_source_ids) == 0 or total_data_sources_size == 0
     ):
-        # put a poison pill in the queue to stop the crew events stream
-        crew_events_queue.put(CrewEvent(type=poison_pill, name="no-op"))
-        return _stream_direct_llm_chat(
-            session, response_id, query, user_name, crew_events_queue
-        )
+        # put a poison pill in the queue to stop the tool events stream
+        tool_events_queue.put(ToolEvent(type=poison_pill, name="no-op"))
+        return _stream_direct_llm_chat(session, response_id, query, user_name)

     condensed_question, streaming_chat_response = build_streamer(
-        crew_events_queue, query, query_configuration, session
+        tool_events_queue, query, query_configuration, session
     )
     return _run_streaming_chat(
         session,

@@ -125,10 +123,11 @@ def _run_streaming_chat(
     query: str,
     query_configuration: QueryConfiguration,
     user_name: Optional[str],
+    streaming_chat_response: StreamingAgentChatResponse,
     condensed_question: Optional[str] = None,
-    streaming_chat_response: StreamingAgentChatResponse = None,
 ) -> Generator[ChatResponse, None, None]:
     response: ChatResponse = ChatResponse(message=ChatMessage(content=query))
+
     if streaming_chat_response.chat_stream:
         for response in streaming_chat_response.chat_stream:
             response.additional_kwargs["response_id"] = response_id

@@ -152,7 +151,7 @@


 def build_streamer(
-    crew_events_queue: Queue[CrewEvent],
+    chat_events_queue: Queue[ToolEvent],
     query: str,
     query_configuration: QueryConfiguration,
     session: Session,

@@ -181,9 +180,8 @@
         query,
         query_configuration,
         chat_messages,
-        crew_events_queue=crew_events_queue,
+        tool_events_queue=chat_events_queue,
         session=session,
-        retriever=retriever,
     )
     return condensed_question, streaming_chat_response

@@ -193,7 +191,6 @@ def _stream_direct_llm_chat(
     response_id: str,
     query: str,
     user_name: Optional[str],
-    queue: Queue[CrewEvent],
 ) -> Generator[ChatResponse, None, None]:
     record_direct_llm_mlflow_run(response_id, session, user_name)

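
One of the mypy fixes in this file is the parameter reorder in _run_streaming_chat: streaming_chat_response loses its "= None" default, which conflicted with its non-Optional annotation, and moves ahead of the defaulted condensed_question so required arguments stay first. A minimal sketch of the same shape, with FakeStreamingResponse standing in for the real StreamingAgentChatResponse (the names here are illustrative, not the module's actual helpers):

# Why the reorder matters: a required, correctly typed parameter cannot hide
# behind "= None", and defaulted parameters must come after it.
from dataclasses import dataclass
from typing import Iterator, Optional


@dataclass
class FakeStreamingResponse:
    # Stand-in exposing only the attribute the body dereferences.
    chat_stream: Optional[Iterator[str]] = None


def run_streaming_chat(
    query: str,
    streaming_chat_response: FakeStreamingResponse,  # now required, correctly typed
    condensed_question: Optional[str] = None,        # defaulted arguments stay last
) -> Iterator[str]:
    if streaming_chat_response.chat_stream:
        yield from streaming_chat_response.chat_stream
    else:
        yield query


print(list(run_streaming_chat("hi", FakeStreamingResponse(iter(["a", "b"])))))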

llm-service/app/services/models/providers/bedrock.py

Lines changed: 24 additions & 12 deletions
@@ -44,6 +44,7 @@
 import requests
 from botocore.auth import SigV4Auth
 from botocore.awsrequest import AWSRequest
+from fastapi import HTTPException
 from llama_index.embeddings.bedrock import BedrockEmbedding
 from llama_index.llms.bedrock_converse import BedrockConverse
 from llama_index.llms.bedrock_converse.utils import BEDROCK_MODELS

@@ -153,22 +154,33 @@ def get_aws_responses(
         raise_for_http_error(response)
         return cast(dict[str, Any], response.json())

-    responses: list[dict[str, Any] | None] = [None for _ in aws_requests]
+    responses: list[dict[str, Any] | None] = []
     with concurrent.futures.ThreadPoolExecutor() as executor:
-        future_to_index = {
-            executor.submit(get_aws_responses, url, headers): idx
-            for idx, (url, headers) in enumerate(aws_requests)
-        }
-        for future in concurrent.futures.as_completed(future_to_index):
-            idx = future_to_index[future]
+        results = executor.map(
+            lambda url_and_headers: get_aws_responses(*url_and_headers),
+            aws_requests,
+        )
+        while True:
             try:
-                responses[idx] = future.result()
-            except Exception:
+                result = next(results)
+                responses.append(result)
+            except StopIteration:
+                break
+            except HTTPException as e:
+                model_id = str(e).split("/")[-1]
                 logger.exception(
-                    "Error fetching data for model %s", models[idx]["modelId"]
+                    "Error fetching data for model %s",
+                    model_id,
                 )
-                responses[idx] = None
-
+                responses.append(None)
+                continue
+            except Exception as e:
+                logger.exception(
+                    "Unexpected error fetching data: %s",
+                    e,
+                )
+                responses.append(None)
+                continue
     for model, model_data in zip(models, responses):
         if model_data:
             if model_data["entitlementAvailability"] == "AVAILABLE":
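
This hunk swaps the submit/as_completed bookkeeping for executor.map plus per-item exception handling around next(), so a failed availability lookup is logged and recorded as None instead of aborting the whole scan. One way to get the same per-item isolation, shown here only as a sketch with hypothetical names (fetch_model_entitlement, fetch_or_none, MODEL_IDS are not the module's real helpers), is to catch inside the mapped callable so the map result iterator itself never raises:

# Sketch: tolerate individual worker failures by catching inside the mapped
# callable; executor.map preserves input order, so results line up with inputs.
import concurrent.futures
import logging
from typing import Any, Optional

logger = logging.getLogger(__name__)

MODEL_IDS = ["model-a", "model-b", "model-c"]  # hypothetical model identifiers


def fetch_model_entitlement(model_id: str) -> dict[str, Any]:
    # Hypothetical lookup; "model-b" simulates a failed availability check.
    if model_id == "model-b":
        raise RuntimeError(f"availability check failed for {model_id}")
    return {"modelId": model_id, "entitlementAvailability": "AVAILABLE"}


def fetch_or_none(model_id: str) -> Optional[dict[str, Any]]:
    try:
        return fetch_model_entitlement(model_id)
    except Exception:
        logger.exception("Error fetching data for model %s", model_id)
        return None  # a failed lookup only blanks its own slot


with concurrent.futures.ThreadPoolExecutor() as executor:
    responses = list(executor.map(fetch_or_none, MODEL_IDS))

for model_id, data in zip(MODEL_IDS, responses):
    print(model_id, "->", data and data["entitlementAvailability"])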

0 commit comments
