Skip to content

Commit a86a121

Browse files
Fix Mem0 OSS (#2604)
* Fix Mem0 OSS * add test * fix lint and tests * fix * add tests * drop test * changed to class comparision * fixed test cases * Update src/crewai/memory/storage/mem0_storage.py * Update src/crewai/memory/storage/mem0_storage.py * fix * fix lock file --------- Co-authored-by: Vidit-Ostwal <viditostwal@gmail.com>
1 parent 566935f commit a86a121

File tree

5 files changed

+218
-15
lines changed

5 files changed

+218
-15
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ pandas = [
6060
openpyxl = [
6161
"openpyxl>=3.1.5",
6262
]
63-
mem0 = ["mem0ai>=0.1.29"]
63+
mem0 = ["mem0ai>=0.1.94"]
6464
docling = [
6565
"docling>=2.12.0",
6666
]

src/crewai/memory/storage/mem0_storage.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -88,15 +88,17 @@ def save(self, value: Any, metadata: Dict[str, Any]) -> None:
8888
}
8989

9090
if params:
91-
self.memory.add(value, **params | {"output_format": "v1.1"})
91+
if isinstance(self.memory, MemoryClient):
92+
params["output_format"] = "v1.1"
93+
self.memory.add(value, **params)
9294

9395
def search(
9496
self,
9597
query: str,
9698
limit: int = 3,
9799
score_threshold: float = 0.35,
98100
) -> List[Any]:
99-
params = {"query": query, "limit": limit}
101+
params = {"query": query, "limit": limit, "output_format": "v1.1"}
100102
if user_id := self._get_user_id():
101103
params["user_id"] = user_id
102104

@@ -116,8 +118,11 @@ def search(
116118

117119
# Discard the filters for now since we create the filters
118120
# automatically when the crew is created.
121+
if isinstance(self.memory, Memory):
122+
del params["metadata"], params["output_format"]
123+
119124
results = self.memory.search(**params)
120-
return [r for r in results if r["score"] >= score_threshold]
125+
return [r for r in results["results"] if r["score"] >= score_threshold]
121126

122127
def _get_user_id(self) -> str:
123128
return self._get_config().get("user_id", "")

tests/memory/user_memory_test.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
from unittest.mock import MagicMock, patch
32

43
import pytest
@@ -65,4 +64,4 @@ def test_save_and_search(user_memory):
6564
with patch.object(UserMemory, 'search', return_value=expected_result) as mock_search:
6665
find = UserMemory.search("test value", score_threshold=0.01)[0]
6766
mock_search.assert_called_once_with("test value", score_threshold=0.01)
68-
assert find == expected_result[0]
67+
assert find == expected_result[0]

tests/storage/test_mem0_storage.py

+85-3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
class MockCrew:
1616
def __init__(self, memory_config):
1717
self.memory_config = memory_config
18+
self.agents = [MagicMock(role="Test Agent")]
1819

1920

2021
@pytest.fixture
@@ -107,11 +108,13 @@ def mem0_storage_with_memory_client_using_config_from_crew(mock_mem0_memory_clie
107108

108109

109110
@pytest.fixture
110-
def mem0_storage_with_memory_client_using_explictly_config(mock_mem0_memory_client):
111+
def mem0_storage_with_memory_client_using_explictly_config(mock_mem0_memory_client, mock_mem0_memory):
111112
"""Fixture to create a Mem0Storage instance with mocked dependencies"""
112113

113-
# We need to patch the MemoryClient before it's instantiated
114-
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
114+
# We need to patch both MemoryClient and Memory to prevent actual initialization
115+
with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client), \
116+
patch.object(Memory, "__new__", return_value=mock_mem0_memory):
117+
115118
crew = MockCrew(
116119
memory_config={
117120
"provider": "mem0",
@@ -155,3 +158,82 @@ def test_mem0_storage_with_explict_config(
155158
mem0_storage_with_memory_client_using_explictly_config.memory_config
156159
== expected_config
157160
)
161+
162+
163+
def test_save_method_with_memory_oss(mem0_storage_with_mocked_config):
164+
"""Test save method for different memory types"""
165+
mem0_storage, _, _ = mem0_storage_with_mocked_config
166+
mem0_storage.memory.add = MagicMock()
167+
168+
# Test short_term memory type (already set in fixture)
169+
test_value = "This is a test memory"
170+
test_metadata = {"key": "value"}
171+
172+
mem0_storage.save(test_value, test_metadata)
173+
174+
mem0_storage.memory.add.assert_called_once_with(
175+
test_value,
176+
agent_id="Test_Agent",
177+
infer=False,
178+
metadata={"type": "short_term", "key": "value"},
179+
)
180+
181+
182+
def test_save_method_with_memory_client(mem0_storage_with_memory_client_using_config_from_crew):
183+
"""Test save method for different memory types"""
184+
mem0_storage = mem0_storage_with_memory_client_using_config_from_crew
185+
mem0_storage.memory.add = MagicMock()
186+
187+
# Test short_term memory type (already set in fixture)
188+
test_value = "This is a test memory"
189+
test_metadata = {"key": "value"}
190+
191+
mem0_storage.save(test_value, test_metadata)
192+
193+
mem0_storage.memory.add.assert_called_once_with(
194+
test_value,
195+
agent_id="Test_Agent",
196+
infer=False,
197+
metadata={"type": "short_term", "key": "value"},
198+
output_format="v1.1"
199+
)
200+
201+
202+
def test_search_method_with_memory_oss(mem0_storage_with_mocked_config):
203+
"""Test search method for different memory types"""
204+
mem0_storage, _, _ = mem0_storage_with_mocked_config
205+
mock_results = {"results": [{"score": 0.9, "content": "Result 1"}, {"score": 0.4, "content": "Result 2"}]}
206+
mem0_storage.memory.search = MagicMock(return_value=mock_results)
207+
208+
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
209+
210+
mem0_storage.memory.search.assert_called_once_with(
211+
query="test query",
212+
limit=5,
213+
agent_id="Test_Agent",
214+
user_id="test_user"
215+
)
216+
217+
assert len(results) == 1
218+
assert results[0]["content"] == "Result 1"
219+
220+
221+
def test_search_method_with_memory_client(mem0_storage_with_memory_client_using_config_from_crew):
222+
"""Test search method for different memory types"""
223+
mem0_storage = mem0_storage_with_memory_client_using_config_from_crew
224+
mock_results = {"results": [{"score": 0.9, "content": "Result 1"}, {"score": 0.4, "content": "Result 2"}]}
225+
mem0_storage.memory.search = MagicMock(return_value=mock_results)
226+
227+
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
228+
229+
mem0_storage.memory.search.assert_called_once_with(
230+
query="test query",
231+
limit=5,
232+
agent_id="Test_Agent",
233+
metadata={"type": "short_term"},
234+
user_id="test_user",
235+
output_format='v1.1'
236+
)
237+
238+
assert len(results) == 1
239+
assert results[0]["content"] == "Result 1"

0 commit comments

Comments
 (0)