Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 50 additions & 48 deletions cookbook/agents/state/last_n_session_messages.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,69 @@
import asyncio
import os

from agno.agent import Agent
from agno.db.sqlite import SqliteDb
from agno.db.sqlite import AsyncSqliteDb
from agno.models.openai import OpenAIChat

# Remove the tmp db file before running the script
if os.path.exists("tmp/data.db"):
os.remove("tmp/data.db")

# Create agents for different users to demonstrate user-specific session history
user_1_agent = Agent(
agent = Agent(
model=OpenAIChat(id="gpt-4o-mini"),
user_id="user_1",
db=SqliteDb(db_file="tmp/data.db"),
add_history_to_context=True,
num_history_runs=3,
db=AsyncSqliteDb(db_file="tmp/data.db"),
search_session_history=True, # allow searching previous sessions
num_history_sessions=2, # only include the last 2 sessions in the search to avoid context length issues
)

user_2_agent = Agent(
model=OpenAIChat(id="gpt-4o-mini"),
user_id="user_2",
db=SqliteDb(db_file="tmp/data.db"),
add_history_to_context=True,
num_history_runs=3,
search_session_history=True,
num_history_sessions=2,
)

# User 1 sessions
print("=== User 1 Sessions ===")
user_1_agent.print_response(
"What is the capital of South Africa?", session_id="user1_session_1"
)
user_1_agent.print_response(
"What is the capital of China?", session_id="user1_session_2"
)
user_1_agent.print_response(
"What is the capital of France?", session_id="user1_session_3"
)
async def main():
# User 1 sessions
print("=== User 1 Sessions ===")
await agent.aprint_response(
"What is the capital of South Africa?",
session_id="user1_session_1",
user_id="user_1",
)
await agent.aprint_response(
"What is the capital of China?", session_id="user1_session_2", user_id="user_1"
)
await agent.aprint_response(
"What is the capital of France?", session_id="user1_session_3", user_id="user_1"
)

# User 2 sessions
print("\n=== User 2 Sessions ===")
user_2_agent.print_response(
"What is the population of India?", session_id="user2_session_1"
)
user_2_agent.print_response(
"What is the currency of Japan?", session_id="user2_session_2"
)
# User 2 sessions
print("\n=== User 2 Sessions ===")
await agent.aprint_response(
"What is the population of India?",
session_id="user2_session_1",
user_id="user_2",
)
await agent.aprint_response(
"What is the currency of Japan?", session_id="user2_session_2", user_id="user_2"
)

# Now test session history search - each user should only see their own sessions
print("\n=== Testing Session History Search ===")
print(
"User 1 asking about previous conversations (should only see capitals, not population/currency):"
)
user_1_agent.print_response(
"What did I discuss in my previous conversations?", session_id="user1_session_4"
)
# Now test session history search - each user should only see their own sessions
print("\n=== Testing Session History Search ===")
print(
"User 1 asking about previous conversations (should only see capitals, not population/currency):"
)
await agent.aprint_response(
"What did I discuss in my previous conversations?",
session_id="user1_session_4",
user_id="user_1",
)

print(
"\nUser 2 asking about previous conversations (should only see population/currency, not capitals):"
)
user_2_agent.print_response(
"What did I discuss in my previous conversations?", session_id="user2_session_3"
)
print(
"\nUser 2 asking about previous conversations (should only see population/currency, not capitals):"
)
await agent.aprint_response(
"What did I discuss in my previous conversations?",
session_id="user2_session_3",
user_id="user_2",
)


if __name__ == "__main__":
asyncio.run(main())
70 changes: 70 additions & 0 deletions cookbook/teams/session/11_search_session_history.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import asyncio
import os

from agno.db.sqlite import AsyncSqliteDb
from agno.models.openai import OpenAIChat
from agno.team import Team

# Remove the tmp db file before running the script
if os.path.exists("tmp/data.db"):
os.remove("tmp/data.db")

# Create agents for different users to demonstrate user-specific session history
team = Team(
model=OpenAIChat(id="gpt-4o-mini"),
members=[], # No members, just for the demo
db=AsyncSqliteDb(db_file="tmp/data.db"),
search_session_history=True, # allow searching previous sessions
num_history_sessions=2, # only include the last 2 sessions in the search to avoid context length issues
)


async def main():
# User 1 sessions
print("=== User 1 Sessions ===")
await team.aprint_response(
"What is the capital of South Africa?",
session_id="user1_session_1",
user_id="user_1",
)
await team.aprint_response(
"What is the capital of China?", session_id="user1_session_2", user_id="user_1"
)
await team.aprint_response(
"What is the capital of France?", session_id="user1_session_3", user_id="user_1"
)

# User 2 sessions
print("\n=== User 2 Sessions ===")
await team.aprint_response(
"What is the population of India?",
session_id="user2_session_1",
user_id="user_2",
)
await team.aprint_response(
"What is the currency of Japan?", session_id="user2_session_2", user_id="user_2"
)

# Now test session history search - each user should only see their own sessions
print("\n=== Testing Session History Search ===")
print(
"User 1 asking about previous conversations (should only see capitals, not population/currency):"
)
await team.aprint_response(
"What did I discuss in my previous conversations?",
session_id="user1_session_4",
user_id="user_1",
)

print(
"\nUser 2 asking about previous conversations (should only see population/currency, not capitals):"
)
await team.aprint_response(
"What did I discuss in my previous conversations?",
session_id="user2_session_3",
user_id="user_2",
)


if __name__ == "__main__":
asyncio.run(main())
70 changes: 52 additions & 18 deletions libs/agno/agno/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5313,6 +5313,7 @@ async def aget_tools(
session: AgentSession,
user_id: Optional[str] = None,
knowledge_filters: Optional[Dict[str, Any]] = None,
check_mcp_tools: bool = True,
) -> List[Union[Toolkit, Callable, Function, Dict]]:
agent_tools: List[Union[Toolkit, Callable, Function, Dict]] = []

Expand All @@ -5339,7 +5340,7 @@ async def aget_tools(
continue

# Only add the tool if it successfully connected and built its tools
if not tool.initialized: # type: ignore
if check_mcp_tools and not tool.initialized: # type: ignore
continue

agent_tools.append(tool)
Expand All @@ -5353,7 +5354,9 @@ async def aget_tools(
agent_tools.append(self._get_tool_call_history_function(session=session))
if self.search_session_history:
agent_tools.append(
await self._aget_previous_sessions_messages_function(num_history_sessions=self.num_history_sessions)
await self._aget_previous_sessions_messages_function(
num_history_sessions=self.num_history_sessions, user_id=user_id
)
)

if self.enable_agentic_memory:
Expand Down Expand Up @@ -5427,14 +5430,14 @@ def _determine_tools_for_model(
if name in _function_names:
continue
_function_names.append(name)

_func = _func.model_copy(deep=True)
_func._agent = self
_func.process_entrypoint(strict=strict)
if strict and _func.strict is None:
_func.strict = True
if self.tool_hooks is not None:
_func.tool_hooks = self.tool_hooks
_functions.append(_func.model_copy(deep=True))
_functions.append(_func)
log_debug(f"Added tool {name} from {tool.name}")

# Add instructions from the toolkit
Expand All @@ -5446,13 +5449,15 @@ def _determine_tools_for_model(
continue
_function_names.append(tool.name)

tool._agent = self
tool.process_entrypoint(strict=strict)
tool = tool.model_copy(deep=True)

tool._agent = self
if strict and tool.strict is None:
tool.strict = True
if self.tool_hooks is not None:
tool.tool_hooks = self.tool_hooks
_functions.append(tool.model_copy(deep=True))
_functions.append(tool)
log_debug(f"Added tool {tool.name}")

# Add instructions from the Function
Expand All @@ -5468,12 +5473,13 @@ def _determine_tools_for_model(
_function_names.append(function_name)

_func = Function.from_callable(tool, strict=strict)
_func = _func.model_copy(deep=True)
_func._agent = self
if strict:
_func.strict = True
if self.tool_hooks is not None:
_func.tool_hooks = self.tool_hooks
_functions.append(_func.model_copy(deep=True))
_functions.append(_func)
log_debug(f"Added tool {_func.name}")
except Exception as e:
log_warning(f"Could not add tool {tool}: {e}")
Expand Down Expand Up @@ -5806,7 +5812,11 @@ def get_run_output(self, run_id: str, session_id: Optional[str] = None) -> Optio
Returns:
Optional[RunOutput]: The RunOutput from the database or None if not found.
"""
return cast(RunOutput, get_run_output_util(self, run_id=run_id, session_id=session_id))
if not session_id and not self.session_id:
raise Exception("No session_id provided")

session_id_to_load = session_id or self.session_id
return cast(RunOutput, get_run_output_util(self, run_id=run_id, session_id=session_id_to_load))

async def aget_run_output(self, run_id: str, session_id: Optional[str] = None) -> Optional[RunOutput]:
"""
Expand All @@ -5818,7 +5828,11 @@ async def aget_run_output(self, run_id: str, session_id: Optional[str] = None) -
Returns:
Optional[RunOutput]: The RunOutput from the database or None if not found.
"""
return cast(RunOutput, await aget_run_output_util(self, run_id=run_id, session_id=session_id))
if not session_id and not self.session_id:
raise Exception("No session_id provided")

session_id_to_load = session_id or self.session_id
return cast(RunOutput, await aget_run_output_util(self, run_id=run_id, session_id=session_id_to_load))

def get_last_run_output(self, session_id: Optional[str] = None) -> Optional[RunOutput]:
"""
Expand All @@ -5830,7 +5844,11 @@ def get_last_run_output(self, session_id: Optional[str] = None) -> Optional[RunO
Returns:
Optional[RunOutput]: The last run response from the database or None if not found.
"""
return cast(RunOutput, get_last_run_output_util(self, session_id=session_id))
if not session_id and not self.session_id:
raise Exception("No session_id provided")

session_id_to_load = session_id or self.session_id
return cast(RunOutput, get_last_run_output_util(self, session_id=session_id_to_load))

async def aget_last_run_output(self, session_id: Optional[str] = None) -> Optional[RunOutput]:
"""
Expand All @@ -5842,7 +5860,11 @@ async def aget_last_run_output(self, session_id: Optional[str] = None) -> Option
Returns:
Optional[RunOutput]: The last run response from the database or None if not found.
"""
return cast(RunOutput, await aget_last_run_output_util(self, session_id=session_id))
if not session_id and not self.session_id:
raise Exception("No session_id provided")

session_id_to_load = session_id or self.session_id
return cast(RunOutput, await aget_last_run_output_util(self, session_id=session_id_to_load))

def get_session(
self,
Expand Down Expand Up @@ -9541,12 +9563,14 @@ def get_previous_session_messages() -> str:

return get_previous_session_messages

async def _aget_previous_sessions_messages_function(self, num_history_sessions: Optional[int] = 2) -> Callable:
async def _aget_previous_sessions_messages_function(
self, num_history_sessions: Optional[int] = 2, user_id: Optional[str] = None
) -> Function:
"""Factory function to create a get_previous_session_messages function.

Args:
num_history_sessions: The last n sessions to be taken from db

user_id: The user ID to filter sessions by
Returns:
Callable: A function that retrieves messages from previous sessions
"""
Expand All @@ -9564,12 +9588,22 @@ async def aget_previous_session_messages() -> str:
if self.db is None:
return "Previous session messages not available"

if isinstance(self.db, AsyncBaseDb):
selected_sessions = await self.db.get_sessions(
session_type=SessionType.AGENT, limit=num_history_sessions
if self._has_async_db():
selected_sessions = await self.db.get_sessions( # type: ignore
session_type=SessionType.AGENT,
limit=num_history_sessions,
user_id=user_id,
sort_by="created_at",
sort_order="desc",
)
else:
selected_sessions = self.db.get_sessions(session_type=SessionType.AGENT, limit=num_history_sessions)
selected_sessions = self.db.get_sessions(
session_type=SessionType.AGENT,
limit=num_history_sessions,
user_id=user_id,
sort_by="created_at",
sort_order="desc",
)

all_messages = []
seen_message_pairs = set()
Expand Down Expand Up @@ -9602,7 +9636,7 @@ async def aget_previous_session_messages() -> str:

return json.dumps([msg.to_dict() for msg in all_messages]) if all_messages else "No history found"

return aget_previous_session_messages
return Function.from_callable(aget_previous_session_messages, name="get_previous_session_messages")

###########################################################################
# Print Response
Expand Down
3 changes: 3 additions & 0 deletions libs/agno/agno/os/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ def filter_meaningful_config(d: Dict[str, Any], defaults: Dict[str, Any]) -> Opt
agent_tools = await agent.aget_tools(
session=AgentSession(session_id=str(uuid4()), session_data={}),
run_response=RunOutput(run_id=str(uuid4())),
check_mcp_tools=False,
)
formatted_tools = format_tools(agent_tools) if agent_tools else None

Expand Down Expand Up @@ -472,7 +473,9 @@ def filter_meaningful_config(d: Dict[str, Any], defaults: Dict[str, Any]) -> Opt
async_mode=True,
session_state={},
team_run_context={},
check_mcp_tools=False,
)
print(team.tools, _tools)
team_tools = _tools
formatted_tools = format_team_tools(team_tools) if team_tools else None

Expand Down
Loading
Loading