diff --git a/frontend/src/components/CitationLink.tsx b/frontend/src/components/CitationLink.tsx index 4259e4cee..fce9984b6 100644 --- a/frontend/src/components/CitationLink.tsx +++ b/frontend/src/components/CitationLink.tsx @@ -5,6 +5,8 @@ import { TooltipProvider, TooltipTrigger, } from "@/components/ui/tooltip" +import { FileType } from "shared/types" +import { getFileType } from "shared/fileUtils" export interface Citation { url: string @@ -12,7 +14,6 @@ export interface Citation { docId: string itemId?: string clId?: string - chunkIndex?: number } export const createCitationLink = @@ -33,8 +34,9 @@ export const createCitationLink = const [isTooltipOpen, setIsTooltipOpen] = useState(false) // Extract citation index from children (which should be the citation number like "1", "2", etc.) - const citationIndex = - typeof children === "string" ? parseInt(children) - 1 : -1 + const parts = typeof children === "string" ? children.split("_") : [] + const citationIndex = parts.length > 0 ? parseInt(parts[0]) - 1 : -1 + let chunkIndex = parts.length > 1 ? parseInt(parts[1]) : undefined // Get citation by index if valid, otherwise fall back to URL matching const citation = @@ -44,6 +46,15 @@ export const createCitationLink = ? citations.find((c) => c.url === href) : undefined + if (chunkIndex !== undefined && citation) { + children = (citationIndex + 1).toString() + if ( + getFileType({ type: "", name: citation?.title ?? "" }) === + FileType.SPREADSHEET + ) + chunkIndex = Math.max(chunkIndex - 1, 0) + } + if (citation && citation.clId && citation.itemId) { return ( @@ -56,11 +67,7 @@ export const createCitationLink = e.preventDefault() e.stopPropagation() if (onCitationClick) { - if (citation.chunkIndex !== undefined) { - onCitationClick(citation, citation.chunkIndex) - } else { - onCitationClick(citation) - } + onCitationClick(citation, chunkIndex) } setIsTooltipOpen(false) }} @@ -69,66 +76,66 @@ export const createCitationLink = {showTooltip && ( - { - // Prevent closing when clicking inside the tooltip - e.preventDefault() - }} - > -
{ + { + // Prevent closing when clicking inside the tooltip e.preventDefault() - e.stopPropagation() - if (onCitationClick) { - onCitationClick(citation) - } - setIsTooltipOpen(false) }} > - {/* Document Icon */} -
- - - { + e.preventDefault() + e.stopPropagation() + if (onCitationClick) { + onCitationClick(citation) + } + setIsTooltipOpen(false) + }} + > + {/* Document Icon */} +
+ -
- - {/* Content */} -
-
- {citation.title.split("/").pop() || "Untitled Document"} + xmlns="http://www.w3.org/2000/svg" + className="text-gray-600 dark:text-gray-400" + > + + +
-
- {citation.title.replace(/[^/]*$/, "") || "No file name"} + + {/* Content */} +
+
+ {citation.title.split("/").pop() || "Untitled Document"} +
+
+ {citation.title.replace(/[^/]*$/, "") || "No file name"} +
-
- + )} @@ -139,18 +146,20 @@ export const createCitationLink = const isNumericChild = typeof children === "string" && !isNaN(parseInt(children)) && - parseInt(children).toString() === children.trim() + parseInt(children).toString() === children.split("_")[0].trim() return ( {isNumericChild ? ( {children} - ) : ( children )} + ) : ( + children + )} ) } diff --git a/frontend/src/components/CitationPreview.tsx b/frontend/src/components/CitationPreview.tsx index 58c1ee2f3..77ba52fb5 100644 --- a/frontend/src/components/CitationPreview.tsx +++ b/frontend/src/components/CitationPreview.tsx @@ -1,4 +1,4 @@ -import React, { useEffect, useState } from "react" +import React, { useEffect, useState, useRef, useMemo } from "react" import { X, FileText, ExternalLink, ArrowLeft } from "lucide-react" import { Citation } from "shared/types" import PdfViewer from "./PdfViewer" @@ -8,8 +8,9 @@ import { api } from "@/api" import { authFetch } from "@/utils/authFetch" import ExcelViewer from "./ExcelViewer" import CsvViewer from "./CsvViewer" -import { DocumentOperationsProvider } from "@/contexts/DocumentOperationsContext" +import { DocumentOperations } from "@/contexts/DocumentOperationsContext" import TxtViewer from "./TxtViewer" +import { useScopedFind } from "@/hooks/useScopedFind" interface CitationPreviewProps { citation: Citation | null @@ -17,112 +18,199 @@ interface CitationPreviewProps { onClose: () => void onBackToSources?: () => void showBackButton?: boolean + documentOperationsRef?: React.RefObject + onDocumentLoaded?: () => void } -export const CitationPreview: React.FC = React.memo( - ({ citation, isOpen, onClose, onBackToSources, showBackButton = false }) => { - const [documentContent, setDocumentContent] = useState(null) - const [loading, setLoading] = useState(false) - const [error, setError] = useState(null) - - useEffect(() => { - if (!citation || !isOpen) { - setDocumentContent(null) - setError(null) - return - } +// Inner component that has access to DocumentOperations context +const CitationPreview: React.FC = ({ + citation, + isOpen, + onClose, + onBackToSources, + showBackButton = false, + documentOperationsRef, + onDocumentLoaded, +}) => { + const [documentContent, setDocumentContent] = useState(null) + const [loading, setLoading] = useState(false) + const [error, setError] = useState(null) + const containerRef = useRef(null) - const loadDocument = async () => { - setLoading(true) - setError(null) - try { - if ( - citation.app === "KnowledgeBase" && - citation.itemId && - citation.clId - ) { - const response = - await api.cl[citation.clId].files[citation.itemId].content.$get() + useEffect(() => { + if (!citation || !isOpen) { + setDocumentContent(null) + setError(null) + return + } - if (!response.ok) { - throw new Error( - `Failed to fetch document: ${response.statusText}`, - ) - } + const loadDocument = async () => { + setLoading(true) + setError(null) + try { + if ( + citation.app === "KnowledgeBase" && + citation.itemId && + citation.clId + ) { + const response = + await api.cl[citation.clId].files[citation.itemId].content.$get() - const blob = await response.blob() - setDocumentContent(blob) - } else if (citation.url) { - // For external documents, try to fetch directly - const response = await authFetch(citation.url, { - method: "GET", - }) + if (!response.ok) { + throw new Error(`Failed to fetch document: ${response.statusText}`) + } - if (!response.ok) { - throw new Error( - `Failed to fetch document: ${response.statusText}`, - ) - } + const blob = await response.blob() + 
setDocumentContent(blob) + } else if (citation.url) { + // For external documents, try to fetch directly + const response = await authFetch(citation.url, { + method: "GET", + }) - const blob = await response.blob() - setDocumentContent(blob) - } else { - throw new Error("No document source available") + if (!response.ok) { + throw new Error(`Failed to fetch document: ${response.statusText}`) } - } catch (err) { - console.error("Error loading document:", err) - setError( - err instanceof Error ? err.message : "Failed to load document", + + const blob = await response.blob() + setDocumentContent(blob) + } else { + throw new Error("No document source available") + } + } catch (err) { + console.error("Error loading document:", err) + setError(err instanceof Error ? err.message : "Failed to load document") + } finally { + setLoading(false) + } + } + + loadDocument() + }, [citation, isOpen]) + + const { highlightText, clearHighlights, scrollToMatch } = useScopedFind( + containerRef, + { + documentId: citation?.itemId, + }, + ) + + // Expose the highlight functions via the document operations ref + useEffect(() => { + if (documentOperationsRef?.current) { + documentOperationsRef.current.highlightText = async ( + text: string, + chunkIndex: number, + pageIndex?: number, + waitForTextLayer: boolean = false, + ) => { + if (!containerRef.current) { + return false + } + + try { + const success = await highlightText( + text, + chunkIndex, + pageIndex, + waitForTextLayer, ) - } finally { - setLoading(false) + return success + } catch (error) { + console.error("Error calling highlightText:", error) + return false } } - loadDocument() - }, [citation, isOpen]) - - const getFileExtension = (filename: string): string => { - return filename.toLowerCase().split(".").pop() || "" + documentOperationsRef.current.clearHighlights = clearHighlights + documentOperationsRef.current.scrollToMatch = scrollToMatch } + }, [documentOperationsRef, highlightText, clearHighlights, scrollToMatch]) - const renderViewer = () => { - if (!documentContent || !citation) return null + useEffect(() => { + clearHighlights() + }, [citation?.itemId, clearHighlights]) - const fileName = citation.title || "" - const extension = getFileExtension(fileName) + const getFileExtension = (mimeType: string, filename: string): string => { + if (mimeType === "application/pdf") { + return "pdf" + } + if ( + mimeType === + "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + ) { + return "docx" + } + if (mimeType === "application/msword") { + return "doc" + } + if (mimeType === "text/markdown") { + return "md" + } + if (mimeType === "text/plain") { + return "txt" + } + if (mimeType === "application/vnd.ms-excel") { + return "xls" + } + if ( + mimeType === + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" + ) { + return "xlsx" + } + if (mimeType === "text/csv") { + return "csv" + } + if (mimeType === "text/tsv") { + return "tsv" + } + return filename.toLowerCase().split(".").pop() || "" + } - // Create a File object from the blob - const file = new File([documentContent], fileName, { - type: documentContent.type || getDefaultMimeType(extension), - }) - - switch (extension) { - case "pdf": - return ( - - - - ) - case "md": - case "markdown": - return ( + const viewerElement = useMemo(() => { + if (!documentContent || !citation) return null + + const fileName = citation.title || "" + const extension = getFileExtension(documentContent.type, fileName) + + // Create a File object from the blob + const file = new 
File([documentContent], fileName, { + type: documentContent.type || getDefaultMimeType(extension), + }) + + switch (extension) { + case "pdf": + return ( +
+ +
+ ) + case "md": + case "markdown": + return ( +
+ ) + case "docx": + case "doc": + return ( +
= React.memo( renderFooters: true, renderFootnotes: true, renderEndnotes: true, + renderComments: false, + renderChanges: false, breakPages: true, + ignoreLastRenderedPageBreak: true, + inWrapper: true, + ignoreWidth: false, + ignoreHeight: false, + ignoreFonts: false, }} /> - ) - case "xlsx": - case "xls": - return( - - ) - case "csv": - case "tsv": - return( - - ) - case "txt": - case "text": - return( - - ) - +
+ ) + case "xlsx": + case "xls": + return ( +
+ +
+ ) + case "csv": + case "tsv": + return ( +
+ +
+ ) + case "txt": + case "text": + return ( +
+ +
+ ) - default: - // For other file types, try to display as text or show a generic message - return ( -
- -
- Preview not available for this file type. -
- {citation.url && ( - - - Open in new tab - - )} -
- ) - } + default: + // For other file types, try to display as text or show a generic message + return ( +
+ +
+ Preview not available for this file type. +
+ {citation.url && ( + + + Open in new tab + + )} +
+ ) } + }, [citation, documentContent]) - const getDefaultMimeType = (extension: string): string => { - switch (extension) { - case "pdf": - return "application/pdf" - case "docx": - return "application/vnd.openxmlformats-officedocument.wordprocessingml.document" - case "doc": - return "application/msword" - case "md": - case "markdown": - return "text/markdown" - case "txt": - return "text/plain" - default: - return "application/octet-stream" - } + const getDefaultMimeType = (extension: string): string => { + switch (extension) { + case "pdf": + return "application/pdf" + case "docx": + return "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + case "doc": + return "application/msword" + case "md": + case "markdown": + return "text/markdown" + case "txt": + return "text/plain" + default: + return "application/octet-stream" + } + } + + // Notify parent when document is loaded and ready + useEffect(() => { + if ( + !loading && + !error && + documentContent && + onDocumentLoaded && + viewerElement + ) { + onDocumentLoaded() } + }, [loading, error, documentContent, onDocumentLoaded, viewerElement]) - if (!isOpen) return null + if (!isOpen) return null - return ( -
- {/* Header */} -
-
- {showBackButton && onBackToSources && ( - + return (
+ {/* Header */} +
+
+ {showBackButton && onBackToSources && ( + + )} +
+
+ {citation.title.split("/").pop() || "Document Preview"} +
+ {citation?.app && ( +
+ Source:{" "} + {citation.title.replace(/\/[^/]*$/, "") || "Unknown Source"} +
)} -
-
- {citation.title.split("/").pop() || "Document Preview"} -
- {citation?.app && ( -
- Source:{" "} - {citation.title.replace(/\/[^/]*$/, "") || "Unknown Source"} -
- )} -
-
+ +
- {/* Content */} -
- {loading && ( -
-
-
-
- Loading document... -
-
+ {/* Content */} +
+ {loading && ( +
+
+
+
+ Loading document... +
- )} +
+ )} - {error && ( -
-
-
- -
-
{error}
- {citation?.url && ( - - - Try opening in new tab - - )} + {error && ( +
+
+
+
+
{error}
+ {citation?.url && ( + + + Try opening in new tab + + )}
- )} +
+ )} - {!loading && !error && documentContent && ( -
{renderViewer()}
- )} -
+ {!loading && !error && documentContent && ( +
{viewerElement}
+ )}
- ) - }, -) +
+ ) +} CitationPreview.displayName = "CitationPreview" diff --git a/frontend/src/contexts/DocumentOperationsContext.tsx b/frontend/src/contexts/DocumentOperationsContext.tsx index 8ad4f822f..8e28cbdec 100644 --- a/frontend/src/contexts/DocumentOperationsContext.tsx +++ b/frontend/src/contexts/DocumentOperationsContext.tsx @@ -1,8 +1,19 @@ -import React, { createContext, useContext, useRef, useImperativeHandle, forwardRef } from 'react' +import React, { + createContext, + useContext, + useRef, + useImperativeHandle, + forwardRef, +} from "react" // Define the interface for document operations export interface DocumentOperations { - highlightText?: (text: string, chunkIndex: number, pageIndex?: number) => Promise + highlightText?: ( + text: string, + chunkIndex: number, + pageIndex?: number, + waitForTextLayer?: boolean, + ) => Promise clearHighlights?: () => void scrollToMatch?: (index: number) => boolean goToPage?: (pageIndex: number) => Promise @@ -15,20 +26,29 @@ const DocumentOperationsContext = createContext<{ } | null>(null) // Provider component -export const DocumentOperationsProvider: React.FC<{ children: React.ReactNode }> = ({ children }) => { - const documentOperationsRef = useRef({} as DocumentOperations) +export const DocumentOperationsProvider: React.FC<{ + children: React.ReactNode +}> = ({ children }) => { + const documentOperationsRef = useRef( + {} as DocumentOperations, + ) - const setGoToPageFn = React.useCallback((fn: ((pageIndex: number) => Promise) | null) => { - if (documentOperationsRef.current) { - documentOperationsRef.current.goToPage = fn || undefined - } - }, []) + const setGoToPageFn = React.useCallback( + (fn: ((pageIndex: number) => Promise) | null) => { + if (documentOperationsRef.current) { + documentOperationsRef.current.goToPage = fn || undefined + } + }, + [], + ) return ( - + {children} ) @@ -38,43 +58,66 @@ export const DocumentOperationsProvider: React.FC<{ children: React.ReactNode }> export const useDocumentOperations = () => { const context = useContext(DocumentOperationsContext) if (!context) { - throw new Error('useDocumentOperations must be used within a DocumentOperationsProvider') + throw new Error( + "useDocumentOperations must be used within a DocumentOperationsProvider", + ) } return context } // Higher-order component to expose document operations via ref export const withDocumentOperations =
( - Component: React.ComponentType
}> + Component: React.ComponentType< + P & { documentOperationsRef: React.RefObject } + >, ) => { return forwardRef((props, ref) => { const { documentOperationsRef } = useDocumentOperations() - - useImperativeHandle(ref, () => ({ - highlightText: async (text: string, chunkIndex: number, pageIndex?: number) => { - if (documentOperationsRef.current?.highlightText) { - return await documentOperationsRef.current.highlightText(text, chunkIndex, pageIndex) - } - return false - }, - clearHighlights: () => { - if (documentOperationsRef.current?.clearHighlights) { - documentOperationsRef.current.clearHighlights() - } - }, - scrollToMatch: (index: number) => { - if (documentOperationsRef.current?.scrollToMatch) { - return documentOperationsRef.current.scrollToMatch(index) - } - return false - }, - goToPage: async (pageIndex: number) => { - if (documentOperationsRef.current?.goToPage) { - await documentOperationsRef.current.goToPage(pageIndex) - } - } - }), [documentOperationsRef]) - return + useImperativeHandle( + ref, + () => ({ + highlightText: async ( + text: string, + chunkIndex: number, + pageIndex?: number, + waitForTextLayer: boolean = false, + ) => { + if (documentOperationsRef.current?.highlightText) { + return await documentOperationsRef.current.highlightText( + text, + chunkIndex, + pageIndex, + waitForTextLayer, + ) + } + return false + }, + clearHighlights: () => { + if (documentOperationsRef.current?.clearHighlights) { + documentOperationsRef.current.clearHighlights() + } + }, + scrollToMatch: (index: number) => { + if (documentOperationsRef.current?.scrollToMatch) { + return documentOperationsRef.current.scrollToMatch(index) + } + return false + }, + goToPage: async (pageIndex: number) => { + if (documentOperationsRef.current?.goToPage) { + await documentOperationsRef.current.goToPage(pageIndex) + } + }, + }), + [documentOperationsRef], + ) + + return ( + + ) }) } diff --git a/frontend/src/hooks/useScopedFind.ts b/frontend/src/hooks/useScopedFind.ts index ea258e5fc..34b85f22a 100644 --- a/frontend/src/hooks/useScopedFind.ts +++ b/frontend/src/hooks/useScopedFind.ts @@ -1,365 +1,410 @@ -import { useCallback, useEffect, useState, useRef } from "react"; -import { api } from "@/api"; -import { useDocumentOperations } from "@/contexts/DocumentOperationsContext"; +import { useCallback, useEffect, useState, useRef } from "react" +import { api } from "@/api" +import { useDocumentOperations } from "@/contexts/DocumentOperationsContext" type Options = { - caseSensitive?: boolean; - highlightClass?: string; - activeClass?: string; - debug?: boolean; // Enable debug logging - documentId?: string; // Document ID for caching -}; + caseSensitive?: boolean + highlightClass?: string + activeClass?: string + debug?: boolean // Enable debug logging + documentId?: string // Document ID for caching +} type HighlightMatch = { - startIndex: number; - endIndex: number; - length: number; - similarity: number; - highlightedText: string; - originalLine?: string; - processedLine?: string; -}; + startIndex: number + endIndex: number + length: number + similarity: number + highlightedText: string + originalLine?: string + processedLine?: string +} type HighlightResponse = { - success: boolean; - matches?: HighlightMatch[]; - totalMatches?: number; - message?: string; - debug?: any; -}; + success: boolean + matches?: HighlightMatch[] + totalMatches?: number + message?: string + debug?: any +} type CacheEntry = { - response: HighlightResponse; - timestamp: number; -}; + response: HighlightResponse + timestamp: 
number +} type HighlightCache = { - [key: string]: CacheEntry; -}; + [key: string]: CacheEntry +} // Cache duration constant - defined at module scope to prevent re-declaration on each render -const CACHE_DURATION = 5 * 60 * 1000; // 5 minutes +const CACHE_DURATION = 5 * 60 * 1000 // 5 minutes export function useScopedFind( containerRef: React.RefObject, - opts: Options = {} + opts: Options = {}, ) { - const { documentOperationsRef } = useDocumentOperations(); - const { + const { documentOperationsRef } = useDocumentOperations() + const { caseSensitive = true, - highlightClass = "bg-yellow-200/60 dark:bg-yellow-200/40 rounded-sm px-0.5 py-px", + highlightClass = "bg-yellow-200/60 dark:bg-yellow-200/40 rounded-sm px-0.5 py-px", debug = false, documentId, - } = opts; + } = opts // Cache for API responses - const cacheRef = useRef({}); + const cacheRef = useRef({}) - const [matches, setMatches] = useState([]); - const [index, setIndex] = useState(0); - const [isLoading, setIsLoading] = useState(false); + const [matches, setMatches] = useState([]) + const [index, setIndex] = useState(0) + const [isLoading, setIsLoading] = useState(false) // Generate cache key based on document ID, chunk index, and options - const generateCacheKey = useCallback(( - docId: string | undefined, - chunkIdx: number | null | undefined, - ): string => { - const keyComponents = [ - docId || 'no-doc-id', - chunkIdx !== null && chunkIdx !== undefined ? chunkIdx.toString() : 'no-chunk-idx', - ]; - return keyComponents.join('|'); - }, []); + const generateCacheKey = useCallback( + ( + docId: string | undefined, + chunkIdx: number | null | undefined, + ): string => { + const keyComponents = [ + docId || "no-doc-id", + chunkIdx !== null && chunkIdx !== undefined + ? chunkIdx.toString() + : "no-chunk-idx", + ] + return keyComponents.join("|") + }, + [], + ) // Clean expired cache entries const cleanExpiredCache = useCallback(() => { - const now = Date.now(); - const cache = cacheRef.current; - Object.keys(cache).forEach(key => { + const now = Date.now() + const cache = cacheRef.current + Object.keys(cache).forEach((key) => { if (now - cache[key].timestamp > CACHE_DURATION) { - delete cache[key]; + delete cache[key] } - }); - }, []); + }) + }, []) // Extract text content from the container const extractContainerText = useCallback((container: HTMLElement): string => { const walker = document.createTreeWalker(container, NodeFilter.SHOW_TEXT, { acceptNode(n) { - const p = (n as Text).parentElement; - if (!p) return NodeFilter.FILTER_REJECT; - const tag = p.tagName.toLowerCase(); - if (tag === "script" || tag === "style") return NodeFilter.FILTER_REJECT; - if (!(n as Text).nodeValue?.trim()) return NodeFilter.FILTER_REJECT; - return NodeFilter.FILTER_ACCEPT; + const p = (n as Text).parentElement + if (!p) return NodeFilter.FILTER_REJECT + const tag = p.tagName.toLowerCase() + if (tag === "script" || tag === "style") return NodeFilter.FILTER_REJECT + if (!(n as Text).nodeValue?.trim()) return NodeFilter.FILTER_REJECT + return NodeFilter.FILTER_ACCEPT }, - }); + }) - let text = ""; - let node: Node | null; + let text = "" + let node: Node | null while ((node = walker.nextNode())) { - text += (node as Text).nodeValue; + text += (node as Text).nodeValue } - - return text; - }, []); + + return text + }, []) // Create highlight marks based on backend response - const createHighlightMarks = useCallback(( - container: HTMLElement, - match: HighlightMatch - ): HTMLElement[] => { - const marks: HTMLElement[] = []; - - try { - // Find all 
text nodes and their positions - const walker = document.createTreeWalker(container, NodeFilter.SHOW_TEXT, { - acceptNode(n) { - const p = (n as Text).parentElement; - if (!p) return NodeFilter.FILTER_REJECT; - const tag = p.tagName.toLowerCase(); - if (tag === "script" || tag === "style") return NodeFilter.FILTER_REJECT; - if (!n.nodeValue?.trim()) return NodeFilter.FILTER_REJECT; - return NodeFilter.FILTER_ACCEPT; - }, - }); - - const textNodes: { node: Text; start: number; end: number }[] = []; - let currentPos = 0; - let node: Node | null; - - // Build a map of text nodes and their positions - while ((node = walker.nextNode())) { - const textNode = node as Text; - const nodeLength = textNode.nodeValue!.length; - textNodes.push({ - node: textNode, - start: currentPos, - end: currentPos + nodeLength - }); - currentPos += nodeLength; - } - - // Find all text nodes that intersect with our match - const intersectingNodes = textNodes.filter(({ start, end }) => - start < match.endIndex && end > match.startIndex - ); - - // Create highlights for each intersecting text node - for (const { node: textNode, start: nodeStart } of intersectingNodes) { - const startOffset = Math.max(0, match.startIndex - nodeStart); - const endOffset = Math.min(textNode.nodeValue!.length, match.endIndex - nodeStart); - - if (startOffset < endOffset) { - try { - // Create a range for this text segment - const range = document.createRange(); - range.setStart(textNode, startOffset); - range.setEnd(textNode, endOffset); - - // Create and insert the mark - const mark = document.createElement("mark"); - mark.className = `${highlightClass}`; - mark.setAttribute('data-match-index', '0'); - + const createHighlightMarks = useCallback( + (container: HTMLElement, match: HighlightMatch): HTMLElement[] => { + const marks: HTMLElement[] = [] + + try { + // Find all text nodes and their positions + const walker = document.createTreeWalker( + container, + NodeFilter.SHOW_TEXT, + { + acceptNode(n) { + const p = (n as Text).parentElement + if (!p) return NodeFilter.FILTER_REJECT + const tag = p.tagName.toLowerCase() + if (tag === "script" || tag === "style") + return NodeFilter.FILTER_REJECT + if (!n.nodeValue?.trim()) return NodeFilter.FILTER_REJECT + return NodeFilter.FILTER_ACCEPT + }, + }, + ) + + const textNodes: { node: Text; start: number; end: number }[] = [] + let currentPos = 0 + let node: Node | null + + // Build a map of text nodes and their positions + while ((node = walker.nextNode())) { + const textNode = node as Text + const nodeLength = textNode.nodeValue!.length + textNodes.push({ + node: textNode, + start: currentPos, + end: currentPos + nodeLength, + }) + currentPos += nodeLength + } + + // Find all text nodes that intersect with our match + const intersectingNodes = textNodes.filter( + ({ start, end }) => start < match.endIndex && end > match.startIndex, + ) + + // Create highlights for each intersecting text node + for (const { node: textNode, start: nodeStart } of intersectingNodes) { + const startOffset = Math.max(0, match.startIndex - nodeStart) + const endOffset = Math.min( + textNode.nodeValue!.length, + match.endIndex - nodeStart, + ) + + if (startOffset < endOffset) { try { - range.surroundContents(mark); - marks.push(mark); - } catch (rangeError) { - console.warn('Failed to wrap range with mark, trying alternative approach:', rangeError); - - // Alternative: split text node and insert mark - const originalText = textNode.nodeValue!; - const beforeText = textNode.nodeValue!.substring(0, startOffset); - 
const matchText = textNode.nodeValue!.substring(startOffset, endOffset); - const afterText = textNode.nodeValue!.substring(endOffset); - + // Create a range for this text segment + const range = document.createRange() + range.setStart(textNode, startOffset) + range.setEnd(textNode, endOffset) + + // Create and insert the mark + const mark = document.createElement("mark") + mark.className = `${highlightClass}` + mark.setAttribute("data-match-index", "0") + try { - // Replace the text node content with before text - textNode.nodeValue = beforeText; - - // Create and insert the mark - const mark = document.createElement("mark"); - mark.className = `${highlightClass}`; - mark.setAttribute('data-match-index', '0'); - mark.textContent = matchText; - - // Insert mark after the text node - textNode.parentNode!.insertBefore(mark, textNode.nextSibling); - marks.push(mark); - - // Insert remaining text after the mark - if (afterText) { - const afterNode = document.createTextNode(afterText); - mark.parentNode!.insertBefore(afterNode, mark.nextSibling); + range.surroundContents(mark) + marks.push(mark) + } catch (rangeError) { + console.warn( + "Failed to wrap range with mark, trying alternative approach:", + rangeError, + ) + + // Alternative: split text node and insert mark + const originalText = textNode.nodeValue! + const beforeText = textNode.nodeValue!.substring(0, startOffset) + const matchText = textNode.nodeValue!.substring( + startOffset, + endOffset, + ) + const afterText = textNode.nodeValue!.substring(endOffset) + + try { + // Replace the text node content with before text + textNode.nodeValue = beforeText + + // Create and insert the mark + const mark = document.createElement("mark") + mark.className = `${highlightClass}` + mark.setAttribute("data-match-index", "0") + mark.textContent = matchText + + // Insert mark after the text node + textNode.parentNode!.insertBefore(mark, textNode.nextSibling) + marks.push(mark) + + // Insert remaining text after the mark + if (afterText) { + const afterNode = document.createTextNode(afterText) + mark.parentNode!.insertBefore(afterNode, mark.nextSibling) + } + } catch (fallbackError) { + // Restore original text on error + textNode.nodeValue = originalText + console.error( + "Fallback highlighting approach failed:", + fallbackError, + ) } - } catch (fallbackError) { - // Restore original text on error - textNode.nodeValue = originalText; - console.error('Fallback highlighting approach failed:', fallbackError); } + } catch (error) { + console.warn( + "Error processing text node for highlighting:", + error, + ) } - } catch (error) { - console.warn('Error processing text node for highlighting:', error); } } + } catch (error) { + console.error("Error creating highlight marks:", error) } - - } catch (error) { - console.error('Error creating highlight marks:', error); - } - - return marks; - }, [highlightClass]); + + return marks + }, + [highlightClass], + ) const clearHighlights = useCallback(() => { - const root = containerRef.current; - if (!root) return; - - const marks = root.querySelectorAll('mark[data-match-index]'); + const root = containerRef.current + if (!root) return + + const marks = root.querySelectorAll("mark[data-match-index]") marks.forEach((m) => { - const parent = m.parentNode!; + const parent = m.parentNode! 
// unwrap - while (m.firstChild) parent.insertBefore(m.firstChild, m); - parent.removeChild(m); - parent.normalize(); // merge adjacent text nodes - }); - - setMatches([]); - setIndex(0); - }, [containerRef]); + while (m.firstChild) parent.insertBefore(m.firstChild, m) + parent.removeChild(m) + parent.normalize() // merge adjacent text nodes + }) + + setMatches([]) + setIndex(0) + }, [containerRef]) // Wait for text layer to be fully rendered and positioned - const waitForTextLayerReady = useCallback(async (container: HTMLElement, timeoutMs = 5000): Promise => { - return new Promise((resolve) => { - const startTime = Date.now(); - let lastTextLength = 0; - let text = ''; - let stableCount = 0; - const requiredStableChecks = 3; - - const checkTextLayer = () => { - const currentTime = Date.now(); - if (currentTime - startTime > timeoutMs) { - if (debug) { - console.log('Text layer wait timeout reached'); - } - resolve(text); - return; - } - - // Extract current text length - text = extractContainerText(container); - const currentTextLength = text.length; - - if (debug && currentTextLength !== lastTextLength) { - console.log(`Text layer length changed: ${lastTextLength} -> ${currentTextLength}`); - } - - // Check if text length has stabilized - if (currentTextLength === lastTextLength && currentTextLength > 0) { - stableCount++; - if (stableCount >= requiredStableChecks) { + const waitForTextLayerReady = useCallback( + async (container: HTMLElement, timeoutMs = 5000): Promise => { + return new Promise((resolve) => { + const startTime = Date.now() + let lastTextLength = 0 + let text = "" + let stableCount = 0 + const requiredStableChecks = 3 + + const checkTextLayer = () => { + const currentTime = Date.now() + if (currentTime - startTime > timeoutMs) { if (debug) { - console.log(`Text layer stabilized at length ${currentTextLength}`); + console.log("Text layer wait timeout reached") } - resolve(text); - return; + resolve(text) + return } - } else { - stableCount = 0; + + // Extract current text length + text = extractContainerText(container) + const currentTextLength = text.length + + if (debug && currentTextLength !== lastTextLength) { + console.log( + `Text layer length changed: ${lastTextLength} -> ${currentTextLength}`, + ) + } + + // Check if text length has stabilized + if (currentTextLength === lastTextLength && currentTextLength > 0) { + stableCount++ + if (stableCount >= requiredStableChecks) { + if (debug) { + console.log( + `Text layer stabilized at length ${currentTextLength}`, + ) + } + resolve(text) + return + } + } else { + stableCount = 0 + } + + lastTextLength = currentTextLength + + // Use requestAnimationFrame for the next check to ensure DOM updates are processed + requestAnimationFrame(() => { + setTimeout(checkTextLayer, 50) // Check every 50ms + }) } - - lastTextLength = currentTextLength; - - // Use requestAnimationFrame for the next check to ensure DOM updates are processed - requestAnimationFrame(() => { - setTimeout(checkTextLayer, 50); // Check every 50ms - }); - }; - - // Start checking after one animation frame - requestAnimationFrame(checkTextLayer); - }); - }, [extractContainerText, debug]); + + // Start checking after one animation frame + requestAnimationFrame(checkTextLayer) + }) + }, + [extractContainerText, debug], + ) const highlightText = useCallback( - async (text: string, chunkIndex: number, pageIndex?: number): Promise => { + async ( + text: string, + chunkIndex: number, + pageIndex?: number, + waitForTextLayer: boolean = false, + ): Promise => { if 
(debug) { - console.log('highlightText called with:', text); + console.log("highlightText called with:", text) } - - const root = containerRef.current; + + const root = containerRef.current if (!root) { - if (debug) console.log('No container ref found'); - return false; + if (debug) console.log("No container ref found") + return false } - + if (debug) { - console.log('Container found:', root); + console.log("Container found:", root) } - clearHighlights(); - if (!text) return false; + clearHighlights() + if (!text) return false + + setIsLoading(true) - setIsLoading(true); - try { - let containerText = ''; + let containerText = "" // For PDFs, ensure the page is rendered before extracting text if (documentOperationsRef?.current?.goToPage) { if (debug) { - console.log('PDF or Spreadsheet detected', pageIndex); + console.log("PDF or Spreadsheet detected", pageIndex) } - if(pageIndex !== undefined) { + if (pageIndex !== undefined) { if (debug) { - console.log('Going to page or subsheet:', pageIndex); + console.log("Going to page or subsheet:", pageIndex) } - await documentOperationsRef.current.goToPage(pageIndex); - + await documentOperationsRef.current.goToPage(pageIndex) + // Wait for text layer to be fully rendered and positioned if (debug) { - console.log('Waiting for text layer to be ready...'); + console.log("Waiting for text layer to be ready...") } - containerText = await waitForTextLayerReady(root); + containerText = await waitForTextLayerReady(root) if (debug) { - console.log('Text layer ready, proceeding with highlighting'); + console.log("Text layer ready, proceeding with highlighting") } } else { if (debug) { - console.log('No page or subsheet index provided, skipping highlight'); + console.log( + "No page or subsheet index provided, skipping highlight", + ) } - return false; + return false } } else { - containerText = extractContainerText(root); + if (waitForTextLayer) { + containerText = await waitForTextLayerReady(root) + } else { + containerText = extractContainerText(root) + } } - + if (debug) { - console.log('Container text extracted, length:', containerText.length); + console.log("Container text extracted, length:", containerText.length) } // Clean expired cache entries - cleanExpiredCache(); + cleanExpiredCache() // Generate cache key - const canUseCache = !!documentId; + const canUseCache = !!documentId const cacheKey = canUseCache ? generateCacheKey(documentId, chunkIndex) - : ''; + : "" // Check cache first (only if safe) - const cachedEntry = canUseCache ? cacheRef.current[cacheKey] : undefined; - let result: HighlightResponse; + const cachedEntry = canUseCache ? cacheRef.current[cacheKey] : undefined + let result: HighlightResponse - if (cachedEntry && (Date.now() - cachedEntry.timestamp) < CACHE_DURATION) { + if ( + cachedEntry && + Date.now() - cachedEntry.timestamp < CACHE_DURATION + ) { if (debug) { - console.log('Using cached result for key:', cacheKey); + console.log("Using cached result for key:", cacheKey) } - result = cachedEntry.response; + result = cachedEntry.response } else { if (debug) { - console.log('Cache miss, making API call for key:', cacheKey); + console.log("Cache miss, making API call for key:", cacheKey) } const response = await api.highlight.$post({ @@ -368,134 +413,156 @@ export function useScopedFind( documentContent: containerText, options: { caseSensitive, - } - } - }); + }, + }, + }) if (!response.ok) { - throw new Error(`HTTP error! status: ${response.status}`); + throw new Error(`HTTP error! 
status: ${response.status}`) } - result = await response.json(); - + result = await response.json() + // Only cache successful responses and only when safe if (result.success && canUseCache) { cacheRef.current[cacheKey] = { response: result, timestamp: Date.now(), - }; + } if (debug) { - console.log('Cached successful result for key:', cacheKey); + console.log("Cached successful result for key:", cacheKey) } } else if (result.success && !canUseCache && debug) { - console.log('Skipping cache write (no documentId)'); + console.log("Skipping cache write (no documentId)") } else { if (debug) { - console.log('Not caching failed response for key:', cacheKey); + console.log("Not caching failed response for key:", cacheKey) } } } - + if (debug) { - console.log('Backend response:', result); + console.log("Backend response:", result) } if (!result.success || !result.matches || result.matches.length === 0) { if (debug) { - console.log('No matches found:', result.message); + console.log("No matches found:", result.message) } - return false; + return false } // Create highlight marks for all matches - const allMarks: HTMLElement[] = []; - let longestMatchIndex = 0; - let longestMatchLength = 0; - + const allMarks: HTMLElement[] = [] + let longestMatchIndex = 0 + let longestMatchLength = 0 + result.matches.forEach((match, matchIndex) => { - const marks = createHighlightMarks(root, match); - marks.forEach(mark => { - mark.setAttribute('data-match-index', matchIndex.toString()); - }); - allMarks.push(...marks); - + const marks = createHighlightMarks(root, match) + marks.forEach((mark) => { + mark.setAttribute("data-match-index", matchIndex.toString()) + }) + allMarks.push(...marks) + if (match.length > longestMatchLength) { - longestMatchLength = match.length; - longestMatchIndex = allMarks.length - marks.length; + longestMatchLength = match.length + longestMatchIndex = allMarks.length - marks.length } - }); - + }) + if (debug) { - console.log(`Created ${allMarks.length} highlight marks from ${result.matches.length} matches`); - console.log(`Longest match index: ${longestMatchIndex} with length: ${longestMatchLength}`); + console.log( + `Created ${allMarks.length} highlight marks from ${result.matches.length} matches`, + ) + console.log( + `Longest match index: ${longestMatchIndex} with length: ${longestMatchLength}`, + ) } - - setMatches(allMarks); - setIndex(longestMatchIndex); - - return allMarks.length > 0; - + + setMatches(allMarks) + setIndex(longestMatchIndex) + + return allMarks.length > 0 } catch (error) { - console.error('Error during backend highlighting:', error); - return false; + console.error("Error during backend highlighting:", error) + return false } finally { - setIsLoading(false); + setIsLoading(false) } }, - [clearHighlights, containerRef, extractContainerText, createHighlightMarks, caseSensitive, debug, documentId, generateCacheKey, cleanExpiredCache] - ); + [ + clearHighlights, + containerRef, + extractContainerText, + createHighlightMarks, + caseSensitive, + debug, + documentId, + generateCacheKey, + cleanExpiredCache, + ], + ) const scrollToMatch = useCallback( (matchIndex: number = 0) => { - if (!matches.length || !containerRef.current) return false; - const bounded = ((matchIndex % matches.length) + matches.length) % matches.length; + if (!matches.length || !containerRef.current) return false + const bounded = + ((matchIndex % matches.length) + matches.length) % matches.length + + const container = containerRef.current + const target = matches[bounded] - const container = 
containerRef.current; - const target = matches[bounded]; - if (container.scrollHeight > container.clientHeight) { - const containerRect = container.getBoundingClientRect(); - const targetRect = target.getBoundingClientRect(); - - const targetTop = targetRect.top - containerRect.top; - const containerHeight = container.clientHeight; - const targetHeight = targetRect.height; - - const scrollTop = container.scrollTop + targetTop - (containerHeight / 2) + (targetHeight / 2); - - container.scrollTo({ - top: Math.max(0, scrollTop), - behavior: 'smooth' - }); + const containerRect = container.getBoundingClientRect() + const targetRect = target.getBoundingClientRect() + + const targetTop = targetRect.top - containerRect.top + const containerHeight = container.clientHeight + const targetHeight = targetRect.height + + const scrollTop = + container.scrollTop + + targetTop - + containerHeight / 2 + + targetHeight / 2 + + container.scrollTo({ + top: Math.max(0, scrollTop), + behavior: "smooth", + }) } else { - target.scrollIntoView({ block: "center", inline: "nearest", behavior: "smooth" }); + target.scrollIntoView({ + block: "center", + inline: "nearest", + behavior: "smooth", + }) } - setIndex(bounded); - return true; + setIndex(bounded) + return true }, - [matches, containerRef] - ); + [matches, containerRef], + ) // Auto-scroll to the current index (which is set to the longest match) whenever matches update useEffect(() => { if (matches.length) { - scrollToMatch(index); + scrollToMatch(index) } - }, [matches, index]); + }, [matches, index]) // Clean up when container unmounts - useEffect(() => () => clearHighlights(), [clearHighlights]); + useEffect(() => () => clearHighlights(), [clearHighlights]) // Clean up expired cache entries periodically useEffect(() => { const interval = setInterval(() => { - cleanExpiredCache(); - }, CACHE_DURATION / 2); // Clean every 2.5 minutes + cleanExpiredCache() + }, CACHE_DURATION / 2) // Clean every 2.5 minutes - return () => clearInterval(interval); - }, [cleanExpiredCache]); + return () => clearInterval(interval) + }, [cleanExpiredCache]) return { highlightText, @@ -504,5 +571,5 @@ export function useScopedFind( matches, index, isLoading, - }; + } } diff --git a/frontend/src/routes/_authenticated/chat.tsx b/frontend/src/routes/_authenticated/chat.tsx index 1528d5a2f..3b49461c2 100644 --- a/frontend/src/routes/_authenticated/chat.tsx +++ b/frontend/src/routes/_authenticated/chat.tsx @@ -97,7 +97,7 @@ import { ShareModal } from "@/components/ShareModal" import { AttachmentGallery } from "@/components/AttachmentGallery" import { useVirtualizer } from "@tanstack/react-virtual" import { renderToStaticMarkup } from "react-dom/server" -import { CitationPreview } from "@/components/CitationPreview" +import CitationPreview from "@/components/CitationPreview" import { createCitationLink } from "@/components/CitationLink" import { createPortal } from "react-dom" import { @@ -105,6 +105,10 @@ import { processMessage, createTableComponents, } from "@/utils/chatUtils.tsx" +import { + useDocumentOperations, + DocumentOperationsProvider, +} from "@/contexts/DocumentOperationsContext" export const THINKING_PLACEHOLDER = "Thinking" @@ -225,6 +229,7 @@ export const ChatPage = ({ const [sharedChatData, setSharedChatData] = useState(null) const [sharedChatLoading, setSharedChatLoading] = useState(false) const [sharedChatError, setSharedChatError] = useState(null) + const { documentOperationsRef } = useDocumentOperations() const data = useLoaderData({ from: isWithChatId @@ -408,7 
+413,11 @@ export const ChatPage = ({ const [selectedCitation, setSelectedCitation] = useState( null, ) + const [selectedChunkIndex, setSelectedChunkIndex] = useState( + null, + ) const [cameFromSources, setCameFromSources] = useState(false) + const [isDocumentLoaded, setIsDocumentLoaded] = useState(false) // Compute disableRetry flag for retry buttons const disableRetry = isStreaming || retryIsStreaming || isSharedChat @@ -1177,9 +1186,104 @@ export const ChatPage = ({ } } + // Handle chunk index changes from CitationPreview + const handleChunkIndexChange = useCallback( + async (newChunkIndex: number | null, documentId: string, docId: string) => { + if (!documentId) { + console.error("handleChunkIndexChange called without documentId") + return + } + + if (selectedCitation?.itemId !== documentId) { + return + } + + if (newChunkIndex === null) { + documentOperationsRef?.current?.clearHighlights?.() + return + } + + if (newChunkIndex !== null) { + try { + const chunkContentResponse = await api.chunk[":cId"].files[ + ":docId" + ].content.$get({ + param: { cId: newChunkIndex.toString(), docId: docId }, + }) + + if (!chunkContentResponse.ok) { + console.error( + "Failed to fetch chunk content:", + chunkContentResponse.status, + ) + toast({ + title: "Error", + description: "Failed to load chunk content", + variant: "destructive", + }) + return + } + + const chunkContent = await chunkContentResponse.json() + + // Ensure we are still on the same document before mutating UI + if (selectedCitation?.itemId !== documentId) { + return + } + + if (chunkContent && chunkContent.chunkContent) { + if (documentOperationsRef?.current?.clearHighlights) { + documentOperationsRef.current.clearHighlights() + } + + if (documentOperationsRef?.current?.highlightText) { + try { + await documentOperationsRef.current.highlightText( + chunkContent.chunkContent, + newChunkIndex, + chunkContent.pageIndex, + true, + ) + } catch (error) { + console.error( + "Error highlighting chunk text:", + chunkContent.chunkContent, + error, + ) + } + } + } + } catch (error) { + console.error("Error in handleChunkIndexChange:", error) + toast({ + title: "Error", + description: "Failed to process chunk navigation", + variant: "destructive", + }) + } + } + }, + [selectedCitation, toast, documentOperationsRef], + ) + + useEffect(() => { + if (selectedCitation && isDocumentLoaded) { + handleChunkIndexChange( + selectedChunkIndex, + selectedCitation?.itemId ?? "", + selectedCitation?.docId ?? "", + ) + } + }, [ + selectedChunkIndex, + selectedCitation, + isDocumentLoaded, + handleChunkIndexChange, + ]) + // Handler for citation clicks - moved before conditional returns const handleCitationClick = useCallback( - (citation: Citation, fromSources: boolean = false) => { + (citation: Citation, chunkIndex?: number, fromSources: boolean = false) => { if (!citation || !citation.clId || !citation.itemId) { // For citations without clId or itemId, open as regular link if (citation.url) { @@ -1197,6 +1301,11 @@ export const ChatPage = ({ setCurrentCitations([]) setCurrentMessageId(null) } + // Handle chunk index change if provided + setSelectedChunkIndex(null) + setTimeout(() => { + setSelectedChunkIndex(chunkIndex ?? 
null) + }, 0) }, [], ) @@ -1206,8 +1315,27 @@ export const ChatPage = ({ setIsCitationPreviewOpen(false) setSelectedCitation(null) setCameFromSources(false) + setSelectedChunkIndex(null) + setIsDocumentLoaded(false) + }, []) + + // Callback for when document is loaded in CitationPreview + const handleDocumentLoaded = useCallback(() => { + setIsDocumentLoaded(true) }, []) + useEffect(() => { + setIsDocumentLoaded(false) + }, [selectedCitation]) + + useEffect(() => { + setIsCitationPreviewOpen(false) + setSelectedCitation(null) + setCameFromSources(false) + setSelectedChunkIndex(null) + setIsDocumentLoaded(false) + }, [chatId]) + // Handler for back to sources navigation const handleBackToSources = useCallback(() => { if (currentCitations.length > 0 && currentMessageId) { @@ -1216,6 +1344,7 @@ export const ChatPage = ({ setIsCitationPreviewOpen(false) setSelectedCitation(null) setCameFromSources(false) + setSelectedChunkIndex(null) } }, [currentCitations, currentMessageId]) @@ -1610,6 +1739,8 @@ export const ChatPage = ({ onClose={handleCloseCitationPreview} showBackButton={cameFromSources} onBackToSources={handleBackToSources} + documentOperationsRef={documentOperationsRef} + onDocumentLoaded={handleDocumentLoaded} />
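// Illustrative sketch of how the new "N_M" citation tokens flow through the UI;
// assumes the rendered citation text looks like "12_3" (citation number 12,
// chunk index 3), as produced by processMessage and split on "_" in
// createCitationLink. parseCitationToken is a hypothetical helper, not part of
// this diff.
const parseCitationToken = (token: string) => {
  const [citationPart, chunkPart] = token.split("_")
  return {
    // Citations display 1-based, but index 0-based into the citations array.
    citationIndex: parseInt(citationPart, 10) - 1,
    // Plain "[12]" citations carry no chunk part, so chunkIndex stays undefined.
    chunkIndex: chunkPart !== undefined ? parseInt(chunkPart, 10) : undefined,
  }
}
// Example: parseCitationToken("12_3") => { citationIndex: 11, chunkIndex: 3 }.
// A click then calls onCitationClick(citation, chunkIndex), which sets
// selectedChunkIndex; once onDocumentLoaded fires, handleChunkIndexChange
// fetches the chunk content and asks the viewer to highlight it.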
) @@ -1683,7 +1814,11 @@ const CitationList = ({ onCitationClick, }: { citations: Citation[] - onCitationClick?: (citation: Citation, fromSources?: boolean) => void + onCitationClick?: ( + citation: Citation, + chunkIndex?: number, + fromSources?: boolean, + ) => void }) => { return (
    @@ -1693,7 +1828,7 @@ const CitationList = ({ className="border-[#E6EBF5] dark:border-gray-700 border-[1px] rounded-[10px] mt-[12px] w-[85%] cursor-pointer hover:border-gray-400 dark:hover:border-gray-500 transition-colors" onClick={(e) => { e.preventDefault() - onCitationClick?.(citation, true) + onCitationClick?.(citation, undefined, true) }} >
    @@ -1729,7 +1864,11 @@ const Sources = ({ showSources: boolean citations: Citation[] closeSources: () => void - onCitationClick?: (citation: Citation, fromSources?: boolean) => void + onCitationClick?: ( + citation: Citation, + chunkIndex?: number, + fromSources?: boolean, + ) => void }) => { return showSources ? (
    + + + ) }, errorComponent: errorComponent, diff --git a/frontend/src/utils/chatUtils.tsx b/frontend/src/utils/chatUtils.tsx index 0626a045f..f3b10f233 100644 --- a/frontend/src/utils/chatUtils.tsx +++ b/frontend/src/utils/chatUtils.tsx @@ -4,7 +4,8 @@ import { splitGroupedCitationsWithSpaces } from "@/lib/utils" export const generateUUID = () => crypto.randomUUID() export const textToCitationIndex = /\[(\d+)\]/g -export const textToImageCitationIndex = /\[(\d+_\d+)\]/g +export const textToImageCitationIndex = /(? { @@ -12,6 +13,7 @@ export const cleanCitationsFromResponse = (text: string): string => { return text .replace(textToCitationIndex, "") .replace(textToImageCitationIndex, "") + .replace(textToKbItemCitationIndex, "") .replace(/[ \t]+/g, " ") .trim() } @@ -22,18 +24,6 @@ export const processMessage = ( citationUrls: string[], ) => { text = splitGroupedCitationsWithSpaces(text) - text = text.replace( - /(\[\d+_\d+\])/g, - (fullMatch, capturedCitation, offset, string) => { - // Check if this image citation appears earlier in the string - const firstIndex = string.indexOf(fullMatch) - if (firstIndex < offset) { - // remove duplicate image citations - return "" - } - return capturedCitation - }, - ) text = text.replace( textToImageCitationIndex, (match, citationKey, offset, string) => { @@ -46,6 +36,16 @@ export const processMessage = ( return `![image-citation:${citationKey}](image-citation:${citationKey})` }, ) + text = text.replace(textToKbItemCitationIndex, (_, citationKey) => { + const index = citationMap + ? citationMap[parseInt(citationKey.split("_")[0], 10)] + : parseInt(citationKey.split("_")[0], 10) + const chunkIndex = parseInt(citationKey.split("_")[1], 10) + const url = citationUrls[index] + return typeof index === "number" && typeof chunkIndex === "number" && url + ? `[${index + 1}_${chunkIndex}](${url})` + : "" + }) if (citationMap) { return text.replace(textToCitationIndex, (match, num) => { diff --git a/server/ai/agentPrompts.ts b/server/ai/agentPrompts.ts index 73ec5529c..fd5c8bd07 100644 --- a/server/ai/agentPrompts.ts +++ b/server/ai/agentPrompts.ts @@ -859,16 +859,38 @@ You must respond in valid JSON format with the following structure: # Error Handling If information is missing or unclear: Set "answer" to null` -export const agentBaselineFileContextPromptJson = ( +export const agentBaselineKbContextPromptJson = ( userContext: string, dateForAI: string, retrievedContext: string, + agentPromptData?: AgentPromptData ) => `The current date is: ${dateForAI}. Based on this information, make your answers. Don't try to give vague answers without any logic. Be formal as much as possible. -You are an AI assistant with access to a SINGLE file. You have access to the following types of data: +You are an AI assistant with access to some data given as context. You should only answer from that given context. You have access to the following types of data: +1. Files (documents, spreadsheets, etc.) +2. User profiles +3. Emails +4. Calendar events + +${agentPromptData ? ` +# Context of the agent {priority} +Name: ${agentPromptData.name} +Description: ${agentPromptData.description} +Prompt: ${agentPromptData.prompt} + +# Agent Sources +${agentPromptData.sources.length > 0 ? agentPromptData.sources.map((source) => `- ${typeof source === "string" ? source : JSON.stringify(source)}`).join("\\n") : "No specific sources provided by agent."} +This is the context of the agent, it is very important to follow this. 
You MUST prioritize and filter information based on the # Agent Sources provided. If sources are listed, your response should strictly align with the content and type of these sources. If no specific sources are listed under # Agent Sources, proceed with the general context.` : ""} + +## File & Chunk Formatting (CRITICAL) +- Each file starts with a header line exactly like: + index {docId} {file context begins here...} +- \`docId\` is a unique identifier for that file (e.g., 0, 1, 2, etc.). +- Inside the file context, text is split into chunks. +- Each chunk might begin with a bracketed numeric index, e.g.: [0], [1], [2], etc. +- This is the chunk index within that file, if it exists. -1. Files (pdfs, documents, readme etc.) The context provided will be formatted with specific fields: ## File Context Format - Title @@ -877,12 +899,34 @@ The context provided will be formatted with specific fields: - File Size - Creation and update timestamps - Owner information -- Content chunks with their indices +- Content chunks with their indices (inline within the file content) - Relevance scores -## Chunk Context Format (IMPORTANT) -- The entire file is provided below as a single text block. -- The file is split into chunks inline; each chunk begins with a bracketed numeric index like [0], [1], [2], etc. -- These indices are the ONLY valid citation targets. +## User Context Format +- App and Entity type +- Addition date +- Name and email +- Gender +- Job title +- Department +- Location +- Relevance score +## Email Context Format +- App and Entity type +- Timestamp +- Subject +- From/To/Cc/Bcc +- Labels +- Content chunks +- Relevance score +## Event Context Format +- App and Entity type +- Event name and description +- Location and URLs +- Time information +- Organizer and attendees +- Recurrence patterns +- Meeting links +- Relevance score # Context of the user talking to you ${userContext} This includes: @@ -894,53 +938,58 @@ This includes: ${retrievedContext} # Guidelines for Response 1. Data Interpretation: - - Use ONLY the provided chunks as your knowledge base. - - Treat each [number] as the authoritative chunk index. + - Use ONLY the provided files and their chunks as your knowledge base. + - Treat every file header \`index {docId} ...\` as the start of a new document. + - Treat every bracketed number like [0], [1], [2] as the authoritative chunk index within that document. - If dates exist, interpret them relative to the user's timezone when paraphrasing. 2. Response Structure: - - Start with the most relevant facts from the chunks. + - Start with the most relevant facts from the chunks across files. - Keep order chronological when it helps comprehension. - - Every factual statement MUST cite the chunk it came from using [index] where index = the chunk's \`index\` value. + - Every factual statement MUST cite the exact chunk it came from using the format: + K[docId_chunkIndex] + where: + - \`docId\` is taken from the file header line ("index {docId} ..."). + - \`chunkIndex\` is the bracketed number prefixed on that chunk within the same file. + - Examples: + - Single citation: "X is true K[12_3]." + - Two citations in one sentence (from different files or chunks): "X K[12_3] and Y K[7_0]." - Use at most 1-2 citations per sentence; NEVER add more than 2 for one sentence. -3. Citation Rules (CHUNK-LEVEL ONLY): - - Format: [0], [12], [37] — the number is the chunk \`index\`. - - Place the citation immediately after the relevant statement. - - Do NOT cite the file itself, only chunks. 
- - Do NOT group indices inside one bracket. WRONG: "[0, 1]". - - If a sentence draws from two distinct chunks, cite them as separate brackets inline, e.g., "... was agreed [3] and finalized [7]". +3. Citation Rules (DOCUMENT+CHUNK LEVEL ONLY): + - ALWAYS cite at the chunk level with the K[docId_chunkIndex] format. + - Place the citation immediately after the relevant claim. + - Do NOT group indices inside one set of brackets (WRONG: "K[12_3,7_1]"). + - If a sentence draws on two distinct chunks (possibly from different files), include two separate citations inline, e.g., "... K[12_3] ... K[7_1]". - Only cite information that appears verbatim or is directly inferable from the cited chunk. + - If you cannot ground a claim to a specific chunk, do not make the claim. 4. Quality Assurance: - - Cross-check across multiple chunks when available and note inconsistencies. - - Briefly note inconsistencies if chunks conflict. + - Cross-check across multiple chunks/files when available and briefly note inconsistencies if they exist. - Keep tone professional and concise. - - Acknowledge gaps if the chunks don't contain enough detail. + - Acknowledge gaps if the provided chunks don't contain enough detail. # Response Format You must respond in valid JSON format with the following structure: { - "answer": "Your detailed answer to the query found in context with citations in [index] format or null if not found. This can be well formatted markdown value inside the answer field." + "answer": "Your detailed answer to the query based ONLY on the provided files, with citations in K[docId_chunkIndex] format, or null if not found. This can be well formatted markdown inside the answer field." } -If NO relevant items are found in Retrieved Context or context doesn't match query: +If NO relevant items are found in Retrieved Context or the context doesn't match the query: { "answer": null } # Important Notes: -- Do not worry about sensitive questions, you are a bot with the access and authorization to answer based on context -- Maintain professional tone appropriate for workspace context -- Format dates relative to current user time -- Clean and normalize any raw content as needed -- Consider the relationship between different pieces of content -- If no clear answer is found in the retrieved context, set "answer" to null -- Do not explain why you couldn't find the answer in the context, just set it to null -- We want only 2 cases, either answer is found or we set it to null -- No explanation why answer was not found in the context, just set it to null -- Citations must use the exact index numbers from the provided context -- Keep citations natural and relevant - don't overcite -- Ensure that any mention of dates or times is expressed in the user's local time zone. Always respect the user's time zone. +- Do not worry about sensitive questions; you are authorized to answer based on the provided context. +- Maintain a professional tone appropriate for a workspace context. +- Format dates relative to current user time. +- Clean and normalize any raw content as needed. +- Consider relationships between pieces of content across files. +- If no clear answer is found in the provided chunks, set "answer" to null. +- Do not explain why an answer wasn't found; simply set it to null. +- Citations must use the exact K[docId_chunkIndex] format. +- Keep citations natural and relevant—don't overcite. +- Ensure all mentions of dates/times are expressed in the user's local time zone. 
# Error Handling -If information is missing or unclear, or the query lacks context set "answer" as null` +If information is missing or unclear, or the query lacks context, set "answer" as null` export const agentQueryRewritePromptJson = ( userContext: string, diff --git a/server/ai/context.ts b/server/ai/context.ts index 04b04d6d8..a75c1f1ae 100644 --- a/server/ai/context.ts +++ b/server/ai/context.ts @@ -37,123 +37,144 @@ import { chunkSheetWithHeaders } from "@/sheetChunk" // Utility function to extract header from chunks and remove headers from each chunk const extractHeaderAndDataChunks = ( - chunks_summary: (string | { chunk: string; score: number; index: number })[] | undefined, + chunks_summary: + | (string | { chunk: string; score: number; index: number })[] + | undefined, + matchfeatures?: any, +): { + chunks_summary: (string | { chunk: string; score: number; index: number })[] matchfeatures?: any -): { - chunks_summary: (string | { chunk: string; score: number; index: number })[]; - matchfeatures?: any; } => { if (!chunks_summary || chunks_summary.length === 0) { - return { chunks_summary: [], matchfeatures }; + return { chunks_summary: [], matchfeatures } } // Find the header from the first chunk - let headerChunk = ''; + let headerChunk = "" if (chunks_summary.length > 0) { - const firstChunk = typeof chunks_summary[0] === "string" ? chunks_summary[0] : chunks_summary[0].chunk; - const lines = firstChunk.split('\n'); - if (lines.length > 0 && lines[0].includes('\t')) { - headerChunk = lines[0]; // Extract the header line + const firstChunk = + typeof chunks_summary[0] === "string" + ? chunks_summary[0] + : chunks_summary[0].chunk + const lines = firstChunk.split("\n") + if (lines.length > 0 && lines[0].includes("\t")) { + headerChunk = lines[0] // Extract the header line } } - + // Process all chunks: remove header from each and keep only data rows - const processedChunks: (string | { chunk: string; score: number; index: number })[] = []; - let newMatchfeatures = matchfeatures; - + const processedChunks: ( + | string + | { chunk: string; score: number; index: number } + )[] = [] + let newMatchfeatures = matchfeatures + // Add header as first chunk if found, using the same structure as original if (headerChunk) { if (typeof chunks_summary[0] === "string") { - processedChunks.push(headerChunk); + processedChunks.push(headerChunk) } else { processedChunks.push({ chunk: headerChunk, score: 1, index: 0, - }); + }) } - // Update matchfeatures to include the header chunk score - if (newMatchfeatures) { - const existingCells = newMatchfeatures.chunk_scores?.cells || {}; - const scores = Object.values(existingCells) as number[]; - const maxScore = scores.length > 0 ? Math.max(...scores) : 0; - // Create new chunk_scores that match the new chunks - const newChunkScores: Record = {} - newChunkScores["0"] = maxScore + 1 - Object.entries(existingCells).forEach(([idx, score]) => { - newChunkScores[(parseInt(idx) + 1).toString()] = score as number - }) - - newMatchfeatures = { - ...newMatchfeatures, - chunk_scores: { - cells: newChunkScores - } - }; - } + // Update matchfeatures to include the header chunk score + if (newMatchfeatures) { + const existingCells = newMatchfeatures.chunk_scores?.cells || {} + const scores = Object.values(existingCells) as number[] + const maxScore = scores.length > 0 ? 
Math.max(...scores) : 0 + // Create new chunk_scores that match the new chunks + const newChunkScores: Record = {} + newChunkScores["0"] = maxScore + 1 + Object.entries(existingCells).forEach(([idx, score]) => { + newChunkScores[(parseInt(idx) + 1).toString()] = score as number + }) + + newMatchfeatures = { + ...newMatchfeatures, + chunk_scores: { + cells: newChunkScores, + }, + } + } } - + // Process each original chunk: remove header and add data rows for (let i = 0; i < chunks_summary.length; i++) { - const originalChunk = chunks_summary[i]; - const chunkContent = typeof originalChunk === "string" ? originalChunk : originalChunk.chunk; - const lines = chunkContent.split('\n'); - + const originalChunk = chunks_summary[i] + const chunkContent = + typeof originalChunk === "string" ? originalChunk : originalChunk.chunk + const lines = chunkContent.split("\n") + // Skip the first line (header) and keep only data rows - const dataRows = lines.slice(1).filter(line => line.trim().length > 0); + const dataRows = lines.slice(1).filter((line) => line.trim().length > 0) if (dataRows.length > 0) { - const dataContent = dataRows.join('\n'); - + const dataContent = dataRows.join("\n") + if (typeof originalChunk === "string") { - processedChunks.push(dataContent); + processedChunks.push(dataContent) } else { processedChunks.push({ chunk: dataContent, score: originalChunk.score, - index: originalChunk.index - }); + index: originalChunk.index, + }) } } } - - return { chunks_summary: processedChunks, matchfeatures: newMatchfeatures }; -}; + + return { chunks_summary: processedChunks, matchfeatures: newMatchfeatures } +} // Utility function to process sheet queries for spreadsheet files const processSheetQuery = async ( - chunks_summary: (string | { chunk: string; score: number; index: number })[] | undefined, + chunks_summary: + | (string | { chunk: string; score: number; index: number })[] + | undefined, query: string, - matchfeatures: any + matchfeatures: any, ): Promise<{ - chunks_summary: { chunk: string; score: number; index: number }[]; - matchfeatures: any; - maxSummaryChunks: number; + chunks_summary: { chunk: string; score: number; index: number }[] + matchfeatures: any + maxSummaryChunks: number } | null> => { const duckDBResult = await querySheetChunks( - chunks_summary?.map((c) => typeof c === "string" ? c : c.chunk) || [], - query + chunks_summary?.map((c) => (typeof c === "string" ? 
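For clarity, a small self-contained illustration (with hypothetical TSV data) of the transformation extractHeaderAndDataChunks performs: the header row detected on the first chunk is promoted to a standalone chunk at position 0, each remaining chunk keeps only its data rows, and the chunk scores in matchfeatures are shifted by one to match.

```ts
// Hypothetical input: two sheet chunks that each repeat the TSV header row.
const input = [
  { chunk: "name\tregion\nAda\tEU\nBob\tUS", score: 0.8, index: 0 },
  { chunk: "name\tregion\nCyd\tAPAC", score: 0.5, index: 1 },
]

// Expected shape of extractHeaderAndDataChunks(input).chunks_summary:
// [
//   { chunk: "name\tregion", score: 1, index: 0 },        // promoted header
//   { chunk: "Ada\tEU\nBob\tUS", score: 0.8, index: 0 },  // data rows only
//   { chunk: "Cyd\tAPAC", score: 0.5, index: 1 },
// ]
```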
c : c.chunk)) || [],
+    query,
   )
-
+
   // If DuckDB query failed (null means not metric-related or SQL generation failed), return null to fallback to original approach
   if (!duckDBResult) {
-    return null;
+    return null
   }
-
+
   // Create metadata chunk with query information (excluding data)
-  const metadataChunk = JSON.stringify({
-    assumptions: duckDBResult.assumptions,
-    schema_fragment: duckDBResult.schema_fragment
-  }, null, 2)
-
+  const metadataChunk = JSON.stringify(
+    {
+      assumptions: duckDBResult.assumptions,
+      schema_fragment: duckDBResult.schema_fragment,
+    },
+    null,
+    2,
+  )
+
   // Use chunkSheetWithHeaders to chunk the 2D array data
-  const dataChunks = chunkSheetWithHeaders(duckDBResult.data.rows, {headerRows: 1})
-
+  const dataChunks = chunkSheetWithHeaders(duckDBResult.data.rows, {
+    headerRows: 1,
+  })
+
   // Combine metadata chunk with data chunks
   const allChunks = [metadataChunk, ...dataChunks]
-
-  const newChunksSummary = allChunks.map((c, idx) => ({chunk: c, score: 0, index: idx}))
-
+
+  const newChunksSummary = allChunks.map((c, idx) => ({
+    chunk: c,
+    score: 0,
+    index: idx,
+  }))
+
   // Update matchfeatures to correspond to the new chunks
   let newMatchfeatures = matchfeatures
   if (matchfeatures) {
@@ -162,20 +183,20 @@ const processSheetQuery = async (
     allChunks.forEach((_, idx) => {
       newChunkScores[idx.toString()] = 0 // All new chunks get score 0
     })
-
+
     // Update the matchfeatures with new chunk_scores
     newMatchfeatures = {
       ...matchfeatures,
       chunk_scores: {
-        cells: newChunkScores
-      }
+        cells: newChunkScores,
+      },
     }
   }
-
+
   return {
     chunks_summary: newChunksSummary,
     matchfeatures: newMatchfeatures,
-    maxSummaryChunks: allChunks.length
+    maxSummaryChunks: allChunks.length,
   }
 }

@@ -249,7 +270,7 @@ const constructFileContext = (
   return `App: ${fields.app}
 Entity: ${fields.entity}
-Title: ${fields.title ? `Title: ${fields.title}` : ""}${typeof fields.createdAt === "number" && isFinite(fields.createdAt) ? `\nCreated: ${getRelativeTime(fields.createdAt)} (${new Date(fields.createdAt).toLocaleString("en-US", {timeZone: userTimezone})})` : ""}${typeof fields.updatedAt === "number" && isFinite(fields.updatedAt) ? `\nUpdated At: ${getRelativeTime(fields.updatedAt)} (${new Date(fields.updatedAt).toLocaleString("en-US", {timeZone: userTimezone})})` : ""}
+${fields.title ? `Title: ${fields.title}` : ""}${typeof fields.createdAt === "number" && isFinite(fields.createdAt) ? `\nCreated: ${getRelativeTime(fields.createdAt)} (${new Date(fields.createdAt).toLocaleString("en-US", { timeZone: userTimezone })})` : ""}${typeof fields.updatedAt === "number" && isFinite(fields.updatedAt) ? `\nUpdated At: ${getRelativeTime(fields.updatedAt)} (${new Date(fields.updatedAt).toLocaleString("en-US", { timeZone: userTimezone })})` : ""}
 ${fields.owner ? `Owner: ${fields.owner}` : ""}
 ${fields.parentId ? `parent FolderId: ${fields.parentId}` : ""}
 ${fields.ownerEmail ? `Owner Email: ${fields.ownerEmail}` : ""}
@@ -316,7 +337,7 @@ const constructMailContext = (
   }

   return `App: ${fields.app}
-Entity: ${fields.entity}${typeof fields.timestamp === "number" && isFinite(fields.timestamp) ? `\nSent: ${getRelativeTime(fields.timestamp)} (${new Date(fields.timestamp).toLocaleString("en-US", {timeZone: userTimezone})})` : ""}
+Entity: ${fields.entity}${typeof fields.timestamp === "number" && isFinite(fields.timestamp) ? `\nSent: ${getRelativeTime(fields.timestamp)} (${new Date(fields.timestamp).toLocaleString("en-US", { timeZone: userTimezone })})` : ""}
 ${fields.subject ?
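A hedged sketch of how a caller is expected to consume processSheetQuery, mirroring the null fallback that answerContextMap applies further down; the wrapper name and its inputs are hypothetical.

```ts
// Hypothetical caller. On success, chunk 0 is a JSON metadata chunk
// ({ assumptions, schema_fragment }) and chunks 1..n are header-aware data
// chunks; on null (not a metric query, or SQL generation failed), callers
// keep the original chunks, capped the same way the diff does.
async function summarizeSheet(
  chunks: { chunk: string; score: number; index: number }[],
  query: string,
  matchfeatures: any,
) {
  const result = await processSheetQuery(chunks, query, matchfeatures)
  if (!result) {
    return { chunks, maxSummaryChunks: Math.min(chunks.length, 100) }
  }
  return {
    chunks: result.chunks_summary,
    maxSummaryChunks: result.maxSummaryChunks,
  }
}
```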
`Subject: ${fields.subject}` : ""} ${fields.from ? `From: ${fields.from}` : ""} ${fields.to ? `To: ${fields.to.join(", ")}` : ""} @@ -348,7 +369,7 @@ const constructSlackMessageContext = ( Username: ${fields.username} Message: ${fields.text} ${fields.threadId ? "it's a message thread" : ""} - ${typeof fields.createdAt === "number" && isFinite(fields.createdAt) ? `\n Time: ${getRelativeTime(fields.createdAt)} (${new Date(fields.createdAt).toLocaleString("en-US", {timeZone: userTimezone})})` : ""} + ${typeof fields.createdAt === "number" && isFinite(fields.createdAt) ? `\n Time: ${getRelativeTime(fields.createdAt)} (${new Date(fields.createdAt).toLocaleString("en-US", { timeZone: userTimezone })})` : ""} User is part of Workspace: ${fields.teamName} vespa relevance score: ${relevance}` } @@ -380,7 +401,7 @@ ${ typeof fields.createdAt === "number" && isFinite(fields.createdAt) ? `\nCreated: ${getRelativeTime(fields.createdAt)} (${new Date( fields.createdAt, - ).toLocaleString("en-US", {timeZone: userTimezone})})` + ).toLocaleString("en-US", { timeZone: userTimezone })})` : "" } vespa relevance score: ${relevance}` @@ -432,7 +453,7 @@ const constructMailAttachmentContext = ( Entity: ${fields.entity} ${ typeof fields.timestamp === "number" && isFinite(fields.timestamp) - ? `\nSent: ${getRelativeTime(fields.timestamp)} (${new Date(fields.timestamp).toLocaleString("en-US", {timeZone: userTimeZone})})` + ? `\nSent: ${getRelativeTime(fields.timestamp)} (${new Date(fields.timestamp).toLocaleString("en-US", { timeZone: userTimeZone })})` : "" } ${fields.filename ? `Filename: ${fields.filename}` : ""} @@ -460,7 +481,7 @@ ${ ? `\nStart Time: ${ !fields.defaultStartTime ? new Date(fields.startTime).toUTCString() + - `(${new Date(fields.startTime).toLocaleString("en-US", {timeZone: userTimeZone})})` + `(${new Date(fields.startTime).toLocaleString("en-US", { timeZone: userTimeZone })})` : `No start time specified but date is ${new Date(fields.startTime)}` }` : "" @@ -470,7 +491,7 @@ ${ ? `\nEnd Time: ${ !fields.defaultStartTime ? new Date(fields.endTime).toUTCString() + - `(${new Date(fields.endTime).toLocaleString("en-US", {timeZone: userTimeZone})})` + `(${new Date(fields.endTime).toLocaleString("en-US", { timeZone: userTimeZone })})` : `No end time specified but date is ${new Date(fields.endTime)}` }` : "" @@ -697,12 +718,12 @@ const constructDataSourceFileContext = ( ${fields.fileSize ? `File Size: ${fields.fileSize} bytes` : ""} ${ typeof fields.createdAt === "number" && isFinite(fields.createdAt) - ? `\nCreated: ${getRelativeTime(fields.createdAt)} (${new Date(fields.createdAt).toLocaleString("en-US", {timeZone: userTimeZone})})` + ? `\nCreated: ${getRelativeTime(fields.createdAt)} (${new Date(fields.createdAt).toLocaleString("en-US", { timeZone: userTimeZone })})` : "" } ${ typeof fields.updatedAt === "number" && isFinite(fields.updatedAt) - ? `\nUpdated At: ${getRelativeTime(fields.updatedAt)} (${new Date(fields.updatedAt).toLocaleString("en-US", {timeZone: userTimeZone})})` + ? `\nUpdated At: ${getRelativeTime(fields.updatedAt)} (${new Date(fields.updatedAt).toLocaleString("en-US", { timeZone: userTimeZone })})` : "" } ${fields.uploadedBy ? 
`Uploaded By: ${fields.uploadedBy}` : ""} @@ -716,9 +737,8 @@ const constructCollectionFileContext = ( relevance: number, maxSummaryChunks?: number, isSelectedFiles?: boolean, - isMsgWithSources?: boolean, + isMsgWithKbItems?: boolean, ): string => { - if (!maxSummaryChunks && !isSelectedFiles) { maxSummaryChunks = fields.chunks_summary?.length } @@ -745,8 +765,7 @@ const constructCollectionFileContext = ( } let content = "" - if (isMsgWithSources && fields.chunks_pos_summary) { - // When user has selected one file to chat with, use original chunk positions + if (isMsgWithKbItems && fields.chunks_pos_summary) { content = chunks .map((v) => { const originalIndex = fields.chunks_pos_summary?.[v.index] ?? v.index @@ -779,13 +798,12 @@ const constructCollectionFileContext = ( ? fields.image_chunks_summary?.length : 5 - if (fields.matchfeatures) { + const summaryStrings = + fields.image_chunks_summary?.map((c) => + typeof c === "string" ? c : c.chunk, + ) || [] - const summaryStrings = fields.image_chunks_summary?.map((c) => - typeof c === "string" ? c : c.chunk, - ) || [] - imageChunks = getSortedScoredImageChunks( fields.matchfeatures, fields.image_chunks_pos_summary as number[], @@ -794,7 +812,7 @@ const constructCollectionFileContext = ( ) } else { const imageChunksPos = fields.image_chunks_pos_summary as number[] - + imageChunks = fields.image_chunks_summary?.map((chunk, idx) => { const result = { @@ -807,10 +825,9 @@ const constructCollectionFileContext = ( } let imageContent = imageChunks - .slice(0, maxImageChunks) - .map((v) => v.chunk) - .join("\n") - + .slice(0, maxImageChunks) + .map((v) => v.chunk) + .join("\n") return `Source: Knowledge Base File: ${fields.fileName || "N/A"} @@ -850,7 +867,12 @@ export const answerMetadataContextMap = ( searchResult.relevance, ) } else if (searchResult.fields.sddocname === eventSchema) { - return constructEventContext(searchResult.fields, searchResult.relevance, dateForAI, userTimeZone) + return constructEventContext( + searchResult.fields, + searchResult.relevance, + dateForAI, + userTimeZone, + ) } else { throw new Error( `Invalid search result type: ${searchResult.fields.sddocname}`, @@ -889,36 +911,58 @@ export const answerContextMap = async ( userMetadata: UserMetadataType, maxSummaryChunks?: number, isSelectedFiles?: boolean, - isMsgWithSources?: boolean, + isMsgWithKbItems?: boolean, query?: string, ): Promise => { - if(searchResult.fields.sddocname === fileSchema || searchResult.fields.sddocname === dataSourceFileSchema || searchResult.fields.sddocname === KbItemsSchema || searchResult.fields.sddocname === mailAttachmentSchema) { + if ( + searchResult.fields.sddocname === fileSchema || + searchResult.fields.sddocname === dataSourceFileSchema || + searchResult.fields.sddocname === KbItemsSchema || + searchResult.fields.sddocname === mailAttachmentSchema + ) { let mimeType - if(searchResult.fields.sddocname === mailAttachmentSchema) { + if (searchResult.fields.sddocname === mailAttachmentSchema) { mimeType = searchResult.fields.fileType } else { mimeType = searchResult.fields.mimeType } - if(mimeType === "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" || + if ( + mimeType === + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" || mimeType === "application/vnd.ms-excel" || - mimeType === "text/csv") { - const result = extractHeaderAndDataChunks(searchResult.fields.chunks_summary, searchResult.fields.matchfeatures); - searchResult.fields.chunks_summary = result.chunks_summary; - if (result.matchfeatures) { 
- searchResult.fields.matchfeatures = result.matchfeatures; - } - - if (query) { - const sheetResult = await processSheetQuery(searchResult.fields.chunks_summary, query, searchResult.fields.matchfeatures) - if (sheetResult) { - const { chunks_summary, matchfeatures, maxSummaryChunks: newMaxSummaryChunks } = sheetResult - searchResult.fields.chunks_summary = chunks_summary - searchResult.fields.matchfeatures = matchfeatures - maxSummaryChunks = newMaxSummaryChunks - } else { - maxSummaryChunks = Math.min(searchResult.fields.chunks_summary?.length || 0, 100) - } + mimeType === "text/csv" + ) { + const result = extractHeaderAndDataChunks( + searchResult.fields.chunks_summary, + searchResult.fields.matchfeatures, + ) + searchResult.fields.chunks_summary = result.chunks_summary + if (result.matchfeatures) { + searchResult.fields.matchfeatures = result.matchfeatures + } + + if (query) { + const sheetResult = await processSheetQuery( + searchResult.fields.chunks_summary, + query, + searchResult.fields.matchfeatures, + ) + if (sheetResult) { + const { + chunks_summary, + matchfeatures, + maxSummaryChunks: newMaxSummaryChunks, + } = sheetResult + searchResult.fields.chunks_summary = chunks_summary + searchResult.fields.matchfeatures = matchfeatures + maxSummaryChunks = newMaxSummaryChunks + } else { + maxSummaryChunks = Math.min( + searchResult.fields.chunks_summary?.length || 0, + 100, + ) } + } } } if (searchResult.fields.sddocname === fileSchema) { @@ -940,7 +984,12 @@ export const answerContextMap = async ( isSelectedFiles, ) } else if (searchResult.fields.sddocname === eventSchema) { - return constructEventContext(searchResult.fields, searchResult.relevance, userMetadata.dateForAI, userMetadata.userTimezone) + return constructEventContext( + searchResult.fields, + searchResult.relevance, + userMetadata.dateForAI, + userMetadata.userTimezone, + ) } else if (searchResult.fields.sddocname === mailAttachmentSchema) { return constructMailAttachmentContext( searchResult.fields, @@ -975,7 +1024,7 @@ export const answerContextMap = async ( searchResult.relevance, maxSummaryChunks, isSelectedFiles, - isMsgWithSources, + isMsgWithKbItems, ) } else { throw new Error( diff --git a/server/ai/provider/index.ts b/server/ai/provider/index.ts index 056688bae..d3ce536d0 100644 --- a/server/ai/provider/index.ts +++ b/server/ai/provider/index.ts @@ -107,8 +107,8 @@ import type { ProviderV2 } from "@ai-sdk/provider" import { agentAnalyzeInitialResultsOrRewriteSystemPrompt, agentAnalyzeInitialResultsOrRewriteV2SystemPrompt, - agentBaselineFileContextPromptJson, agentBaselineFilesContextPromptJson, + agentBaselineKbContextPromptJson, agentBaselinePrompt, agentBaselinePromptJson, agentBaselineReasoningPromptJson, @@ -1247,7 +1247,7 @@ export const baselineRAGJsonStream = ( retrievedCtx: string, params: ModelParams, specificFiles?: boolean, - isMsgWithSources?: boolean, + isMsgWithKbItems?: boolean, ): AsyncIterableIterator => { if (!params.modelId) { params.modelId = defaultFastModel @@ -1261,23 +1261,34 @@ export const baselineRAGJsonStream = ( if (specificFiles) { Logger.info("Using baselineFilesContextPromptJson") - if (isMsgWithSources) { - params.systemPrompt = agentBaselineFileContextPromptJson( - userCtx, - userMetadata.dateForAI, - retrievedCtx, - ) - } else if (!isAgentPromptEmpty(params.agentPrompt)) { - params.systemPrompt = agentBaselineFilesContextPromptJson( - userCtx, - indexToCitation(retrievedCtx), - parseAgentPrompt(params.agentPrompt), - ) + if (!isAgentPromptEmpty(params.agentPrompt)) { + if 
(isMsgWithKbItems) { + params.systemPrompt = agentBaselineKbContextPromptJson( + userCtx, + userMetadata.dateForAI, + retrievedCtx, + parseAgentPrompt(params.agentPrompt), + ) + } else { + params.systemPrompt = agentBaselineFilesContextPromptJson( + userCtx, + indexToCitation(retrievedCtx), + parseAgentPrompt(params.agentPrompt), + ) + } } else { - params.systemPrompt = baselineFilesContextPromptJson( - userCtx, - indexToCitation(retrievedCtx), - ) + if (isMsgWithKbItems) { + params.systemPrompt = agentBaselineKbContextPromptJson( + userCtx, + userMetadata.dateForAI, + retrievedCtx, + ) + } else { + params.systemPrompt = baselineFilesContextPromptJson( + userCtx, + indexToCitation(retrievedCtx), + ) + } } } else if (defaultReasoning) { Logger.info("Using baselineReasoningPromptJson") diff --git a/server/api/chat/chat.ts b/server/api/chat/chat.ts index 17f77eb7e..3dfd3e2d2 100644 --- a/server/api/chat/chat.ts +++ b/server/api/chat/chat.ts @@ -181,6 +181,7 @@ import { isValidApp, isValidEntity, collectFollowupContext, + textToKbItemCitationIndex, } from "./utils" import { getRecentChainBreakClassifications, @@ -193,9 +194,6 @@ import { } from "@/db/attachment" import type { AttachmentMetadata } from "@/shared/types" import { parseAttachmentMetadata } from "@/utils/parseAttachment" -import { isImageFile } from "shared/fileUtils" -import { promises as fs } from "node:fs" -import path from "node:path" import { getAgentUsageByUsers, getChatCountsByAgents, @@ -468,21 +466,27 @@ const checkAndYieldCitations = async function* ( baseIndex: number = 0, email: string, yieldedImageCitations: Set, - isMsgWithSources?: boolean, + isMsgWithKbItems?: boolean, ) { const text = splitGroupedCitationsWithSpaces(textInput) let match let imgMatch + let kbMatch = null while ( (match = textToCitationIndex.exec(text)) !== null || - (imgMatch = textToImageCitationIndex.exec(text)) !== null + (imgMatch = textToImageCitationIndex.exec(text)) !== null || + (isMsgWithKbItems && + (kbMatch = textToKbItemCitationIndex.exec(text)) !== null) ) { - if (match) { - const citationIndex = parseInt(match[1], 10) + if (match || kbMatch) { + let citationIndex = 0 + if (match) { + citationIndex = parseInt(match[1], 10) + } else if (kbMatch) { + citationIndex = parseInt(kbMatch[1].split("_")[0], 10) + } if (!yieldedCitations.has(citationIndex)) { - const item = isMsgWithSources - ? results[baseIndex] - : results[citationIndex - baseIndex] + const item = results[citationIndex - baseIndex] if (item) { // TODO: fix this properly, empty citations making streaming broke const f = (item as any)?.fields @@ -496,15 +500,13 @@ const checkAndYieldCitations = async function* ( yield { citation: { index: citationIndex, - item: isMsgWithSources - ? 
searchToCitation(item as VespaSearchResults, citationIndex) - : searchToCitation(item as VespaSearchResults), + item: searchToCitation(item as VespaSearchResults), }, } yieldedCitations.add(citationIndex) } else { loggerWithChild({ email: email }).error( - `Found a citation index but could not find it in the search result: ${citationIndex}, ${results.length}` + `Found a citation index but could not find it in the search result: ${citationIndex}, ${results.length}`, ) } } @@ -553,7 +555,7 @@ const checkAndYieldCitations = async function* ( yieldedImageCitations.add(citationIndex) } else { loggerWithChild({ email: email }).error( - `Found a citation index but could not find it in the search result: ${citationIndex}, ${results.length}` + `Found a citation index but could not find it in the search result: ${citationIndex}, ${results.length}`, ) continue } @@ -595,7 +597,7 @@ async function* processIterator( previousResultsLength: number = 0, userRequestsReasoning?: boolean, email?: string, - isMsgWithSources?: boolean, + isMsgWithKbItems?: boolean, ): AsyncIterableIterator< ConverseResponse & { citation?: { index: number; item: any } @@ -624,7 +626,7 @@ async function* processIterator( previousResultsLength, email!, yieldedImageCitations, - isMsgWithSources, + isMsgWithKbItems, ) yield { text: chunk.text, reasoning } } else { @@ -650,7 +652,7 @@ async function* processIterator( previousResultsLength, email!, yieldedImageCitations, - isMsgWithSources, + isMsgWithKbItems, ) yield { text: token, reasoning } } @@ -688,7 +690,7 @@ async function* processIterator( previousResultsLength, email!, yieldedImageCitations, - isMsgWithSources, + isMsgWithKbItems, ) currentAnswer = parsed.answer } @@ -1105,6 +1107,7 @@ export async function buildContext( userMetadata: UserMetadataType, startIndex: number = 0, builtUserQuery?: string, + isMsgWithKbItems?: boolean, ): Promise { const contextPromises = results?.map( async (v, i) => @@ -1113,7 +1116,7 @@ export async function buildContext( userMetadata, maxSummaryCount, undefined, - undefined, + isMsgWithKbItems, builtUserQuery, )}`, ) @@ -1385,7 +1388,7 @@ async function* generateIterativeTimeFilterAndQueryRewrite( dataSourceIds: agentSpecificDataSourceIds, channelIds: channelIds, collectionSelections: agentSpecificCollectionSelections, - selectedItem: selectedItem,//agentIntegration format (app_integrations format) + selectedItem: selectedItem, //agentIntegration format (app_integrations format) }, ) } @@ -1482,6 +1485,7 @@ async function* generateIterativeTimeFilterAndQueryRewrite( userMetadata, 0, message, + agentSpecificCollectionSelections.length > 0, ) const queryRewriteSpan = rewriteSpan?.startSpan("query_rewriter") @@ -1616,6 +1620,7 @@ async function* generateIterativeTimeFilterAndQueryRewrite( userMetadata, 0, message, + agentSpecificCollectionSelections.length > 0, ) const { imageFileNames } = extractImageFileNames( @@ -1648,6 +1653,8 @@ async function* generateIterativeTimeFilterAndQueryRewrite( agentPrompt, imageFileNames, }, + agentSpecificCollectionSelections.length > 0, + agentSpecificCollectionSelections.length > 0, ) const answer = yield* processIterator( @@ -1656,6 +1663,7 @@ async function* generateIterativeTimeFilterAndQueryRewrite( previousResultsLength, config.isReasoning && userRequestsReasoning, email, + agentSpecificCollectionSelections.length > 0, ) if (answer) { ragSpan?.setAttribute("answer_found", true) @@ -1810,6 +1818,7 @@ async function* generateIterativeTimeFilterAndQueryRewrite( userMetadata, startIndex, message, + 
agentSpecificCollectionSelections.length > 0, ) const { imageFileNames } = extractImageFileNames( @@ -1845,6 +1854,8 @@ async function* generateIterativeTimeFilterAndQueryRewrite( messages, imageFileNames, }, + agentSpecificCollectionSelections.length > 0, + agentSpecificCollectionSelections.length > 0, ) const answer = yield* processIterator( @@ -1853,6 +1864,7 @@ async function* generateIterativeTimeFilterAndQueryRewrite( previousResultsLength, config.isReasoning && userRequestsReasoning, email, + agentSpecificCollectionSelections.length > 0, ) if (answer) { @@ -1891,7 +1903,7 @@ async function* generateAnswerFromGivenContext( passedSpan?: Span, threadIds?: string[], attachmentFileIds?: string[], - isMsgWithSources?: boolean, + isMsgWithKbItems?: boolean, modelId?: string, isValidPath?: boolean, folderIds?: string[], @@ -2134,7 +2146,7 @@ async function* generateAnswerFromGivenContext( userMetadata, i < chunksPerDocument.length ? chunksPerDocument[i] : 0, true, - isMsgWithSources, + isMsgWithKbItems, message, ) if ( @@ -2158,7 +2170,7 @@ async function* generateAnswerFromGivenContext( ) } } - return isMsgWithSources ? content : `Index ${i + startIndex} \n ${content}` + return `Index ${i + startIndex} \n ${content}` }) const resolvedContexts = contextPromises @@ -2210,7 +2222,7 @@ async function* generateAnswerFromGivenContext( imageFileNames: finalImageFileNames, }, true, - isMsgWithSources, + isMsgWithKbItems, ) const answer = yield* processIterator( @@ -2219,7 +2231,7 @@ async function* generateAnswerFromGivenContext( previousResultsLength, userRequestsReasoning, email, - isMsgWithSources, + isMsgWithKbItems, ) if (answer) { generateAnswerSpan?.setAttribute("answer_found", true) @@ -2904,6 +2916,7 @@ async function* processResultsForMetadata( email?: string, agentContext?: string, modelId?: string, + isMsgWithKbItems?: boolean, ) { if (app?.length == 1 && app[0] === Apps.GoogleDrive) { chunksCount = config.maxGoogleDriveSummary @@ -2919,7 +2932,14 @@ async function* processResultsForMetadata( "Document chunk size", `full_context maxed to ${chunksCount}`, ) - const context = await buildContext(items, chunksCount, userMetadata, 0, input) + const context = await buildContext( + items, + chunksCount, + userMetadata, + 0, + input, + isMsgWithKbItems, + ) const { imageFileNames } = extractImageFileNames(context, items) const streamOptions = { stream: true, @@ -2947,6 +2967,8 @@ async function* processResultsForMetadata( userMetadata, context, streamOptions, + isMsgWithKbItems, + isMsgWithKbItems, ) } @@ -2955,6 +2977,8 @@ async function* processResultsForMetadata( items, 0, config.isReasoning && userRequestsReasoning, + email, + isMsgWithKbItems, ) } @@ -3291,7 +3315,14 @@ async function* generateMetadataQueryAnswer( pageSpan?.setAttribute( "context", - await buildContext(items, 20, userMetadata, 0, input), + await buildContext( + items, + 20, + userMetadata, + 0, + input, + agentSpecificCollectionSelections.length > 0, + ), ) if (!items.length) { loggerWithChild({ email: email }).info( @@ -3319,6 +3350,7 @@ async function* generateMetadataQueryAnswer( email, agentPrompt, modelId, + agentSpecificCollectionSelections.length > 0, ) if (answer == null) { @@ -3479,7 +3511,14 @@ async function* generateMetadataQueryAnswer( span?.setAttribute( "context", - await buildContext(items, 20, userMetadata, 0, input), + await buildContext( + items, + 20, + userMetadata, + 0, + input, + agentSpecificCollectionSelections.length > 0, + ), ) span?.end() loggerWithChild({ email: email }).info( @@ -3509,6 
+3548,7 @@ async function* generateMetadataQueryAnswer( email, agentPrompt, modelId, + agentSpecificCollectionSelections.length > 0, ) return } else if ( @@ -3628,7 +3668,14 @@ async function* generateMetadataQueryAnswer( ) iterationSpan?.setAttribute( `context`, - await buildContext(items, 20, userMetadata, 0, input), + await buildContext( + items, + 20, + userMetadata, + 0, + input, + agentSpecificCollectionSelections.length > 0, + ), ) iterationSpan?.end() @@ -3661,6 +3708,7 @@ async function* generateMetadataQueryAnswer( email, agentPrompt, modelId, + agentSpecificCollectionSelections.length > 0, ) if (answer == null) { @@ -3939,7 +3987,7 @@ export async function* UnderstandMessageAndAnswerForGivenContext( threadIds?: string[], attachmentFileIds?: string[], agentPrompt?: string, - isMsgWithSources?: boolean, + isMsgWithKbItems?: boolean, modelId?: string, isValidPath?: boolean, folderIds?: string[], @@ -3973,7 +4021,7 @@ export async function* UnderstandMessageAndAnswerForGivenContext( passedSpan, threadIds, attachmentFileIds, - isMsgWithSources, + isMsgWithKbItems, modelId, isValidPath, folderIds, @@ -6986,14 +7034,18 @@ export const EnhancedMessageFeedbackApi = async (c: Context) => { // Debug logging loggerWithChild({ email: email }).info( `Enhanced feedback request received - ${JSON.stringify({ - messageId, - type, - shareChat, - customFeedback: !!customFeedback, - selectedOptionsCount: selectedOptions?.length || 0, - }, null, 2)} - },` + ${JSON.stringify( + { + messageId, + type, + shareChat, + customFeedback: !!customFeedback, + selectedOptionsCount: selectedOptions?.length || 0, + }, + null, + 2, + )} + },`, ) const message = await getMessageByExternalId(db, messageId) diff --git a/server/api/chat/types.ts b/server/api/chat/types.ts index 051bbe56f..04cac55e7 100644 --- a/server/api/chat/types.ts +++ b/server/api/chat/types.ts @@ -72,7 +72,6 @@ export const MinimalCitationSchema = z.object({ threadId: z.string().optional(), itemId: z.string().optional(), clId: z.string().optional(), - chunkIndex: z.number().int().min(0).optional(), }) export type Citation = z.infer diff --git a/server/api/chat/utils.ts b/server/api/chat/utils.ts index d6a538a73..01d40361e 100644 --- a/server/api/chat/utils.ts +++ b/server/api/chat/utils.ts @@ -450,10 +450,7 @@ export const extractImageFileNames = ( return { imageFileNames } } -export const searchToCitation = ( - result: VespaSearchResults, - chunkIndex?: number, -): Citation => { +export const searchToCitation = (result: VespaSearchResults): Citation => { const fields = result.fields if (result.fields.sddocname === userSchema) { return { @@ -535,7 +532,6 @@ export const searchToCitation = ( entity: clFields.entity, itemId: clFields.itemId, clId: clFields.clId, - chunkIndex: chunkIndex, } } else if (result.fields.sddocname === chatContainerSchema) { return { @@ -558,7 +554,8 @@ const searchToCitations = (results: VespaSearchResults[]): Citation[] => { } export const textToCitationIndex = /\[(\d+)\]/g -export const textToImageCitationIndex = /\[(\d+_\d+)\]/g +export const textToImageCitationIndex = /(?