Skip to content

Commit d3555c8

Browse files
authored
Introduce new gateway type: 'chat' (#398)
1 parent 3bc13b9 commit d3555c8

File tree

14 files changed

+1149
-554
lines changed

14 files changed

+1149
-554
lines changed

examples/applications/gateway-authentication/gateways.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,13 @@ gateways:
3535
headers:
3636
- key: langstream-client-session-id
3737
value-from-parameters: sessionId
38+
- id: chat-no-auth
39+
type: chat
40+
chat-options:
41+
headers:
42+
- value-from-parameters: session-id
43+
questions-topic: input-topic
44+
answers-topic: output-topic
3845

3946
- id: produce-input-auth-google
4047
type: produce

examples/instances/kafka-docker.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@ instance:
2020
type: "kafka"
2121
configuration:
2222
admin:
23-
bootstrap.servers: localhost:9092
23+
bootstrap.servers: localhost:39092

langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/WebSocketConfig.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import ai.langstream.api.storage.ApplicationStore;
2020
import ai.langstream.apigateway.config.GatewayTestAuthenticationProperties;
2121
import ai.langstream.apigateway.runner.TopicConnectionsRuntimeProviderBean;
22+
import ai.langstream.apigateway.websocket.handlers.ChatHandler;
2223
import ai.langstream.apigateway.websocket.handlers.ConsumeHandler;
2324
import ai.langstream.apigateway.websocket.handlers.ProduceHandler;
2425
import jakarta.annotation.PreDestroy;
@@ -42,6 +43,7 @@ public class WebSocketConfig implements WebSocketConfigurer {
4243

4344
public static final String CONSUME_PATH = "/v1/consume/{tenant}/{application}/{gateway}";
4445
public static final String PRODUCE_PATH = "/v1/produce/{tenant}/{application}/{gateway}";
46+
public static final String CHAT_PATH = "/v1/chat/{tenant}/{application}/{gateway}";
4547

4648
private final ApplicationStore applicationStore;
4749
private final TopicConnectionsRuntimeProviderBean topicConnectionsRuntimeRegistryProvider;
@@ -61,6 +63,12 @@ public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
6163
.addHandler(
6264
new ProduceHandler(applicationStore, topicConnectionsRuntimeRegistry),
6365
PRODUCE_PATH)
66+
.addHandler(
67+
new ChatHandler(
68+
applicationStore,
69+
consumeThreadPool,
70+
topicConnectionsRuntimeRegistry),
71+
CHAT_PATH)
6472
.setAllowedOrigins("*")
6573
.addInterceptors(
6674
new HttpSessionHandshakeInterceptor(),

langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/AbstractHandler.java

Lines changed: 341 additions & 2 deletions
Large diffs are not rendered by default.
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
/*
2+
* Copyright DataStax, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package ai.langstream.apigateway.websocket.handlers;
17+
18+
import static ai.langstream.apigateway.websocket.WebSocketConfig.CHAT_PATH;
19+
20+
import ai.langstream.api.model.Gateway;
21+
import ai.langstream.api.runner.code.Header;
22+
import ai.langstream.api.runner.code.Record;
23+
import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry;
24+
import ai.langstream.api.storage.ApplicationStore;
25+
import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext;
26+
import java.util.ArrayList;
27+
import java.util.List;
28+
import java.util.Map;
29+
import java.util.concurrent.ExecutorService;
30+
import java.util.function.Function;
31+
import lombok.extern.slf4j.Slf4j;
32+
import org.springframework.util.StringUtils;
33+
import org.springframework.web.socket.CloseStatus;
34+
import org.springframework.web.socket.TextMessage;
35+
import org.springframework.web.socket.WebSocketSession;
36+
37+
@Slf4j
38+
public class ChatHandler extends AbstractHandler {
39+
40+
private final ExecutorService executor;
41+
42+
public ChatHandler(
43+
ApplicationStore applicationStore,
44+
ExecutorService executor,
45+
TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry) {
46+
super(applicationStore, topicConnectionsRuntimeRegistry);
47+
this.executor = executor;
48+
}
49+
50+
@Override
51+
public String path() {
52+
return CHAT_PATH;
53+
}
54+
55+
@Override
56+
Gateway.GatewayType gatewayType() {
57+
return Gateway.GatewayType.chat;
58+
}
59+
60+
@Override
61+
String tenantFromPath(Map<String, String> parsedPath, Map<String, String> queryString) {
62+
return parsedPath.get("tenant");
63+
}
64+
65+
@Override
66+
String applicationIdFromPath(Map<String, String> parsedPath, Map<String, String> queryString) {
67+
return parsedPath.get("application");
68+
}
69+
70+
@Override
71+
String gatewayFromPath(Map<String, String> parsedPath, Map<String, String> queryString) {
72+
return parsedPath.get("gateway");
73+
}
74+
75+
@Override
76+
protected List<String> getAllRequiredParameters(Gateway gateway) {
77+
List<String> parameters = gateway.getParameters();
78+
if (parameters == null) {
79+
parameters = new ArrayList<>();
80+
}
81+
if (gateway.getChatOptions() != null && gateway.getChatOptions().getHeaders() != null) {
82+
for (Gateway.KeyValueComparison header : gateway.getChatOptions().getHeaders()) {
83+
parameters.add(header.key());
84+
}
85+
}
86+
return parameters;
87+
}
88+
89+
@Override
90+
public void onBeforeHandshakeCompleted(
91+
AuthenticatedGatewayRequestContext context, Map<String, Object> attributes)
92+
throws Exception {
93+
94+
setupReader(context);
95+
setupProducer(context);
96+
97+
sendClientConnectedEvent(context);
98+
}
99+
100+
private void setupProducer(AuthenticatedGatewayRequestContext context) {
101+
final Gateway.ChatOptions chatOptions = context.gateway().getChatOptions();
102+
103+
List<Gateway.KeyValueComparison> headerConfig = new ArrayList<>();
104+
final List<Gateway.KeyValueComparison> gwHeaders = chatOptions.getHeaders();
105+
if (gwHeaders != null) {
106+
for (Gateway.KeyValueComparison gwHeader : gwHeaders) {
107+
headerConfig.add(gwHeader);
108+
}
109+
}
110+
final List<Header> commonHeaders =
111+
getProducerCommonHeaders(
112+
headerConfig, context.userParameters(), context.principalValues());
113+
setupProducer(
114+
context.attributes(),
115+
chatOptions.getQuestionsTopic(),
116+
context.application().getInstance().streamingCluster(),
117+
commonHeaders,
118+
context.tenant(),
119+
context.applicationId(),
120+
context.gateway().getId());
121+
}
122+
123+
private void setupReader(AuthenticatedGatewayRequestContext context) throws Exception {
124+
final Gateway.ChatOptions chatOptions = context.gateway().getChatOptions();
125+
126+
List<Gateway.KeyValueComparison> headerFilters = new ArrayList<>();
127+
final List<Gateway.KeyValueComparison> gwHeaders = chatOptions.getHeaders();
128+
if (gwHeaders != null) {
129+
for (Gateway.KeyValueComparison gwHeader : gwHeaders) {
130+
headerFilters.add(gwHeader);
131+
}
132+
}
133+
final List<Function<Record, Boolean>> messageFilters =
134+
createMessageFilters(
135+
headerFilters, context.userParameters(), context.principalValues());
136+
137+
setupReader(
138+
context.attributes(),
139+
chatOptions.getAnswersTopic(),
140+
context.application().getInstance().streamingCluster(),
141+
messageFilters,
142+
context.options());
143+
}
144+
145+
@Override
146+
public void onOpen(
147+
WebSocketSession webSocketSession, AuthenticatedGatewayRequestContext context) {
148+
startReadingMessages(webSocketSession, context, executor);
149+
}
150+
151+
@Override
152+
public void onMessage(
153+
WebSocketSession webSocketSession,
154+
AuthenticatedGatewayRequestContext context,
155+
TextMessage message)
156+
throws Exception {
157+
produceMessage(webSocketSession, message);
158+
}
159+
160+
@Override
161+
public void onClose(
162+
WebSocketSession webSocketSession,
163+
AuthenticatedGatewayRequestContext context,
164+
CloseStatus status) {}
165+
166+
@Override
167+
void validateOptions(Map<String, String> options) {
168+
for (Map.Entry<String, String> option : options.entrySet()) {
169+
switch (option.getKey()) {
170+
case "position":
171+
if (!StringUtils.hasText(option.getValue())) {
172+
throw new IllegalArgumentException("'position' cannot be blank");
173+
}
174+
break;
175+
default:
176+
throw new IllegalArgumentException("Unknown option " + option.getKey());
177+
}
178+
}
179+
}
180+
}

0 commit comments

Comments
 (0)