
Commit 088c1a5

Add BERT Tokenizer as OpenSearch built-in analyzer (#3719)
* bert analyzer
* add license header
* add rest test case
* load from zip
* address comments
* retry for init

Signed-off-by: zhichao-aws <zhichaog@amazon.com>
1 parent 37d79e6 commit 088c1a5

19 files changed (+923, -8 lines)

ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java

Lines changed: 5 additions & 0 deletions
@@ -38,6 +38,7 @@ public class MLEngine {

     public static final String REGISTER_MODEL_FOLDER = "register";
     public static final String DEPLOY_MODEL_FOLDER = "deploy";
+    public static final String ANALYSIS_FOLDER = "analysis";
     private final String MODEL_REPO = "https://artifacts.opensearch.org/models/ml-models";

     @Getter
@@ -114,6 +115,10 @@ public Path getModelCacheRootPath() {
         return mlModelsCachePath.resolve("models");
     }

+    public Path getAnalysisRootPath() {
+        return mlModelsCachePath.resolve(ANALYSIS_FOLDER);
+    }
+
     public MLModel train(Input input) {
         validateMLInput(input);
         MLInput mlInput = (MLInput) input;

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/tokenize/SparseTokenizerModel.java

Lines changed: 2 additions & 8 deletions
@@ -6,11 +6,8 @@
 package org.opensearch.ml.engine.algorithms.tokenize;

 import static org.opensearch.ml.common.CommonValue.ML_MAP_RESPONSE_KEY;
-import static org.opensearch.ml.common.utils.StringUtils.gson;

 import java.io.IOException;
-import java.io.InputStreamReader;
-import java.lang.reflect.Type;
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.util.ArrayList;
@@ -31,10 +28,9 @@
 import org.opensearch.ml.common.output.model.ModelTensorOutput;
 import org.opensearch.ml.common.output.model.ModelTensors;
 import org.opensearch.ml.engine.algorithms.DLModel;
+import org.opensearch.ml.engine.analysis.DJLUtils;
 import org.opensearch.ml.engine.annotation.Function;

-import com.google.gson.reflect.TypeToken;
-
 import ai.djl.MalformedModelException;
 import ai.djl.huggingface.tokenizers.Encoding;
 import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
@@ -110,9 +106,7 @@ protected void doLoadModel(
         tokenizer = HuggingFaceTokenizer.builder().optPadding(true).optTokenizerPath(modelPath.resolve("tokenizer.json")).build();
         idf = new HashMap<>();
         if (Files.exists(modelPath.resolve(IDF_FILE_NAME))) {
-            Type mapType = new TypeToken<Map<String, Float>>() {
-            }.getType();
-            idf = gson.fromJson(new InputStreamReader(Files.newInputStream(modelPath.resolve(IDF_FILE_NAME))), mapType);
+            idf = DJLUtils.fetchTokenWeights(modelPath.resolve(IDF_FILE_NAME));
        }
         log.info("sparse tokenize Model {} is successfully deployed", modelId);
     }
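
The refactor above delegates IDF parsing to the new DJLUtils.fetchTokenWeights helper. As a rough illustration of the file format it expects (a flat JSON object of token-to-weight entries parsed with Gson into a Map<String, Float>), here is a minimal, self-contained sketch that mirrors the same parsing pattern; the temp file and sample weights are made up for the example.

import java.io.InputStreamReader;
import java.lang.reflect.Type;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Map;

import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;

public class TokenWeightsSketch {
    public static void main(String[] args) throws Exception {
        // Hypothetical token-weights content; real files ship alongside the model artifacts.
        Path file = Files.createTempFile("token-weights", ".json");
        Files.writeString(file, "{\"hello\": 6.9375, \"world\": 5.0}");

        // Same Gson pattern fetchTokenWeights uses internally.
        Type mapType = new TypeToken<Map<String, Float>>() {
        }.getType();
        Map<String, Float> weights = new Gson().fromJson(new InputStreamReader(Files.newInputStream(file)), mapType);
        System.out.println(weights.get("hello")); // 6.9375
    }
}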

ml-algorithms/src/main/java/org/opensearch/ml/engine/analysis/DJLUtils.java

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */
package org.opensearch.ml.engine.analysis;

import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.io.IOException;
import java.io.InputStreamReader;
import java.lang.reflect.Type;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.Map;
import java.util.concurrent.Callable;

import org.opensearch.ml.engine.MLEngine;

import com.google.gson.reflect.TypeToken;

import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import lombok.Getter;
import lombok.Setter;

/**
 * Utility class for DJL (Deep Java Library) operations related to tokenization and model handling.
 */
public class DJLUtils {
    @Getter
    @Setter
    private static MLEngine mlEngine;

    private static <T> T withDJLContext(Callable<T> action) throws PrivilegedActionException {
        return AccessController.doPrivileged((PrivilegedExceptionAction<T>) () -> {
            ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader();
            try {
                System.setProperty("java.library.path", mlEngine.getMlCachePath().toAbsolutePath().toString());
                System.setProperty("DJL_CACHE_DIR", mlEngine.getMlCachePath().toAbsolutePath().toString());
                Thread.currentThread().setContextClassLoader(ai.djl.Model.class.getClassLoader());

                return action.call();
            } finally {
                Thread.currentThread().setContextClassLoader(contextClassLoader);
            }
        });
    }

    /**
     * Creates a new HuggingFaceTokenizer instance for the given resource path.
     * @param resourcePath The resource path of the tokenizer to create
     * @return A new HuggingFaceTokenizer instance
     * @throws RuntimeException if tokenizer initialization fails
     */
    public static HuggingFaceTokenizer buildHuggingFaceTokenizer(Path resourcePath) {
        try {
            return withDJLContext(() -> { return HuggingFaceTokenizer.newInstance(resourcePath); });
        } catch (PrivilegedActionException e) {
            throw new RuntimeException("Failed to initialize Hugging Face tokenizer. " + e);
        }
    }

    /**
     * Fetches token weights from a specified file.
     * @param resourcePath The path of the token weights file to parse
     * @return A map of token to weight mappings
     * @throws RuntimeException if file fetching or parsing fails
     */
    public static Map<String, Float> fetchTokenWeights(Path resourcePath) {
        try {
            Type mapType = new TypeToken<Map<String, Float>>() {
            }.getType();
            return gson.fromJson(new InputStreamReader(Files.newInputStream(resourcePath)), mapType);
        } catch (IOException e) {
            throw new RuntimeException("Failed to parse token weights file. " + e);
        }
    }
}
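
A minimal sketch of how these helpers might be called, assuming an already-constructed MLEngine instance and a tokenizer.json plus a token-weights file sitting under the analysis root; the file names below are illustrative, not taken from this commit.

import java.nio.file.Path;
import java.util.Map;

import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.analysis.DJLUtils;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;

public class DJLUtilsSketch {
    public static void demo(MLEngine engine) {
        DJLUtils.setMlEngine(engine);                      // Lombok @Setter on the static field
        Path analysisRoot = engine.getAnalysisRootPath();  // <ml cache>/analysis, added in MLEngine above

        // The tokenizer is built inside the privileged DJL context set up by withDJLContext.
        HuggingFaceTokenizer tokenizer = DJLUtils.buildHuggingFaceTokenizer(analysisRoot.resolve("tokenizer.json"));
        Encoding encoding = tokenizer.encode("hello world");
        System.out.println(encoding.getTokens().length);

        // Token weights are a JSON map of token -> float ("idf.json" is a made-up name).
        Map<String, Float> weights = DJLUtils.fetchTokenWeights(analysisRoot.resolve("idf.json"));
        System.out.println(weights.size());
    }
}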

ml-algorithms/src/main/java/org/opensearch/ml/engine/analysis/HFModelAnalyzer.java

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */
package org.opensearch.ml.engine.analysis;

import java.util.function.Supplier;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.Tokenizer;

/**
 * Custom Lucene Analyzer that uses the HFModelTokenizer for text analysis.
 * Provides a way to process text using Hugging Face models within OpenSearch.
 */
public class HFModelAnalyzer extends Analyzer {
    Supplier<Tokenizer> tokenizerSupplier;

    public HFModelAnalyzer(Supplier<Tokenizer> tokenizerSupplier) {
        this.tokenizerSupplier = tokenizerSupplier;
    }

    @Override
    protected TokenStreamComponents createComponents(String fieldName) {
        final Tokenizer src = tokenizerSupplier.get();
        return new TokenStreamComponents(src, src);
    }
}
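
A minimal usage sketch of the analyzer's supplier-based wiring, not part of the commit: it substitutes a stock Lucene WhitespaceTokenizer for HFModelTokenizer (which appears later in this diff) just to show how createComponents consumes the supplied tokenizer.

import java.io.IOException;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.core.WhitespaceTokenizer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.opensearch.ml.engine.analysis.HFModelAnalyzer;

public class HFModelAnalyzerSketch {
    public static void main(String[] args) throws IOException {
        // Any Supplier<Tokenizer> works; the provider below wires in HFModelTokenizerFactory::create instead.
        Analyzer analyzer = new HFModelAnalyzer(WhitespaceTokenizer::new);
        try (TokenStream ts = analyzer.tokenStream("field", "hello world")) {
            CharTermAttribute term = ts.addAttribute(CharTermAttribute.class);
            ts.reset();
            while (ts.incrementToken()) {
                System.out.println(term.toString()); // prints "hello", then "world"
            }
            ts.end();
        }
        analyzer.close();
    }
}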

ml-algorithms/src/main/java/org/opensearch/ml/engine/analysis/HFModelAnalyzerProvider.java

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */
package org.opensearch.ml.engine.analysis;

import org.opensearch.common.settings.Settings;
import org.opensearch.env.Environment;
import org.opensearch.index.IndexSettings;
import org.opensearch.index.analysis.AbstractIndexAnalyzerProvider;

/**
 * Provider class for HFModelAnalyzer instances.
 * Handles the creation and configuration of HFModelAnalyzer instances within OpenSearch.
 */
public class HFModelAnalyzerProvider extends AbstractIndexAnalyzerProvider<HFModelAnalyzer> {
    private final HFModelAnalyzer analyzer;

    public HFModelAnalyzerProvider(IndexSettings indexSettings, Environment environment, String name, Settings settings) {
        super(indexSettings, name, settings);
        HFModelTokenizerFactory tokenizerFactory = new HFModelTokenizerFactory(indexSettings, environment, name, settings);
        analyzer = new HFModelAnalyzer(tokenizerFactory::create);
    }

    @Override
    public HFModelAnalyzer get() {
        return analyzer;
    }
}
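
For context, a provider with this constructor shape can be exposed through OpenSearch's AnalysisPlugin extension point. The sketch below is hypothetical and not taken from this commit (the commit's actual registration as a built-in/pre-built analyzer happens in files not shown in this excerpt); the analyzer name "hf_model_analyzer" is chosen only for illustration.

import java.util.Map;

import org.apache.lucene.analysis.Analyzer;
import org.opensearch.index.analysis.AnalyzerProvider;
import org.opensearch.indices.analysis.AnalysisModule.AnalysisProvider;
import org.opensearch.ml.engine.analysis.HFModelAnalyzerProvider;
import org.opensearch.plugins.AnalysisPlugin;
import org.opensearch.plugins.Plugin;

public class ExampleAnalysisPlugin extends Plugin implements AnalysisPlugin {
    @Override
    public Map<String, AnalysisProvider<AnalyzerProvider<? extends Analyzer>>> getAnalyzers() {
        // "hf_model_analyzer" is an illustrative name; HFModelAnalyzerProvider's
        // (IndexSettings, Environment, String, Settings) constructor matches AnalysisProvider#get.
        return Map.of("hf_model_analyzer", HFModelAnalyzerProvider::new);
    }
}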

ml-algorithms/src/main/java/org/opensearch/ml/engine/analysis/HFModelTokenizer.java

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */
package org.opensearch.ml.engine.analysis;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Map;
import java.util.Objects;
import java.util.function.Supplier;

import org.apache.commons.lang3.StringUtils;
import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.OffsetAttribute;
import org.apache.lucene.analysis.tokenattributes.PayloadAttribute;
import org.apache.lucene.util.BytesRef;

import com.google.common.io.CharStreams;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import lombok.extern.log4j.Log4j2;

/**
 * A Lucene Tokenizer implementation that uses a Hugging Face tokenizer for tokenization.
 * Supports token weighting and handles overflow scenarios.
 */
@Log4j2
public class HFModelTokenizer extends Tokenizer {
    public static final String NAME = "hf_model_tokenizer";
    private static final Float DEFAULT_TOKEN_WEIGHT = 1.0f;

    private final CharTermAttribute termAtt;
    private final PayloadAttribute payloadAtt;
    private final OffsetAttribute offsetAtt;
    private final Supplier<HuggingFaceTokenizer> tokenizerSupplier;
    private final Supplier<Map<String, Float>> tokenWeightsSupplier;

    private Encoding encoding;
    private int tokenIdx = 0;
    private int overflowingIdx = 0;

    public HFModelTokenizer(Supplier<HuggingFaceTokenizer> huggingFaceTokenizerSupplier) {
        this(huggingFaceTokenizerSupplier, null);
    }

    public HFModelTokenizer(Supplier<HuggingFaceTokenizer> huggingFaceTokenizerSupplier, Supplier<Map<String, Float>> weightsSupplier) {
        termAtt = addAttribute(CharTermAttribute.class);
        offsetAtt = addAttribute(OffsetAttribute.class);
        if (Objects.nonNull(weightsSupplier)) {
            payloadAtt = addAttribute(PayloadAttribute.class);
        } else {
            payloadAtt = null;
        }
        tokenizerSupplier = huggingFaceTokenizerSupplier;
        tokenWeightsSupplier = weightsSupplier;
    }

    @Override
    public void reset() throws IOException {
        super.reset();
        tokenIdx = 0;
        overflowingIdx = -1;
        String inputStr = CharStreams.toString(input);
        // For the pre-built analyzer, reset() is called with empty input in checkVersions when a new index service is created.
        // We want to lazy-load the tokenizer only when it is really needed, so we use a supplier and skip empty input.
        encoding = StringUtils.isEmpty(inputStr) ? null : tokenizerSupplier.get().encode(inputStr, false, true);
    }

    private static boolean isLastTokenInEncodingSegment(int idx, Encoding encodingSegment) {
        return idx >= encodingSegment.getTokens().length || encodingSegment.getAttentionMask()[idx] == 0;
    }

    public static byte[] floatToBytes(float value) {
        return ByteBuffer.allocate(4).putFloat(value).array();
    }

    public static float bytesToFloat(byte[] bytes) {
        return ByteBuffer.wrap(bytes).getFloat();
    }

    @Override
    final public boolean incrementToken() throws IOException {
        clearAttributes();
        if (Objects.isNull(encoding))
            return false;
        Encoding curEncoding = overflowingIdx == -1 ? encoding : encoding.getOverflowing()[overflowingIdx];

        while (!isLastTokenInEncodingSegment(tokenIdx, curEncoding) || overflowingIdx < encoding.getOverflowing().length) {
            if (isLastTokenInEncodingSegment(tokenIdx, curEncoding)) {
                // finished the current segment; move to the next overflowing segment
                // until overflowingIdx reaches encoding.getOverflowing().length
                tokenIdx = 0;
                overflowingIdx++;
                if (overflowingIdx >= encoding.getOverflowing().length) {
                    return false;
                }
                curEncoding = encoding.getOverflowing()[overflowingIdx];
            } else {
                termAtt.append(curEncoding.getTokens()[tokenIdx]);
                offsetAtt
                    .setOffset(curEncoding.getCharTokenSpans()[tokenIdx].getStart(), curEncoding.getCharTokenSpans()[tokenIdx].getEnd());
                if (Objects.nonNull(tokenWeightsSupplier)) {
                    // for neural sparse query, write the token weight to the payload field
                    payloadAtt
                        .setPayload(
                            new BytesRef(
                                floatToBytes(
                                    tokenWeightsSupplier.get().getOrDefault(curEncoding.getTokens()[tokenIdx], DEFAULT_TOKEN_WEIGHT)
                                )
                            )
                        );
                }
                tokenIdx++;
                return true;
            }
        }

        return false;
    }
}
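
The payload written in incrementToken is just a 4-byte float (ByteBuffer's default big-endian order), so a neural sparse query can recover each token's weight from the term payload. A small round-trip sketch using the helpers above; the weight value is arbitrary.

import org.apache.lucene.util.BytesRef;
import org.opensearch.ml.engine.analysis.HFModelTokenizer;

public class PayloadEncodingSketch {
    public static void main(String[] args) {
        float weight = 2.375f;                                         // arbitrary example weight
        byte[] encoded = HFModelTokenizer.floatToBytes(weight);        // 4 bytes, big-endian
        BytesRef payload = new BytesRef(encoded);                      // what gets stored on the term
        float decoded = HFModelTokenizer.bytesToFloat(payload.bytes);  // freshly wrapped, so offset 0 / length 4
        System.out.println(decoded);                                   // 2.375
    }
}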
