Skip to content

Commit 175bf51

Browse files
authored
Add Role enum to openai module (#287)
* updated role from string to enum for chat messages Signed-off-by: Shonda-Adena-Witherspoon <shonda.adena.witherspoon@ibm.com> * corrected role in chat completion delta to optional Signed-off-by: Shonda-Adena-Witherspoon <shonda.adena.witherspoon@ibm.com> --------- Signed-off-by: Shonda-Adena-Witherspoon <shonda.adena.witherspoon@ibm.com>
1 parent 408badf commit 175bf51

File tree

3 files changed

+26
-14
lines changed

3 files changed

+26
-14
lines changed

src/clients/openai.rs

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -386,10 +386,21 @@ pub struct JsonSchemaObject {
386386
pub required: Option<Vec<String>>,
387387
}
388388

389+
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
390+
#[serde(rename_all = "lowercase")]
391+
pub enum Role {
392+
#[default]
393+
User,
394+
Developer,
395+
Assistant,
396+
System,
397+
Tool,
398+
}
399+
389400
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
390401
pub struct Message {
391-
/// The role of the messages author.
392-
pub role: String,
402+
/// The role of the author of this message.
403+
pub role: Role,
393404
/// The contents of the message.
394405
#[serde(skip_serializing_if = "Option::is_none")]
395406
pub content: Option<Content>,
@@ -552,7 +563,7 @@ pub struct ChatCompletionChoice {
552563
#[derive(Debug, Clone, Serialize, Deserialize)]
553564
pub struct ChatCompletionMessage {
554565
/// The role of the author of this message.
555-
pub role: String,
566+
pub role: Role,
556567
/// The contents of the message.
557568
pub content: Option<String>,
558569
/// The tool calls generated by the model, such as function calls.
@@ -635,7 +646,7 @@ pub struct ChatCompletionChunkChoice {
635646
pub struct ChatCompletionDelta {
636647
/// The role of the author of this message.
637648
#[serde(skip_serializing_if = "Option::is_none")]
638-
pub role: Option<String>,
649+
pub role: Option<Role>,
639650
/// The contents of the message.
640651
#[serde(skip_serializing_if = "Option::is_none")]
641652
pub content: Option<String>,

src/orchestrator/chat_completions_detection.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ use crate::{
3232
detector::{ChatDetectionRequest, ContentAnalysisRequest},
3333
openai::{
3434
ChatCompletion, ChatCompletionChoice, ChatCompletionsRequest, ChatCompletionsResponse,
35-
ChatDetections, Content, InputDetectionResult, OpenAiClient, OrchestratorWarning,
35+
ChatDetections, Content, InputDetectionResult, OpenAiClient, OrchestratorWarning, Role,
3636
},
3737
},
3838
config::DetectorType,
@@ -51,7 +51,7 @@ pub struct ChatMessageInternal {
5151
/// Index of the message
5252
pub message_index: usize,
5353
/// The role of the messages author.
54-
pub role: String,
54+
pub role: Role,
5555
/// The contents of the message.
5656
#[serde(skip_serializing_if = "Option::is_none")]
5757
pub content: Option<Content>,
@@ -425,7 +425,7 @@ mod tests {
425425
let messages = vec![ChatMessageInternal {
426426
message_index: 0,
427427
content: Some(Content::Text("hello".to_string())),
428-
role: "assistant".to_string(),
428+
role: Role::Assistant,
429429
..Default::default()
430430
}];
431431
let processed_messages = preprocess_chat_messages(&ctx, &detectors, messages).unwrap();
@@ -458,7 +458,7 @@ mod tests {
458458
message_index: 0,
459459
content: Some(Content::Text("hello".to_string())),
460460
// Invalid role will return error used for testing
461-
role: "foo".to_string(),
461+
role: Role::Tool,
462462
..Default::default()
463463
}];
464464

src/orchestrator/detector_processing/content.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
1616
*/
1717
use crate::{
18-
clients::openai::Content, models::ValidationError,
18+
clients::openai::{Content, Role},
19+
models::ValidationError,
1920
orchestrator::chat_completions_detection::ChatMessageInternal,
2021
};
2122

@@ -38,7 +39,7 @@ pub fn filter_chat_messages(
3839
));
3940
}
4041
// 2. Role is user | assistant | system
41-
if !matches!(message.role.as_str(), "user" | "assistant" | "system") {
42+
if !matches!(message.role, Role::User | Role::Assistant | Role::System) {
4243
return Err(ValidationError::Invalid(
4344
"Last message role must be user, assistant, or system".into(),
4445
));
@@ -62,7 +63,7 @@ mod tests {
6263
let message = vec![ChatMessageInternal {
6364
message_index: 0,
6465
content: Some(Content::Text("hello".to_string())),
65-
role: "assistant".to_string(),
66+
role: Role::Assistant,
6667
..Default::default()
6768
}];
6869

@@ -79,13 +80,13 @@ mod tests {
7980
ChatMessageInternal {
8081
message_index: 0,
8182
content: Some(Content::Text("hello".to_string())),
82-
role: "assistant".to_string(),
83+
role: Role::Assistant,
8384
..Default::default()
8485
},
8586
ChatMessageInternal {
8687
message_index: 1,
8788
content: Some(Content::Text("bot".to_string())),
88-
role: "assistant".to_string(),
89+
role: Role::Assistant,
8990
..Default::default()
9091
},
9192
];
@@ -102,7 +103,7 @@ mod tests {
102103
let message = vec![ChatMessageInternal {
103104
message_index: 0,
104105
content: Some(Content::Text("hello".to_string())),
105-
role: "invalid_role".to_string(),
106+
role: Role::Tool,
106107
..Default::default()
107108
}];
108109

0 commit comments

Comments
 (0)