
Commit 7c3e4af

[kafka] Support writing messages with different schemas (#413)
1 parent 0e5f74d commit 7c3e4af

2 files changed: +138 −15

langstream-kafka-runtime/src/main/java/ai/langstream/kafka/runner/KafkaProducerWrapper.java

Lines changed: 26 additions & 10 deletions
@@ -74,9 +74,14 @@ byte[].class, new ByteArraySerializer(),
     private final AtomicInteger totalIn = new AtomicInteger();
     KafkaProducer<Object, Object> producer;
     Serializer keySerializer;
+    Class cacheKeyForKeySerializer;
     Serializer valueSerializer;
+
+    Class cacheKeyForValueSerializer;
     Serializer headerSerializer;
 
+    Class cacheKeyForHeaderSerializer;
+
     boolean forcedKeySerializer;
     boolean forcedValueSerializer;
 
@@ -151,7 +156,7 @@ public Map<String, Object> getInfo() {
     }
 
     @Override
-    public CompletableFuture<?> write(Record r) {
+    public synchronized CompletableFuture<?> write(Record r) {
         CompletableFuture<?> handle = new CompletableFuture<>();
         try {
             List<org.apache.kafka.common.header.Header> headers = new ArrayList<>();
@@ -160,8 +165,11 @@ public CompletableFuture<?> write(Record r) {
                 if (forcedKeySerializer) {
                     key = r.key();
                 } else {
-                    if (keySerializer == null) {
-                        keySerializer = getSerializer(r.key().getClass(), keySerializers, true);
+                    Class<?> keyClass = r.key().getClass();
+                    if (keySerializer == null
+                            || !(Objects.equals(keyClass, cacheKeyForKeySerializer))) {
+                        keySerializer = getSerializer(keyClass, keySerializers, true);
+                        cacheKeyForKeySerializer = keyClass;
                     }
                     key = keySerializer.serialize(topicName, r.key());
                 }
@@ -171,9 +179,11 @@ public CompletableFuture<?> write(Record r) {
                 if (forcedValueSerializer) {
                     value = r.value();
                 } else {
-                    if (valueSerializer == null) {
-                        valueSerializer =
-                                getSerializer(r.value().getClass(), valueSerializers, false);
+                    Class<?> valueClass = r.value().getClass();
+                    if (valueSerializer == null
+                            || !(Objects.equals(valueClass, cacheKeyForValueSerializer))) {
+                        valueSerializer = getSerializer(valueClass, valueSerializers, false);
+                        cacheKeyForValueSerializer = valueClass;
                     }
                     value = valueSerializer.serialize(topicName, r.value());
                 }
@@ -182,10 +192,13 @@ public CompletableFuture<?> write(Record r) {
             for (Header header : r.headers()) {
                 Object headerValue = header.value();
                 byte[] serializedHeader = null;
+
                 if (headerValue != null) {
-                    if (headerSerializer == null) {
-                        headerSerializer =
-                                getSerializer(headerValue.getClass(), headerSerializers, null);
+                    Class<?> headerClass = headerValue.getClass();
+                    if (headerSerializer == null
+                            || !(Objects.equals(headerClass, cacheKeyForHeaderSerializer))) {
+                        headerSerializer = getSerializer(headerClass, headerSerializers, null);
+                        cacheKeyForHeaderSerializer = headerClass;
                    }
                    serializedHeader = headerSerializer.serialize(topicName, headerValue);
                }
@@ -194,7 +207,10 @@ public CompletableFuture<?> write(Record r) {
             }
             ProducerRecord<Object, Object> record =
                     new ProducerRecord<>(topicName, null, null, key, value, headers);
-            log.info("Sending record {}", record);
+
+            if (log.isDebugEnabled()) {
+                log.debug("Sending record {}", record);
+            }
 
             producer.send(
                     record,
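In short, each of the three serializer fields (key, value, header) now acts as a single-slot cache keyed by the runtime class of the last payload: the serializer is re-resolved through getSerializer(...) whenever the incoming class differs from the cached one, which is what lets one producer interleave messages with different schemas. write(...) becomes synchronized because the cache fields are shared mutable state. Below is a minimal, self-contained sketch of the same pattern; SerializerCache and its registry map are hypothetical names for illustration, not LangStream API:

    import java.util.Map;
    import java.util.Objects;
    import org.apache.kafka.common.serialization.IntegerSerializer;
    import org.apache.kafka.common.serialization.Serializer;
    import org.apache.kafka.common.serialization.StringSerializer;

    // Hypothetical single-slot serializer cache, mirroring the diff above.
    class SerializerCache {
        private Serializer<Object> cachedSerializer; // last serializer resolved
        private Class<?> cachedClass;                // class it was resolved for

        // Re-resolve from the registry only when the payload class changes;
        // synchronized for the same reason write() is in the real code.
        @SuppressWarnings("unchecked")
        synchronized byte[] serialize(
                String topic, Object payload, Map<Class<?>, Serializer<?>> registry) {
            Class<?> payloadClass = payload.getClass();
            if (cachedSerializer == null || !Objects.equals(payloadClass, cachedClass)) {
                cachedSerializer = (Serializer<Object>) registry.get(payloadClass);
                cachedClass = payloadClass;
            }
            return cachedSerializer.serialize(topic, payload);
        }

        public static void main(String[] args) {
            Map<Class<?>, Serializer<?>> registry =
                    Map.of(
                            String.class, new StringSerializer(),
                            Integer.class, new IntegerSerializer());
            SerializerCache cache = new SerializerCache();
            cache.serialize("t", "hello", registry); // resolves StringSerializer
            cache.serialize("t", 42, registry);      // class changed: re-resolved
            cache.serialize("t", 43, registry);      // same class: cache hit
        }
    }

A single slot is enough here because re-running the lookup is cheap; a Map<Class<?>, Serializer<?>> keyed cache would be the natural extension if alternating schemas ever made the re-resolution show up in profiles.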

langstream-kafka-runtime/src/test/java/ai/langstream/kafka/KafkaConsumerTest.java

Lines changed: 112 additions & 5 deletions
@@ -24,6 +24,7 @@
 import ai.langstream.api.model.Module;
 import ai.langstream.api.model.StreamingCluster;
 import ai.langstream.api.model.TopicDefinition;
+import ai.langstream.api.runner.code.Header;
 import ai.langstream.api.runner.code.Record;
 import ai.langstream.api.runner.code.SimpleRecord;
 import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry;
@@ -39,9 +40,11 @@
 import ai.langstream.kafka.runner.KafkaTopicConnectionsRuntime;
 import ai.langstream.kafka.runtime.KafkaTopic;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.UUID;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
@@ -66,7 +69,7 @@ class KafkaConsumerTest {
     @ValueSource(ints = {1, 4})
     public void testKafkaConsumerCommitOffsets(int numPartitions) throws Exception {
         final AdminClient admin = kafkaContainer.getAdmin();
-        String topicName = "input-topic-" + numPartitions + "parts";
+        String topicName = "input-topic-" + numPartitions + "parts-" + UUID.randomUUID();
         Application applicationInstance =
                 ModelBuilder.buildApplicationInstance(
                         Map.of(
@@ -190,7 +193,7 @@ public void testKafkaConsumerCommitOffsetsMultiThread() throws Exception {
         int numPartitions = 4;
         int numThreads = 8;
         final AdminClient admin = kafkaContainer.getAdmin();
-        String topicName = "input-topic-" + numPartitions + "-parts-mt";
+        String topicName = "input-topic-" + numPartitions + "-parts-mt-" + UUID.randomUUID();
         Application applicationInstance =
                 ModelBuilder.buildApplicationInstance(
                         Map.of(
@@ -285,7 +288,7 @@ public void testKafkaConsumerCommitOffsetsMultiThread() throws Exception {
     public void testRestartConsumer() throws Exception {
         int numPartitions = 1;
         final AdminClient admin = kafkaContainer.getAdmin();
-        String topicName = "input-topic-restart";
+        String topicName = "input-topic-restart-" + UUID.randomUUID();
         Application applicationInstance =
                 ModelBuilder.buildApplicationInstance(
                         Map.of(
@@ -370,6 +373,105 @@ public void testRestartConsumer() throws Exception {
         }
     }
 
+    @Test
+    public void testMultipleSchemas() throws Exception {
+        int numPartitions = 1;
+        final AdminClient admin = kafkaContainer.getAdmin();
+        String topicName = "input-topic-multi-schemas-" + UUID.randomUUID();
+        Application applicationInstance =
+                ModelBuilder.buildApplicationInstance(
+                                Map.of(
+                                        "module.yaml",
+                                        """
+                                        module: "module-1"
+                                        id: "pipeline-1"
+                                        topics:
+                                          - name: %s
+                                            creation-mode: create-if-not-exists
+                                            partitions: %d
+                                        """
+                                                .formatted(topicName, numPartitions)),
+                                buildInstanceYaml(),
+                                null)
+                        .getApplication();
+
+        @Cleanup
+        ApplicationDeployer deployer =
+                ApplicationDeployer.builder()
+                        .registry(new ClusterRuntimeRegistry())
+                        .pluginsRegistry(new PluginsRegistry())
+                        .topicConnectionsRuntimeRegistry(new TopicConnectionsRuntimeRegistry())
+                        .build();
+
+        Module module = applicationInstance.getModule("module-1");
+
+        ExecutionPlan implementation = deployer.createImplementation("app", applicationInstance);
+        assertTrue(
+                implementation.getConnectionImplementation(
+                                module, Connection.fromTopic(TopicDefinition.fromName(topicName)))
+                        instanceof KafkaTopic);
+
+        deployer.deploy("tenant", implementation, null);
+
+        Set<String> topics = admin.listTopics().names().get();
+        log.info("Topics {}", topics);
+        assertTrue(topics.contains(topicName));
+
+        Map<String, TopicDescription> stats = admin.describeTopics(Set.of(topicName)).all().get();
+        assertEquals(numPartitions, stats.get(topicName).partitions().size());
+
+        deployer.delete("tenant", implementation, null);
+        topics = admin.listTopics().names().get();
+        log.info("Topics {}", topics);
+        assertFalse(topics.contains(topicName));
+
+        StreamingCluster streamingCluster =
+                implementation.getApplication().getInstance().streamingCluster();
+        KafkaTopicConnectionsRuntime runtime = new KafkaTopicConnectionsRuntime();
+        runtime.init(streamingCluster);
+        String agentId = "agent-1";
+        try (TopicProducer producer =
+                runtime.createProducer(agentId, streamingCluster, Map.of("topic", topicName))) {
+            producer.start();
+
+            int numIterations = 5;
+            for (int i = 0; i < numIterations; i++) {
+
+                producer.write(generateRecord(1, "string")).join();
+                producer.write(generateRecord("two", 2)).join();
+
+                producer.write(
+                                generateRecord(
+                                        "two",
+                                        2,
+                                        new SimpleRecord.SimpleHeader("h1", 7),
+                                        new SimpleRecord.SimpleHeader("h2", "bar")))
+                        .join();
+
+                producer.write(generateRecord(1, "string")).join();
+                producer.write(generateRecord("two", 2)).join();
+
+                producer.write(
+                                generateRecord(
+                                        "two",
+                                        2,
+                                        new SimpleRecord.SimpleHeader("h1", 7),
+                                        new SimpleRecord.SimpleHeader("h2", "bar")))
+                        .join();
+
+                try (KafkaConsumerWrapper consumer =
+                        (KafkaConsumerWrapper)
+                                runtime.createConsumer(
+                                        agentId, streamingCluster, Map.of("topic", topicName))) {
+
+                    consumer.start();
+                    List<Record> readFromConsumer = consumeRecords(consumer, 6);
+                    consumer.commit(readFromConsumer);
+                }
+            }
+        }
+    }
+
     @NotNull
     private static List<Record> consumeRecords(TopicConsumer consumer, int atLeast) {
         List<Record> readFromConsumer = new ArrayList<>();
@@ -387,12 +489,17 @@ private static List<Record> consumeRecords(TopicConsumer consumer, int atLeast)
         return readFromConsumer;
     }
 
-    private static Record generateRecord(String value) {
+    private static Record generateRecord(Object value) {
+        return generateRecord(value, value);
+    }
+
+    private static Record generateRecord(Object key, Object value, Header... headers) {
         return SimpleRecord.builder()
-                .key(value)
+                .key(key)
                 .value(value)
                 .origin("origin")
                 .timestamp(System.currentTimeMillis())
+                .headers(headers != null ? Arrays.asList(headers) : List.of())
                 .build();
     }
 
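On the test side, two supporting changes back the new testMultipleSchemas test: every topic name now carries a UUID.randomUUID() suffix, so reruns against the shared Kafka container cannot collide with topics left over from an earlier test, and generateRecord is generalized from String to arbitrary key/value objects with optional headers. The alternating writes are the heart of the test; taken straight from the diff, with comments added to show what the producer's serializer cache sees:

    // Integer key, String value: serializers resolved for these classes.
    producer.write(generateRecord(1, "string")).join();
    // String key, Integer value: both classes changed, both caches refresh.
    producer.write(generateRecord("two", 2)).join();

Without the new class check, the second write would reuse the serializers picked for the first record, handing a String to an Integer serializer; instead the test drives six records per iteration, with flipping key/value schemas and mixed header types, through the producer and asserts they all round-trip through the consumer.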
