Skip to content

Commit 6800532

Browse files
committed
POC: Support sparse model return token_id
Signed-off-by: yuye-aws <yuyezhu@amazon.com>
1 parent dad243f commit 6800532

File tree

5 files changed

+127
-12
lines changed

5 files changed

+127
-12
lines changed

common/src/main/java/org/opensearch/ml/common/input/nlp/TextDocsMLInput.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,7 @@
2727
* ML input class which supports a list of text docs.
2828
* This class can be used for TEXT_EMBEDDING model.
2929
*/
30-
@org.opensearch.ml.common.annotation.MLInput(functionNames = {
31-
FunctionName.TEXT_EMBEDDING,
32-
FunctionName.SPARSE_ENCODING,
33-
FunctionName.SPARSE_TOKENIZE })
30+
@org.opensearch.ml.common.annotation.MLInput(functionNames = { FunctionName.TEXT_EMBEDDING })
3431
public class TextDocsMLInput extends MLInput {
3532
public static final String TEXT_DOCS_FIELD = "text_docs";
3633
public static final String RESULT_FILTER_FIELD = "result_filter";
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.input.parameter.textembedding;
7+
8+
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
9+
10+
import java.io.IOException;
11+
import java.util.Locale;
12+
13+
import org.opensearch.core.ParseField;
14+
import org.opensearch.core.common.io.stream.StreamOutput;
15+
import org.opensearch.core.xcontent.NamedXContentRegistry;
16+
import org.opensearch.core.xcontent.XContentBuilder;
17+
import org.opensearch.core.xcontent.XContentParser;
18+
import org.opensearch.ml.common.FunctionName;
19+
import org.opensearch.ml.common.annotation.MLAlgoParameter;
20+
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
21+
22+
import lombok.Builder;
23+
24+
@MLAlgoParameter(algorithms = { FunctionName.SPARSE_ENCODING })
25+
public class SparseEncodingParameters implements MLAlgoParams {
26+
27+
public static final String PARSE_FIELD_NAME = FunctionName.SPARSE_ENCODING.name();
28+
public static final String SPARSE_ENCODING_FORMAT_FIELD = "sparse_encoding_format";
29+
30+
@Override
31+
public int getVersion() {
32+
return 1;
33+
}
34+
35+
@Override
36+
public String getWriteableName() {
37+
return PARSE_FIELD_NAME;
38+
}
39+
40+
@Override
41+
public void writeTo(StreamOutput out) throws IOException {
42+
out.writeOptionalString(sparseEncodingType.name());
43+
}
44+
45+
public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry(
46+
MLAlgoParams.class,
47+
new ParseField(PARSE_FIELD_NAME),
48+
SparseEncodingParameters::parse
49+
);
50+
51+
@Override
52+
public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
53+
xContentBuilder.startObject();
54+
if (sparseEncodingType != null) {
55+
xContentBuilder.field(SPARSE_ENCODING_FORMAT_FIELD, sparseEncodingType.name());
56+
}
57+
xContentBuilder.endObject();
58+
return xContentBuilder;
59+
}
60+
61+
public enum SparseEncodingFormat {
62+
WORD,
63+
INT
64+
}
65+
66+
// The type of the content to be embedded
67+
private final SparseEncodingFormat sparseEncodingType;
68+
69+
@Builder(toBuilder = true)
70+
public SparseEncodingParameters(SparseEncodingFormat sparseEncodingType) {
71+
this.sparseEncodingType = sparseEncodingType;
72+
}
73+
74+
public SparseEncodingFormat getSparseEncodingType() {
75+
return sparseEncodingType;
76+
}
77+
78+
public static MLAlgoParams parse(XContentParser parser) throws IOException {
79+
SparseEncodingFormat sparseEncodingType = null;
80+
81+
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
82+
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
83+
String fieldName = parser.currentName();
84+
parser.nextToken();
85+
86+
if (fieldName.equals(SPARSE_ENCODING_FORMAT_FIELD)) {
87+
String contentType = parser.text();
88+
sparseEncodingType = SparseEncodingFormat.valueOf(contentType.toUpperCase(Locale.ROOT));
89+
} else {
90+
parser.skipChildren();
91+
}
92+
}
93+
return new SparseEncodingParameters(sparseEncodingType);
94+
}
95+
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
1313
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters;
1414
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType;
15+
import org.opensearch.ml.common.input.parameter.textembedding.SparseEncodingParameters;
1516
import org.opensearch.ml.common.model.MLModelConfig;
1617
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
1718
import org.opensearch.ml.common.output.model.ModelResultFilter;
@@ -40,6 +41,10 @@ public ModelTensorOutput predict(String modelId, MLInput mlInput) throws Transla
4041
for (String doc : textDocsInput.getDocs()) {
4142
Input input = new Input();
4243
input.add(doc);
44+
if (mlParams instanceof SparseEncodingParameters) {
45+
input.add("sparse_encoding_format", ((SparseEncodingParameters) mlParams).getSparseEncodingType().name());
46+
}
47+
4348
output = getPredictor().predict(input);
4449
tensorOutputs.add(parseModelTensorOutput(output, resultFilter));
4550
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,34 +6,48 @@
66
package org.opensearch.ml.engine.algorithms.sparse_encoding;
77

88
import static org.opensearch.ml.common.CommonValue.ML_MAP_RESPONSE_KEY;
9+
import static org.opensearch.ml.common.input.parameter.textembedding.SparseEncodingParameters.SPARSE_ENCODING_FORMAT_FIELD;
910

1011
import java.util.ArrayList;
1112
import java.util.Collections;
1213
import java.util.HashMap;
13-
import java.util.Iterator;
1414
import java.util.List;
1515
import java.util.Map;
1616

17+
import org.opensearch.ml.common.input.parameter.textembedding.SparseEncodingParameters;
1718
import org.opensearch.ml.common.output.model.ModelTensor;
1819
import org.opensearch.ml.common.output.model.ModelTensors;
1920
import org.opensearch.ml.engine.algorithms.SentenceTransformerTranslator;
2021

22+
import ai.djl.modality.Input;
2123
import ai.djl.modality.Output;
2224
import ai.djl.ndarray.NDArray;
2325
import ai.djl.ndarray.NDList;
2426
import ai.djl.translate.TranslatorContext;
2527

2628
public class SparseEncodingTranslator extends SentenceTransformerTranslator {
29+
30+
@Override
31+
public NDList processInput(TranslatorContext ctx, Input input) {
32+
String sparse_encoding_format = input.getAsString(SPARSE_ENCODING_FORMAT_FIELD);
33+
if (sparse_encoding_format != null) {
34+
ctx.setAttachment(SPARSE_ENCODING_FORMAT_FIELD, sparse_encoding_format);
35+
}
36+
return super.processInput(ctx, input);
37+
}
38+
2739
@Override
2840
public Output processOutput(TranslatorContext ctx, NDList list) {
2941
Output output = new Output(200, "OK");
42+
Object sparseEncodingFormatObject = ctx.getAttachment(SPARSE_ENCODING_FORMAT_FIELD);
43+
String sparseEncodingFormatString = sparseEncodingFormatObject != null
44+
? sparseEncodingFormatObject.toString()
45+
: SparseEncodingParameters.SparseEncodingFormat.WORD.name();
3046

3147
List<ModelTensor> outputs = new ArrayList<>();
32-
Iterator<NDArray> iterator = list.iterator();
33-
while (iterator.hasNext()) {
34-
NDArray ndArray = iterator.next();
48+
for (NDArray ndArray : list) {
3549
String name = ndArray.getName();
36-
Map<String, Float> tokenWeightsMap = convertOutput(ndArray);
50+
Map<String, Float> tokenWeightsMap = convertOutput(ndArray, sparseEncodingFormatString);
3751
Map<String, ?> wrappedMap = Map.of(ML_MAP_RESPONSE_KEY, Collections.singletonList(tokenWeightsMap));
3852
ModelTensor tensor = ModelTensor.builder().name(name).dataAsMap(wrappedMap).build();
3953
outputs.add(tensor);
@@ -44,12 +58,14 @@ public Output processOutput(TranslatorContext ctx, NDList list) {
4458
return output;
4559
}
4660

47-
private Map<String, Float> convertOutput(NDArray array) {
61+
private Map<String, Float> convertOutput(NDArray array, String sparseEncodingFormat) {
4862
Map<String, Float> map = new HashMap<>();
4963
NDArray nonZeroIndices = array.nonzero().squeeze();
5064

5165
for (long index : nonZeroIndices.toLongArray()) {
52-
String s = this.tokenizer.decode(new long[] { index }, true);
66+
String s = sparseEncodingFormat.equals(SparseEncodingParameters.SparseEncodingFormat.INT.name())
67+
? Long.toString(index)
68+
: this.tokenizer.decode(new long[] { index }, true);
5369
if (!s.isEmpty()) {
5470
map.put(s, array.getFloat(index));
5571
}

plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@
135135
import org.opensearch.ml.common.input.parameter.regression.LogisticRegressionParams;
136136
import org.opensearch.ml.common.input.parameter.sample.SampleAlgoParams;
137137
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters;
138+
import org.opensearch.ml.common.input.parameter.textembedding.SparseEncodingParameters;
138139
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
139140
import org.opensearch.ml.common.settings.MLCommonsSettings;
140141
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
@@ -1038,7 +1039,8 @@ public List<NamedXContentRegistry.Entry> getNamedXContent() {
10381039
RCFSummarizeParams.XCONTENT_REGISTRY,
10391040
LogisticRegressionParams.XCONTENT_REGISTRY,
10401041
TextEmbeddingModelConfig.XCONTENT_REGISTRY,
1041-
AsymmetricTextEmbeddingParameters.XCONTENT_REGISTRY
1042+
AsymmetricTextEmbeddingParameters.XCONTENT_REGISTRY,
1043+
SparseEncodingParameters.XCONTENT_REGISTRY
10421044
);
10431045
}
10441046

0 commit comments

Comments
 (0)