Skip to content

Commit 0936bd1

Browse files
committed
Update
1 parent 8639bf9 commit 0936bd1

File tree

8 files changed

+228
-102
lines changed

8 files changed

+228
-102
lines changed
Lines changed: 50 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,69 @@
1+
import asyncio
12
import os
23

34
from agno.agent import Agent
4-
from agno.db.sqlite import SqliteDb
5+
from agno.db.sqlite import AsyncSqliteDb
56
from agno.models.openai import OpenAIChat
67

78
# Remove the tmp db file before running the script
89
if os.path.exists("tmp/data.db"):
910
os.remove("tmp/data.db")
1011

1112
# Create agents for different users to demonstrate user-specific session history
12-
user_1_agent = Agent(
13+
agent = Agent(
1314
model=OpenAIChat(id="gpt-4o-mini"),
14-
user_id="user_1",
15-
db=SqliteDb(db_file="tmp/data.db"),
16-
add_history_to_context=True,
17-
num_history_runs=3,
15+
db=AsyncSqliteDb(db_file="tmp/data.db"),
1816
search_session_history=True, # allow searching previous sessions
1917
num_history_sessions=2, # only include the last 2 sessions in the search to avoid context length issues
2018
)
2119

22-
user_2_agent = Agent(
23-
model=OpenAIChat(id="gpt-4o-mini"),
24-
user_id="user_2",
25-
db=SqliteDb(db_file="tmp/data.db"),
26-
add_history_to_context=True,
27-
num_history_runs=3,
28-
search_session_history=True,
29-
num_history_sessions=2,
30-
)
3120

32-
# User 1 sessions
33-
print("=== User 1 Sessions ===")
34-
user_1_agent.print_response(
35-
"What is the capital of South Africa?", session_id="user1_session_1"
36-
)
37-
user_1_agent.print_response(
38-
"What is the capital of China?", session_id="user1_session_2"
39-
)
40-
user_1_agent.print_response(
41-
"What is the capital of France?", session_id="user1_session_3"
42-
)
21+
async def main():
22+
# User 1 sessions
23+
print("=== User 1 Sessions ===")
24+
await agent.aprint_response(
25+
"What is the capital of South Africa?",
26+
session_id="user1_session_1",
27+
user_id="user_1",
28+
)
29+
await agent.aprint_response(
30+
"What is the capital of China?", session_id="user1_session_2", user_id="user_1"
31+
)
32+
await agent.aprint_response(
33+
"What is the capital of France?", session_id="user1_session_3", user_id="user_1"
34+
)
4335

44-
# User 2 sessions
45-
print("\n=== User 2 Sessions ===")
46-
user_2_agent.print_response(
47-
"What is the population of India?", session_id="user2_session_1"
48-
)
49-
user_2_agent.print_response(
50-
"What is the currency of Japan?", session_id="user2_session_2"
51-
)
36+
# User 2 sessions
37+
print("\n=== User 2 Sessions ===")
38+
await agent.aprint_response(
39+
"What is the population of India?",
40+
session_id="user2_session_1",
41+
user_id="user_2",
42+
)
43+
await agent.aprint_response(
44+
"What is the currency of Japan?", session_id="user2_session_2", user_id="user_2"
45+
)
5246

53-
# Now test session history search - each user should only see their own sessions
54-
print("\n=== Testing Session History Search ===")
55-
print(
56-
"User 1 asking about previous conversations (should only see capitals, not population/currency):"
57-
)
58-
user_1_agent.print_response(
59-
"What did I discuss in my previous conversations?", session_id="user1_session_4"
60-
)
47+
# Now test session history search - each user should only see their own sessions
48+
print("\n=== Testing Session History Search ===")
49+
print(
50+
"User 1 asking about previous conversations (should only see capitals, not population/currency):"
51+
)
52+
await agent.aprint_response(
53+
"What did I discuss in my previous conversations?",
54+
session_id="user1_session_4",
55+
user_id="user_1",
56+
)
6157

62-
print(
63-
"\nUser 2 asking about previous conversations (should only see population/currency, not capitals):"
64-
)
65-
user_2_agent.print_response(
66-
"What did I discuss in my previous conversations?", session_id="user2_session_3"
67-
)
58+
print(
59+
"\nUser 2 asking about previous conversations (should only see population/currency, not capitals):"
60+
)
61+
await agent.aprint_response(
62+
"What did I discuss in my previous conversations?",
63+
session_id="user2_session_3",
64+
user_id="user_2",
65+
)
66+
67+
68+
if __name__ == "__main__":
69+
asyncio.run(main())
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import asyncio
2+
import os
3+
4+
from agno.agent import Agent
5+
from agno.db.sqlite import AsyncSqliteDb
6+
from agno.models.openai import OpenAIChat
7+
from agno.team import Team
8+
9+
# Remove the tmp db file before running the script
10+
if os.path.exists("tmp/data.db"):
11+
os.remove("tmp/data.db")
12+
13+
# Create agents for different users to demonstrate user-specific session history
14+
team = Team(
15+
model=OpenAIChat(id="gpt-4o-mini"),
16+
members=[], # No members, just for the demo
17+
db=AsyncSqliteDb(db_file="tmp/data.db"),
18+
search_session_history=True, # allow searching previous sessions
19+
num_history_sessions=2, # only include the last 2 sessions in the search to avoid context length issues
20+
)
21+
22+
23+
async def main():
24+
# User 1 sessions
25+
print("=== User 1 Sessions ===")
26+
await team.aprint_response(
27+
"What is the capital of South Africa?",
28+
session_id="user1_session_1",
29+
user_id="user_1",
30+
)
31+
await team.aprint_response(
32+
"What is the capital of China?", session_id="user1_session_2", user_id="user_1"
33+
)
34+
await team.aprint_response(
35+
"What is the capital of France?", session_id="user1_session_3", user_id="user_1"
36+
)
37+
38+
# User 2 sessions
39+
print("\n=== User 2 Sessions ===")
40+
await team.aprint_response(
41+
"What is the population of India?",
42+
session_id="user2_session_1",
43+
user_id="user_2",
44+
)
45+
await team.aprint_response(
46+
"What is the currency of Japan?", session_id="user2_session_2", user_id="user_2"
47+
)
48+
49+
# Now test session history search - each user should only see their own sessions
50+
print("\n=== Testing Session History Search ===")
51+
print(
52+
"User 1 asking about previous conversations (should only see capitals, not population/currency):"
53+
)
54+
await team.aprint_response(
55+
"What did I discuss in my previous conversations?",
56+
session_id="user1_session_4",
57+
user_id="user_1",
58+
)
59+
60+
print(
61+
"\nUser 2 asking about previous conversations (should only see population/currency, not capitals):"
62+
)
63+
await team.aprint_response(
64+
"What did I discuss in my previous conversations?",
65+
session_id="user2_session_3",
66+
user_id="user_2",
67+
)
68+
69+
70+
if __name__ == "__main__":
71+
asyncio.run(main())

libs/agno/agno/agent/agent.py

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5313,6 +5313,7 @@ async def aget_tools(
53135313
session: AgentSession,
53145314
user_id: Optional[str] = None,
53155315
knowledge_filters: Optional[Dict[str, Any]] = None,
5316+
check_mcp_tools: bool = True,
53165317
) -> List[Union[Toolkit, Callable, Function, Dict]]:
53175318
agent_tools: List[Union[Toolkit, Callable, Function, Dict]] = []
53185319

@@ -5339,7 +5340,7 @@ async def aget_tools(
53395340
continue
53405341

53415342
# Only add the tool if it successfully connected and built its tools
5342-
if not tool.initialized: # type: ignore
5343+
if check_mcp_tools and not tool.initialized: # type: ignore
53435344
continue
53445345

53455346
agent_tools.append(tool)
@@ -5353,7 +5354,9 @@ async def aget_tools(
53535354
agent_tools.append(self._get_tool_call_history_function(session=session))
53545355
if self.search_session_history:
53555356
agent_tools.append(
5356-
await self._aget_previous_sessions_messages_function(num_history_sessions=self.num_history_sessions)
5357+
await self._aget_previous_sessions_messages_function(
5358+
num_history_sessions=self.num_history_sessions, user_id=user_id
5359+
)
53575360
)
53585361

53595362
if self.enable_agentic_memory:
@@ -5427,14 +5430,14 @@ def _determine_tools_for_model(
54275430
if name in _function_names:
54285431
continue
54295432
_function_names.append(name)
5430-
5433+
_func = _func.model_copy(deep=True)
54315434
_func._agent = self
54325435
_func.process_entrypoint(strict=strict)
54335436
if strict and _func.strict is None:
54345437
_func.strict = True
54355438
if self.tool_hooks is not None:
54365439
_func.tool_hooks = self.tool_hooks
5437-
_functions.append(_func.model_copy(deep=True))
5440+
_functions.append(_func)
54385441
log_debug(f"Added tool {name} from {tool.name}")
54395442

54405443
# Add instructions from the toolkit
@@ -5446,13 +5449,15 @@ def _determine_tools_for_model(
54465449
continue
54475450
_function_names.append(tool.name)
54485451

5449-
tool._agent = self
54505452
tool.process_entrypoint(strict=strict)
5453+
tool = tool.model_copy(deep=True)
5454+
5455+
tool._agent = self
54515456
if strict and tool.strict is None:
54525457
tool.strict = True
54535458
if self.tool_hooks is not None:
54545459
tool.tool_hooks = self.tool_hooks
5455-
_functions.append(tool.model_copy(deep=True))
5460+
_functions.append(tool)
54565461
log_debug(f"Added tool {tool.name}")
54575462

54585463
# Add instructions from the Function
@@ -5468,12 +5473,13 @@ def _determine_tools_for_model(
54685473
_function_names.append(function_name)
54695474

54705475
_func = Function.from_callable(tool, strict=strict)
5476+
_func = _func.model_copy(deep=True)
54715477
_func._agent = self
54725478
if strict:
54735479
_func.strict = True
54745480
if self.tool_hooks is not None:
54755481
_func.tool_hooks = self.tool_hooks
5476-
_functions.append(_func.model_copy(deep=True))
5482+
_functions.append(_func)
54775483
log_debug(f"Added tool {_func.name}")
54785484
except Exception as e:
54795485
log_warning(f"Could not add tool {tool}: {e}")
@@ -5806,7 +5812,11 @@ def get_run_output(self, run_id: str, session_id: Optional[str] = None) -> Optio
58065812
Returns:
58075813
Optional[RunOutput]: The RunOutput from the database or None if not found.
58085814
"""
5809-
return cast(RunOutput, get_run_output_util(self, run_id=run_id, session_id=session_id))
5815+
if not session_id and not self.session_id:
5816+
raise Exception("No session_id provided")
5817+
5818+
session_id_to_load = session_id or self.session_id
5819+
return cast(RunOutput, get_run_output_util(self, run_id=run_id, session_id=session_id_to_load))
58105820

58115821
async def aget_run_output(self, run_id: str, session_id: Optional[str] = None) -> Optional[RunOutput]:
58125822
"""
@@ -5818,7 +5828,11 @@ async def aget_run_output(self, run_id: str, session_id: Optional[str] = None) -
58185828
Returns:
58195829
Optional[RunOutput]: The RunOutput from the database or None if not found.
58205830
"""
5821-
return cast(RunOutput, await aget_run_output_util(self, run_id=run_id, session_id=session_id))
5831+
if not session_id and not self.session_id:
5832+
raise Exception("No session_id provided")
5833+
5834+
session_id_to_load = session_id or self.session_id
5835+
return cast(RunOutput, await aget_run_output_util(self, run_id=run_id, session_id=session_id_to_load))
58225836

58235837
def get_last_run_output(self, session_id: Optional[str] = None) -> Optional[RunOutput]:
58245838
"""
@@ -5830,7 +5844,11 @@ def get_last_run_output(self, session_id: Optional[str] = None) -> Optional[RunO
58305844
Returns:
58315845
Optional[RunOutput]: The last run response from the database or None if not found.
58325846
"""
5833-
return cast(RunOutput, get_last_run_output_util(self, session_id=session_id))
5847+
if not session_id and not self.session_id:
5848+
raise Exception("No session_id provided")
5849+
5850+
session_id_to_load = session_id or self.session_id
5851+
return cast(RunOutput, get_last_run_output_util(self, session_id=session_id_to_load))
58345852

58355853
async def aget_last_run_output(self, session_id: Optional[str] = None) -> Optional[RunOutput]:
58365854
"""
@@ -5842,7 +5860,11 @@ async def aget_last_run_output(self, session_id: Optional[str] = None) -> Option
58425860
Returns:
58435861
Optional[RunOutput]: The last run response from the database or None if not found.
58445862
"""
5845-
return cast(RunOutput, await aget_last_run_output_util(self, session_id=session_id))
5863+
if not session_id and not self.session_id:
5864+
raise Exception("No session_id provided")
5865+
5866+
session_id_to_load = session_id or self.session_id
5867+
return cast(RunOutput, await aget_last_run_output_util(self, session_id=session_id_to_load))
58465868

58475869
def get_session(
58485870
self,
@@ -9541,12 +9563,14 @@ def get_previous_session_messages() -> str:
95419563

95429564
return get_previous_session_messages
95439565

9544-
async def _aget_previous_sessions_messages_function(self, num_history_sessions: Optional[int] = 2) -> Callable:
9566+
async def _aget_previous_sessions_messages_function(
9567+
self, num_history_sessions: Optional[int] = 2, user_id: Optional[str] = None
9568+
) -> Function:
95459569
"""Factory function to create a get_previous_session_messages function.
95469570
95479571
Args:
95489572
num_history_sessions: The last n sessions to be taken from db
9549-
9573+
user_id: The user ID to filter sessions by
95509574
Returns:
95519575
Callable: A function that retrieves messages from previous sessions
95529576
"""
@@ -9564,12 +9588,22 @@ async def aget_previous_session_messages() -> str:
95649588
if self.db is None:
95659589
return "Previous session messages not available"
95669590

9567-
if isinstance(self.db, AsyncBaseDb):
9568-
selected_sessions = await self.db.get_sessions(
9569-
session_type=SessionType.AGENT, limit=num_history_sessions
9591+
if self._has_async_db():
9592+
selected_sessions = await self.db.get_sessions( # type: ignore
9593+
session_type=SessionType.AGENT,
9594+
limit=num_history_sessions,
9595+
user_id=user_id,
9596+
sort_by="created_at",
9597+
sort_order="desc",
95709598
)
95719599
else:
9572-
selected_sessions = self.db.get_sessions(session_type=SessionType.AGENT, limit=num_history_sessions)
9600+
selected_sessions = self.db.get_sessions(
9601+
session_type=SessionType.AGENT,
9602+
limit=num_history_sessions,
9603+
user_id=user_id,
9604+
sort_by="created_at",
9605+
sort_order="desc",
9606+
)
95739607

95749608
all_messages = []
95759609
seen_message_pairs = set()
@@ -9602,7 +9636,7 @@ async def aget_previous_session_messages() -> str:
96029636

96039637
return json.dumps([msg.to_dict() for msg in all_messages]) if all_messages else "No history found"
96049638

9605-
return aget_previous_session_messages
9639+
return Function.from_callable(aget_previous_session_messages, name="get_previous_session_messages")
96069640

96079641
###########################################################################
96089642
# Print Response

libs/agno/agno/os/schema.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ def filter_meaningful_config(d: Dict[str, Any], defaults: Dict[str, Any]) -> Opt
246246
agent_tools = await agent.aget_tools(
247247
session=AgentSession(session_id=str(uuid4()), session_data={}),
248248
run_response=RunOutput(run_id=str(uuid4())),
249+
check_mcp_tools=False,
249250
)
250251
formatted_tools = format_tools(agent_tools) if agent_tools else None
251252

@@ -472,7 +473,9 @@ def filter_meaningful_config(d: Dict[str, Any], defaults: Dict[str, Any]) -> Opt
472473
async_mode=True,
473474
session_state={},
474475
team_run_context={},
476+
check_mcp_tools=False,
475477
)
478+
print(team.tools, _tools)
476479
team_tools = _tools
477480
formatted_tools = format_team_tools(team_tools) if team_tools else None
478481

0 commit comments

Comments
 (0)