Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 17 additions & 12 deletions server/api/chat/agents.ts
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ import {
convertReasoningStepToText,
extractFileIdsFromMessage,
extractImageFileNames,
extractItemIdsFromPath,
getCitationToImage,
handleError,
isMessageWithContext,
Expand Down Expand Up @@ -3532,16 +3533,17 @@ export const AgentMessageApi = async (c: Context) => {
}
}
const isMsgWithContext = isMessageWithContext(message)
const extractedInfo =
isMsgWithContext || (path && ids)
? await extractFileIdsFromMessage(message, email, ids)
: {
totalValidFileIdsFromLinkCount: 0,
fileIds: [],
collectionFolderIds: [],
}
const extractedInfo = isMsgWithContext
? await extractFileIdsFromMessage(message, email)
: {
totalValidFileIdsFromLinkCount: 0,
fileIds: [],
collectionFolderIds: [],
}
const pathExtractedInfo = isValidPath
? await extractItemIdsFromPath(ids)
: { collectionFileIds: [], collectionFolderIds: [], collectionIds: [] }
let fileIds = extractedInfo?.fileIds
let folderIds = extractedInfo?.collectionFolderIds
if (nonImageAttachmentFileIds && nonImageAttachmentFileIds.length > 0) {
fileIds = [...fileIds, ...nonImageAttachmentFileIds]
}
Expand Down Expand Up @@ -3760,9 +3762,6 @@ export const AgentMessageApi = async (c: Context) => {
imageAttachmentFileIds,
agentPromptForLLM,
fileIds.length > 0,
actualModelId,
Boolean(isValidPath),
folderIds,
)
stream.writeSSE({
event: ChatSSEvents.Start,
Expand Down Expand Up @@ -4304,6 +4303,12 @@ export const AgentMessageApi = async (c: Context) => {
userRequestsReasoning,
understandSpan,
agentPromptForLLM,
actualModelId,
{
collectionfileIds: pathExtractedInfo.collectionFileIds,
collectionFolderIds: pathExtractedInfo.collectionFolderIds,
collectionIds: pathExtractedInfo.collectionIds,
},
)
stream.writeSSE({
event: ChatSSEvents.Start,
Expand Down
72 changes: 67 additions & 5 deletions server/api/chat/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1222,6 +1222,7 @@ async function* generateIterativeTimeFilterAndQueryRewrite(
userRequestsReasoning?: boolean,
queryRagSpan?: Span,
agentPrompt?: string,
pathExtractedInfo?: pathExtractedInfo,
): AsyncIterableIterator<
ConverseResponse & {
citation?: { index: number; item: any }
Expand Down Expand Up @@ -1345,8 +1346,25 @@ async function* generateIterativeTimeFilterAndQueryRewrite(
const collectionIds: string[] = []
const collectionFolderIds: string[] = []
const collectionFileIds: string[] = []
let source = []
if (
pathExtractedInfo &&
(pathExtractedInfo.collectionfileIds.length ||
pathExtractedInfo.collectionFolderIds.length ||
pathExtractedInfo.collectionIds.length)
) {
if (pathExtractedInfo.collectionFolderIds.length) {
source = pathExtractedInfo.collectionFolderIds
} else if (pathExtractedInfo.collectionfileIds.length) {
source = pathExtractedInfo.collectionfileIds
} else {
source = pathExtractedInfo.collectionIds
}
} else {
source = selectedItems[Apps.KnowledgeBase]
}

for (const itemId of selectedItems[Apps.KnowledgeBase]) {
for (const itemId of source) {
if (itemId.startsWith("cl-")) {
// Entire collection - remove cl- prefix
collectionIds.push(itemId.replace(/^cl[-_]/, ""))
Expand Down Expand Up @@ -2483,6 +2501,7 @@ async function* generatePointQueryTimeExpansion(
userRequestsReasoning: boolean,
eventRagSpan?: Span,
agentPrompt?: string,
pathExtractedInfo?: pathExtractedInfo,
): AsyncIterableIterator<
ConverseResponse & {
citation?: { index: number; item: any }
Expand Down Expand Up @@ -2600,8 +2619,24 @@ async function* generatePointQueryTimeExpansion(
const collectionIds: string[] = []
const collectionFolderIds: string[] = []
const collectionFileIds: string[] = []

for (const itemId of selectedItems[Apps.KnowledgeBase]) {
let source = []
if (
pathExtractedInfo &&
(pathExtractedInfo.collectionfileIds.length ||
pathExtractedInfo.collectionFolderIds.length ||
pathExtractedInfo.collectionIds.length)
) {
if (pathExtractedInfo.collectionFolderIds.length) {
source = pathExtractedInfo.collectionFolderIds
} else if (pathExtractedInfo.collectionfileIds.length) {
source = pathExtractedInfo.collectionfileIds
} else {
source = pathExtractedInfo.collectionIds
}
} else {
source = selectedItems[Apps.KnowledgeBase]
}
for (const itemId of source) {
if (itemId.startsWith("cl-")) {
// Entire collection - remove cl- prefix
collectionIds.push(itemId.replace(/^cl[-_]/, ""))
Expand Down Expand Up @@ -3047,6 +3082,7 @@ async function* generateMetadataQueryAnswer(
agentPrompt?: string,
maxIterations = 5,
modelId?: string,
pathExtractedInfo?: pathExtractedInfo,
): AsyncIterableIterator<
ConverseResponse & {
citation?: { index: number; item: any }
Expand Down Expand Up @@ -3163,8 +3199,24 @@ async function* generateMetadataQueryAnswer(
const collectionIds: string[] = []
const collectionFolderIds: string[] = []
const collectionFileIds: string[] = []

for (const itemId of selectedItems[Apps.KnowledgeBase]) {
let source = []
if (
pathExtractedInfo &&
(pathExtractedInfo.collectionfileIds.length ||
pathExtractedInfo.collectionFolderIds.length ||
pathExtractedInfo.collectionIds.length)
) {
if (pathExtractedInfo.collectionFolderIds?.length) {
source = pathExtractedInfo.collectionFolderIds
} else if (pathExtractedInfo.collectionfileIds.length) {
source = pathExtractedInfo.collectionfileIds
} else {
source = pathExtractedInfo.collectionIds
}
} else {
source = selectedItems[Apps.KnowledgeBase]
}
for (const itemId of source) {
if (itemId.startsWith("cl-")) {
// Entire collection - remove cl- prefix
collectionIds.push(itemId.replace(/^cl[-_]/, ""))
Expand Down Expand Up @@ -3777,6 +3829,12 @@ const fallbackText = (classification: QueryRouterLLMResponse): string => {
return `${searchDescription}${timeDescription}`
}

export type pathExtractedInfo = {
collectionfileIds: string[]
collectionFolderIds: string[]
collectionIds: string[]
}

export async function* UnderstandMessageAndAnswer(
email: string,
userCtx: string,
Expand All @@ -3789,6 +3847,7 @@ export async function* UnderstandMessageAndAnswer(
passedSpan?: Span,
agentPrompt?: string,
modelId?: string,
pathExtractedInfo?: pathExtractedInfo,
): AsyncIterableIterator<
ConverseResponse & {
citation?: { index: number; item: any }
Expand Down Expand Up @@ -3836,6 +3895,7 @@ export async function* UnderstandMessageAndAnswer(
agentPrompt,
5,
modelId,
pathExtractedInfo,
)

let hasYieldedAnswer = false
Expand Down Expand Up @@ -3884,6 +3944,7 @@ export async function* UnderstandMessageAndAnswer(
userRequestsReasoning,
eventRagSpan,
agentPrompt,
pathExtractedInfo,
)
} else {
loggerWithChild({ email: email }).info(
Expand All @@ -3906,6 +3967,7 @@ export async function* UnderstandMessageAndAnswer(
userRequestsReasoning,
ragSpan,
agentPrompt, // Pass agentPrompt to generateIterativeTimeFilterAndQueryRewrite
pathExtractedInfo,
)
}
}
Expand Down
73 changes: 73 additions & 0 deletions server/api/chat/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ import {
getCollectionFoldersItemIds,
} from "@/db/knowledgeBase"
import { db } from "@/db/client"
import { collections, collectionItems } from "@/db/schema"
import { and, eq, isNull } from "drizzle-orm"
import { get } from "http"

function slackTs(ts: string | number) {
Expand Down Expand Up @@ -691,7 +693,78 @@ export const extractFileIdsFromMessage = async (
collectionFolderIds: collectionIds.filter(Boolean),
}
}
export const extractItemIdsFromPath = async (
pathRefId: any,
): Promise<{
collectionFileIds: string[]
collectionFolderIds: string[]
collectionIds: string[]
}> => {
const collectionFileIds: string[] = []
const collectionFolderIds: string[] = []
const collectionIds: string[] = []

// If pathRefId is null or undefined, return empty arrays
if (!pathRefId) {
return {
collectionFileIds,
collectionFolderIds,
collectionIds,
}
}

const vespaId = String(pathRefId)

try {
// Query collectionItems first to get id and type
const [item] = await db
.select({ id: collectionItems.id, type: collectionItems.type })
.from(collectionItems)
.where(
and(
eq(collectionItems.vespaDocId, vespaId),
isNull(collectionItems.deletedAt),
),
)

if (item) {
// Based on type, append the appropriate prefix
if (item.type === "file") {
collectionFileIds.push(`clf-${item.id}`)
} else if (item.type === "folder") {
collectionFolderIds.push(`clfd-${item.id}`)
}
} else {
// If not found in collectionItems, check collections table
const [collection] = await db
.select({ id: collections.id })
.from(collections)
.where(
and(
eq(collections.vespaDocId, vespaId),
isNull(collections.deletedAt),
),
)

if (collection) {
collectionIds.push(`cl-${collection.id}`)
}
}
} catch (error) {
// Log error but don't throw - return empty arrays
getLoggerWithChild(Subsystem.Chat)().error(
`Error extracting item IDs from pathRefId: ${vespaId}`,
error,
)
}

// Ensure we always return the same structure
return {
collectionFileIds,
collectionFolderIds,
collectionIds,
}
}
export const handleError = (error: any) => {
let errorMessage = "Something went wrong. Please try again."
if (error?.code === OpenAIError.RateLimitError) {
Expand Down
Loading