Skip to content

Commit c202774

Browse files
authored
Add error handling for unknown fields for chat completion detection request (#296)
* added serde deny_unknown_fields attribute to relevant chat completion request Signed-off-by: Shonda-Adena-Witherspoon <shonda.adena.witherspoon@ibm.com> * added validation unit tests Signed-off-by: Shonda-Adena-Witherspoon <shonda.adena.witherspoon@ibm.com> * modidified json_data to remove actual model and detector ids in unit tests Signed-off-by: Shonda-Adena-Witherspoon <shonda.adena.witherspoon@ibm.com> * removed async from chat completions unit test Signed-off-by: Shonda-Adena-Witherspoon <shonda.adena.witherspoon@ibm.com> --------- Signed-off-by: Shonda-Adena-Witherspoon <shonda.adena.witherspoon@ibm.com> Co-authored-by: Shonda-Adena-Witherspoon <shonda.adena.witherspoon@ibm.com>
1 parent d82b00d commit c202774

File tree

2 files changed

+90
-4
lines changed

2 files changed

+90
-4
lines changed

src/clients/openai.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ impl From<ChatCompletion> for ChatCompletionsResponse {
172172
}
173173

174174
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
175+
#[serde(deny_unknown_fields)]
175176
pub struct ChatCompletionsRequest {
176177
/// A list of messages comprising the conversation so far.
177178
pub messages: Vec<Message>,
@@ -290,6 +291,7 @@ pub struct ChatCompletionsRequest {
290291

291292
/// Structure to contain parameters for detectors.
292293
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
294+
#[serde(deny_unknown_fields)]
293295
pub struct DetectorConfig {
294296
#[serde(skip_serializing_if = "Option::is_none")]
295297
pub input: Option<HashMap<String, DetectorParams>>,
@@ -398,6 +400,7 @@ pub enum Role {
398400
}
399401

400402
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
403+
#[serde(deny_unknown_fields)]
401404
pub struct Message {
402405
/// The role of the author of this message.
403406
pub role: Role,

src/orchestrator/chat_completions_detection.rs

Lines changed: 87 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -396,8 +396,8 @@ mod tests {
396396

397397
// Test to verify preprocess_chat_messages works correctly for multiple content type detectors
398398
// with single message in chat request
399-
#[tokio::test]
400-
async fn pretest_process_chat_messages_multiple_content_detector() {
399+
#[test]
400+
fn pretest_process_chat_messages_multiple_content_detector() {
401401
// Test setup
402402
let clients = ClientMap::new();
403403
let detector_1_id = "detector1";
@@ -436,8 +436,8 @@ mod tests {
436436

437437
// Test preprocess_chat_messages returns error correctly for multiple content type detectors
438438
// with incorrect message requirements
439-
#[tokio::test]
440-
async fn pretest_process_chat_messages_error_handling() {
439+
#[test]
440+
fn pretest_process_chat_messages_error_handling() {
441441
// Test setup
442442
let clients = ClientMap::new();
443443
let detector_1_id = "detector1";
@@ -473,4 +473,87 @@ mod tests {
473473
"validation error: Last message role must be user, assistant, or system"
474474
);
475475
}
476+
// validate chat completions request with invalid fields
477+
// (nonexistant fields or typos)
478+
#[test]
479+
fn test_validate() {
480+
// Additional unknown field (additional_field)
481+
let json_data = r#"
482+
{
483+
"messages": [
484+
{
485+
"content": "this is a nice sentence",
486+
"role": "user",
487+
"name": "string"
488+
}
489+
],
490+
"model": "my_model",
491+
"additional_field": "test",
492+
"n": 1,
493+
"temperature": 1,
494+
"top_p": 1,
495+
"user": "user-1234",
496+
"detectors": {
497+
"input": {}
498+
}
499+
}
500+
"#;
501+
let result: Result<ChatCompletionsRequest, _> = serde_json::from_str(json_data);
502+
assert!(result.is_err());
503+
let error = result.unwrap_err().to_string();
504+
assert!(error
505+
.to_string()
506+
.contains("unknown field `additional_field"));
507+
508+
// Additional unknown field (additional_message")
509+
let json_data = r#"
510+
{
511+
"messages": [
512+
{
513+
"content": "this is a nice sentence",
514+
"role": "user",
515+
"name": "string",
516+
"additional_msg: "test"
517+
}
518+
],
519+
"model": "my_model",
520+
"n": 1,
521+
"temperature": 1,
522+
"top_p": 1,
523+
"user": "user-1234",
524+
"detectors": {
525+
"input": {}
526+
}
527+
}
528+
"#;
529+
let result: Result<ChatCompletionsRequest, _> = serde_json::from_str(json_data);
530+
assert!(result.is_err());
531+
let error = result.unwrap_err().to_string();
532+
assert!(error.to_string().contains("unknown field `additional_msg"));
533+
534+
// Additional unknown field (typo for input field in detectors)
535+
let json_data = r#"
536+
{
537+
"messages": [
538+
{
539+
"content": "this is a nice sentence",
540+
"role": "user",
541+
"name": "string"
542+
}
543+
],
544+
"model": "my_model",
545+
"n": 1,
546+
"temperature": 1,
547+
"top_p": 1,
548+
"user": "user-1234",
549+
"detectors": {
550+
"inputs": {}
551+
}
552+
}
553+
"#;
554+
let result: Result<ChatCompletionsRequest, _> = serde_json::from_str(json_data);
555+
assert!(result.is_err());
556+
let error = result.unwrap_err().to_string();
557+
assert!(error.to_string().contains("unknown field `inputs"));
558+
}
476559
}

0 commit comments

Comments
 (0)