diff --git a/frontend/src/components/OAuthTab.tsx b/frontend/src/components/OAuthTab.tsx index 857e3ed22..7cebdafbe 100644 --- a/frontend/src/components/OAuthTab.tsx +++ b/frontend/src/components/OAuthTab.tsx @@ -11,11 +11,15 @@ import { CardHeader, CardTitle, } from "@/components/ui/card" +import { Button } from "@/components/ui/button" import { Apps, UserRole } from "shared/types" import { LoaderContent } from "@/lib/common" import { OAuthIntegrationStatus } from "@/types" import { X } from "lucide-react" import { ConfirmModal } from "@/components/ui/confirmModal" +import { api } from "@/api" +import { toast } from "@/hooks/use-toast" +import { getErrorMessage } from "@/lib/utils" interface OAuthTabProps { isPending: boolean @@ -24,6 +28,7 @@ interface OAuthTabProps { updateStatus: string handleDelete: () => void userRole: UserRole + connectorId?: string } const OAuthTab = ({ @@ -33,12 +38,14 @@ const OAuthTab = ({ updateStatus, handleDelete, userRole, + connectorId, }: OAuthTabProps) => { const [modalState, setModalState] = useState<{ open: boolean title: string description: string }>({ open: false, title: "", description: "" }) + const [isStartingIngestion, setIsStartingIngestion] = useState(false) const handleConfirmDelete = () => { handleDelete() @@ -59,6 +66,50 @@ const OAuthTab = ({ })) } + const handleStartIngestion = async () => { + if (!connectorId) { + toast({ + title: "Error", + description: "Connector ID not found", + variant: "destructive", + }) + return + } + + setIsStartingIngestion(true) + try { + // Role-based API routing + const isAdmin = + userRole === UserRole.Admin || userRole === UserRole.SuperAdmin + + const response = isAdmin + ? await api.admin.google.start_ingestion.$post({ + json: { connectorId }, + }) + : await api.google.start_ingestion.$post({ + json: { connectorId }, + }) + + if (response.ok) { + toast({ + title: "Ingestion Started", + description: "Data ingestion has been initiated successfully", + }) + setOAuthIntegrationStatus(OAuthIntegrationStatus.OAuthConnecting) + } else { + throw new Error("Failed to start ingestion") + } + } catch (error) { + toast({ + title: "Failed to Start Ingestion", + description: `Error: ${getErrorMessage(error)}`, + variant: "destructive", + }) + } finally { + setIsStartingIngestion(false) + } + } + return ( {isPending ? ( @@ -84,6 +135,24 @@ const OAuthTab = ({ /> + ) : oauthIntegrationStatus === + OAuthIntegrationStatus.OAuthReadyForIngestion ? ( + + + Google OAuth + + OAuth authentication completed. Ready to start data ingestion. + + + + + + ) : ( diff --git a/frontend/src/routes/_authenticated/admin/integrations/google.tsx b/frontend/src/routes/_authenticated/admin/integrations/google.tsx index 4d06003c3..7d76416a1 100644 --- a/frontend/src/routes/_authenticated/admin/integrations/google.tsx +++ b/frontend/src/routes/_authenticated/admin/integrations/google.tsx @@ -1108,6 +1108,8 @@ const AdminLayout = ({ user, workspace, agentWhiteList }: AdminPageProps) => { setOAuthIntegrationStatus(OAuthIntegrationStatus.OAuthConnecting) } else if (connector?.status === ConnectorStatus.Connected) { setOAuthIntegrationStatus(OAuthIntegrationStatus.OAuthConnected) + } else if (connector?.status === ConnectorStatus.Authenticated) { + setOAuthIntegrationStatus(OAuthIntegrationStatus.OAuthReadyForIngestion) } else if (connector?.status === ConnectorStatus.NotConnected) { setOAuthIntegrationStatus(OAuthIntegrationStatus.OAuth) } else { @@ -1127,15 +1129,17 @@ const AdminLayout = ({ user, workspace, agentWhiteList }: AdminPageProps) => { const serviceAccountConnector = data.find( (c) => c.authType === AuthType.ServiceAccount, ) - const oauthConnector = data.find((c) => c.authType === AuthType.OAuth) + const oauthConnector = data.find( + (c) => c.app === Apps.GoogleDrive && c.authType === AuthType.OAuth, + ) if (serviceAccountConnector) { serviceAccountSocket = wsClient.ws.$ws({ - query: { id: serviceAccountConnector.id }, + query: { id: serviceAccountConnector.externalId }, }) serviceAccountSocket?.addEventListener("open", () => { logger.info( - `Service Account WebSocket opened for ${serviceAccountConnector.id}`, + `Service Account WebSocket opened for ${serviceAccountConnector.externalId}`, ) }) serviceAccountSocket?.addEventListener("message", (e) => { @@ -1313,6 +1317,13 @@ const AdminLayout = ({ user, workspace, agentWhiteList }: AdminPageProps) => { updateStatus={updateStatus} handleDelete={handleDelete} userRole={user.role} + connectorId={ + data?.find( + (v) => + v.app === Apps.GoogleDrive && + v.authType === AuthType.OAuth, + )?.id + } /> diff --git a/frontend/src/routes/_authenticated/admin/integrations/slack.tsx b/frontend/src/routes/_authenticated/admin/integrations/slack.tsx index b1ec200e0..ecb7b3f94 100644 --- a/frontend/src/routes/_authenticated/admin/integrations/slack.tsx +++ b/frontend/src/routes/_authenticated/admin/integrations/slack.tsx @@ -696,17 +696,33 @@ const SlackOAuthTab = ({ )} + ) : oauthIntegrationStatus === OAuthIntegrationStatus.OAuthReadyForIngestion ? ( + // If OAuth completed and ready for ingestion, show the Start Ingestion button +
+

+ OAuth authentication completed. Ready to start data ingestion. +

+ +
) : oauthIntegrationStatus === - OAuthIntegrationStatus.OAuthConnected || - oauthIntegrationStatus === - OAuthIntegrationStatus.OAuthConnecting ? ( - // If connected or connecting, show the Start Ingestion button + OAuthIntegrationStatus.OAuthConnected ? ( + // If connected, show the Start Ingestion button + ) : oauthIntegrationStatus === + OAuthIntegrationStatus.OAuthConnecting ? ( + // If connecting, show connecting status (same as Google) + "Connecting" ) : null} @@ -775,6 +791,8 @@ export const Slack = ({ setOAuthIntegrationStatus(OAuthIntegrationStatus.OAuthConnecting) } else if (connector?.status === ConnectorStatus.Connected) { setOAuthIntegrationStatus(OAuthIntegrationStatus.OAuthConnected) + } else if (connector?.status === ConnectorStatus.Authenticated) { + setOAuthIntegrationStatus(OAuthIntegrationStatus.OAuthReadyForIngestion) } else if (connector?.status === ConnectorStatus.NotConnected) { setOAuthIntegrationStatus(OAuthIntegrationStatus.OAuth) } else if (connector?.status === ConnectorStatus.Paused) { @@ -1005,7 +1023,9 @@ export const Slack = ({ {oauthIntegrationStatus === OAuthIntegrationStatus.OAuthConnecting || oauthIntegrationStatus === - OAuthIntegrationStatus.OAuthConnected ? ( + OAuthIntegrationStatus.OAuthConnected || + oauthIntegrationStatus === + OAuthIntegrationStatus.OAuthReadyForIngestion ? ( { useEffect(() => { let socket: WebSocket | null = null if (!isPending && data && data.length > 0) { - const oauthConnector = data.find((c) => c.authType === AuthType.OAuth) + const oauthConnector = data.find( + (c) => c.app === Apps.GoogleDrive && c.authType === AuthType.OAuth, + ) if (oauthConnector) { socket = wsClient.ws.$ws({ query: { id: oauthConnector.id }, diff --git a/frontend/src/types.ts b/frontend/src/types.ts index 58ff1d11b..981d52c23 100644 --- a/frontend/src/types.ts +++ b/frontend/src/types.ts @@ -51,6 +51,7 @@ export enum OAuthIntegrationStatus { Provider = "Provider", // yet to create provider OAuth = "OAuth", // provider created but OAuth not yet connected OAuthConnecting = "OAuthConnecting", + OAuthReadyForIngestion = "OAuthReadyForIngestion", // OAuth completed, ready to start ingestion OAuthConnected = "OAuthConnected", OAuthPaused = "OAuthPaused", } diff --git a/server/api/admin.ts b/server/api/admin.ts index b3cb0151b..fbb3066d9 100644 --- a/server/api/admin.ts +++ b/server/api/admin.ts @@ -110,7 +110,10 @@ import { ConnectorNotCreated, NoUserFound, } from "@/errors" -import { handleGoogleServiceAccountIngestion } from "@/integrations/google" +import { + handleGoogleOAuthIngestion, + handleGoogleServiceAccountIngestion, +} from "@/integrations/google" import { scopes } from "@/integrations/google/config" import { ServiceAccountIngestMoreUsers } from "@/integrations/google" import { handleSlackChannelIngestion } from "@/integrations/slack/channelIngest" @@ -1446,7 +1449,7 @@ export const AddStdioMCPConnector = async (c: Context) => { export const StartSlackIngestionApi = async (c: Context) => { const { sub } = c.get(JwtPayloadKey) // @ts-ignore - Assuming payload is validated by zValidator - const payload = c.req.valid("json") as { connectorId: string } + const payload = c.req.valid("json") as { connectorId: number } try { const userRes = await getUserByEmail(db, sub) @@ -1459,7 +1462,7 @@ export const StartSlackIngestionApi = async (c: Context) => { } const [user] = userRes - const connector = await getConnector(db, parseInt(payload.connectorId)) + const connector = await getConnector(db, payload.connectorId) if (!connector) { throw new HTTPException(404, { message: "Connector not found" }) } @@ -1494,11 +1497,64 @@ export const StartSlackIngestionApi = async (c: Context) => { } } +export const StartGoogleIngestionApi = async (c: Context) => { + const { sub } = c.get(JwtPayloadKey) + // @ts-ignore - Assuming payload is validated by zValidator + const payload = c.req.valid("json") as { connectorId: string } + try { + const userRes = await getUserByEmail(db, sub) + if (!userRes || !userRes.length) { + loggerWithChild({ email: sub }).error( + { sub }, + "No user found for sub in StartGoogleIngestionApi", + ) + throw new NoUserFound({}) + } + const [user] = userRes + + const connector = await getConnectorByExternalId( + db, + payload.connectorId, + user.id, + ) + if (!connector) { + throw new HTTPException(404, { message: "Connector not found" }) + } + + // Call the main Google ingestion function + handleGoogleOAuthIngestion({ + connectorId: connector.id, + app: connector.app as Apps, + externalId: connector.externalId, + authType: connector.authType as AuthType, + email: sub, + }).catch((error) => { + loggerWithChild({ email: sub }).error( + error, + `Background Google ingestion failed for connector ${connector.id}: ${getErrorMessage(error)}`, + ) + }) + + return c.json({ + success: true, + message: "Regular Google ingestion started.", + }) + } catch (error: any) { + loggerWithChild({ email: sub }).error( + error, + `Error starting regular Google ingestion: ${getErrorMessage(error)}`, + ) + if (error instanceof HTTPException) throw error + throw new HTTPException(500, { + message: `Failed to start regular Google ingestion: ${getErrorMessage(error)}`, + }) + } +} export const IngestMoreChannelApi = async (c: Context) => { const { sub } = c.get(JwtPayloadKey) // @ts-ignore const payload = c.req.valid("json") as { - connectorId: string + connectorId: number channelsToIngest: string[] startDate: string endDate: string @@ -1507,7 +1563,7 @@ export const IngestMoreChannelApi = async (c: Context) => { try { const email = sub const resp = await handleSlackChannelIngestion( - parseInt(payload.connectorId), + payload.connectorId, payload.channelsToIngest, payload.startDate, payload.endDate, diff --git a/server/api/oauth.ts b/server/api/oauth.ts index 5f04e2927..e9a2539b0 100644 --- a/server/api/oauth.ts +++ b/server/api/oauth.ts @@ -135,7 +135,7 @@ export const OAuthCallback = async (c: Context) => { const connector: SelectConnector = await updateConnector(db, connectorId, { subject: email, oauthCredentials: JSON.stringify(tokens), - status: ConnectorStatus.Connecting, + status: ConnectorStatus.Authenticated, }) const SaasJobPayload: SaaSOAuthJob = { connectorId: connector.id, @@ -146,13 +146,8 @@ export const OAuthCallback = async (c: Context) => { } if (IsGoogleApp(app)) { - // Start ingestion in the background, but catch any errors it might throw later - handleGoogleOAuthIngestion(SaasJobPayload).catch((error) => { - loggerWithChild({ email: email }).error( - error, - `Background Google OAuth ingestion failed for connector ${connector.id}: ${getErrorMessage(error)}`, - ) - }) + // moved the ingestion logic to sync-server , once the user will click on start ingestion + // ingestion will start on sync-server } else if (IsMicrosoftApp(app)) { handleMicrosoftOAuthIngestion(SaasJobPayload).catch((error) => { loggerWithChild({ email: email }).error( diff --git a/server/config.ts b/server/config.ts index 768c01373..3c0cfd964 100644 --- a/server/config.ts +++ b/server/config.ts @@ -63,7 +63,7 @@ let CurrentAuthType: AuthType = (process.env.AUTH_TYPE as AuthType) || AuthType.OAuth const MAX_IMAGE_SIZE_BYTES = 4 * 1024 * 1024 const MAX_SERVICE_ACCOUNT_FILE_SIZE_BYTES = 3 * 1024 // 3KB - generous limit for service account JSON files - +const AccessTokenCookie = "access-token" // TODO: // instead of TOGETHER_MODEL, OLLAMA_MODEL we should just have MODEL if present means they are selecting the model // since even docs have to be updated we can make this change in one go including that, so will be done later @@ -230,6 +230,7 @@ export default { defaultRecencyDecayRate: 0.1, // Decay rate for recency scoring in Vespa searches CurrentAuthType, getDatabaseUrl, + AccessTokenCookie, fileProcessingWorkerThreads, fileProcessingTeamSize, pdfFileProcessingWorkerThreads, diff --git a/server/integrations/google/index.ts b/server/integrations/google/index.ts index 473494cff..791f8f006 100644 --- a/server/integrations/google/index.ts +++ b/server/integrations/google/index.ts @@ -862,6 +862,18 @@ export const handleGoogleOAuthIngestion = async (data: SaaSOAuthJob) => { // const data: SaaSOAuthJob = job.data as SaaSOAuthJob const logger = loggerWithChild({ email: data.email }) try { + // Update status to Connecting when ingestion starts + try { + await db + .update(connectors) + .set({ + status: ConnectorStatus.Connecting, + }) + .where(eq(connectors.id, data.connectorId)) + } catch (error) { + logger.error(error, `Failed to update connector status to Connecting`) + throw error + } // we will first fetch the change token // and poll the changes in a new Cron Job const connector: SelectConnector = await getOAuthConnectorWithCredentials( @@ -2943,21 +2955,20 @@ export async function* listFiles( } do { - const res: any = - await retryWithBackoff( - () => - drive.files.list({ - q: query, - pageSize: 100, - fields: - "nextPageToken, files(id, webViewLink, size, parents, createdTime, modifiedTime, name, owners, fileExtension, mimeType, permissions(id, type, emailAddress))", - pageToken: nextPageToken, - }), - `Fetching all files from Google Drive`, - Apps.GoogleDrive, - 0, - client, - ) + const res: any = await retryWithBackoff( + () => + drive.files.list({ + q: query, + pageSize: 100, + fields: + "nextPageToken, files(id, webViewLink, size, parents, createdTime, modifiedTime, name, owners, fileExtension, mimeType, permissions(id, type, emailAddress))", + pageToken: nextPageToken, + }), + `Fetching all files from Google Drive`, + Apps.GoogleDrive, + 0, + client, + ) if (res.data.files) { yield res.data.files @@ -2994,17 +3005,16 @@ export const googleDocsVespa = async ( email: userEmail, }) try { - const docResponse: any = - await retryWithBackoff( - () => - docs.documents.get({ - documentId: doc.id as string, - }), - `Fetching document with documentId ${doc.id}`, - Apps.GoogleDrive, - 0, - client, - ) + const docResponse: any = await retryWithBackoff( + () => + docs.documents.get({ + documentId: doc.id as string, + }), + `Fetching document with documentId ${doc.id}`, + Apps.GoogleDrive, + 0, + client, + ) if (!docResponse || !docResponse.data) { throw new DocsParsingError( `Could not get document content for file: ${doc.id}`, @@ -3336,20 +3346,19 @@ export async function countDriveFiles( } do { - const res: any = - await retryWithBackoff( - () => - drive.files.list({ - q: query, - pageSize: 1000, - fields: "nextPageToken, files(id)", - pageToken: nextPageToken, - }), - `Counting Drive files (pageToken: ${nextPageToken || "initial"})`, - Apps.GoogleDrive, - 0, - client, - ) + const res: any = await retryWithBackoff( + () => + drive.files.list({ + q: query, + pageSize: 1000, + fields: "nextPageToken, files(id)", + pageToken: nextPageToken, + }), + `Counting Drive files (pageToken: ${nextPageToken || "initial"})`, + Apps.GoogleDrive, + 0, + client, + ) fileCount += res.data.files?.length || 0 nextPageToken = res.data.nextPageToken as string | undefined } while (nextPageToken) diff --git a/server/integrations/metricStream.ts b/server/integrations/metricStream.ts index c87368cbf..09e84b104 100644 --- a/server/integrations/metricStream.ts +++ b/server/integrations/metricStream.ts @@ -1,4 +1,18 @@ import type { WSContext } from "hono/ws" +import config from "@/config" +import { getLogger } from "@/logger" +import { Subsystem } from "@/types" + +const Logger = getLogger(Subsystem.Server).child({ module: "metricStream" }) + +// Lazy-loaded sync-server function to avoid circular dependencies +let syncServerModule: any = null +const getSyncServerModule = async () => { + if (!syncServerModule) { + syncServerModule = await import("../sync-server") + } + return syncServerModule +} export const wsConnections = new Map() @@ -6,11 +20,35 @@ export const closeWs = (connectorId: string) => { wsConnections.get(connectorId)?.close(1000, "Job finished") } -// TODO: scope it per user email who is integration -// if multiple people are doing oauth it should just work -export const sendWebsocketMessage = (message: string, connectorId: string) => { +// Function to send WebSocket message directly (when running on main server) +const sendWebsocketMessageDirect = (message: string, connectorId: string) => { const ws: WSContext = wsConnections.get(connectorId) if (ws) { ws.send(JSON.stringify({ message })) } } + +// Function to send WebSocket message via sync-server WebSocket connection +const sendWebsocketMessageViaSyncServerWS = async ( + message: string, + connectorId: string, +) => { + try { + const { sendWebsocketMessageToMainServer } = await getSyncServerModule() + sendWebsocketMessageToMainServer(message, connectorId) + } catch (error) { + Logger.error( + error, + "Error sending WebSocket message via sync-server WebSocket - message will be lost", + ) + } +} + +// TODO: scope it per user email who is integration +// if multiple people are doing oauth it should just work +export const sendWebsocketMessage = (message: string, connectorId: string) => { + // Always forward to main server from sync-server + // The main server will then forward to the frontend client + // This ensures the correct flow: Sync-Server → Main Server → Frontend Client + sendWebsocketMessageViaSyncServerWS(message, connectorId) +} diff --git a/server/integrations/slack/index.ts b/server/integrations/slack/index.ts index 7467ae8f3..1798bc959 100644 --- a/server/integrations/slack/index.ts +++ b/server/integrations/slack/index.ts @@ -884,7 +884,22 @@ export const handleSlackIngestion = async (data: SaaSOAuthJob) => { db, data.connectorId, ) - + // change the status of connector to connecting + // before update the status is authenticated + try { + await db + .update(connectors) + .set({ + status: ConnectorStatus.Connecting, + }) + .where(eq(connectors.id, data.connectorId)) + } catch (error) { + loggerWithChild({ email: data.email }).error( + error, + `Failed to update connector status to Connecting`, + ) + throw error + } // Initialize ingestion state const initialState: SlackOAuthIngestionState = { app: Apps.Slack, diff --git a/server/server.ts b/server/server.ts index c1e6d9ef9..135b4084c 100644 --- a/server/server.ts +++ b/server/server.ts @@ -62,8 +62,6 @@ import { GetConnectorTools, // Added GetConnectorTools UpdateToolsStatusApi, // Added for tool status updates AdminDeleteUserData, - IngestMoreChannelApi, - StartSlackIngestionApi, GetProviders, GetAdminChats, GetAdminAgents, @@ -78,8 +76,6 @@ import { agentAnalysisQuerySchema, AddServiceConnectionMicrosoft, UpdateUser, - HandlePerUserSlackSync, - HandlePerUserGoogleWorkSpaceSync, ListAllLoggedInUsers, ListAllIngestedUsers, GetKbVespaContent, @@ -174,7 +170,11 @@ import { CreateApiKeySchema, getDocumentSchema, } from "@/shared/types" // Import Apps -import { wsConnections } from "@/integrations/metricStream" +import { + wsConnections, + sendWebsocketMessage, +} from "@/integrations/metricStream" + import { EvaluateHandler, ListDatasetsHandler, @@ -316,7 +316,7 @@ const postOauthRedirect = config.postOauthRedirect const accessTokenSecret = process.env.ACCESS_TOKEN_SECRET! const refreshTokenSecret = process.env.REFRESH_TOKEN_SECRET! -const AccessTokenCookieName = "access-token" +const AccessTokenCookieName = config.AccessTokenCookie const RefreshTokenCookieName = "refresh-token" const Logger = getLogger(Subsystem.Server) @@ -1150,15 +1150,14 @@ export const AppRoutes = app zValidator("form", createOAuthProvider), CreateOAuthProvider, ) - .post( - "/slack/ingest_more_channel", - zValidator("json", ingestMoreChannelSchema), - IngestMoreChannelApi, + .post("/slack/ingest_more_channel", (c) => + proxyToSyncServer(c, "/slack/ingest_more_channel"), ) - .post( - "/slack/start_ingestion", - zValidator("json", startSlackIngestionSchema), - StartSlackIngestionApi, + .post("/slack/start_ingestion", (c) => + proxyToSyncServer(c, "/slack/start_ingestion"), + ) + .post("/google/start_ingestion", (c) => + proxyToSyncServer(c, "/google/start_ingestion"), ) .delete( "/oauth/connector/delete", @@ -1182,8 +1181,10 @@ export const AppRoutes = app .get("/list_loggedIn_users", ListAllLoggedInUsers) .get("/list_ingested_users", ListAllIngestedUsers) .post("/change_role", zValidator("form", UserRoleChangeSchema), UpdateUser) - .post("/syncGoogleWorkSpaceByMail", HandlePerUserGoogleWorkSpaceSync) - .post("syncSlackByMail", HandlePerUserSlackSync) + .post("/syncGoogleWorkSpaceByMail", (c) => + proxyToSyncServer(c, "/syncGoogleWorkSpaceByMail"), + ) + .post("syncSlackByMail", (c) => proxyToSyncServer(c, "/syncSlackByMail")) // create the provider + connector .post( "/oauth/create", @@ -1195,15 +1196,14 @@ export const AppRoutes = app zValidator("form", microsoftServiceSchema), AddServiceConnectionMicrosoft, ) - .post( - "/slack/ingest_more_channel", - zValidator("json", ingestMoreChannelSchema), - IngestMoreChannelApi, + .post("/slack/ingest_more_channel", (c) => + proxyToSyncServer(c, "/slack/ingest_more_channel"), ) - .post( - "/slack/start_ingestion", - zValidator("json", startSlackIngestionSchema), - StartSlackIngestionApi, + .post("/slack/start_ingestion", (c) => + proxyToSyncServer(c, "/slack/start_ingestion"), + ) + .post("/google/start_ingestion", (c) => + proxyToSyncServer(c, "/google/start_ingestion"), ) .delete( "/oauth/connector/delete", @@ -1228,10 +1228,8 @@ export const AppRoutes = app zValidator("form", updateServiceConnectionSchema), UpdateServiceConnection, ) - .post( - "/google/service_account/ingest_more", - zValidator("json", serviceAccountIngestMoreSchema), - ServiceAccountIngestMoreUsersApi, + .post("/google/service_account/ingest_more", (c) => + proxyToSyncServer(c, "/google/service_account/ingest_more"), ) // create the provider + connector .post( @@ -1312,6 +1310,58 @@ export const AppRoutes = app GetAllUserFeedbackMessages, ) +// WebSocket endpoint for sync-server connections +export const SyncServerWsApp = app.get( + "/internal/sync-websocket", + upgradeWebSocket((c) => { + // Verify authentication + const authHeader = c.req.header("Authorization") + const expectedSecret = process.env.METRICS_SECRET + + if ( + !authHeader || + !authHeader.startsWith("Bearer ") || + authHeader.slice(7) !== expectedSecret + ) { + Logger.warn("Unauthorized sync-server WebSocket connection attempt") + return { + onOpen() { + // Close immediately if unauthorized + }, + } + } + + return { + onOpen(event, ws) { + Logger.info("Sync-server WebSocket connected") + }, + onMessage(event, ws) { + try { + const { message, connectorId } = JSON.parse(event.data.toString()) + + // Forward message to the appropriate frontend WebSocket connection + const frontendWs = wsConnections.get(connectorId) + if (frontendWs) { + frontendWs.send(JSON.stringify({ message })) + Logger.info( + `WebSocket message forwarded from sync-server to frontend for connector ${connectorId}`, + ) + } else { + Logger.warn( + `No frontend WebSocket connection found for connector ${connectorId}`, + ) + } + } catch (error) { + Logger.error(error, "Error processing sync-server WebSocket message") + } + }, + onClose: (event, ws) => { + Logger.info("Sync-server WebSocket connection closed") + }, + } + }), +) + app.get("/oauth/callback", AuthMiddleware, OAuthCallback) app.get( "/oauth/start", @@ -1353,6 +1403,48 @@ app .post("/cl/:clId/items/upload", UploadFilesApi) // Upload files to KB .delete("/cl/:clId/items/:itemId", DeleteItemApi) // Delete Item in KB +// Proxy function to forward ingestion API calls to sync server +const proxyToSyncServer = async (c: Context, endpoint: string) => { + try { + // Get JWT token from cookie + const token = getCookie(c, AccessTokenCookieName) + if (!token) { + throw new HTTPException(401, { message: "No authentication token" }) + } + + // Get request body + const body = await c.req.json() + + // Forward to sync server + const response = await fetch( + `http://localhost:${config.syncServerPort}${endpoint}`, + { + method: "POST", + headers: { + "Content-Type": "application/json", + Cookie: `${AccessTokenCookieName}=${token}`, + }, + body: JSON.stringify(body), + }, + ) + + if (!response.ok) { + const errorData = await response + .json() + .catch(() => ({ message: "Proxy request failed" })) + throw new HTTPException(response.status as any, { + message: errorData.message || "Proxy request failed", + }) + } + + return c.json(await response.json()) + } catch (error) { + if (error instanceof HTTPException) throw error + Logger.error(error, `Proxy request to ${endpoint} failed`) + throw new HTTPException(500, { message: "Proxy request failed" }) + } +} + const generateTokens = async ( email: string, role: string, @@ -1653,7 +1745,7 @@ app.get("/*", AuthRedirect, serveStatic({ path: "./dist/index.html" })) export const init = async () => { // Initialize API server queue (only FileProcessingQueue, no workers) await initApiServerQueue() - + if (isSlackEnabled()) { Logger.info("Slack Web API client initialized and ready.") try { diff --git a/server/shared/types.ts b/server/shared/types.ts index ffa346be8..c5f47fa9c 100644 --- a/server/shared/types.ts +++ b/server/shared/types.ts @@ -138,10 +138,12 @@ export enum ConnectorStatus { Connected = "connected", // Pending = 'pending', Connecting = "connecting", + Paused = "paused", Failed = "failed", // for oauth we will default to this NotConnected = "not-connected", + Authenticated = "authenticated", } export enum SyncJobStatus { diff --git a/server/sync-server.ts b/server/sync-server.ts index 4bf17de8d..d2976dd07 100644 --- a/server/sync-server.ts +++ b/server/sync-server.ts @@ -2,20 +2,111 @@ import { Hono } from "hono" import { init as initQueue } from "@/queue" import config from "@/config" import { getLogger, LogMiddleware } from "@/logger" -import { Subsystem } from "@/types" +import { startGoogleIngestionSchema, Subsystem } from "@/types" import { InitialisationError } from "@/errors" import metricRegister from "@/metrics/sharedRegistry" import { isSlackEnabled, startSocketMode } from "@/integrations/slack/client" +import { jwt } from "hono/jwt" +import { zValidator } from "@hono/zod-validator" +import { + IngestMoreChannelApi, + StartSlackIngestionApi, + ServiceAccountIngestMoreUsersApi, + HandlePerUserSlackSync, + HandlePerUserGoogleWorkSpaceSync, + StartGoogleIngestionApi, +} from "@/api/admin" +import { + ingestMoreChannelSchema, + startSlackIngestionSchema, + serviceAccountIngestMoreSchema, +} from "@/types" +import { db } from "@/db/client" +import { getUserByEmail } from "@/db/user" +import type { JwtVariables } from "hono/jwt" +import type { Context, Next } from "hono" +import WebSocket from "ws" import { Worker } from "worker_threads" import path from "path" const Logger = getLogger(Subsystem.SyncServer) -const app = new Hono() +const app = new Hono<{ Variables: JwtVariables }>() const honoMiddlewareLogger = LogMiddleware(Subsystem.SyncServer) -// Add logging middleware +// WebSocket connection to main server for forwarding stats +let mainServerWebSocket: WebSocket | null = null + +const connectToMainServer = () => { + const mainServerUrl = `ws://localhost:${config.port}/internal/sync-websocket` + const authSecret = process.env.METRICS_SECRET + mainServerWebSocket = new WebSocket(mainServerUrl, { + headers: { + Authorization: `Bearer ${authSecret}`, + }, + }) + + mainServerWebSocket.on("open", () => {}) + + mainServerWebSocket.on("error", (error) => { + Logger.error(error, "WebSocket connection to main server failed") + mainServerWebSocket = null + // Retry connection after 5 seconds + setTimeout(connectToMainServer, 5000) + }) + + mainServerWebSocket.on("close", () => { + mainServerWebSocket = null + // Retry connection after 5 seconds + setTimeout(connectToMainServer, 5000) + }) +} + +// Function to send WebSocket message to main server +export const sendWebsocketMessageToMainServer = ( + message: string, + connectorId: string, +) => { + if ( + mainServerWebSocket && + mainServerWebSocket.readyState === WebSocket.OPEN + ) { + try { + mainServerWebSocket.send(JSON.stringify({ message, connectorId })) + } catch (error) { + Logger.error( + error, + `Failed to send WebSocket message for connector ${connectorId} - message lost`, + ) + } + } else { + Logger.warn( + `Cannot send WebSocket message - connection not available for connector ${connectorId}. Connection state: ${mainServerWebSocket?.readyState || "null"}. Message lost.`, + ) + + // Try to reconnect if connection is not available + if ( + !mainServerWebSocket || + mainServerWebSocket.readyState === WebSocket.CLOSED + ) { + Logger.info("Attempting to reconnect to main server...") + connectToMainServer() + } + } +} + +// JWT Authentication middleware +const accessTokenSecret = process.env.ACCESS_TOKEN_SECRET! +const AccessTokenCookieName = config.AccessTokenCookie +const { JwtPayloadKey } = config + +const AuthMiddleware = jwt({ + secret: accessTokenSecret, + cookie: AccessTokenCookieName, +}) + +// Add logging middleware to all routes app.use("*", honoMiddlewareLogger) // Health check endpoint @@ -46,58 +137,109 @@ app.get("/status", (c) => { }) }) +// // Protected ingestion API routes - require JWT authentication +app.use("*", AuthMiddleware) + +// Slack ingestion APIs +app.post( + "/slack/ingest_more_channel", + zValidator("json", ingestMoreChannelSchema), + IngestMoreChannelApi, +) + +app.post( + "/slack/start_ingestion", + zValidator("json", startSlackIngestionSchema), + StartSlackIngestionApi, +) + +// Google Workspace APIs +app.post( + "/google/service_account/ingest_more", + zValidator("json", serviceAccountIngestMoreSchema), + ServiceAccountIngestMoreUsersApi, +) + +app.post( + "/google/start_ingestion", + zValidator("json", startGoogleIngestionSchema), + StartGoogleIngestionApi, +) +// Sync APIs +app.post("/syncSlackByMail", HandlePerUserSlackSync) +app.post("/syncGoogleWorkSpaceByMail", HandlePerUserGoogleWorkSpaceSync) + const startAndMonitorWorkers = ( workerScript: string, workerType: string, count: number, workerThreads: Worker[], - arrayIndexOffset: number + arrayIndexOffset: number, ) => { Logger.info(`Starting ${count} ${workerType} processing worker threads...`) for (let i = 0; i < count; i++) { const workerIndexForLogging = i + 1 const workerArrayIndex = arrayIndexOffset + i - const worker = new Worker(path.join(__dirname, workerScript)) workerThreads.push(worker) worker.on("message", (message) => { if (message.status === "initialized") { - Logger.info(`${workerType} processing worker thread ${workerIndexForLogging} initialized successfully`) + Logger.info( + `${workerType} processing worker thread ${workerIndexForLogging} initialized successfully`, + ) } else if (message.status === "error") { - Logger.error(`${workerType} processing worker thread ${workerIndexForLogging} failed: ${message.error}`) + Logger.error( + `${workerType} processing worker thread ${workerIndexForLogging} failed: ${message.error}`, + ) } }) worker.on("error", (error) => { - Logger.error(error, `${workerType} processing worker thread ${workerIndexForLogging} error`) + Logger.error( + error, + `${workerType} processing worker thread ${workerIndexForLogging} error`, + ) }) worker.on("exit", (code) => { if (code !== 0) { - Logger.error(`${workerType} processing worker thread ${workerIndexForLogging} exited with code ${code}`) - - Logger.info(`Restarting ${workerType} processing worker thread ${workerIndexForLogging}...`) + Logger.error( + `${workerType} processing worker thread ${workerIndexForLogging} exited with code ${code}`, + ) + + Logger.info( + `Restarting ${workerType} processing worker thread ${workerIndexForLogging}...`, + ) const newWorker = new Worker(path.join(__dirname, workerScript)) workerThreads[workerArrayIndex] = newWorker - + // Re-attach event listeners for the new worker newWorker.on("message", (message) => { if (message.status === "initialized") { - Logger.info(`${workerType} processing worker thread ${workerIndexForLogging} restarted and initialized successfully`) + Logger.info( + `${workerType} processing worker thread ${workerIndexForLogging} restarted and initialized successfully`, + ) } else if (message.status === "error") { - Logger.error(`${workerType} processing worker thread ${workerIndexForLogging} failed: ${message.error}`) + Logger.error( + `${workerType} processing worker thread ${workerIndexForLogging} failed: ${message.error}`, + ) } }) - + newWorker.on("error", (error) => { - Logger.error(error, `${workerType} processing worker thread ${workerIndexForLogging} error`) + Logger.error( + error, + `${workerType} processing worker thread ${workerIndexForLogging} error`, + ) }) - + newWorker.on("exit", (code) => { if (code !== 0) { - Logger.error(`${workerType} processing worker thread ${workerIndexForLogging} exited with code ${code}`) + Logger.error( + `${workerType} processing worker thread ${workerIndexForLogging} exited with code ${code}`, + ) } }) } @@ -114,8 +256,20 @@ export const initSyncServer = async () => { const pdfWorkerCount = config.pdfFileProcessingWorkerThreads // Start workers using the helper function - startAndMonitorWorkers("fileProcessingWorker.ts", "File", fileWorkerCount, workerThreads, 0) - startAndMonitorWorkers("pdfFileProcessingWorker.ts", "PDF file", pdfWorkerCount, workerThreads, fileWorkerCount) + startAndMonitorWorkers( + "fileProcessingWorker.ts", + "File", + fileWorkerCount, + workerThreads, + 0, + ) + startAndMonitorWorkers( + "pdfFileProcessingWorker.ts", + "PDF file", + pdfWorkerCount, + workerThreads, + fileWorkerCount, + ) // Initialize the queue system in background - don't await (excluding file processing) initQueue() @@ -126,6 +280,8 @@ export const initSyncServer = async () => { Logger.error(error, "Failed to initialize queue system") }) + // Connect to main server WebSocket + connectToMainServer() Logger.info("Sync Server initialization completed") } diff --git a/server/types.ts b/server/types.ts index a7bf2b193..0438ab7fc 100644 --- a/server/types.ts +++ b/server/types.ts @@ -563,6 +563,9 @@ export const ingestMoreChannelSchema = z.object({ export const startSlackIngestionSchema = z.object({ connectorId: z.number(), }) +export const startGoogleIngestionSchema = z.object({ + connectorId: z.string(), +}) export type EntityType = | DriveEntity