Skip to content

Commit e3ae3f0

Browse files
committed
Don't convert schema-defined strings to other types during validation
Signed-off-by: Daniel Widdis <widdis@gmail.com>
1 parent c3b1fd6 commit e3ae3f0

File tree

3 files changed

+58
-18
lines changed

3 files changed

+58
-18
lines changed

plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ public void validateInputSchema(String modelId, MLInput mlInput) {
262262
try {
263263
String InputString = mlInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString();
264264
// Process the parameters field in the input dataset to convert it back to its original datatype, instead of a string
265-
String processedInputString = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(InputString);
265+
String processedInputString = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(InputString, inputSchemaString);
266266
MLNodeUtils.validateSchema(inputSchemaString, processedInputString);
267267
} catch (Exception e) {
268268
throw new OpenSearchStatusException(

plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,18 @@ public static void validateSchema(String schemaString, String instanceString) th
9191
}
9292

9393
/**
94-
* This method processes the input JSON string and replaces the string values of the parameters with JSON objects if the string is a valid JSON.
94+
* This method processes the input JSON string and replaces each string-valued entry under "parameters" with its parsed JSON value when the string is valid JSON, unless the schema declares that entry's type as "string".
9595
* @param inputJson The input JSON string
96+
* @param schemaJson The schema matching the input JSON string
9697
* @return The processed JSON string
9798
*/
98-
public static String processRemoteInferenceInputDataSetParametersValue(String inputJson) throws IOException {
99+
public static String processRemoteInferenceInputDataSetParametersValue(String inputJson, String schemaJson) throws IOException {
99100
ObjectMapper mapper = new ObjectMapper();
100101
JsonNode rootNode = mapper.readTree(inputJson);
102+
JsonNode schemaNode = mapper.readTree(schemaJson);
103+
104+
// Get the schema properties for parameters if they exist
105+
JsonNode parametersSchema = schemaNode.path("properties").path("parameters").path("properties");
101106

102107
if (rootNode.has("parameters") && rootNode.get("parameters").isObject()) {
103108
ObjectNode parametersNode = (ObjectNode) rootNode.get("parameters");
@@ -106,15 +111,12 @@ public static String processRemoteInferenceInputDataSetParametersValue(String in
106111
String key = entry.getKey();
107112
JsonNode value = entry.getValue();
108113

109-
if (value.isTextual()) {
110-
String textValue = value.asText();
114+
if (value.isTextual() && !isStringTypeInSchema(parametersSchema, key)) {
111115
try {
112-
// Try to parse the string as JSON
113-
JsonNode parsedValue = mapper.readTree(textValue);
114-
// If successful, replace the string with the parsed JSON
116+
JsonNode parsedValue = mapper.readTree(value.asText());
115117
parametersNode.set(key, parsedValue);
116118
} catch (IOException e) {
117-
// If parsing fails, it's not a valid JSON string, so keep it as is
119+
// If parsing fails, the text is not valid JSON, so keep the original string value
118120
parametersNode.set(key, value);
119121
}
120122
}
@@ -123,6 +125,11 @@ public static String processRemoteInferenceInputDataSetParametersValue(String in
123125
return mapper.writeValueAsString(rootNode);
124126
}
125127

128+
private static boolean isStringTypeInSchema(JsonNode schema, String fieldName) {
129+
JsonNode typeNode = schema.path(fieldName).path("type");
130+
return typeNode.isTextual() && typeNode.asText().equals("string");
131+
}
132+
126133
public static void checkOpenCircuitBreaker(MLCircuitBreakerService mlCircuitBreakerService, MLStats mlStats) {
127134
ThresholdCircuitBreaker openCircuitBreaker = mlCircuitBreakerService.checkOpenCB();
128135
if (openCircuitBreaker != null) {

plugin/src/test/java/org/opensearch/ml/utils/MLNodeUtilsTests.java

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -118,61 +118,94 @@ public void testValidateRemoteInputWithTitanMultiModalRemoteSchema() throws IOEx
118118

119119
@Test
120120
public void testProcessRemoteInferenceInputDataSetParametersValueNoParameters() throws IOException {
121+
String schema = "{\"type\": \"object\",\"properties\": {}}";
121122
String json = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true}";
122-
String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json);
123+
String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json, schema);
123124
assertEquals(json, processedJson);
124125
}
125126

126127
@Test
127128
public void testProcessRemoteInferenceInputDataSetInvalidJson() {
129+
String schema = "{\"type\": \"object\",\"properties\": {}}";
128130
String json = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"a\"}}";
129-
assertThrows(JsonParseException.class, () -> MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json));
131+
assertThrows(JsonParseException.class, () -> MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json, schema));
130132
}
131133

132134
@Test
133135
public void testProcessRemoteInferenceInputDataSetEmptyParameters() throws IOException {
136+
String schema = "{\"type\": \"object\",\"properties\": {\"parameters\": {\"type\": \"object\"}}}";
134137
String json = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{}}";
135-
String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json);
138+
String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json, schema);
136139
assertEquals(json, processedJson);
137140
}
138141

139142
@Test
140143
public void testProcessRemoteInferenceInputDataSetParametersValueParametersWrongType() throws IOException {
144+
String schema = "{\"type\": \"object\",\"properties\": {\"parameters\": {\"type\": \"array\"}}}";
141145
String json = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":[\"Hello\",\"world\"]}";
142-
String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json);
146+
String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json, schema);
143147
assertEquals(json, processedJson);
144148
}
145149

146150
@Test
147151
public void testProcessRemoteInferenceInputDataSetParametersValueWithParametersProcessArray() throws IOException {
152+
String schema = "{\"type\": \"object\",\"properties\": {\"parameters\": {\"type\": \"object\",\"properties\": {"
153+
+ "\"texts\": {\"type\": \"array\",\"items\": {\"type\": \"string\"}}"
154+
+ "}}}}";
148155
String json = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"texts\":\"[\\\"Hello\\\",\\\"world\\\"]\"}}";
149156
String expectedJson = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"texts\":[\"Hello\",\"world\"]}}";
150-
String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json);
157+
String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json, schema);
151158
assertEquals(expectedJson, processedJson);
152159
}
153160

154161
@Test
155162
public void testProcessRemoteInferenceInputDataSetParametersValueWithParametersProcessObject() throws IOException {
163+
String schema = "{\"type\": \"object\",\"properties\": {\"parameters\": {\"type\": \"object\",\"properties\": {"
164+
+ "\"messages\": {\"type\": \"object\"}"
165+
+ "}}}}";
156166
String json =
157-
"{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"messages\":\"{\\\"role\\\":\\\"system\\\",\\\"foo\\\":\\\"{\\\\\\\"a\\\\\\\": \\\\\\\"b\\\\\\\"}\\\",\\\"content\\\":{\\\"a\\\":\\\"b\\\"}}\"}}}";
167+
"{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"messages\":\"{\\\"role\\\":\\\"system\\\",\\\"foo\\\":\\\"{\\\\\\\"a\\\\\\\": \\\\\\\"b\\\\\\\"}\\\",\\\"content\\\":{\\\"a\\\":\\\"b\\\"}}\"}}";
158168
String expectedJson =
159169
"{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"messages\":{\"role\":\"system\",\"foo\":\"{\\\"a\\\": \\\"b\\\"}\",\"content\":{\"a\":\"b\"}}}}";
160-
String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json);
170+
String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json, schema);
161171
assertEquals(expectedJson, processedJson);
162172
}
163173

174+
@Test
175+
public void testProcessRemoteInferenceInputDataSetParametersValueWithParametersQuotedNumber() throws IOException {
176+
String schema = "{\"type\": \"object\",\"properties\": {\"parameters\": {\"type\": \"object\",\"properties\": {"
177+
+ "\"key1\": {\"type\": \"string\"},"
178+
+ "\"key2\": {\"type\": \"integer\"},"
179+
+ "\"key3\": {\"type\": \"boolean\"}"
180+
+ "}}}}";
181+
String json = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"key1\":\"123\",\"key2\":123,\"key3\":true}}";
182+
String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json, schema);
183+
assertEquals(json, processedJson);
184+
}
185+
164186
@Test
165187
public void testProcessRemoteInferenceInputDataSetParametersValueWithParametersNoProcess() throws IOException {
188+
String schema = "{\"type\": \"object\",\"properties\": {\"parameters\": {\"type\": \"object\",\"properties\": {"
189+
+ "\"key1\": {\"type\": \"string\"},"
190+
+ "\"key2\": {\"type\": \"integer\"},"
191+
+ "\"key3\": {\"type\": \"boolean\"}"
192+
+ "}}}}";
166193
String json = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"key1\":\"foo\",\"key2\":123,\"key3\":true}}";
167-
String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json);
194+
String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json, schema);
168195
assertEquals(json, processedJson);
169196
}
170197

171198
@Test
172199
public void testProcessRemoteInferenceInputDataSetParametersValueWithParametersInvalidJson() throws IOException {
200+
String schema = "{\"type\": \"object\",\"properties\": {\"parameters\": {\"type\": \"object\",\"properties\": {"
201+
+ "\"key1\": {\"type\": \"string\"},"
202+
+ "\"key2\": {\"type\": \"integer\"},"
203+
+ "\"key3\": {\"type\": \"boolean\"},"
204+
+ "\"texts\": {\"type\": \"array\"}"
205+
+ "}}}}";
173206
String json =
174207
"{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"texts\":\"[\\\"Hello\\\",\\\"world\\\"\"}}";
175-
String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json);
208+
String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json, schema);
176209
assertEquals(json, processedJson);
177210
}
178211
}

0 commit comments

Comments
 (0)