Skip to content

Add support for storing, maintaining user memories in the backend #1168

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions src/interface/web/app/components/userMemory/userMemory.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import { useState } from "react";
import { Input } from "@/components/ui/input";
import { Button } from "@/components/ui/button";
import { Pencil, TrashSimple, FloppyDisk, X } from "@phosphor-icons/react";
import { useToast } from "@/components/ui/use-toast";

export interface UserMemorySchema {
id: number;
raw: string;
created_at: string;
}

interface UserMemoryProps {
memory: UserMemorySchema;
onDelete: (id: number) => void;
onUpdate: (id: number, raw: string) => void;
}

export function UserMemory({ memory, onDelete, onUpdate }: UserMemoryProps) {
const [isEditing, setIsEditing] = useState(false);
const [content, setContent] = useState(memory.raw);
const { toast } = useToast();

const handleUpdate = () => {
onUpdate(memory.id, content);
setIsEditing(false);
toast({
title: "Memory Updated",
description: "Your memory has been successfully updated.",
});
};

const handleDelete = () => {
onDelete(memory.id);
toast({
title: "Memory Deleted",
description: "Your memory has been successfully deleted.",
});
};

return (
<div className="flex items-center gap-2 w-full">
{isEditing ? (
<>
<Input
value={content}
onChange={(e) => setContent(e.target.value)}
className="flex-1"
/>
<Button
variant="ghost"
size="icon"
onClick={handleUpdate}
title="Save"
>
<FloppyDisk className="h-4 w-4" />
</Button>
<Button
variant="ghost"
size="icon"
onClick={() => setIsEditing(false)}
title="Cancel"
>
<X className="h-4 w-4" />
</Button>
</>
) : (
<>
<Input value={memory.raw} readOnly className="flex-1" />
<Button
variant="ghost"
size="icon"
onClick={() => setIsEditing(true)}
title="Edit"
>
<Pencil className="h-4 w-4" />
</Button>
<Button
variant="ghost"
size="icon"
onClick={handleDelete}
title="Delete"
>
<TrashSimple className="h-4 w-4" />
</Button>
</>
)}
</div>
);
}
122 changes: 119 additions & 3 deletions src/interface/web/app/settings/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import { Button } from "@/components/ui/button";
import { InputOTP, InputOTPGroup, InputOTPSlot } from "@/components/ui/input-otp";
import { Input } from "@/components/ui/input";
import { Card, CardContent, CardFooter, CardHeader } from "@/components/ui/card";

import {
DropdownMenu,
DropdownMenuContent,
Expand All @@ -23,9 +24,25 @@ import {
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import {
AlertDialog, AlertDialogAction, AlertDialogCancel,
AlertDialogContent, AlertDialogDescription, AlertDialogFooter, AlertDialogHeader, AlertDialogTitle, AlertDialogTrigger
AlertDialog,
AlertDialogAction,
AlertDialogCancel,
AlertDialogContent,
AlertDialogDescription,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogTitle,
AlertDialogTrigger
} from "@/components/ui/alert-dialog";

import {
Dialog,
DialogContent,
DialogHeader,
DialogTitle,
DialogTrigger
} from "@/components/ui/dialog";

import { Table, TableBody, TableCell, TableRow } from "@/components/ui/table";

import {
Expand Down Expand Up @@ -67,6 +84,7 @@ import Loading from "../components/loading/loading";
import IntlTelInput from "intl-tel-input/react";
import { SidebarInset, SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar";
import { AppSidebar } from "../components/appSidebar/appSidebar";
import { UserMemory, UserMemorySchema } from "../components/userMemory/userMemory";
import { Separator } from "@/components/ui/separator";
import { KhojLogoType } from "../components/logo/khojLogo";
import { Progress } from "@/components/ui/progress";
Expand Down Expand Up @@ -308,6 +326,7 @@ export default function SettingsView() {
const [numberValidationState, setNumberValidationState] = useState<PhoneNumberValidationState>(
PhoneNumberValidationState.Verified,
);
const [memories, setMemories] = useState<UserMemorySchema[]>([]);
const [isExporting, setIsExporting] = useState(false);
const [exportProgress, setExportProgress] = useState(0);
const [exportedConversations, setExportedConversations] = useState(0);
Expand Down Expand Up @@ -649,6 +668,65 @@ export default function SettingsView() {
}
};

const fetchMemories = async () => {
try {
console.log("Fetching memories...");
const response = await fetch('/api/memories/');
if (!response.ok) throw new Error('Failed to fetch memories');
const data = await response.json();
setMemories(data);
} catch (error) {
console.error('Error fetching memories:', error);
toast({
title: "Error",
description: "Failed to fetch memories. Please try again.",
variant: "destructive"
});
}
};

const handleDeleteMemory = async (id: number) => {
try {
const response = await fetch(`/api/memories/${id}`, {
method: 'DELETE'
});
if (!response.ok) throw new Error('Failed to delete memory');
setMemories(memories.filter(memory => memory.id !== id));
} catch (error) {
console.error('Error deleting memory:', error);
toast({
title: "Error",
description: "Failed to delete memory. Please try again.",
variant: "destructive"
});
}
};

const handleUpdateMemory = async (id: number, raw: string) => {
try {
const response = await fetch(`/api/memories/${id}`, {
method: 'PUT',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ raw, memory_id: id }),
});
if (!response.ok) throw new Error('Failed to update memory');
const updatedMemory: UserMemorySchema = await response.json();
setMemories(memories.map(memory =>
memory.id === id ? updatedMemory : memory
));
} catch (error) {
console.error('Error updating memory:', error);
toast({
title: "Error",
description: "Failed to update memory. Please try again.",
variant: "destructive"
});
}
};


const syncContent = async (type: string) => {
try {
const response = await fetch(`/api/content?t=${type}`, {
Expand Down Expand Up @@ -1237,7 +1315,45 @@ export default function SettingsView() {
</Button>
</CardFooter>
</Card>

<Card className={cardClassName}>
<CardHeader className="text-xl flex flex-row">
<Brain className="h-7 w-7 mr-2" />
Memories
</CardHeader>
<CardContent className="overflow-hidden">
<p className="pb-4 text-gray-400">
View and manage your long-term memories
</p>
</CardContent>
<CardFooter className="flex flex-wrap gap-4">
<Dialog onOpenChange={(open) => open && fetchMemories()}>
<DialogTrigger asChild>
<Button variant="outline">
<Brain className="h-5 w-5 mr-2" />
Browse Memories
</Button>
</DialogTrigger>
<DialogContent className="max-w-2xl max-h-[80vh] overflow-y-auto">
<DialogHeader>
<DialogTitle>Your Memories</DialogTitle>
</DialogHeader>
<div className="grid gap-4 py-4">
{memories.map((memory) => (
<UserMemory
key={memory.id}
memory={memory}
onDelete={handleDeleteMemory}
onUpdate={handleUpdateMemory}
/>
))}
{memories.length === 0 && (
<p className="text-center text-gray-500">No memories found</p>
)}
</div>
</DialogContent>
</Dialog>
</CardFooter>
</Card>
<Card className={cardClassName}>
<CardHeader className="text-xl flex flex-row">
<TrashSimple className="h-7 w-7 mr-2 text-red-500" />
Expand Down
2 changes: 2 additions & 0 deletions src/khoj/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ def configure_routes(app):
from khoj.routers.api_agents import api_agents
from khoj.routers.api_chat import api_chat
from khoj.routers.api_content import api_content
from khoj.routers.api_memories import api_memories
from khoj.routers.api_model import api_model
from khoj.routers.notion import notion_router
from khoj.routers.web_client import web_client
Expand All @@ -322,6 +323,7 @@ def configure_routes(app):
app.include_router(api_chat, prefix="/api/chat")
app.include_router(api_agents, prefix="/api/agents")
app.include_router(api_model, prefix="/api/model")
app.include_router(api_memories, prefix="/api/memories")
app.include_router(api_content, prefix="/api/content")
app.include_router(notion_router, prefix="/api/notion")
app.include_router(web_client)
Expand Down
89 changes: 89 additions & 0 deletions src/khoj/database/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
Subscription,
TextToImageModelConfig,
UserConversationConfig,
UserMemory,
UserRequests,
UserTextToImageModelConfig,
UserVoiceModelConfig,
Expand Down Expand Up @@ -574,6 +575,16 @@ def get_default_search_model() -> SearchModelConfig:
return SearchModelConfig.objects.first()


async def aget_default_search_model() -> SearchModelConfig:
default_search_model = await SearchModelConfig.objects.filter(name="default").afirst()

if default_search_model:
return default_search_model
elif await SearchModelConfig.objects.count() == 0:
await SearchModelConfig.objects.acreate()
return await SearchModelConfig.objects.afirst()


def get_or_create_search_models():
search_models = SearchModelConfig.objects.all()
if search_models.count() == 0:
Expand Down Expand Up @@ -1989,3 +2000,81 @@ def delete_automation(user: KhojUser, automation_id: str):

automation.remove()
return automation_metadata


class UserMemoryAdapters:
@staticmethod
@require_valid_user
async def pull_memories(user: KhojUser, window=10, limit=5) -> list[UserMemory]:
"""
Pulls memories from the database for a given user. Medium term memory.
"""
time_frame = datetime.now(timezone.utc) - timedelta(days=window)
memories = UserMemory.objects.filter(user=user, updated_at__gte=time_frame).order_by("-created_at")[:limit]
return memories

@staticmethod
@require_valid_user
async def save_memory(user: KhojUser, memory: str) -> UserMemory:
"""
Saves a memory to the database for a given user.
"""
embeddings_model = state.embeddings_model
model = await aget_default_search_model()

embeddings = await sync_to_async(embeddings_model[model.name].embed_query)(memory)
memory_instance = await UserMemory.objects.acreate(
user=user, embeddings=embeddings, raw=memory, search_model=model
)

return memory_instance

@staticmethod
@require_valid_user
async def search_memories(user: KhojUser, query: str) -> list[UserMemory]:
"""
Searches for memories in the database for a given user. Long term memory.
"""
embeddings_model = state.embeddings_model
model = await aget_default_search_model()

max_distance = model.bi_encoder_confidence_threshold or math.inf

embedded_query = await sync_to_async(embeddings_model[model.name].embed_query)(query)

relevant_memories = (
UserMemory.objects.filter(user=user)
.annotate(distance=CosineDistance("embeddings", embedded_query))
.order_by("distance")
)

relevant_memories = relevant_memories.filter(distance__lte=max_distance)

return relevant_memories[:10]

@staticmethod
@require_valid_user
async def delete_memory(user: KhojUser, memory_id: str) -> bool:
"""
Deletes a memory from the database for a given user.
"""
try:
memory = await UserMemory.objects.aget(user=user, id=memory_id)
await memory.adelete()
return True
except UserMemory.DoesNotExist:
return False

@staticmethod
def convert_memories_to_dict(memories: List[UserMemory]) -> List[dict]:
"""
Converts a list of Memory objects to a list of dictionaries.
"""
return [
{
"id": memory.id,
"raw": memory.raw,
"updated_at": memory.updated_at,
}
for memory in memories
]
2 changes: 2 additions & 0 deletions src/khoj/database/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
Subscription,
TextToImageModelConfig,
UserConversationConfig,
UserMemory,
UserRequests,
UserVoiceModelConfig,
VoiceModelOption,
Expand Down Expand Up @@ -181,6 +182,7 @@ def get_email_login_url(self, request, queryset):
admin.site.register(VoiceModelOption, unfold_admin.ModelAdmin)
admin.site.register(UserRequests, unfold_admin.ModelAdmin)
admin.site.register(RateLimitRecord, unfold_admin.ModelAdmin)
admin.site.register(UserMemory, unfold_admin.ModelAdmin)


@admin.register(Agent)
Expand Down
Loading
Loading