Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 121 additions & 14 deletions enterprise/workers/socket/src/durable-objects/websocket-room.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ export class WebSocketRoom extends DurableObject<IEnv> {
return new Response('Missing JWT token', { status: 400 });
}

const contextKeys = this.extractContextKeysFromHeader(request);

const [client, server] = Object.values(new WebSocketPair());

/*
Expand All @@ -64,6 +66,7 @@ export class WebSocketRoom extends DurableObject<IEnv> {
server.serializeAttachment({
jwtToken,
connectedAt: Date.now(),
contextKeys,
});

// Use waitUntil to allow hibernation without waiting for API call
Expand Down Expand Up @@ -122,7 +125,7 @@ export class WebSocketRoom extends DurableObject<IEnv> {
/**
* Send message to a specific user
*/
async sendToUser(userId: string, event: string, data: unknown): Promise<void> {
async sendToUser(userId: string, event: string, data: unknown, contextKeys?: string[]): Promise<void> {
const userConnections = this.ctx.getWebSockets(`user:${userId}`);

if (userConnections.length === 0) {
Expand All @@ -136,20 +139,30 @@ export class WebSocketRoom extends DurableObject<IEnv> {
timestamp: Date.now(),
});

// Send to all user connections in parallel using Promise.allSettled
// This prevents one failed connection from blocking others
const sendPromises = userConnections.map(async (ws) => {
try {
ws.send(message);
} catch (error) {
console.error(`Failed to send message to user ${userId}:`, error);
// Connection will be cleaned up automatically by Cloudflare
throw error; // Re-throw to be caught by Promise.allSettled
}
});
if (this.isFeatureFlagOff(contextKeys)) {
await this.broadcastToAllSockets(userId, message, userConnections);
} else {
await this.sendToMatchingContexts(userId, message, contextKeys, userConnections);
}
}

// Wait for all sends to complete, but don't fail if some connections error
await Promise.allSettled(sendPromises);
/**
* Context matching logic (same as ws.gateway.ts)
*/
private isExactMatch(messageContextKeys: string[], inboxContextKeys?: string[]): boolean {
if (inboxContextKeys === undefined) {
return true;
}

if (messageContextKeys.length === 0) {
return inboxContextKeys.length === 0;
}

if (messageContextKeys.length !== inboxContextKeys.length) {
return false;
}

return messageContextKeys.every((key) => inboxContextKeys.includes(key));
}

/**
Expand Down Expand Up @@ -255,6 +268,7 @@ export class WebSocketRoom extends DurableObject<IEnv> {
environmentId,
connectedAt: attachment.connectedAt || Date.now(),
jwtToken: attachment.jwtToken,
contextKeys: attachment.contextKeys,
};
}

Expand All @@ -276,4 +290,97 @@ export class WebSocketRoom extends DurableObject<IEnv> {
);
}
}

private extractContextKeysFromHeader(request: Request): string[] | undefined {
const contextKeysHeader = request.headers.get('X-Context-Keys');

if (!contextKeysHeader) {
return undefined;
}

try {
return JSON.parse(contextKeysHeader);
} catch (e) {
console.error('Failed to parse contextKeys:', e);

return undefined;
}
}

/**
* Check if feature flag is OFF (contextKeys is undefined)
*/
private isFeatureFlagOff(contextKeys?: string[]): boolean {
return contextKeys === undefined;
}

/**
* Broadcast message to all user connections (FF OFF behavior)
*/
private async broadcastToAllSockets(userId: string, message: string, sockets: WebSocket[]): Promise<void> {
const sendPromises = sockets.map(async (ws) => {
try {
ws.send(message);
} catch (error) {
console.error(`Failed to send message to user ${userId}:`, error);
throw error;
}
});

await Promise.allSettled(sendPromises);
}

/**
* Send message only to sockets with matching contexts (FF ON behavior)
*/
private async sendToMatchingContexts(
userId: string,
message: string,
messageContextKeys: string[] | undefined,
sockets: WebSocket[]
): Promise<void> {
if (!messageContextKeys) {
return;
}

const sendPromises = sockets.map(async (ws) => {
const metadata = this.getConnectionMetadata(ws);

if (!metadata) {
return;
}

const inboxContextKeys = metadata.contextKeys;

if (this.shouldDeliverMessage(messageContextKeys, inboxContextKeys)) {
await this.deliverMessageToSocket(ws, message, userId, inboxContextKeys);
}
});

await Promise.allSettled(sendPromises);
}

/**
* Determine if message should be delivered based on context match
*/
private shouldDeliverMessage(messageContextKeys: string[], inboxContextKeys?: string[]): boolean {
return this.isExactMatch(messageContextKeys, inboxContextKeys);
}

/**
* Deliver message to a specific socket
*/
private async deliverMessageToSocket(
ws: WebSocket,
message: string,
userId: string,
_inboxContextKeys?: string[]
): Promise<void> {
try {
ws.send(message);
} catch (error) {
console.error(`Failed to send message to user ${userId}:`, error);
throw error;
}
}
}
10 changes: 7 additions & 3 deletions enterprise/workers/socket/src/handlers/websocket.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ export async function handleWebSocketUpgrade(context: Context) {
const subscriberId = context.get('subscriberId');
const organizationId = context.get('organizationId');
const environmentId = context.get('environmentId');
const contextKeys = context.get('contextKeys');

// Extract JWT token from query parameter
const jwtToken = context.req.query('token');
Expand All @@ -28,6 +29,7 @@ export async function handleWebSocketUpgrade(context: Context) {
'X-Organization-Id': organizationId,
'X-Environment-Id': environmentId,
'X-JWT-Token': jwtToken || '',
'X-Context-Keys': contextKeys !== undefined ? JSON.stringify(contextKeys) : '',
},
body: context.req.raw.body,
});
Expand All @@ -38,7 +40,7 @@ export async function handleWebSocketUpgrade(context: Context) {
// Send message handler - Protected by internal API key authentication
export async function handleSendMessage(context: Context) {
try {
const { userId, event, data, environmentId } = await context.req.json();
const { userId, event, data, environmentId, contextKeys } = await context.req.json();

// Validate required fields
if (!userId || !event) {
Expand All @@ -57,7 +59,9 @@ export async function handleSendMessage(context: Context) {
// Create room ID based on environment and user
const roomId = `${environmentId}:${userId}`;

console.log(`[Internal API] Routing message to room: ${roomId} for user: ${userId}, event: ${event}`);
console.log(
`[Internal API] Routing message to room: ${roomId} for user: ${userId}, event: ${event}, contextKeys: ${contextKeys}`
);

/*
* Get the Durable Object instance for the appropriate room
Expand All @@ -69,7 +73,7 @@ export async function handleSendMessage(context: Context) {
const id = namespace.idFromName(roomId);
const stub = namespace.get(id);

await stub.sendToUser(userId, event, data);
await stub.sendToUser(userId, event, data, contextKeys);

return context.json({ success: true, roomId, timestamp: new Date().toISOString() });
} catch (error) {
Expand Down
2 changes: 2 additions & 0 deletions enterprise/workers/socket/src/middleware/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ export async function authenticateJWT(context: Context, next: Next) {
const userId = userPayload._id;
const subscriberId = userPayload.subscriberId || userId;
const { organizationId, environmentId } = userPayload;
const contextKeys = userPayload.contextKeys;

if (!userId || !subscriberId || !organizationId || !environmentId) {
return context.json({ error: 'Unauthorized: Missing required user information in JWT' }, 401);
Expand All @@ -53,6 +54,7 @@ export async function authenticateJWT(context: Context, next: Next) {
context.set('subscriberId', subscriberId);
context.set('organizationId', organizationId);
context.set('environmentId', environmentId);
context.set('contextKeys', contextKeys);

await next();
} catch (error) {
Expand Down
3 changes: 2 additions & 1 deletion enterprise/workers/socket/src/types/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ export interface IConnectionMetadata {
environmentId: string;
connectedAt: number;
jwtToken: string;
contextKeys?: string[];
}

export interface IWebSocketRoom {
sendToUser(userId: string, event: string, data: unknown): Promise<void>;
sendToUser(userId: string, event: string, data: unknown, contextKeys?: string[]): Promise<void>;
getActiveConnectionsForUser(userId: string): number;
getTotalActiveConnections(): number;
getConnectionCapacity(): { current: number; max: number; available: number };
Expand Down
Loading