Skip to content

Commit 463cc0e

Browse files
authored
Merge branch 'main' into mainline
2 parents 9ae8ce5 + dad243f commit 463cc0e

30 files changed

+1178
-93
lines changed

common/src/main/java/org/opensearch/ml/common/CommonValue.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ public class CommonValue {
7272
public static final Version VERSION_2_18_0 = Version.fromString("2.18.0");
7373
public static final Version VERSION_2_19_0 = Version.fromString("2.19.0");
7474
public static final Version VERSION_3_0_0 = Version.fromString("3.0.0");
75+
public static final Version VERSION_3_1_0 = Version.fromString("3.1.0");
7576

7677
// Connector Constants
7778
public static final String NAME_FIELD = "name";

common/src/main/java/org/opensearch/ml/common/MLModel.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,15 @@
3232
import org.opensearch.core.xcontent.XContentParser;
3333
import org.opensearch.ml.common.connector.Connector;
3434
import org.opensearch.ml.common.controller.MLRateLimiter;
35+
import org.opensearch.ml.common.model.BaseModelConfig;
3536
import org.opensearch.ml.common.model.Guardrails;
3637
import org.opensearch.ml.common.model.MLDeploySetting;
3738
import org.opensearch.ml.common.model.MLModelConfig;
3839
import org.opensearch.ml.common.model.MLModelFormat;
3940
import org.opensearch.ml.common.model.MLModelState;
4041
import org.opensearch.ml.common.model.MetricsCorrelationModelConfig;
4142
import org.opensearch.ml.common.model.QuestionAnsweringModelConfig;
43+
import org.opensearch.ml.common.model.RemoteModelConfig;
4244
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
4345

4446
import lombok.Builder;
@@ -278,8 +280,12 @@ public MLModel(StreamInput input) throws IOException {
278280
modelConfig = new MetricsCorrelationModelConfig(input);
279281
} else if (algorithm.equals(FunctionName.QUESTION_ANSWERING)) {
280282
modelConfig = new QuestionAnsweringModelConfig(input);
281-
} else {
283+
} else if (algorithm.equals(FunctionName.TEXT_EMBEDDING)) {
282284
modelConfig = new TextEmbeddingModelConfig(input);
285+
} else if (algorithm.equals(FunctionName.REMOTE)) {
286+
modelConfig = new RemoteModelConfig(input);
287+
} else {
288+
modelConfig = new BaseModelConfig(input);
283289
}
284290
}
285291
if (input.readBoolean()) {
@@ -623,8 +629,12 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
623629
modelConfig = MetricsCorrelationModelConfig.parse(parser);
624630
} else if (FunctionName.QUESTION_ANSWERING.name().equals(algorithmName)) {
625631
modelConfig = QuestionAnsweringModelConfig.parse(parser);
626-
} else {
632+
} else if (FunctionName.TEXT_EMBEDDING.name().equals(algorithmName)) {
627633
modelConfig = TextEmbeddingModelConfig.parse(parser);
634+
} else if (FunctionName.REMOTE.name().equals(algorithmName)) {
635+
modelConfig = RemoteModelConfig.parse(parser);
636+
} else {
637+
modelConfig = BaseModelConfig.parse(parser);
628638
}
629639
break;
630640
case DEPLOY_SETTING_FIELD:
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.model;
7+
8+
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
9+
import static org.opensearch.ml.common.CommonValue.VERSION_3_1_0;
10+
11+
import java.io.IOException;
12+
import java.util.Map;
13+
import java.util.Set;
14+
import java.util.stream.Collectors;
15+
16+
import org.opensearch.common.xcontent.XContentHelper;
17+
import org.opensearch.common.xcontent.XContentType;
18+
import org.opensearch.core.ParseField;
19+
import org.opensearch.core.common.io.stream.StreamInput;
20+
import org.opensearch.core.common.io.stream.StreamOutput;
21+
import org.opensearch.core.xcontent.NamedXContentRegistry;
22+
import org.opensearch.core.xcontent.ToXContent;
23+
import org.opensearch.core.xcontent.XContentBuilder;
24+
import org.opensearch.core.xcontent.XContentParser;
25+
26+
import lombok.Builder;
27+
import lombok.Getter;
28+
import lombok.Setter;
29+
30+
/**
31+
* Base configuration class for ML local models. This class handles
32+
* the basic configuration parameters that every local model can support.
33+
*/
34+
@Setter
35+
@Getter
36+
public class BaseModelConfig extends MLModelConfig {
37+
public static final String PARSE_FIELD_NAME = "base";
38+
public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry(
39+
BaseModelConfig.class,
40+
new ParseField(PARSE_FIELD_NAME),
41+
it -> parse(it)
42+
);
43+
44+
public static final String ADDITIONAL_CONFIG_FIELD = "additional_config";
45+
protected Map<String, Object> additionalConfig;
46+
47+
@Builder(builderMethodName = "baseModelConfigBuilder")
48+
public BaseModelConfig(String modelType, String allConfig, Map<String, Object> additionalConfig) {
49+
super(modelType, allConfig);
50+
this.additionalConfig = additionalConfig;
51+
validateNoDuplicateKeys(allConfig, additionalConfig);
52+
}
53+
54+
public static BaseModelConfig parse(XContentParser parser) throws IOException {
55+
String modelType = null;
56+
String allConfig = null;
57+
Map<String, Object> additionalConfig = null;
58+
59+
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
60+
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
61+
String fieldName = parser.currentName();
62+
parser.nextToken();
63+
64+
switch (fieldName) {
65+
case MODEL_TYPE_FIELD:
66+
modelType = parser.text();
67+
break;
68+
case ALL_CONFIG_FIELD:
69+
allConfig = parser.text();
70+
break;
71+
case ADDITIONAL_CONFIG_FIELD:
72+
additionalConfig = parser.map();
73+
break;
74+
default:
75+
parser.skipChildren();
76+
break;
77+
}
78+
}
79+
return new BaseModelConfig(modelType, allConfig, additionalConfig);
80+
}
81+
82+
@Override
83+
public String getWriteableName() {
84+
return PARSE_FIELD_NAME;
85+
}
86+
87+
public BaseModelConfig(StreamInput in) throws IOException {
88+
super(in);
89+
if (in.getVersion().onOrAfter(VERSION_3_1_0)) {
90+
this.additionalConfig = in.readMap();
91+
}
92+
}
93+
94+
@Override
95+
public void writeTo(StreamOutput out) throws IOException {
96+
super.writeTo(out);
97+
if (out.getVersion().onOrAfter(VERSION_3_1_0)) {
98+
out.writeMap(additionalConfig);
99+
}
100+
}
101+
102+
@Override
103+
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
104+
builder.startObject();
105+
if (modelType != null) {
106+
builder.field(MODEL_TYPE_FIELD, modelType);
107+
}
108+
if (allConfig != null) {
109+
builder.field(ALL_CONFIG_FIELD, allConfig);
110+
}
111+
if (additionalConfig != null) {
112+
builder.field(ADDITIONAL_CONFIG_FIELD, additionalConfig);
113+
}
114+
builder.endObject();
115+
return builder;
116+
}
117+
118+
protected void validateNoDuplicateKeys(String allConfig, Map<String, Object> additionalConfig) {
119+
if (allConfig == null || additionalConfig == null || additionalConfig.isEmpty()) {
120+
return;
121+
}
122+
123+
Map<String, Object> allConfigMap = XContentHelper.convertToMap(XContentType.JSON.xContent(), allConfig, false);
124+
Set<String> duplicateKeys = allConfigMap.keySet().stream().filter(additionalConfig::containsKey).collect(Collectors.toSet());
125+
if (!duplicateKeys.isEmpty()) {
126+
throw new IllegalArgumentException(
127+
"Duplicate keys found in both all_config and additional_config: " + String.join(", ", duplicateKeys)
128+
);
129+
}
130+
}
131+
132+
public Map<String, Object> getAdditionalConfig() {
133+
return this.additionalConfig;
134+
}
135+
}

0 commit comments

Comments
 (0)