Skip to content

Fix Mem0 OSS #2604

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Apr 28, 2025
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ pandas = [
openpyxl = [
"openpyxl>=3.1.5",
]
mem0 = ["mem0ai>=0.1.29"]
mem0 = ["mem0ai>=0.1.90"]
docling = [
"docling>=2.12.0",
]
Expand Down
11 changes: 8 additions & 3 deletions src/crewai/memory/storage/mem0_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,17 @@ def save(self, value: Any, metadata: Dict[str, Any]) -> None:
}

if params:
self.memory.add(value, **params | {"output_format": "v1.1"})
if isinstance(self.memory, MemoryClient) or isinstance(self.memory, AsyncMemoryClient):
params["output_format"] = "v1.1"
self.memory.add(value, **params)

def search(
self,
query: str,
limit: int = 3,
score_threshold: float = 0.35,
) -> List[Any]:
params = {"query": query, "limit": limit}
params = {"query": query, "limit": limit, "output_format": "v1.1"}
if user_id := self._get_user_id():
params["user_id"] = user_id

Expand All @@ -116,8 +118,11 @@ def search(

# Discard the filters for now since we create the filters
# automatically when the crew is created.
if isinstance(self.memory, Memory) or isinstance(self.memory, AsyncMemory):
del params["metadata"], params["output_format"]

results = self.memory.search(**params)
return [r for r in results if r["score"] >= score_threshold]
return [r for r in results["results"] if r["score"] >= score_threshold]

def _get_user_id(self) -> str:
return self._get_config().get("user_id", "")
Expand Down
3 changes: 1 addition & 2 deletions tests/memory/user_memory_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

from unittest.mock import MagicMock, patch

import pytest
Expand Down Expand Up @@ -65,4 +64,4 @@ def test_save_and_search(user_memory):
with patch.object(UserMemory, 'search', return_value=expected_result) as mock_search:
find = UserMemory.search("test value", score_threshold=0.01)[0]
mock_search.assert_called_once_with("test value", score_threshold=0.01)
assert find == expected_result[0]
assert find == expected_result[0]
88 changes: 85 additions & 3 deletions tests/storage/test_mem0_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
class MockCrew:
def __init__(self, memory_config):
self.memory_config = memory_config
self.agents = [MagicMock(role="Test Agent")]


@pytest.fixture
Expand Down Expand Up @@ -107,11 +108,13 @@ def mem0_storage_with_memory_client_using_config_from_crew(mock_mem0_memory_clie


@pytest.fixture
def mem0_storage_with_memory_client_using_explictly_config(mock_mem0_memory_client):
def mem0_storage_with_memory_client_using_explictly_config(mock_mem0_memory_client, mock_mem0_memory):
"""Fixture to create a Mem0Storage instance with mocked dependencies"""

# We need to patch the MemoryClient before it's instantiated
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
# We need to patch both MemoryClient and Memory to prevent actual initialization
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client), \
patch.object(Memory, "__new__", return_value=mock_mem0_memory):

crew = MockCrew(
memory_config={
"provider": "mem0",
Expand Down Expand Up @@ -155,3 +158,82 @@ def test_mem0_storage_with_explict_config(
mem0_storage_with_memory_client_using_explictly_config.memory_config
== expected_config
)


def test_save_method_with_memory_oss(mem0_storage_with_mocked_config):
"""Test save method for different memory types"""
mem0_storage, _, _ = mem0_storage_with_mocked_config
mem0_storage.memory.add = MagicMock()

# Test short_term memory type (already set in fixture)
test_value = "This is a test memory"
test_metadata = {"key": "value"}

mem0_storage.save(test_value, test_metadata)

mem0_storage.memory.add.assert_called_once_with(
test_value,
agent_id="Test_Agent",
infer=False,
metadata={"type": "short_term", "key": "value"},
)


def test_save_method_with_memory_client(mem0_storage_with_memory_client_using_config_from_crew):
"""Test save method for different memory types"""
mem0_storage = mem0_storage_with_memory_client_using_config_from_crew
mem0_storage.memory.add = MagicMock()

# Test short_term memory type (already set in fixture)
test_value = "This is a test memory"
test_metadata = {"key": "value"}

mem0_storage.save(test_value, test_metadata)

mem0_storage.memory.add.assert_called_once_with(
test_value,
agent_id="Test_Agent",
infer=False,
metadata={"type": "short_term", "key": "value"},
output_format="v1.1"
)


def test_search_method_with_memory_oss(mem0_storage_with_mocked_config):
"""Test search method for different memory types"""
mem0_storage, _, _ = mem0_storage_with_mocked_config
mock_results = {"results": [{"score": 0.9, "content": "Result 1"}, {"score": 0.4, "content": "Result 2"}]}
mem0_storage.memory.search = MagicMock(return_value=mock_results)

results = mem0_storage.search("test query", limit=5, score_threshold=0.5)

mem0_storage.memory.search.assert_called_once_with(
query="test query",
limit=5,
agent_id="Test_Agent",
user_id="test_user"
)

assert len(results) == 1
assert results[0]["content"] == "Result 1"


def test_search_method_with_memory_client(mem0_storage_with_memory_client_using_config_from_crew):
"""Test search method for different memory types"""
mem0_storage = mem0_storage_with_memory_client_using_config_from_crew
mock_results = {"results": [{"score": 0.9, "content": "Result 1"}, {"score": 0.4, "content": "Result 2"}]}
mem0_storage.memory.search = MagicMock(return_value=mock_results)

results = mem0_storage.search("test query", limit=5, score_threshold=0.5)

mem0_storage.memory.search.assert_called_once_with(
query="test query",
limit=5,
agent_id="Test_Agent",
metadata={"type": "short_term"},
user_id="test_user",
output_format='v1.1'
)

assert len(results) == 1
assert results[0]["content"] == "Result 1"
53 changes: 49 additions & 4 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading