|
2 | 2 | import datetime |
3 | 3 | import hashlib |
4 | 4 | import time |
5 | | -from typing import Any, Dict, List |
| 5 | +from typing import Any, Dict, List, Optional |
6 | 6 |
|
7 | 7 | from pipecat.frames.frames import Frame, LLMMessagesFrame |
8 | 8 | from pipecat.processors.aggregators.openai_llm_context import ( |
|
25 | 25 | raise Exception(f"Missing module: {e}") |
26 | 26 |
|
27 | 27 |
|
| 28 | +def get_user_memories(user_id: str, limit: int = 5): |
| 29 | + """ |
| 30 | + Shared function to retrieve user memories from Mem0. |
| 31 | + Uses ImprovedMem0MemoryService for reliable memory retrieval. |
| 32 | +
|
| 33 | + Args: |
| 34 | + user_id: User identifier for memory retrieval |
| 35 | + limit: Maximum number of memories to retrieve |
| 36 | +
|
| 37 | + Returns: |
| 38 | + List of memories or None if failed |
| 39 | + """ |
| 40 | + if not config.MEM0_API_KEY: |
| 41 | + logger.debug("MEM0_API_KEY not configured") |
| 42 | + return None |
| 43 | + |
| 44 | + try: |
| 45 | + # Use ImprovedMem0MemoryService following the same pattern as the working pipeline |
| 46 | + memory_params = ImprovedMem0MemoryService.InputParams() |
| 47 | + memory_service = ImprovedMem0MemoryService( |
| 48 | + api_key=config.MEM0_API_KEY, |
| 49 | + user_id=user_id, |
| 50 | + params=memory_params, |
| 51 | + ) |
| 52 | + |
| 53 | + # Access the memory client to get all memories for the user |
| 54 | + memories = memory_service.memory_client.get_all(user_id=user_id, limit=limit) |
| 55 | + return memories |
| 56 | + except Exception as e: |
| 57 | + logger.error(f"Error retrieving memories: {e}") |
| 58 | + return None |
| 59 | + |
| 60 | + |
| 61 | +def format_memories_as_context(memories) -> Optional[str]: |
| 62 | + """ |
| 63 | + Format memories into user context data (just the memory data, no instructions). |
| 64 | + Instructions will be moved to system prompt. |
| 65 | +
|
| 66 | + Args: |
| 67 | + memories: List of memory objects from Mem0 |
| 68 | +
|
| 69 | + Returns: |
| 70 | + str: Formatted memory data, None if no valid memories |
| 71 | + """ |
| 72 | + if not memories: |
| 73 | + return None |
| 74 | + |
| 75 | + context_lines = [] |
| 76 | + |
| 77 | + for memory in memories: |
| 78 | + # Extract memory text based on Mem0 response format |
| 79 | + memory_text = "" |
| 80 | + |
| 81 | + if isinstance(memory, dict): |
| 82 | + # Handle dict format |
| 83 | + memory_text = ( |
| 84 | + memory.get("text", "") |
| 85 | + or memory.get("memory", "") |
| 86 | + or memory.get("content", "") |
| 87 | + ) |
| 88 | + elif hasattr(memory, "text"): |
| 89 | + memory_text = memory.text |
| 90 | + elif hasattr(memory, "memory"): |
| 91 | + memory_text = memory.memory |
| 92 | + elif hasattr(memory, "content"): |
| 93 | + memory_text = memory.content |
| 94 | + |
| 95 | + # Clean and validate memory text |
| 96 | + if memory_text and isinstance(memory_text, str): |
| 97 | + cleaned_text = memory_text.strip() |
| 98 | + if cleaned_text and len(cleaned_text) > 10: # Skip very short memories |
| 99 | + context_lines.append(f"- {cleaned_text}") |
| 100 | + |
| 101 | + # Only return context if we have actual content |
| 102 | + if not context_lines: |
| 103 | + return None |
| 104 | + |
| 105 | + # Just the memory data - instructions are now in system prompt |
| 106 | + context = f""" |
| 107 | +
|
| 108 | +[USER MEMORY CONTEXT] |
| 109 | +{chr(10).join(context_lines)} |
| 110 | +
|
| 111 | +""" |
| 112 | + return context |
| 113 | + |
| 114 | + |
| 115 | +# ====== END CONSOLIDATED MEMORY FUNCTIONS ====== |
| 116 | + |
| 117 | + |
28 | 118 | class ImprovedMem0MemoryService(Mem0MemoryService): |
29 | 119 | """ |
30 | 120 | An improved version of Mem0MemoryService with enhanced reliability and performance. |
@@ -439,6 +529,43 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): |
439 | 529 |
|
440 | 530 | if latest_user_message: |
441 | 531 | process_start = time.time() |
| 532 | + |
| 533 | + # NEW: Enhance user message with historical memory context |
| 534 | + user_id = getattr(self, "user_id", None) |
| 535 | + if user_id: |
| 536 | + try: |
| 537 | + # Get user memories directly |
| 538 | + memories = get_user_memories(user_id, limit=5) |
| 539 | + |
| 540 | + if memories: |
| 541 | + # Format memories as context |
| 542 | + memory_context = format_memories_as_context(memories) |
| 543 | + |
| 544 | + if memory_context: |
| 545 | + # Enhance user message with memory context |
| 546 | + enhanced_user_message = f"{memory_context}\nCurrent user message: {latest_user_message}" |
| 547 | + |
| 548 | + # Update the context with the enhanced user message |
| 549 | + context_messages = context.get_messages() |
| 550 | + for i, msg in enumerate(context_messages): |
| 551 | + if ( |
| 552 | + msg.get("role") == "user" |
| 553 | + and msg.get("content") |
| 554 | + == latest_user_message |
| 555 | + ): |
| 556 | + # Update the last user message with memory context |
| 557 | + context_messages[i][ |
| 558 | + "content" |
| 559 | + ] = enhanced_user_message |
| 560 | + logger.debug( |
| 561 | + f"Enhanced user message with historical memory context (memories: {len(memories)})" |
| 562 | + ) |
| 563 | + break |
| 564 | + |
| 565 | + except Exception as e: |
| 566 | + logger.error(f"Dynamic memory enhancement failed: {e}") |
| 567 | + # Continue with original message |
| 568 | + |
442 | 569 | # Enhance context with memories before passing it downstream (with fault tolerance) |
443 | 570 | logger.debug( |
444 | 571 | f"Enhancing context with memories based on user message: '{latest_user_message[:50]}...'" |
|
0 commit comments