
Commit 4c8a6ee

sobychacko authored and markpollack committed
Document BatchingStrategy and enhance TokenCountBatching
This commit adds comprehensive documentation for the BatchingStrategy in vector stores and enhances the TokenCountBatchingStrategy class. Key changes:

- Explain batching necessity due to embedding model thresholds
- Describe BatchingStrategy interface and its purpose
- Detail TokenCountBatchingStrategy default implementation
- Provide guidance on using and customizing batching strategies
- Note pre-configured vector stores with default strategy
- Add new constructor for custom TokenCountEstimator in TokenCountBatchingStrategy
- Implement null checks with Spring's Assert utility
- Update docs with new customization options and code examples
1 parent e29d38d commit 4c8a6ee

File tree

2 files changed: +131 -1 lines changed

spring-ai-core/src/main/java/org/springframework/ai/embedding/TokenCountBatchingStrategy.java

Lines changed: 23 additions & 0 deletions

@@ -27,6 +27,7 @@
 import org.springframework.ai.tokenizer.TokenCountEstimator;
 
 import com.knuddels.jtokkit.api.EncodingType;
+import org.springframework.util.Assert;
 
 /**
  * Token count based strategy implementation for {@link BatchingStrategy}. Using openai
@@ -96,12 +97,34 @@ public TokenCountBatchingStrategy(EncodingType encodingType, int maxInputTokenCo
 	 */
 	public TokenCountBatchingStrategy(EncodingType encodingType, int maxInputTokenCount, double reservePercentage,
 			ContentFormatter contentFormatter, MetadataMode metadataMode) {
+		Assert.notNull(encodingType, "EncodingType must not be null");
+		Assert.notNull(contentFormatter, "ContentFormatter must not be null");
+		Assert.notNull(metadataMode, "MetadataMode must not be null");
 		this.tokenCountEstimator = new JTokkitTokenCountEstimator(encodingType);
 		this.maxInputTokenCount = (int) Math.round(maxInputTokenCount * (1 - reservePercentage));
 		this.contentFormater = contentFormatter;
 		this.metadataMode = metadataMode;
 	}
 
+	/**
+	 * Constructs a TokenCountBatchingStrategy with the specified parameters.
+	 * @param tokenCountEstimator the TokenCountEstimator to be used for estimating token
+	 * counts.
+	 * @param maxInputTokenCount the initial upper limit for input tokens.
+	 * @param reservePercentage the percentage of tokens to reserve from the max input
+	 * token count to create a buffer.
+	 * @param contentFormatter the ContentFormatter to be used for formatting content.
+	 * @param metadataMode the MetadataMode to be used for handling metadata.
+	 */
+	public TokenCountBatchingStrategy(TokenCountEstimator tokenCountEstimator, int maxInputTokenCount,
+			double reservePercentage, ContentFormatter contentFormatter, MetadataMode metadataMode) {
+		Assert.notNull(tokenCountEstimator, "TokenCountEstimator must not be null");
+		this.tokenCountEstimator = tokenCountEstimator;
+		this.maxInputTokenCount = (int) Math.round(maxInputTokenCount * (1 - reservePercentage));
+		this.contentFormater = contentFormatter;
+		this.metadataMode = metadataMode;
+	}
+
 	@Override
 	public List<List<Document>> batch(List<Document> documents) {
 		List<List<Document>> batches = new ArrayList<>();
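For context on the `(int) Math.round(maxInputTokenCount * (1 - reservePercentage))` line shared by both constructors, the arithmetic can be checked in isolation. The class below is an illustrative standalone sketch, not part of the commit:

```java
// Standalone sketch of the effective-limit math used by both constructors:
// the usable token budget is the declared maximum minus a reserve buffer.
public class ReserveMath {

    // Mirrors: (int) Math.round(maxInputTokenCount * (1 - reservePercentage))
    public static int effectiveMax(int maxInputTokenCount, double reservePercentage) {
        return (int) Math.round(maxInputTokenCount * (1 - reservePercentage));
    }
}
```

With OpenAI's 8191-token limit and the default 10% reserve, this yields a per-batch budget of 7372 tokens.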

spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs.adoc

Lines changed: 108 additions & 1 deletion
@@ -91,7 +91,114 @@ It will not be initialized for you by default.
 You must opt-in, by passing a `boolean` for the appropriate constructor argument or, if using Spring Boot, setting the appropriate `initialize-schema` property to `true` in `application.properties` or `application.yml`.
 Check the documentation for the vector store you are using for the specific property name.

-== Available Implementations
+== Batching Strategy

When working with vector stores, it's often necessary to embed large numbers of documents.
While it might seem straightforward to make a single call to embed all documents at once, this approach can lead to issues.
Embedding models process text as tokens and have a maximum token limit, often referred to as the context window size.
This limit restricts the amount of text that can be processed in a single embedding request.
Attempting to embed too many tokens in one call can result in errors or truncated embeddings.

To address this token limit, Spring AI implements a batching strategy.
This approach breaks down large sets of documents into smaller batches that fit within the embedding model's maximum context window.
Batching not only solves the token limit issue but can also lead to improved performance and more efficient use of API rate limits.

Spring AI provides this functionality through the `BatchingStrategy` interface, which allows for processing documents in sub-batches based on their token counts.

The core `BatchingStrategy` interface is defined as follows:

[source,java]
----
public interface BatchingStrategy {

	List<List<Document>> batch(List<Document> documents);
}
----

This interface defines a single method, `batch`, which takes a list of documents and returns a list of document batches.

=== Default Implementation

Spring AI provides a default implementation called `TokenCountBatchingStrategy`.
This strategy batches documents based on their token counts, ensuring that each batch does not exceed a calculated maximum input token count.

Key features of `TokenCountBatchingStrategy`:

1. Uses https://platform.openai.com/docs/guides/embeddings/embedding-models[OpenAI's max input token count] (8191) as the default upper limit.
2. Incorporates a reserve percentage (default 10%) to provide a buffer for potential overhead.
3. Calculates the actual max input token count as: `actualMaxInputTokenCount = originalMaxInputTokenCount * (1 - RESERVE_PERCENTAGE)`

The strategy estimates the token count for each document, groups them into batches without exceeding the max input token count, and throws an exception if a single document exceeds this limit.
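The grouping behavior just described can be sketched outside Spring AI. In the sketch below, plain strings and a crude characters-per-token heuristic are simplified stand-ins for the library's `Document` and `TokenCountEstimator` types:

```java
import java.util.ArrayList;
import java.util.List;

// Simplified sketch of token-count batching: greedily pack documents into
// batches whose estimated token totals stay under a limit, and reject any
// single document that exceeds it on its own.
public class TokenBatchSketch {

    // Crude stand-in estimator: roughly four characters per token.
    public static int estimateTokens(String text) {
        return (int) Math.ceil(text.length() / 4.0);
    }

    public static List<List<String>> batch(List<String> documents, int maxTokens) {
        List<List<String>> batches = new ArrayList<>();
        List<String> current = new ArrayList<>();
        int currentTokens = 0;
        for (String doc : documents) {
            int tokens = estimateTokens(doc);
            if (tokens > maxTokens) {
                throw new IllegalArgumentException("Document exceeds the max token limit");
            }
            if (currentTokens + tokens > maxTokens) {
                batches.add(current);
                current = new ArrayList<>();
                currentTokens = 0;
            }
            current.add(doc);
            currentTokens += tokens;
        }
        if (!current.isEmpty()) {
            batches.add(current);
        }
        return batches;
    }
}
```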
You can also customize the `TokenCountBatchingStrategy` to better suit your specific requirements. This can be done by creating a new instance with custom parameters in a Spring Boot `@Configuration` class.

Here's an example of how to create a custom `TokenCountBatchingStrategy` bean:

[source,java]
----
@Configuration
public class EmbeddingConfig {

	@Bean
	public BatchingStrategy customTokenCountBatchingStrategy() {
		return new TokenCountBatchingStrategy(
			EncodingType.CL100K_BASE,  // Specify the encoding type
			8000,                      // Set the maximum input token count
			0.1                        // Set the reserve percentage
		);
	}
}
----

In this configuration:

1. `EncodingType.CL100K_BASE`: Specifies the encoding type used for tokenization. This encoding type is used by the `JTokkitTokenCountEstimator` to accurately estimate token counts.
2. `8000`: Sets the maximum input token count. This value should be less than or equal to the maximum context window size of your embedding model.
3. `0.1`: Sets the reserve percentage. This fraction of the maximum input token count is held back as a buffer, so a value of 0.1 fills each batch to at most 90% of the limit.

By default, this constructor uses `Document.DEFAULT_CONTENT_FORMATTER` for content formatting and `MetadataMode.NONE` for metadata handling. If you need to customize these parameters, you can use the full constructor with additional parameters.

Once defined, this custom `TokenCountBatchingStrategy` bean will be automatically used by the `EmbeddingModel` implementations in your application, replacing the default strategy.

The `TokenCountBatchingStrategy` internally uses a `TokenCountEstimator` (specifically, `JTokkitTokenCountEstimator`) to calculate token counts for efficient batching. This ensures accurate token estimation based on the specified encoding type.

Additionally, `TokenCountBatchingStrategy` provides flexibility by allowing you to pass in your own implementation of the `TokenCountEstimator` interface. This feature enables you to use custom token counting strategies tailored to your specific needs. For example:

[source,java]
----
TokenCountEstimator customEstimator = new YourCustomTokenCountEstimator();
TokenCountBatchingStrategy strategy = new TokenCountBatchingStrategy(
	customEstimator,
	8000,                               // maxInputTokenCount
	0.1,                                // reservePercentage
	Document.DEFAULT_CONTENT_FORMATTER,
	MetadataMode.NONE
);
----
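As a rough illustration of the kind of heuristic a custom estimator might apply, the standalone sketch below approximates tokens from word count. The class and method names are made up for this sketch; a real implementation would implement Spring AI's `TokenCountEstimator` interface instead:

```java
// Illustrative heuristic for a custom token count estimator: a common rule
// of thumb for English text is roughly 4 tokens per 3 words. Not a Spring AI
// class; shown only to suggest what a custom estimator could compute.
public class WordBasedTokenCounter {

    public static int estimate(String text) {
        if (text == null || text.isBlank()) {
            return 0;
        }
        int words = text.trim().split("\\s+").length;
        // tokens ~= words * 4 / 3, rounded up
        return (int) Math.ceil(words * 4.0 / 3.0);
    }
}
```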

=== Custom Implementation

While `TokenCountBatchingStrategy` provides a robust default implementation, you can customize the batching strategy to fit your specific needs.
This can be done through Spring Boot's auto-configuration.

To customize the batching strategy, define a `BatchingStrategy` bean in your Spring Boot application:

[source,java]
----
@Configuration
public class EmbeddingConfig {

	@Bean
	public BatchingStrategy customBatchingStrategy() {
		return new CustomBatchingStrategy();
	}
}
----

This custom `BatchingStrategy` will then be automatically used by the `EmbeddingModel` implementations in your application.
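A custom strategy's `batch` method might, for instance, simply split documents into fixed-size groups and ignore token counts entirely. The sketch below shows that partitioning logic with plain strings standing in for Spring AI's `Document` type; the class name is illustrative only:

```java
import java.util.ArrayList;
import java.util.List;

// Sketch of a possible custom batching policy: split the input into
// fixed-size batches. Plain strings stand in for Spring AI's Document type.
public class FixedSizeBatching {

    public static List<List<String>> batch(List<String> documents, int batchSize) {
        List<List<String>> batches = new ArrayList<>();
        for (int i = 0; i < documents.size(); i += batchSize) {
            int end = Math.min(i + batchSize, documents.size());
            batches.add(new ArrayList<>(documents.subList(i, end)));
        }
        return batches;
    }
}
```

A policy like this is only safe when every batch is known to fit the embedding model's token limit; otherwise a token-aware strategy such as the default remains the better choice.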

NOTE: Vector stores supported by Spring AI are configured to use the default `TokenCountBatchingStrategy`.
SAP Hana vector store is not currently configured for batching.

== VectorStore Implementations

These are the available implementations of the `VectorStore` interface:
