Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 172 additions & 14 deletions enterprise/workers/socket/src/durable-objects/websocket-room.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ export class WebSocketRoom extends DurableObject<IEnv> {
return new Response('Missing JWT token', { status: 400 });
}

const contextKeys = this.extractContextKeysFromHeader(request);

const [client, server] = Object.values(new WebSocketPair());

/*
Expand All @@ -64,8 +66,11 @@ export class WebSocketRoom extends DurableObject<IEnv> {
server.serializeAttachment({
jwtToken,
connectedAt: Date.now(),
contextKeys,
});

this.logConnectionAccepted(userId, contextKeys);

// Use waitUntil to allow hibernation without waiting for API call
this.ctx.waitUntil(
this.notifySubscriberOnlineState(userId, environmentId, true, undefined, jwtToken).catch((error) =>
Expand Down Expand Up @@ -122,7 +127,7 @@ export class WebSocketRoom extends DurableObject<IEnv> {
/**
* Send message to a specific user
*/
async sendToUser(userId: string, event: string, data: unknown): Promise<void> {
async sendToUser(userId: string, event: string, data: unknown, contextKeys?: string[]): Promise<void> {
const userConnections = this.ctx.getWebSockets(`user:${userId}`);

if (userConnections.length === 0) {
Expand All @@ -136,20 +141,30 @@ export class WebSocketRoom extends DurableObject<IEnv> {
timestamp: Date.now(),
});

// Send to all user connections in parallel using Promise.allSettled
// This prevents one failed connection from blocking others
const sendPromises = userConnections.map(async (ws) => {
try {
ws.send(message);
} catch (error) {
console.error(`Failed to send message to user ${userId}:`, error);
// Connection will be cleaned up automatically by Cloudflare
throw error; // Re-throw to be caught by Promise.allSettled
}
});
if (this.isFeatureFlagOff(contextKeys)) {
await this.broadcastToAllSockets(userId, event, message, userConnections);
} else {
await this.sendToMatchingContexts(userId, event, message, contextKeys, userConnections);
}
}

// Wait for all sends to complete, but don't fail if some connections error
await Promise.allSettled(sendPromises);
/**
* Context matching logic (same as ws.gateway.ts)
*/
private isExactMatch(messageContextKeys: string[], inboxContextKeys?: string[]): boolean {
if (inboxContextKeys === undefined) {
return true;
}

if (messageContextKeys.length === 0) {
return inboxContextKeys.length === 0;
}

if (messageContextKeys.length !== inboxContextKeys.length) {
return false;
}

return messageContextKeys.every((key) => inboxContextKeys.includes(key));
}

/**
Expand Down Expand Up @@ -255,6 +270,7 @@ export class WebSocketRoom extends DurableObject<IEnv> {
environmentId,
connectedAt: attachment.connectedAt || Date.now(),
jwtToken: attachment.jwtToken,
contextKeys: attachment.contextKeys,
};
}

Expand All @@ -276,4 +292,146 @@ export class WebSocketRoom extends DurableObject<IEnv> {
);
}
}

private extractContextKeysFromHeader(request: Request): string[] | undefined {
const contextKeysHeader = request.headers.get('X-Context-Keys');

if (!contextKeysHeader) {
return undefined;
}

try {
return JSON.parse(contextKeysHeader);
} catch (e) {
console.error('Failed to parse contextKeys:', e);

return undefined;
}
}

/**
* Log connection acceptance with context information
*/
private logConnectionAccepted(userId: string, contextKeys?: string[]): void {
const contextDisplay = this.formatContextDisplay(contextKeys);
console.log(`Connection accepted for ${userId} with contexts: ${contextDisplay}`);
}

/**
* Format context keys for display in logs
*/
private formatContextDisplay(contextKeys?: string[]): string {
if (contextKeys === undefined) {
return 'FF disabled';
}

if (contextKeys.length === 0) {
return 'no context';
}

return contextKeys.join(', ');
}

/**
* Check if feature flag is OFF (contextKeys is undefined)
*/
private isFeatureFlagOff(contextKeys?: string[]): boolean {
return contextKeys === undefined;
}

/**
* Broadcast message to all user connections (FF OFF behavior)
*/
private async broadcastToAllSockets(
userId: string,
event: string,
message: string,
sockets: WebSocket[]
): Promise<void> {
console.log(`Sending event ${event} to all ${sockets.length} socket(s) (FF disabled)`);

const sendPromises = sockets.map(async (ws) => {
try {
ws.send(message);
} catch (error) {
console.error(`Failed to send message to user ${userId}:`, error);
throw error;
}
});

await Promise.allSettled(sendPromises);
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part can be deleted when we remove the feature flag


/**
* Send message only to sockets with matching contexts (FF ON behavior)
*/
private async sendToMatchingContexts(
userId: string,
event: string,
message: string,
messageContextKeys: string[] | undefined,
sockets: WebSocket[]
): Promise<void> {
if (!messageContextKeys) {
return;
}
console.log(
`Sending event ${event} to ${userId} with message contexts: ${this.formatContextDisplay(messageContextKeys)} (${sockets.length} socket(s))`
);

const sendPromises = sockets.map(async (ws) => {
const metadata = this.getConnectionMetadata(ws);

if (!metadata) {
console.warn(`No metadata found for socket, skipping`);

return;
}

const inboxContextKeys = metadata.contextKeys;

if (this.shouldDeliverMessage(messageContextKeys, inboxContextKeys)) {
await this.deliverMessageToSocket(ws, message, userId, inboxContextKeys);
} else {
this.logContextMismatch(messageContextKeys, inboxContextKeys);
}
});

await Promise.allSettled(sendPromises);
}

/**
* Determine if message should be delivered based on context match
*/
private shouldDeliverMessage(messageContextKeys: string[], inboxContextKeys?: string[]): boolean {
return this.isExactMatch(messageContextKeys, inboxContextKeys);
}

/**
* Deliver message to a specific socket
*/
private async deliverMessageToSocket(
ws: WebSocket,
message: string,
userId: string,
inboxContextKeys?: string[]
): Promise<void> {
try {
ws.send(message);
console.log(`Delivered to socket with inbox contexts: ${this.formatContextDisplay(inboxContextKeys)}`);
} catch (error) {
console.error(`Failed to send message to user ${userId}:`, error);
throw error;
}
}

/**
* Log when a socket is skipped due to context mismatch
*/
private logContextMismatch(messageContextKeys: string[], inboxContextKeys?: string[]): void {
const messageDisplay = messageContextKeys.length === 0 ? 'none' : messageContextKeys.join(', ');
const inboxDisplay = inboxContextKeys?.length === 0 ? 'none' : inboxContextKeys?.join(', ') || 'none';

console.log(`Skipped socket - contexts mismatch. Message: [${messageDisplay}], Inbox: [${inboxDisplay}]`);
}
}
6 changes: 4 additions & 2 deletions enterprise/workers/socket/src/handlers/websocket.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ export async function handleWebSocketUpgrade(context: Context) {
const subscriberId = context.get('subscriberId');
const organizationId = context.get('organizationId');
const environmentId = context.get('environmentId');
const contextKeys = context.get('contextKeys');

// Extract JWT token from query parameter
const jwtToken = context.req.query('token');
Expand All @@ -28,6 +29,7 @@ export async function handleWebSocketUpgrade(context: Context) {
'X-Organization-Id': organizationId,
'X-Environment-Id': environmentId,
'X-JWT-Token': jwtToken || '',
'X-Context-Keys': contextKeys !== undefined ? JSON.stringify(contextKeys) : '',
},
body: context.req.raw.body,
});
Expand All @@ -38,7 +40,7 @@ export async function handleWebSocketUpgrade(context: Context) {
// Send message handler - Protected by internal API key authentication
export async function handleSendMessage(context: Context) {
try {
const { userId, event, data, environmentId } = await context.req.json();
const { userId, event, data, environmentId, contextKeys } = await context.req.json();

// Validate required fields
if (!userId || !event) {
Expand Down Expand Up @@ -69,7 +71,7 @@ export async function handleSendMessage(context: Context) {
const id = namespace.idFromName(roomId);
const stub = namespace.get(id);

await stub.sendToUser(userId, event, data);
await stub.sendToUser(userId, event, data, contextKeys);

return context.json({ success: true, roomId, timestamp: new Date().toISOString() });
} catch (error) {
Expand Down
2 changes: 2 additions & 0 deletions enterprise/workers/socket/src/middleware/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ export async function authenticateJWT(context: Context, next: Next) {
const userId = userPayload._id;
const subscriberId = userPayload.subscriberId || userId;
const { organizationId, environmentId } = userPayload;
const contextKeys = userPayload.contextKeys;

if (!userId || !subscriberId || !organizationId || !environmentId) {
return context.json({ error: 'Unauthorized: Missing required user information in JWT' }, 401);
Expand All @@ -53,6 +54,7 @@ export async function authenticateJWT(context: Context, next: Next) {
context.set('subscriberId', subscriberId);
context.set('organizationId', organizationId);
context.set('environmentId', environmentId);
context.set('contextKeys', contextKeys);

await next();
} catch (error) {
Expand Down
3 changes: 2 additions & 1 deletion enterprise/workers/socket/src/types/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ export interface IConnectionMetadata {
environmentId: string;
connectedAt: number;
jwtToken: string;
contextKeys?: string[];
}

export interface IWebSocketRoom {
sendToUser(userId: string, event: string, data: unknown): Promise<void>;
sendToUser(userId: string, event: string, data: unknown, contextKeys?: string[]): Promise<void>;
getActiveConnectionsForUser(userId: string): number;
getTotalActiveConnections(): number;
getConnectionCapacity(): { current: number; max: number; available: number };
Expand Down
Loading