diff --git a/llm-service/app/routers/index/sessions/__init__.py b/llm-service/app/routers/index/sessions/__init__.py
index 42c9212f..55827886 100644
--- a/llm-service/app/routers/index/sessions/__init__.py
+++ b/llm-service/app/routers/index/sessions/__init__.py
@@ -38,14 +38,14 @@
 import base64
 import json
 import logging
-import queue
 import threading
 import time
-from concurrent.futures import Future, ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor
 from typing import Optional, Generator, Any
 
 from fastapi import APIRouter, Header, HTTPException
 from fastapi.responses import StreamingResponse
+from llama_index.core.base.llms.types import ChatResponse
 from pydantic import BaseModel
 from starlette.responses import ContentStream
 from starlette.types import Receive
@@ -63,8 +63,7 @@
 from ....services.chat_history.paginator import paginate
 from ....services.metadata_apis import session_metadata_api
 from ....services.mlflow import rating_mlflow_log_metric, feedback_mlflow_log_table
-from ....services.query.agents.tool_calling_querier import poison_pill
-from ....services.query.chat_events import ToolEvent
+from ....services.query.chat_events import ChatEvent
 from ....services.session import rename_session
 
 logger = logging.getLogger(__name__)
@@ -258,38 +257,9 @@ def stream_chat_completion(
     session = session_metadata_api.get_session(session_id, user_name=origin_remote_user)
     configuration = request.configuration or RagPredictConfiguration()
-    tool_events_queue: queue.Queue[ToolEvent] = queue.Queue()
 
     # Create a cancellation event to signal when the client disconnects
     cancel_event = threading.Event()
 
-    def tools_callback(chat_future: Future[Any]) -> Generator[str, None, None]:
-        while True:
-            # Check if client has disconnected
-            if cancel_event.is_set():
-                logger.info("Client disconnected, stopping tool callback")
-                # Try to cancel the future if it's still running
-                if not chat_future.done():
-                    chat_future.cancel()
-                break
-
-            if chat_future.done() and (e := chat_future.exception()):
-                raise e
-
-            try:
-                event_data = tool_events_queue.get(block=True, timeout=1.0)
-                if event_data.type == poison_pill:
-                    break
-                event_json = json.dumps({"event": event_data.model_dump()})
-                yield f"data: {event_json}\n\n"
-            except queue.Empty:
-                # Send a heartbeat event every second to keep the connection alive
-                heartbeat = ToolEvent(
-                    type="event", name="Processing", timestamp=time.time()
-                )
-                event_json = json.dumps({"event": heartbeat.model_dump()})
-                yield f"data: {event_json}\n\n"
-                time.sleep(1)
-
     def generate_stream() -> Generator[str, None, None]:
         response_id: str = ""
         executor = None
@@ -303,12 +273,8 @@ def generate_stream() -> Generator[str, None, None]:
                 query=request.query,
                 configuration=configuration,
                 user_name=origin_remote_user,
-                tool_events_queue=tool_events_queue,
             )
 
-            # Yield from tools_callback, which will check for cancellation
-            yield from tools_callback(future)
-
             # If we get here and the cancel_event is set, the client has disconnected
             if cancel_event.is_set():
                 logger.info("Client disconnected, not processing results")
@@ -316,15 +282,22 @@ def generate_stream() -> Generator[str, None, None]:
             first_message = True
             stream = future.result()
 
-            for response in stream:
+            for item in stream:
+                response: ChatResponse = item
                 # Check for cancellation between each response
                 if cancel_event.is_set():
                     logger.info("Client disconnected during result processing")
                     break
-
+                if "chat_event" in response.additional_kwargs:
+                    chat_event: ChatEvent = response.additional_kwargs.get("chat_event")
+                    event_json = json.dumps({"event": chat_event.model_dump()})
+                    yield f"data: {event_json}\n\n"
+                    continue
                 # send an initial message to let the client know the response stream is starting
                 if first_message:
-                    done = ToolEvent(type="done", name="done", timestamp=time.time())
+                    done = ChatEvent(
+                        type="done", name="agent_done", timestamp=time.time()
+                    )
                     event_json = json.dumps({"event": done.model_dump()})
                     yield f"data: {event_json}\n\n"
                     first_message = False
@@ -333,6 +306,9 @@ def generate_stream() -> Generator[str, None, None]:
                     yield f"data: {json_delta}\n\n"
 
             if not cancel_event.is_set() and response_id:
+                done = ChatEvent(type="done", name="chat_done", timestamp=time.time())
+                event_json = json.dumps({"event": done.model_dump()})
+                yield f"data: {event_json}\n\n"
                 yield f'data: {{"response_id" : "{response_id}"}}\n\n'
 
         except TimeoutError:
diff --git a/llm-service/app/services/chat/streaming_chat.py b/llm-service/app/services/chat/streaming_chat.py
index 620e702e..0a3902f8 100644
--- a/llm-service/app/services/chat/streaming_chat.py
+++ b/llm-service/app/services/chat/streaming_chat.py
@@ -37,7 +37,6 @@
 #
 import time
 import uuid
-from queue import Queue
 from typing import Optional, Generator
 
 from llama_index.core.base.llms.types import ChatResponse, ChatMessage
@@ -59,12 +58,10 @@
 from app.services.metadata_apis.session_metadata_api import Session
 from app.services.mlflow import record_direct_llm_mlflow_run
 from app.services.query import querier
-from app.services.query.agents.tool_calling_querier import poison_pill
 from app.services.query.chat_engine import (
     FlexibleContextChatEngine,
     build_flexible_chat_engine,
 )
-from app.services.query.chat_events import ToolEvent
 from app.services.query.querier import (
     build_retriever,
 )
@@ -76,7 +73,6 @@ def stream_chat(
     query: str,
     configuration: RagPredictConfiguration,
     user_name: Optional[str],
-    tool_events_queue: Queue[ToolEvent],
 ) -> Generator[ChatResponse, None, None]:
     query_configuration = QueryConfiguration(
         top_k=session.response_chunks,
@@ -100,12 +96,12 @@ def stream_chat(
         len(session.data_source_ids) == 0 or total_data_sources_size == 0
     ):
         # put a poison pill in the queue to stop the tool events stream
-        tool_events_queue.put(ToolEvent(type=poison_pill, name="no-op"))
         return _stream_direct_llm_chat(session, response_id, query, user_name)
 
     condensed_question, streaming_chat_response = build_streamer(
-        tool_events_queue, query, query_configuration, session
+        query, query_configuration, session
     )
+
     return _run_streaming_chat(
         session,
         response_id,
@@ -127,7 +123,6 @@ def _run_streaming_chat(
     condensed_question: Optional[str] = None,
 ) -> Generator[ChatResponse, None, None]:
     response: ChatResponse = ChatResponse(message=ChatMessage(content=query))
-
     if streaming_chat_response.chat_stream:
         for response in streaming_chat_response.chat_stream:
             response.additional_kwargs["response_id"] = response_id
@@ -151,7 +146,6 @@ def _run_streaming_chat(
 
 
 def build_streamer(
-    chat_events_queue: Queue[ToolEvent],
     query: str,
     query_configuration: QueryConfiguration,
     session: Session,
@@ -180,7 +174,6 @@ def build_streamer(
         query,
         query_configuration,
         chat_messages,
-        tool_events_queue=chat_events_queue,
         session=session,
     )
     return condensed_question, streaming_chat_response
diff --git a/llm-service/app/services/query/tools/__init__.py b/llm-service/app/services/query/agents/agent_tools/__init__.py
similarity index 100%
rename from llm-service/app/services/query/tools/__init__.py
rename to llm-service/app/services/query/agents/agent_tools/__init__.py
diff --git a/llm-service/app/services/query/tools/date.py b/llm-service/app/services/query/agents/agent_tools/date.py
similarity index 83%
rename from llm-service/app/services/query/tools/date.py
rename to llm-service/app/services/query/agents/agent_tools/date.py
index 7ef69a65..ac308196 100644
--- a/llm-service/app/services/query/tools/date.py
+++ b/llm-service/app/services/query/agents/agent_tools/date.py
@@ -36,7 +36,6 @@
 # DATA.
 #
 from datetime import datetime
-from typing import Any
 
 from llama_index.core.tools import BaseTool, ToolOutput, ToolMetadata
 from pydantic import BaseModel
@@ -46,17 +45,28 @@ class DateToolInput(BaseModel):
     """
     Input schema for the DateTool
     """
-    input_: None = None
+
+    input: None = None
+
 
 class DateTool(BaseTool):
     """
    A tool that provides the current date and time.
     """
+
     @property
     def metadata(self) -> ToolMetadata:
-        return ToolMetadata(name="date_tool", description="A tool that provides the current date and time.", fn_schema=DateToolInput)
+        return ToolMetadata(
+            name="date_tool",
+            description="A tool that provides the current date and time.",
+            fn_schema=DateToolInput,
+        )
 
-    def __call__(self, input_: Any) -> ToolOutput:
+    def __call__(self, input: None=None) -> ToolOutput:
         now = datetime.now()
-        return ToolOutput(content=f"The current date is {now.strftime('%Y-%m-%d %H:%M:%S')}", tool_name="date_tool", raw_input={}, raw_output=now)
-
+        return ToolOutput(
+            content=f"The current date is {now.strftime('%Y-%m-%d %H:%M:%S')}",
+            tool_name="date_tool",
+            raw_input={},
+            raw_output=now,
+        )
diff --git a/llm-service/app/services/query/agents/agent_tools/mcp.py b/llm-service/app/services/query/agents/agent_tools/mcp.py
new file mode 100644
index 00000000..07e8e015
--- /dev/null
+++ b/llm-service/app/services/query/agents/agent_tools/mcp.py
@@ -0,0 +1,87 @@
+#
+# CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
+# (C) Cloudera, Inc. 2025
+# All rights reserved.
+#
+# Applicable Open Source License: Apache 2.0
+#
+# NOTE: Cloudera open source products are modular software products
+# made up of hundreds of individual components, each of which was
+# individually copyrighted. Each Cloudera open source product is a
+# collective work under U.S. Copyright Law. Your license to use the
+# collective work is as provided in your written agreement with
+# Cloudera. Used apart from the collective work, this file is
+# licensed for your use pursuant to the open source license
+# identified above.
+#
+# This code is provided to you pursuant a written agreement with
+# (i) Cloudera, Inc. or (ii) a third-party authorized to distribute
+# this code. If you do not have a written agreement with Cloudera nor
+# with an authorized and properly licensed third party, you do not
+# have any rights to access nor to use this code.
+#
+# Absent a written agreement with Cloudera, Inc. ("Cloudera") to the
+# contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY
+# KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED
+# WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO
+# IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND
+# FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU,
+# AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS
+# ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE
+# OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY
+# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR
+# CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES
+# RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF
+# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
+# DATA.
+#
+import json
+import os
+from copy import copy
+
+from llama_index.core.tools import FunctionTool
+from llama_index.tools.mcp import BasicMCPClient, McpToolSpec
+
+from app.config import settings
+
+
+def get_llama_index_tools(server_name: str) -> list[FunctionTool]:
+    """
+    Find an MCP server by name in the mcp.json file and return the appropriate adapter.
+
+    Args:
+        server_name: The name of the MCP server to find
+
+    Returns:
+        An MCPServerAdapter configured for the specified server
+
+    Raises:
+        ValueError: If the server name is not found in the mcp.json file
+    """
+    mcp_json_path = os.path.join(settings.tools_dir, "mcp.json")
+
+    with open(mcp_json_path, "r") as f:
+        mcp_config = json.load(f)
+
+    mcp_servers = mcp_config["mcp_servers"]
+    server_config = next(filter(lambda x: x["name"] == server_name, mcp_servers), None)
+
+    if server_config:
+        environment: dict[str, str] | None = copy(dict(os.environ))
+        if "env" in server_config and environment:
+            environment.update(server_config["env"])
+
+        if "command" in server_config:
+            client = BasicMCPClient(
+                command_or_url=server_config["command"],
+                args=server_config.get("args", []),
+                env=environment,
+            )
+        elif "url" in server_config:
+            client = BasicMCPClient(command_or_url=server_config["url"])
+        else:
+            raise ValueError("Not configured right...fixme")
+        tool_spec = McpToolSpec(client=client)
+        return tool_spec.to_tool_list()
+
+    raise ValueError(f"Invalid configuration for MCP server '{server_name}'")
diff --git a/llm-service/app/services/query/tools/retriever.py b/llm-service/app/services/query/agents/agent_tools/retriever.py
similarity index 100%
rename from llm-service/app/services/query/tools/retriever.py
rename to llm-service/app/services/query/agents/agent_tools/retriever.py
diff --git a/llm-service/app/services/query/agents/tool_calling_querier.py b/llm-service/app/services/query/agents/tool_calling_querier.py
index ffb36695..053493c8 100644
--- a/llm-service/app/services/query/agents/tool_calling_querier.py
+++ b/llm-service/app/services/query/agents/tool_calling_querier.py
@@ -41,28 +41,32 @@
 from typing import Optional, Generator, AsyncGenerator, Callable, cast, Any
 
 import opik
-from llama_index.agent.openai import OpenAIAgent
 from llama_index.core.agent.workflow import (
     FunctionAgent,
     AgentStream,
     ToolCall,
     ToolCallResult,
+    AgentOutput,
+    AgentInput,
+    AgentSetup,
 )
 from llama_index.core.base.llms.types import ChatMessage, MessageRole, ChatResponse
 from llama_index.core.chat_engine.types import StreamingAgentChatResponse
 from llama_index.core.llms.function_calling import FunctionCallingLLM
 from llama_index.core.schema import NodeWithScore
-from llama_index.core.tools import BaseTool, ToolOutput
-from llama_index.llms.openai import OpenAI
+from llama_index.core.tools import BaseTool
 
 from app.ai.indexing.summary_indexer import SummaryIndexer
+from app.services.metadata_apis.session_metadata_api import Session
+from app.services.query.agents.agent_tools.date import DateTool
+from app.services.query.agents.agent_tools.mcp import get_llama_index_tools
+from app.services.query.agents.agent_tools.retriever import (
+    build_retriever_tool,
+)
 from app.services.query.chat_engine import (
     FlexibleContextChatEngine,
 )
-from app.services.query.tools.date import DateTool
-from app.services.query.tools.retriever import (
-    build_retriever_tool,
-)
+from app.services.query.chat_events import ChatEvent
 
 if os.environ.get("ENABLE_OPIK") == "True":
     opik.configure(
@@ -70,7 +74,7 @@
     )
 
 logger = logging.getLogger(__name__)
-# litellm._turn_on_debug()
+
 
 poison_pill = "poison_pill"
 
@@ -106,7 +110,7 @@ def should_use_retrieval(
 * Cite from node_ids in the given format: the node_id \
 should be in an html anchor tag () with an html 'class' of 'rag_citation'. \
 Do not use filenames as citations. Only node ids should be used. \
-For example: 2. Do not make up node ids that are not present
+For example: . Do not make up node ids that are not present
 in the context.
 
 * All citations should be either in-line citations or markdown links.
@@ -142,32 +146,35 @@ def stream_chat(
     chat_engine: Optional[FlexibleContextChatEngine],
     enhanced_query: str,
     chat_messages: list[ChatMessage],
-    additional_tools: list[BaseTool],
+    session: Session,
     data_source_summaries: dict[int, str],
 ) -> StreamingAgentChatResponse:
+    mcp_tools: list[BaseTool] = []
+    if session.query_configuration and session.query_configuration.selected_tools:
+        for tool_name in session.query_configuration.selected_tools:
+            try:
+                mcp_tools.extend(get_llama_index_tools(tool_name))
+            except ValueError as e:
+                logger.warning(f"Could not create adapter for tool {tool_name}: {e}")
+                continue
+
     # Use the existing chat engine with the enhanced query for streaming response
     tools: list[BaseTool] = [DateTool()]
     if use_retrieval and chat_engine:
         retrieval_tool = build_retriever_tool(
-            retriever=chat_engine._retriever,
+            retriever=chat_engine.retriever,
             summaries=data_source_summaries,
-            node_postprocessors=chat_engine._node_postprocessors,
+            node_postprocessors=chat_engine.node_postprocessors,
         )
         tools.append(retrieval_tool)
-    tools.extend(additional_tools)
-    if isinstance(llm, OpenAI):
-        gen, source_nodes = _openai_agent_streamer(
-            chat_messages, enhanced_query, llm, tools
-        )
-    else:
-        gen, source_nodes = _run_non_openai_streamer(
-            chat_messages, enhanced_query, llm, tools
-        )
+    tools.extend(mcp_tools)
+
+    gen, source_nodes = _run_streamer(chat_messages, enhanced_query, llm, tools)
 
     return StreamingAgentChatResponse(chat_stream=gen, source_nodes=source_nodes)
 
 
-def _run_non_openai_streamer(
+def _run_streamer(
     chat_messages: list[ChatMessage],
     enhanced_query: str,
     llm: FunctionCallingLLM,
@@ -184,17 +191,73 @@
 
     async def agen() -> AsyncGenerator[ChatResponse, None]:
         handler = agent.run(user_msg=enhanced_query, chat_history=chat_messages)
+
         async for event in handler.stream_events():
-            if isinstance(event, ToolCall):
-                if verbose and not isinstance(event, ToolCallResult):
-                    print("=== Calling Function ===")
-                    print(
-                        f"Calling function: {event.tool_name} with args: {event.tool_kwargs}"
-                    )
-            if isinstance(event, ToolCallResult):
+            if isinstance(event, AgentSetup):
+                data = f"Agent {event.current_agent_name} setup with input: {event.input[-1].content!s}"
                 if verbose:
-                    print(f"Got output: {event.tool_output!s}")
-                    print("========================")
+                    logger.info("=== Agent Setup ===")
+                    logger.info(data)
+                    logger.info("========================")
+                yield ChatResponse(
+                    message=ChatMessage(
+                        role=MessageRole.FUNCTION,
+                        content="",
+                    ),
+                    delta="",
+                    raw="",
+                    additional_kwargs={
+                        "chat_event": ChatEvent(
+                            type="agent_setup",
+                            name=event.current_agent_name,
+                            data=data,
+                        ),
+                    },
+                )
+            elif isinstance(event, AgentInput):
+                data = f"Agent {event.current_agent_name} started with input: {event.input[-1].content!s}"
+                if verbose:
+                    logger.info("=== Agent Input ===")
+                    logger.info(data)
+                    logger.info("========================")
+                yield ChatResponse(
+                    message=ChatMessage(
+                        role=MessageRole.FUNCTION,
+                        content="",
+                    ),
+                    delta="",
+                    raw="",
+                    additional_kwargs={
+                        "chat_event": ChatEvent(
+                            type="agent_input",
+                            name=event.current_agent_name,
+                            data=data,
+                        ),
+                    },
+                )
+            elif isinstance(event, ToolCall) and not isinstance(event, ToolCallResult):
+                data = f"Calling function: {event.tool_name} with args: {event.tool_kwargs}"
+                if verbose:
+                    logger.info("=== Calling Function ===")
+                    logger.info(data)
+                yield ChatResponse(
+                    message=ChatMessage(
+                        role=MessageRole.TOOL,
+                        content="",
+                    ),
+                    delta="",
+                    raw="",
+                    additional_kwargs={
+                        "chat_event": ChatEvent(
+                            type="tool_call", name=event.tool_name, data=data
+                        ),
+                    },
+                )
+            elif isinstance(event, ToolCallResult):
+                data = f"Got output: {event.tool_output!s}"
+                if verbose:
+                    logger.info(data)
+                    logger.info("========================")
                 if (
                     event.tool_output.raw_output
                     and isinstance(event.tool_output.raw_output, list)
@@ -204,7 +267,47 @@ async def agen() -> AsyncGenerator[ChatResponse, None]:
                     )
                 ):
                     source_nodes.extend(event.tool_output.raw_output)
-            if isinstance(event, AgentStream):
+                yield ChatResponse(
+                    message=ChatMessage(
+                        role=MessageRole.TOOL,
+                        content="",
+                    ),
+                    delta="",
+                    raw="",
+                    additional_kwargs={
+                        "chat_event": ChatEvent(
+                            type="tool_result",
+                            name=event.tool_name,
+                            data=data,
+                        ),
+                    },
+                )
+            elif isinstance(event, AgentOutput):
+                data = f"Agent {event.current_agent_name} response: {event.response!s}"
+                if verbose:
+                    logger.info("=== LLM Response ===")
+                    logger.info(
+                        f"{str(event.response) if event.response else 'No content'}"
+                    )
+                    logger.info("========================")
+                yield ChatResponse(
+                    message=ChatMessage(
+                        role=MessageRole.TOOL,
+                        content=(
+                            event.response.content if event.response.content else ""
+                        ),
+                    ),
+                    delta="",
+                    raw=event.raw,
+                    additional_kwargs={
+                        "chat_event": ChatEvent(
+                            type="agent_response",
+                            name=event.current_agent_name,
+                            data=data,
+                        ),
+                    },
+                )
+            elif isinstance(event, AgentStream):
                 if event.response:
                     # Yield the delta response as a ChatResponse
                     yield ChatResponse(
@@ -214,77 +317,31 @@ async def agen() -> AsyncGenerator[ChatResponse, None]:
                         ),
                         delta=event.delta,
                         raw=event.raw,
-                        additional_kwargs={
-                            "tool_calls": event.tool_calls,
-                        },
                     )
+            else:
+                logger.info(f"Unhandled event of type: {type(event)}: {event}")
+        await handler
+        if e := handler.exception():
+            raise e
+        if handler.ctx:
+            await handler.ctx.shutdown()
 
     def gen() -> Generator[ChatResponse, None, None]:
-        async def collect() -> list[ChatResponse]:
-            results: list[ChatResponse] = []
-            async for chunk in agen():
-                results.append(chunk)
-            return results
+        loop = asyncio.new_event_loop()
+        astream = agen()
+        try:
+            while True:
+                item = loop.run_until_complete(anext(astream))
+                yield item
+        except (StopAsyncIteration, GeneratorExit):
+            pass
+        finally:
+            try:
+                loop.run_until_complete(astream.aclose())
+            except Exception as e:
+                logger.warning(f"Exception during async generator close: {e}")
+            if not loop.is_closed():
+                loop.stop()
+                loop.close()
 
-        item = ChatResponse(
-            message=ChatMessage(role=MessageRole.ASSISTANT, content=""),
-            delta="",
-            raw=None,
-            additional_kwargs={
-                "tool_calls": [],
-            },
-        )
-        for item in asyncio.run(collect()):
-            yield item
-            if verbose:
-                print("=== LLM Response ===")
-                print(
-                    f"{item.message.content.strip() if item.message.content else 'No content'}"
-                )
-                print("========================")
-
-    return gen(), source_nodes
-
-
-def _openai_agent_streamer(
-    chat_messages: list[ChatMessage],
-    enhanced_query: str,
-    llm: OpenAI,
-    tools: list[BaseTool],
-    verbose: bool = True,
-) -> tuple[Generator[ChatResponse, None, None], list[NodeWithScore]]:
-    agent = OpenAIAgent.from_tools(
-        tools=tools,
-        llm=llm,
-        verbose=verbose,
-        system_prompt=DEFAULT_AGENT_PROMPT,
-    )
-    stream_chat_response: StreamingAgentChatResponse = agent.stream_chat(
-        message=enhanced_query, chat_history=chat_messages
-    )
-
-    def gen() -> Generator[ChatResponse, None, None]:
-        response = ""
-        res = stream_chat_response.response_gen
-        for chunk in res:
-            response += chunk
-            finalize_response = ChatResponse(
-                message=ChatMessage(role="assistant", content=response),
-                delta=chunk,
-            )
-            yield finalize_response
-
-    source_nodes = []
-    if stream_chat_response.sources:
-        for tool_output in stream_chat_response.sources:
-            if isinstance(tool_output, ToolOutput):
-                if (
-                    tool_output.raw_output
-                    and isinstance(tool_output.raw_output, list)
-                    and all(
-                        isinstance(elem, NodeWithScore)
-                        for elem in tool_output.raw_output
-                    )
-                ):
-                    source_nodes.extend(tool_output.raw_output)
     return gen(), source_nodes
diff --git a/llm-service/app/services/query/chat_engine.py b/llm-service/app/services/query/chat_engine.py
index c1dc1b81..9f4caeef 100644
--- a/llm-service/app/services/query/chat_engine.py
+++ b/llm-service/app/services/query/chat_engine.py
@@ -225,6 +225,14 @@ def _run_c3(
         return response_synthesizer, context_source, context_nodes
 
+    @property
+    def retriever(self) -> BaseRetriever:
+        return self._retriever
+
+    @property
+    def node_postprocessors(self) -> List[BaseNodePostprocessor]:
+        return self._node_postprocessors
+
 
 
 def build_flexible_chat_engine(
     configuration: QueryConfiguration,
diff --git a/llm-service/app/services/query/chat_events.py b/llm-service/app/services/query/chat_events.py
index 92f4b714..e8fe3936 100644
--- a/llm-service/app/services/query/chat_events.py
+++ b/llm-service/app/services/query/chat_events.py
@@ -40,18 +40,16 @@
 from queue import Queue
 from typing import Optional, Any
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 
 
-class ToolEvent(BaseModel):
+class ChatEvent(BaseModel):
     type: str
     name: str
     data: Optional[str] = None
-    timestamp: float = time.time()
+    timestamp: float = Field(default_factory=lambda : time.time())
 
 
-def step_callback(
-    output: Any, agent: str, tool_events_queue: Queue[ToolEvent]
-) -> None:
+def step_callback(output: Any, agent: str, tool_events_queue: Queue[ChatEvent]) -> None:
     # todo: hook this up
     return None
diff --git a/llm-service/app/services/query/querier.py b/llm-service/app/services/query/querier.py
index b1083724..898120b7 100644
--- a/llm-service/app/services/query/querier.py
+++ b/llm-service/app/services/query/querier.py
@@ -29,11 +29,7 @@
 #
 ##############################################################################
 from __future__ import annotations
-import json
-import os
 import re
-from copy import copy
-from queue import Queue
 from typing import Optional, TYPE_CHECKING, cast
 
 from llama_index.core.base.base_retriever import BaseRetriever
@@ -41,19 +37,14 @@
 from llama_index.core.llms import LLM
 from llama_index.core.llms.function_calling import FunctionCallingLLM
 from llama_index.core.schema import NodeWithScore
-from llama_index.core.tools import BaseTool as LLamaTool
-from llama_index.core.tools import FunctionTool
 
 from .agents.tool_calling_querier import (
     should_use_retrieval,
     stream_chat,
-    poison_pill,
 )
-from .chat_events import ToolEvent
 from .flexible_retriever import FlexibleRetriever
 from .multi_retriever import MultiSourceRetriever
 from ..metadata_apis.session_metadata_api import Session
-from ...config import settings
 
 if TYPE_CHECKING:
     from ..chat.utils import RagContext
@@ -73,75 +64,17 @@
 from app.services.query.query_configuration import QueryConfiguration
 from .chat_engine import build_flexible_chat_engine, FlexibleContextChatEngine
 from ...ai.vector_stores.vector_store_factory import VectorStoreFactory
-from llama_index.tools.mcp import BasicMCPClient, McpToolSpec
 
 logger = logging.getLogger(__name__)
 
 
-def get_llama_index_tools(server_name: str) -> list[FunctionTool]:
-    """
-    Find an MCP server by name in the mcp.json file and return the appropriate adapter.
-
-    Args:
-        server_name: The name of the MCP server to find
-
-    Returns:
-        An MCPServerAdapter configured for the specified server
-
-    Raises:
-        ValueError: If the server name is not found in the mcp.json file
-    """
-    mcp_json_path = os.path.join(settings.tools_dir, "mcp.json")
-
-    with open(mcp_json_path, "r") as f:
-        mcp_config = json.load(f)
-
-    mcp_servers = mcp_config["mcp_servers"]
-    server_config = next(filter(lambda x: x["name"] == server_name, mcp_servers), None)
-
-    if server_config:
-        environment: dict[str, str] | None = copy(dict(os.environ))
-        if "env" in server_config and environment:
-            environment.update(server_config["env"])
-
-        if "command" in server_config:
-            client = BasicMCPClient(
-                command_or_url=server_config["command"],
-                args=server_config.get("args", []),
-                env=environment,
-            )
-        elif "url" in server_config:
-            client = BasicMCPClient(command_or_url=server_config["url"])
-        else:
-            raise ValueError("Not configured right...fixme")
-        tool_spec = McpToolSpec(client=client)
-        return tool_spec.to_tool_list()
-
-    raise ValueError(f"Invalid configuration for MCP server '{server_name}'")
-
-
 def streaming_query(
     chat_engine: Optional[FlexibleContextChatEngine],
     query_str: str,
     configuration: QueryConfiguration,
     chat_messages: list[ChatMessage],
-    tool_events_queue: Queue[ToolEvent],
     session: Session,
 ) -> StreamingAgentChatResponse:
-    all_tools: list[LLamaTool] = []
-
-    if session.query_configuration and session.query_configuration.selected_tools:
-        for tool_name in session.query_configuration.selected_tools:
-            try:
-                llama_tools = get_llama_index_tools(tool_name)
-                # print(
-                #     f"Adding adapter for tools: {[tool.name for tool in adapter.tools]}"
-                # )
-                all_tools.extend(llama_tools)
-            except ValueError as e:
-                logger.warning(f"Could not create adapter for tool {tool_name}: {e}")
-                continue
-
     llm = models.LLM.get(model_name=configuration.model_name)
 
     chat_response: StreamingAgentChatResponse
@@ -156,10 +89,9 @@ def streaming_query(
             chat_engine,
             query_str,
             chat_messages,
-            all_tools,
+            session,
             data_source_summaries,
         )
-        tool_events_queue.put(ToolEvent(type=poison_pill, name="no-op"))
         return chat_response
     if not chat_engine:
         raise HTTPException(
@@ -169,7 +101,6 @@ def streaming_query(
 
     try:
         chat_response = chat_engine.stream_chat(query_str, chat_messages)
-        tool_events_queue.put(ToolEvent(type=poison_pill, name="no-op"))
         logger.debug("query response received from chat engine")
     except botocore.exceptions.ClientError as error:
         logger.warning(error.response)
diff --git a/llm-service/pyproject.toml b/llm-service/pyproject.toml
index 789c9a14..adcad9d2 100644
--- a/llm-service/pyproject.toml
+++ b/llm-service/pyproject.toml
@@ -45,7 +45,6 @@ dependencies = [
    "llama-index-callbacks-opik>=1.1.0",
    "mcp[cli]>=1.9.1",
    "pysqlite3-binary==0.5.4; platform_system == 'Linux' and platform_machine != 'aarch64'",
-   "llama-index-agent-openai>=0.4.8",
    "llama-index-tools-mcp>=0.2.5",
 ]
 requires-python = ">=3.10,<=3.12"
diff --git a/llm-service/uv.lock b/llm-service/uv.lock
index bbe65713..f7c0d7db 100644
--- a/llm-service/uv.lock
+++ b/llm-service/uv.lock
@@ -2044,20 +2044,6 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/10/af/1e344bc8aee41445272e677d802b774b1f8b34bdc3bb5697ba30f0fb5d52/litellm-1.68.0-py3-none-any.whl", hash = "sha256:3bca38848b1a5236b11aa6b70afa4393b60880198c939e582273f51a542d4759", size = 7684460 },
 ]
 
-[[package]]
-name = "llama-index-agent-openai"
-version = "0.4.8"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
-    { name = "llama-index-core" },
-    { name = "llama-index-llms-openai" },
-    { name = "openai" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/2c/10/34454bd6563ff7fb63dec264a34e2749486194f9b4fb1ea8c2e4b9f8e2e9/llama_index_agent_openai-0.4.8.tar.gz", hash = "sha256:ba76f21e1b7f0f66e326dc419c2cc403cbb614ae28f7904540b1103695965f68", size = 12230 }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/ce/b8/1d7f50b6471fd73ff7309e6abd808c935a8d9d8547b192ce56ed3a05c142/llama_index_agent_openai-0.4.8-py3-none-any.whl", hash = "sha256:a03e8609ada0355b408d4173cd7663708f826f23328f9719fba00ea20b6851b6", size = 14212 },
-]
-
 [[package]]
 name = "llama-index-callbacks-opik"
 version = "1.1.0"
@@ -2350,7 +2336,6 @@ dependencies = [
     { name = "docx2txt" },
     { name = "fastapi", extra = ["standard"] },
     { name = "fastapi-utils" },
-    { name = "llama-index-agent-openai" },
     { name = "llama-index-callbacks-opik" },
     { name = "llama-index-core" },
     { name = "llama-index-embeddings-azure-openai" },
@@ -2378,7 +2363,6 @@ dependencies = [
     { name = "presidio-anonymizer" },
     { name = "pydantic" },
     { name = "pydantic-settings" },
-    { name = "pysqlite3-binary", marker = "platform_system == 'Linux'" },
+    { name = "pysqlite3-binary", marker = "platform_machine != 'aarch64' and platform_system == 'Linux'" },
     { name = "python-pptx" },
     { name = "qdrant-client" },
     { name = "torch" },
@@ -2407,7 +2392,6 @@ requires-dist = [
     { name = "docx2txt", specifier = ">=0.8" },
     { name = "fastapi", extras = ["standard"], specifier = ">=0.111.0" },
     { name = "fastapi-utils", specifier = ">=0.8.0" },
-    { name = "llama-index-agent-openai", specifier = ">=0.4.8" },
     { name = "llama-index-callbacks-opik", specifier = ">=1.1.0" },
     { name = "llama-index-core", specifier = ">=0.10.68" },
     { name = "llama-index-embeddings-azure-openai", specifier = ">=0.3.0" },
@@ -2435,7 +2419,7 @@ requires-dist = [
     { name = "presidio-anonymizer", specifier = ">=2.2.355" },
     { name = "pydantic", specifier = ">=2.8.2" },
     { name = "pydantic-settings", specifier = ">=2.3.4" },
-    { name = "pysqlite3-binary", marker = "platform_system == 'Linux'", specifier = "==0.5.4" },
+    { name = "pysqlite3-binary", marker = "platform_machine != 'aarch64' and platform_system == 'Linux'", specifier = "==0.5.4" },
     { name = "python-pptx", specifier = ">=1.0.2" },
     { name = "qdrant-client", specifier = "<1.13.0" },
     { name = "torch", specifier = ">=2.5.1" },
diff --git a/ui/src/api/chatApi.ts b/ui/src/api/chatApi.ts
index d45a2de9..9baf72fc 100644
--- a/ui/src/api/chatApi.ts
+++ b/ui/src/api/chatApi.ts
@@ -331,10 +331,10 @@ export interface ChatMutationResponse {
   text?: string;
   response_id?: string;
   error?: string;
-  event?: ToolEventResponse;
+  event?: ChatEvent;
 }
 
-export interface ToolEventResponse {
+export interface ChatEvent {
   type: string;
   name: string;
   data?: string;
@@ -378,7 +378,7 @@ const canceledChatMessage = (variables: ChatMutationRequest) => {
 
 interface StreamingChatCallbacks {
   onChunk: (msg: string) => void;
-  onEvent: (event: ToolEventResponse) => void;
+  onEvent: (event: ChatEvent) => void;
   getController?: (ctrl: AbortController) => void;
 }
 
@@ -522,7 +522,7 @@ export const useStreamingChatMutation = ({
 const streamChatMutation = async (
   request: ChatMutationRequest,
   onChunk: (chunk: string) => void,
-  onEvent: (event: ToolEventResponse) => void,
+  onEvent: (event: ChatEvent) => void,
   onError: (error: string) => void,
   getController?: (ctrl: AbortController) => void,
 ): Promise => {
@@ -598,9 +598,9 @@ const streamChatMutation = async (
 };
 
 export const getOnEvent = (
-  setStreamedEvent: Dispatch>,
+  setStreamedEvent: Dispatch>,
 ) => {
-  return (event: ToolEventResponse) => {
+  return (event: ChatEvent) => {
     if (event.type === "done") {
       setStreamedEvent([]);
     } else {
diff --git a/ui/src/pages/RagChatTab/ChatLayout.tsx b/ui/src/pages/RagChatTab/ChatLayout.tsx
index 94e39de6..bf90ed3c 100644
--- a/ui/src/pages/RagChatTab/ChatLayout.tsx
+++ b/ui/src/pages/RagChatTab/ChatLayout.tsx
@@ -43,7 +43,7 @@ import { Outlet, useParams } from "@tanstack/react-router";
 import { useMemo, useState } from "react";
 import {
   ChatMessageType,
-  ToolEventResponse,
+  ChatEvent,
   useChatHistoryQuery,
 } from "src/api/chatApi.ts";
 import { RagChatContext } from "pages/RagChatTab/State/RagChatContext.tsx";
@@ -78,7 +78,7 @@ function ChatLayout() {
     useGetDataSourcesForProject(+projectId);
   const [excludeKnowledgeBase, setExcludeKnowledgeBase] = useState(false);
   const [streamedChat, setStreamedChat] = useState("");
-  const [streamedEvent, setStreamedEvent] = useState([]);
+  const [streamedEvent, setStreamedEvent] = useState([]);
   const [streamedAbortController, setStreamedAbortController] = useState();
 
   const {
diff --git a/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/ChatMessageBody.tsx b/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/ChatMessageBody.tsx
index e0bcfc7d..0a4b8a32 100644
--- a/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/ChatMessageBody.tsx
+++ b/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/ChatMessageBody.tsx
@@ -36,7 +36,7 @@
  * DATA.
  */
 
-import { ChatMessageType, ToolEventResponse } from "src/api/chatApi.ts";
+import { ChatMessageType, ChatEvent } from "src/api/chatApi.ts";
 import UserQuestion from "pages/RagChatTab/ChatOutput/ChatMessages/UserQuestion.tsx";
 import { Divider, Flex, Typography } from "antd";
 import Images from "src/components/images/Images.ts";
@@ -53,7 +53,7 @@ export const ChatMessageBody = ({
   streamedEvents,
 }: {
   data: ChatMessageType;
-  streamedEvents?: ToolEventResponse[];
+  streamedEvents?: ChatEvent[];
 }) => {
   return (
diff --git a/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/MarkdownResponse.tsx b/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/MarkdownResponse.tsx
index d8488085..7da784db 100644
--- a/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/MarkdownResponse.tsx
+++ b/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/MarkdownResponse.tsx
@@ -58,7 +58,7 @@ export const MarkdownResponse = ({ data }: { data: ChatMessageType }) => {
         const { href, className, children, ...other } = props;
         if (className === "rag_citation") {
           if (data.source_nodes.length === 0) {
-            return undefined;
+            return {props.children};
           }
           const { source_nodes } = data;
           const sourceNodeIndex = source_nodes.findIndex(
@@ -66,10 +66,13 @@ export const MarkdownResponse = ({ data }: { data: ChatMessageType }) => {
           );
           if (sourceNodeIndex >= 0) {
             return (
-
+
+                {props.children}
+
+
             );
           }
           if (!href?.startsWith("http")) {
diff --git a/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/StreamedEvents.tsx b/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/StreamedEvents.tsx
index 1c099d37..801620eb 100644
--- a/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/StreamedEvents.tsx
+++ b/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/StreamedEvents.tsx
@@ -36,13 +36,13 @@
  * DATA.
 */
 
-import { ToolEventResponse } from "src/api/chatApi.ts";
+import { ChatEvent } from "src/api/chatApi.ts";
 import { Button, Card, Flex, Spin, Typography } from "antd";
 import { format } from "date-fns";
 import { useState } from "react";
 import { MinusOutlined, PlusOutlined } from "@ant-design/icons";
 
-const StreamedEvent = ({ event }: { event: ToolEventResponse }) => {
+const StreamedEvent = ({ event }: { event: ChatEvent }) => {
   return (
@@ -76,7 +76,7 @@ const StreamedEvent = ({ event }: { event: ToolEventResponse }) => {
 const StreamedEvents = ({
   streamedEvents,
 }: {
-  streamedEvents?: ToolEventResponse[];
+  streamedEvents?: ChatEvent[];
 }) => {
   const [collapsed, setCollapsed] = useState(true);
 
diff --git a/ui/src/pages/RagChatTab/State/RagChatContext.tsx b/ui/src/pages/RagChatTab/State/RagChatContext.tsx
index 5b80dd2d..eb9a43e1 100644
--- a/ui/src/pages/RagChatTab/State/RagChatContext.tsx
+++ b/ui/src/pages/RagChatTab/State/RagChatContext.tsx
@@ -40,7 +40,7 @@ import { createContext, Dispatch, SetStateAction } from "react";
 import {
   ChatHistoryResponse,
   ChatMessageType,
-  ToolEventResponse,
+  ChatEvent,
 } from "src/api/chatApi.ts";
 import { Session } from "src/api/sessionApi.ts";
 import { DataSourceType } from "src/api/dataSourceApi.ts";
@@ -64,10 +64,7 @@ export interface RagChatContextType {
   >;
   };
   streamedChatState: [string, Dispatch>];
-  streamedEventState: [
-    ToolEventResponse[],
-    Dispatch>,
-  ];
+  streamedEventState: [ChatEvent[], Dispatch>];
   streamedAbortControllerState: [
     AbortController | undefined,
     Dispatch>,