Skip to content

Commit d8225f3

Browse files
ewilliams-cloudera, baasitsharief, jkwatson, mliu-cloudera, actions-user
authored
Tool Calling (#219)
* feat: added tool calling with retrieval and search results (#214) * Retrieval and Searching happens in tandem * Source nodes cited but not available in chat response are also provided to chat response with a score of 0.0 i.e. source node information used from chat history * Reduced llm calls by removing the number of agents and using only tasks * Further refactoring might be needed * support openai for tool calling * Propagate error from chat streaming * fix mypy issues * conditionally do evals based on whether we have source nodes * don't show direct llm call alert if tool calling is enabled * add opik for llama-index, fix chat history query! * catch an error when trying to extract source nodes from responses * don't scroll to top of response when done streaming * added azure api version to settings and prompt changes * wip on using mcp with crewai * switch to playing with the fetch mcp server * remove the unneeded env to the uvx command * refactor to allow multiple MCP servers/toolsets, and add the mcp-server-fetch as an option * add in text2sql2text tool * move to using mcp.json * add selectedTools as an attribute of the session queryConfiguration * hook up selected tools to the UI * use the session's tools settings * rename amp settings to studio settings * remove path to caii domain * fix table name * remove tools from query configuration and remove calls to search (leave in search code for now) * move mcp.json file to tools in root * fix mypy * Update release version to dev-testing * use the right model name for crew * added verbosity back * Add some text to the tool calling options about the power of the inference model * readme for the mcp.json * remove text2sql tool * minor changes to description and tools manager formatting * add support for env * add in the root environment to the tools * add tools/ to gitignore * load node env before starting up python * fix a mypy issue * fix mypy issue with missing model * remove console.error * bug fix to passing 
task contexts, and minor prompt changes * remove serper tool * Rename model provider tests * fix bug with overriding tools when updating session * added date to researcher task description * Rename methods; use list instead of List * Add methods to provider interface * add more logs for tool calling * Implement methods * Clean up code * wip on moving chat history * get ruff passing * WIP mcp tool calling * ruff check * remove mcp server * add beta tags * remove comment on chat history tool --------- Co-authored-by: Baasit Sharief <baasitsharief@gmail.com> Co-authored-by: jwatson <jkwatson@gmail.com> Co-authored-by: Michael Liu <mliu@cloudera.com> Co-authored-by: actions-user <actions@github.com>
1 parent 0daf0c5 commit d8225f3

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

59 files changed

+1857
-537
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ databases/
99
**/docling-output.txt
1010
**/.DS_Store
1111
.history
12-
addresses/
12+
addresses/
13+
tools/

backend/src/main/java/com/cloudera/cai/rag/Types.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,12 @@ public record RagDataSource(
101101
@Nullable Long totalDocSize,
102102
boolean availableForDefaultProject) {}
103103

104+
@With
104105
public record QueryConfiguration(
105-
boolean enableHyde, boolean enableSummaryFilter, boolean enableToolCalling) {}
106+
boolean enableHyde,
107+
boolean enableSummaryFilter,
108+
boolean enableToolCalling,
109+
List<String> selectedTools) {}
106110

107111
@With
108112
@Builder

backend/src/main/java/com/cloudera/cai/rag/sessions/SessionRepository.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
@Component
5858
public class SessionRepository {
5959
public static final Types.QueryConfiguration DEFAULT_QUERY_CONFIGURATION =
60-
new Types.QueryConfiguration(false, true, false);
60+
new Types.QueryConfiguration(false, true, false, List.of());
6161
private final Jdbi jdbi;
6262
private final ObjectMapper objectMapper = new ObjectMapper();
6363

@@ -169,6 +169,9 @@ private Types.QueryConfiguration extractQueryConfiguration(RowView rowView)
169169
if (queryConfiguration == null) {
170170
return DEFAULT_QUERY_CONFIGURATION;
171171
}
172+
if (queryConfiguration.selectedTools() == null) {
173+
queryConfiguration = queryConfiguration.withSelectedTools(List.of());
174+
}
172175
return queryConfiguration;
173176
}
174177

backend/src/test/java/com/cloudera/cai/rag/TestData.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ public static Types.Session createTestSessionInstance(
8282
"test-model",
8383
"test-rerank-model",
8484
3,
85-
new Types.QueryConfiguration(false, true, true));
85+
new Types.QueryConfiguration(false, true, true, List.of()));
8686
}
8787

8888
public static Types.CreateSession createSessionInstance(String sessionName) {
@@ -97,7 +97,7 @@ public static Types.CreateSession createSessionInstance(
9797
"test-model",
9898
"test-rerank-model",
9999
3,
100-
new Types.QueryConfiguration(false, true, true),
100+
new Types.QueryConfiguration(false, true, true, List.of()),
101101
projectId);
102102
}
103103

backend/src/test/java/com/cloudera/cai/rag/sessions/SessionControllerTest.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,8 @@ void update() {
145145
.withRerankModel(updatedRerankModel)
146146
.withName(updatedName)
147147
.withProjectId(updatedProjectId)
148-
.withQueryConfiguration(new Types.QueryConfiguration(true, false, true)),
148+
.withQueryConfiguration(
149+
new Types.QueryConfiguration(true, false, true, List.of("foo"))),
149150
request);
150151

151152
assertThat(updatedSession.id()).isNotNull();
@@ -160,7 +161,7 @@ void update() {
160161
assertThat(updatedSession.createdById()).isEqualTo("test-user");
161162
assertThat(updatedSession.lastInteractionTime()).isNull();
162163
assertThat(updatedSession.queryConfiguration())
163-
.isEqualTo(new Types.QueryConfiguration(true, false, true));
164+
.isEqualTo(new Types.QueryConfiguration(true, false, true, List.of("foo")));
164165
}
165166

166167
@Test

llm-service/app/config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ def rag_log_level(self) -> int:
6868
def rag_databases_dir(self) -> str:
6969
return os.environ.get("RAG_DATABASES_DIR", os.path.join("..", "databases"))
7070

71+
@property
72+
def tools_dir(self) -> str:
73+
return os.path.join("..", "tools")
74+
7175
@property
7276
def caii_domain(self) -> str:
7377
return os.environ["CAII_DOMAIN"]
@@ -149,6 +153,10 @@ def azure_openai_api_key(self) -> Optional[str]:
149153
def azure_openai_endpoint(self) -> Optional[str]:
150154
return os.environ.get("AZURE_OPENAI_ENDPOINT")
151155

156+
@property
157+
def azure_openai_api_version(self) -> Optional[str]:
158+
return os.environ.get("AZURE_OPENAI_API_VERSION")
159+
152160
@property
153161
def openai_api_key(self) -> Optional[str]:
154162
return os.environ.get("OPENAI_API_KEY")
@@ -157,4 +165,5 @@ def openai_api_key(self) -> Optional[str]:
157165
def openai_api_base(self) -> Optional[str]:
158166
return os.environ.get("OPENAI_API_BASE")
159167

168+
160169
settings = _Settings()

llm-service/app/main.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,14 @@
3838

3939
import functools
4040
import logging
41+
import os
4142
import sys
4243
import time
4344
from collections.abc import Awaitable, Callable
4445
from contextlib import asynccontextmanager
4546
from typing import AsyncGenerator
4647

48+
import opik
4749
from fastapi import FastAPI, Request, Response
4850
from fastapi.middleware.cors import CORSMiddleware
4951
from uvicorn.logging import DefaultFormatter
@@ -73,6 +75,16 @@ def _configure_logger() -> None:
7375

7476
_configure_logger()
7577

78+
if os.environ.get("ENABLE_OPIK") == "True":
79+
opik.configure(
80+
use_local=True, url=os.environ.get("OPIK_URL", "http://localhost:5174")
81+
)
82+
83+
from llama_index.core import set_global_handler
84+
85+
# You should provide your OPIK API key and Workspace using the following environment variables:
86+
# OPIK_API_KEY, OPIK_WORKSPACE
87+
set_global_handler("opik")
7688

7789
###################################
7890
# Lifespan events

llm-service/app/rag_types.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,7 @@
4040

4141
from pydantic import BaseModel
4242

43-
from app.services.query.query_configuration import tool_types
44-
4543

4644
class RagPredictConfiguration(BaseModel):
4745
exclude_knowledge_base: Optional[bool] = False
4846
use_question_condensing: Optional[bool] = True
49-
tools: Optional[list[tool_types]] = None

llm-service/app/routers/index/sessions/__init__.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@
4040
import logging
4141
import queue
4242
import time
43-
from concurrent.futures import ThreadPoolExecutor
44-
from typing import Optional, Generator
43+
from concurrent.futures import Future, ThreadPoolExecutor
44+
from typing import Optional, Generator, Any
4545

4646
from fastapi import APIRouter, Header, HTTPException
4747
from fastapi.responses import StreamingResponse
@@ -60,7 +60,8 @@
6060
from ....services.chat_history.paginator import paginate
6161
from ....services.metadata_apis import session_metadata_api
6262
from ....services.mlflow import rating_mlflow_log_metric, feedback_mlflow_log_table
63-
from ....services.query.agents.crewai_querier import CrewEvent, poison_pill
63+
from ....services.query.agents.crewai_querier import poison_pill
64+
from ....services.query.crew_events import CrewEvent
6465
from ....services.session import rename_session
6566

6667
logger = logging.getLogger(__name__)
@@ -232,8 +233,11 @@ def stream_chat_completion(
232233

233234
crew_events_queue: queue.Queue[CrewEvent] = queue.Queue()
234235

235-
def crew_callback() -> Generator[str, None, None]:
236+
def crew_callback(chat_future: Future[Any]) -> Generator[str, None, None]:
236237
while True:
238+
if chat_future.done() and (e := chat_future.exception()):
239+
raise e
240+
237241
try:
238242
event_data = crew_events_queue.get(block=True, timeout=1.0)
239243
if event_data.type == poison_pill:
@@ -262,7 +266,7 @@ def generate_stream() -> Generator[str, None, None]:
262266
crew_events_queue=crew_events_queue,
263267
)
264268

265-
yield from crew_callback()
269+
yield from crew_callback(future)
266270

267271
first_message = True
268272
for response in future.result():
@@ -278,6 +282,9 @@ def generate_stream() -> Generator[str, None, None]:
278282
json_delta = json.dumps({"text": response.delta})
279283
yield f"data: {json_delta}\n\n"
280284
yield f'data: {{"response_id" : "{response_id}"}}\n\n'
285+
except TimeoutError:
286+
logger.exception("Timeout: Failed to stream chat completion")
287+
yield 'data: {{"error" : "Timeout: Failed to stream chat completion"}}\n\n'
281288
except Exception as e:
282289
logger.exception("Failed to stream chat completion")
283290
yield f'data: {{"error" : "{e}"}}\n\n'

llm-service/app/routers/index/tools/__init__.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,23 +35,30 @@
3535
# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
3636
# DATA.
3737
#
38-
38+
import json
3939
import logging
40+
import os
41+
from typing import Any
42+
from pydantic import BaseModel
4043

4144
from fastapi import APIRouter
42-
from pydantic import BaseModel
4345

4446
from .... import exceptions
47+
from ....config import settings
4548

4649
logger = logging.getLogger(__name__)
4750

4851
router = APIRouter(prefix="/tools", tags=["Tools"])
4952

5053

5154
class Tool(BaseModel):
52-
id: str
55+
"""
56+
Represents a tool in the MCP configuration.
57+
"""
58+
5359
name: str
54-
description: str
60+
61+
metadata: dict[str, Any]
5562

5663

5764
@router.get(
@@ -61,10 +68,9 @@ class Tool(BaseModel):
6168
)
6269
@exceptions.propagates
6370
def tools() -> list[Tool]:
64-
return [
65-
Tool(
66-
id="1",
67-
name="search",
68-
description="Searches the internet for the given query.",
69-
),
70-
]
71+
72+
mcp_json_path = os.path.join(settings.tools_dir, "mcp.json")
73+
74+
with open(mcp_json_path, "r") as f:
75+
mcp_config = json.load(f)
76+
return [Tool(**server) for server in mcp_config["mcp_servers"]]

llm-service/app/services/chat/chat.py

Lines changed: 68 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,15 @@
3535
# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
3636
# DATA.
3737
#
38-
38+
import logging
39+
import re
3940
import time
4041
import uuid
4142
from typing import Optional
4243

4344
from fastapi import HTTPException
4445
from llama_index.core.chat_engine.types import AgentChatResponse
46+
from llama_index.core.schema import NodeWithScore
4547

4648
from app.ai.vector_stores.vector_store_factory import VectorStoreFactory
4749
from app.rag_types import RagPredictConfiguration
@@ -58,6 +60,7 @@
5860
from app.services.query import querier
5961
from app.services.query.query_configuration import QueryConfiguration
6062

63+
logger = logging.getLogger(__name__)
6164

6265
def chat(
6366
session: Session,
@@ -115,22 +118,41 @@ def _run_chat(
115118
query_configuration,
116119
retrieve_chat_history(session.id),
117120
)
118-
return finalize_response(response, condensed_question, data_source_id, query, query_configuration, response_id,session, user_name)
121+
return finalize_response(
122+
response,
123+
condensed_question,
124+
data_source_id,
125+
query,
126+
query_configuration,
127+
response_id,
128+
session,
129+
user_name,
130+
)
119131

120132

121-
def finalize_response(chat_response: AgentChatResponse,
122-
condensed_question: str | None,
123-
data_source_id: Optional[int],
124-
query: str,
125-
query_configuration: QueryConfiguration,
126-
response_id: str,
127-
session: Session,
128-
user_name: Optional[str]) -> RagStudioChatMessage:
133+
def finalize_response(
134+
chat_response: AgentChatResponse,
135+
condensed_question: str | None,
136+
data_source_id: Optional[int],
137+
query: str,
138+
query_configuration: QueryConfiguration,
139+
response_id: str,
140+
session: Session,
141+
user_name: Optional[str],
142+
) -> RagStudioChatMessage:
129143
if condensed_question and (condensed_question.strip() == query.strip()):
130144
condensed_question = None
131-
relevance, faithfulness = evaluators.evaluate_response(
132-
query, chat_response, session.inference_model
133-
)
145+
146+
if data_source_id:
147+
chat_response = extract_nodes_from_response_str(chat_response, data_source_id)
148+
149+
evaluations = []
150+
if len(chat_response.source_nodes) != 0:
151+
relevance, faithfulness = evaluators.evaluate_response(
152+
query, chat_response, session.inference_model
153+
)
154+
evaluations.append(Evaluation(name="relevance", value=relevance))
155+
evaluations.append(Evaluation(name="faithfulness", value=faithfulness))
134156
response_source_nodes = format_source_nodes(chat_response, data_source_id)
135157
new_chat_message = RagStudioChatMessage(
136158
id=response_id,
@@ -141,10 +163,7 @@ def finalize_response(chat_response: AgentChatResponse,
141163
user=query,
142164
assistant=chat_response.response,
143165
),
144-
evaluations=[
145-
Evaluation(name="relevance", value=relevance),
146-
Evaluation(name="faithfulness", value=faithfulness),
147-
],
166+
evaluations=evaluations,
148167
timestamp=time.time(),
149168
condensed_question=condensed_question,
150169
)
@@ -156,6 +175,38 @@ def finalize_response(chat_response: AgentChatResponse,
156175
return new_chat_message
157176

158177

178+
def extract_nodes_from_response_str(
179+
chat_response: AgentChatResponse, data_source_id: int
180+
) -> AgentChatResponse:
181+
# get nodes from response source nodes
182+
node_ids_present = set([node.node_id for node in chat_response.source_nodes])
183+
# pull the source nodes from the response citations
184+
extracted_node_ids = re.findall(
185+
r"<a class='rag_citation' href='(.*?)'>",
186+
chat_response.response,
187+
)
188+
# remove duplicates
189+
extracted_node_ids = [
190+
node_id for node_id in extracted_node_ids if node_id not in node_ids_present
191+
]
192+
if len(extracted_node_ids) > 0:
193+
try:
194+
qdrant_store = VectorStoreFactory.for_chunks(data_source_id)
195+
vector_store = qdrant_store.llama_vector_store()
196+
extracted_source_nodes = vector_store.get_nodes(node_ids=extracted_node_ids)
197+
198+
# cast them into NodeWithScore with score 0.0
199+
extracted_source_nodes_w_score = [
200+
NodeWithScore(node=node, score=0.0) for node in extracted_source_nodes
201+
]
202+
# add the source nodes to the response
203+
chat_response.source_nodes += extracted_source_nodes_w_score
204+
except Exception as e:
205+
logger.warning("Failed to extract nodes from response citations (%s): %s", extracted_node_ids, e)
206+
pass
207+
return chat_response
208+
209+
159210
def direct_llm_chat(
160211
session: Session, response_id: str, query: str, user_name: Optional[str]
161212
) -> RagStudioChatMessage:
@@ -179,5 +230,3 @@ def direct_llm_chat(
179230
)
180231
chat_history_manager.append_to_history(session.id, [new_chat_message])
181232
return new_chat_message
182-
183-

0 commit comments

Comments
 (0)