
Commit 68101a8

Merge pull request #233 from mdevino/chat-standalone-endpoint
Chat standalone endpoint
2 parents a9e62c2 + 47a8dbd commit 68101a8


5 files changed: +297 -20 lines changed


src/clients/detector/text_chat.rs

Lines changed: 58 additions & 5 deletions
@@ -16,14 +16,19 @@
 */

 use async_trait::async_trait;
+use hyper::{HeaderMap, StatusCode};
+use serde::Serialize;

-use super::DEFAULT_PORT;
+use super::{DetectorError, DEFAULT_PORT, DETECTOR_ID_HEADER_NAME};
 use crate::{
-    clients::{create_http_client, Client, HttpClient},
+    clients::{create_http_client, openai::Message, Client, Error, HttpClient},
     config::ServiceConfig,
     health::HealthCheckResult,
+    models::{DetectionResult, DetectorParams},
 };

+const CHAT_DETECTOR_ENDPOINT: &str = "/api/v1/text/chat";
+
 #[cfg_attr(test, faux::create)]
 #[derive(Clone)]
 pub struct TextChatDetectorClient {
@@ -46,9 +51,37 @@ impl TextChatDetectorClient {
         }
     }

-    pub async fn text_chat(&self) {
-        let _url = self.client.base_url().join("/api/v1/text/chat").unwrap();
-        todo!()
+    pub async fn text_chat(
+        &self,
+        model_id: &str,
+        request: ChatDetectionRequest,
+        headers: HeaderMap,
+    ) -> Result<Vec<DetectionResult>, Error> {
+        let url = self.client.base_url().join(CHAT_DETECTOR_ENDPOINT).unwrap();
+        let request = self
+            .client
+            .post(url)
+            .headers(headers)
+            .header(DETECTOR_ID_HEADER_NAME, model_id)
+            .json(&request);
+
+        tracing::debug!("Request being sent to chat detector: {:?}", request);
+        let response = request.send().await?;
+        tracing::debug!("Response received from chat detector: {:?}", response);
+
+        if response.status() == StatusCode::OK {
+            Ok(response.json().await?)
+        } else {
+            let code = response.status().as_u16();
+            let error = response
+                .json::<DetectorError>()
+                .await
+                .unwrap_or(DetectorError {
+                    code,
+                    message: "".into(),
+                });
+            Err(error.into())
+        }
     }
 }

@@ -67,3 +100,23 @@ impl Client for TextChatDetectorClient {
         }
     }
 }
+
+/// A struct representing a request to a detector compatible with the
+/// /api/v1/text/chat endpoint.
+// #[cfg_attr(test, derive(PartialEq))]
+#[derive(Debug, Serialize)]
+pub struct ChatDetectionRequest {
+    /// Chat messages to run detection on
+    pub messages: Vec<Message>,
+
+    pub detector_params: DetectorParams,
+}
+
+impl ChatDetectionRequest {
+    pub fn new(messages: Vec<Message>, detector_params: DetectorParams) -> Self {
+        Self {
+            messages,
+            detector_params,
+        }
+    }
+}
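
Aside (not part of this commit's diff): a minimal sketch of how the new client method and request struct fit together. The detector id and the empty header map are illustrative assumptions; only ChatDetectionRequest::new and TextChatDetectorClient::text_chat come from this commit, and the module paths follow the re-exports used elsewhere in the diff.

// Hypothetical call site inside the crate; not part of this commit.
use hyper::HeaderMap;

use crate::{
    clients::{
        detector::{ChatDetectionRequest, TextChatDetectorClient},
        openai::Message,
    },
    models::{DetectionResult, DetectorParams},
};

async fn chat_detector_example(
    client: &TextChatDetectorClient,
    messages: Vec<Message>,
    detector_params: DetectorParams,
) -> Result<Vec<DetectionResult>, crate::clients::Error> {
    // Body POSTed to /api/v1/text/chat; the detector id travels in the
    // DETECTOR_ID_HEADER_NAME header, added by text_chat() itself.
    let request = ChatDetectionRequest::new(messages, detector_params);
    client
        .text_chat("example_chat_detector", request, HeaderMap::new())
        .await
}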

src/models.rs

Lines changed: 78 additions & 1 deletion
@@ -22,7 +22,11 @@ use std::collections::HashMap;
 use serde::{Deserialize, Serialize};

 use crate::{
-    clients::detector::{ContentAnalysisResponse, ContextType},
+    clients::{
+        self,
+        detector::{ContentAnalysisResponse, ContextType},
+        openai::Content,
+    },
     health::HealthCheckCache,
     pb,
 };
@@ -939,6 +943,79 @@ pub struct ContextDocsResult {
     pub detections: Vec<DetectionResult>,
 }

+/// The request format expected in the /api/v2/text/detect/generated endpoint.
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct ChatDetectionHttpRequest {
+    /// The map of detectors to be used, along with their respective parameters, e.g. thresholds.
+    pub detectors: HashMap<String, DetectorParams>,
+
+    // The list of messages to run detections on.
+    pub messages: Vec<clients::openai::Message>,
+}
+
+impl ChatDetectionHttpRequest {
+    /// Upfront validation of user request
+    pub fn validate(&self) -> Result<(), ValidationError> {
+        // Validate required parameters
+        if self.detectors.is_empty() {
+            return Err(ValidationError::Required("detectors".into()));
+        }
+        if self.messages.is_empty() {
+            return Err(ValidationError::Required("messages".into()));
+        }
+
+        Ok(())
+    }
+
+    /// Validates for the "/api/v1/text/chat" endpoint.
+    pub fn validate_for_text(&self) -> Result<(), ValidationError> {
+        self.validate()?;
+        self.validate_messages()?;
+        validate_detector_params(&self.detectors)?;
+
+        Ok(())
+    }
+
+    /// Validates if message contents are either a string or a content type of type "text"
+    fn validate_messages(&self) -> Result<(), ValidationError> {
+        for message in &self.messages {
+            match &message.content {
+                Some(content) => self.validate_content_type(content)?,
+                None => {
+                    return Err(ValidationError::Invalid(
+                        "Message content cannot be empty".into(),
+                    ))
+                }
+            }
+        }
+        Ok(())
+    }
+
+    /// Validates if content type array contains only text messages
+    fn validate_content_type(&self, content: &Content) -> Result<(), ValidationError> {
+        match content {
+            Content::Array(content) => {
+                for content_part in content {
+                    if content_part.r#type != "text" {
+                        return Err(ValidationError::Invalid(
+                            "Only content of type text is allowed".into(),
+                        ));
+                    }
+                }
+                Ok(())
+            }
+            Content::String(_) => Ok(()), // if message.content is a string, it is a valid message
+        }
+    }
+}
+
+/// The response format of the /api/v2/text/detection/chat endpoint
+#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
+pub struct ChatDetectionResult {
+    /// Detection results
+    pub detections: Vec<DetectionResult>,
+}
+
 /// The request format expected in the /api/v2/text/detect/generated endpoint.
 #[derive(Clone, Debug, Serialize, Deserialize)]
 pub struct DetectionOnGeneratedHttpRequest {
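
Aside (not part of this commit's diff): the validation added here requires non-empty detectors and messages, rejects messages with no content, and allows array-style content only when every part is of type "text". A rough unit-test-style sketch of exercising validate(), assuming the usual OpenAI-style JSON shape for clients::openai::Message and that an empty object deserializes into DetectorParams:

// Hypothetical sketch; not part of this commit.
use crate::models::ChatDetectionHttpRequest;

fn chat_request_validation_example() {
    // Non-empty detectors and messages with plain string content: validate() passes.
    let ok: ChatDetectionHttpRequest = serde_json::from_value(serde_json::json!({
        "detectors": { "example_chat_detector": {} },
        "messages": [{ "role": "user", "content": "Is this message harmful?" }]
    }))
    .unwrap();
    assert!(ok.validate().is_ok());

    // Missing detectors: rejected with ValidationError::Required("detectors").
    let missing: ChatDetectionHttpRequest = serde_json::from_value(serde_json::json!({
        "detectors": {},
        "messages": [{ "role": "user", "content": "hello" }]
    }))
    .unwrap();
    assert!(missing.validate().is_err());
}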

src/orchestrator.rs

Lines changed: 31 additions & 3 deletions
@@ -29,6 +29,7 @@ use uuid::Uuid;

 use crate::{
     clients::{
+        self,
         chunker::ChunkerClient,
         detector::{
             text_context_doc::ContextType, TextChatDetectorClient, TextContextDocDetectorClient,
@@ -40,9 +41,9 @@ use crate::{
     config::{DetectorType, GenerationProvider, OrchestratorConfig},
     health::HealthCheckCache,
     models::{
-        ContextDocsHttpRequest, DetectionOnGeneratedHttpRequest, DetectorParams,
-        GenerationWithDetectionHttpRequest, GuardrailsConfig, GuardrailsHttpRequest,
-        GuardrailsTextGenerationParameters, TextContentDetectionHttpRequest,
+        ChatDetectionHttpRequest, ContextDocsHttpRequest, DetectionOnGeneratedHttpRequest,
+        DetectorParams, GenerationWithDetectionHttpRequest, GuardrailsConfig,
+        GuardrailsHttpRequest, GuardrailsTextGenerationParameters, TextContentDetectionHttpRequest,
     },
 };

@@ -382,6 +383,33 @@ impl ContextDocsDetectionTask {
     }
 }

+/// Task for the /api/v2/text/detection/chat endpoint
+#[derive(Debug)]
+pub struct ChatDetectionTask {
+    /// Request unique identifier
+    pub request_id: Uuid,
+
+    /// Detectors configuration
+    pub detectors: HashMap<String, DetectorParams>,
+
+    // Messages to run detection on
+    pub messages: Vec<clients::openai::Message>,
+
+    // Headermap
+    pub headers: HeaderMap,
+}
+
+impl ChatDetectionTask {
+    pub fn new(request_id: Uuid, request: ChatDetectionHttpRequest, headers: HeaderMap) -> Self {
+        Self {
+            request_id,
+            detectors: request.detectors,
+            messages: request.messages,
+            headers,
+        }
+    }
+}
+
 /// Task for the /api/v2/text/detection/generated endpoint
 #[derive(Debug)]
 pub struct DetectionOnGenerationTask {
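
Aside (not part of this commit's diff): ChatDetectionTask is just the validated HTTP request plus a request id and the incoming headers. A plausible handler-side flow looks like the sketch below; the actual route handler lives in a file not shown in this excerpt, so the surrounding names and the use of Uuid::new_v4 are assumptions.

// Hypothetical handler-side flow; not part of this commit.
use hyper::HeaderMap;
use uuid::Uuid;

use crate::{
    models::{ChatDetectionHttpRequest, ChatDetectionResult},
    orchestrator::{ChatDetectionTask, Error, Orchestrator},
};

async fn chat_detection_flow_example(
    orchestrator: &Orchestrator,
    request: ChatDetectionHttpRequest,
    headers: HeaderMap,
) -> Result<ChatDetectionResult, Error> {
    // Request validation (request.validate_for_text()) would typically run first;
    // mapping validation errors to an HTTP status is omitted here.
    let task = ChatDetectionTask::new(Uuid::new_v4(), request, headers);
    orchestrator.handle_chat_detection(task).await
}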

src/orchestrator/unary.rs

Lines changed: 107 additions & 8 deletions
@@ -25,25 +25,28 @@ use futures::{
 use tracing::{debug, error, info, instrument};

 use super::{
-    apply_masks, get_chunker_ids, Chunk, ClassificationWithGenTask, Context,
+    apply_masks, get_chunker_ids, ChatDetectionTask, Chunk, ClassificationWithGenTask, Context,
     ContextDocsDetectionTask, DetectionOnGenerationTask, Error, GenerationWithDetectionTask,
     Orchestrator, TextContentDetectionTask,
 };
 use crate::{
     clients::{
         chunker::{tokenize_whole_doc, ChunkerClient, DEFAULT_CHUNKER_ID},
         detector::{
-            ContentAnalysisRequest, ContentAnalysisResponse, ContextDocsDetectionRequest,
-            ContextType, GenerationDetectionRequest, TextContentsDetectorClient,
-            TextContextDocDetectorClient, TextGenerationDetectorClient,
+            ChatDetectionRequest, ContentAnalysisRequest, ContentAnalysisResponse,
+            ContextDocsDetectionRequest, ContextType, GenerationDetectionRequest,
+            TextChatDetectorClient, TextContentsDetectorClient, TextContextDocDetectorClient,
+            TextGenerationDetectorClient,
         },
+        openai::Message,
         GenerationClient,
     },
     models::{
-        ClassifiedGeneratedTextResult, ContextDocsResult, DetectionOnGenerationResult,
-        DetectionResult, DetectorParams, GenerationWithDetectionResult,
-        GuardrailsTextGenerationParameters, InputWarning, InputWarningReason,
-        TextContentDetectionResult, TextGenTokenClassificationResults, TokenClassificationResult,
+        ChatDetectionResult, ClassifiedGeneratedTextResult, ContextDocsResult,
+        DetectionOnGenerationResult, DetectionResult, DetectorParams,
+        GenerationWithDetectionResult, GuardrailsTextGenerationParameters, InputWarning,
+        InputWarningReason, TextContentDetectionResult, TextGenTokenClassificationResults,
+        TokenClassificationResult,
     },
     orchestrator::UNSUITABLE_INPUT_MESSAGE,
     pb::caikit::runtime::chunkers,
@@ -447,6 +450,61 @@ impl Orchestrator {
             }
         }
     }
+
+    /// Handles detections on chat messages (without performing generation)
+    pub async fn handle_chat_detection(
+        &self,
+        task: ChatDetectionTask,
+    ) -> Result<ChatDetectionResult, Error> {
+        info!(
+            request_id = ?task.request_id,
+            detectors = ?task.detectors,
+            "handling detection on chat content task"
+        );
+        let ctx = self.ctx.clone();
+        let headers = task.headers;
+
+        let task_handle = tokio::spawn(async move {
+            // call detection
+            let detections = try_join_all(
+                task.detectors
+                    .iter()
+                    .map(|(detector_id, detector_params)| {
+                        let ctx = ctx.clone();
+                        let detector_id = detector_id.clone();
+                        let detector_params = detector_params.clone();
+                        let messages = task.messages.clone();
+                        let headers = headers.clone();
+                        async {
+                            detect_for_chat(ctx, detector_id, detector_params, messages, headers)
+                                .await
+                        }
+                    })
+                    .collect::<Vec<_>>(),
+            )
+            .await?
+            .into_iter()
+            .flatten()
+            .collect::<Vec<_>>();
+
+            Ok(ChatDetectionResult { detections })
+        });
+        match task_handle.await {
+            // Task completed successfully
+            Ok(Ok(result)) => Ok(result),
+            // Task failed, return error propagated from child task that failed
+            Ok(Err(error)) => {
+                error!(request_id = ?task.request_id, %error, "detection task on chat failed");
+                Err(error)
+            }
+            // Task cancelled or panicked
+            Err(error) => {
+                let error = error.into();
+                error!(request_id = ?task.request_id, %error, "detection task on chat failed");
+                Err(error)
+            }
+        }
+    }
 }

 /// Handles input detection task.
@@ -711,6 +769,47 @@ pub async fn detect_for_generation(
     Ok::<Vec<DetectionResult>, Error>(response)
 }

+/// Calls a detector that implements the /api/v1/text/chat endpoint
+pub async fn detect_for_chat(
+    ctx: Arc<Context>,
+    detector_id: String,
+    detector_params: DetectorParams,
+    messages: Vec<Message>,
+    headers: HeaderMap,
+) -> Result<Vec<DetectionResult>, Error> {
+    let detector_id = detector_id.clone();
+    let threshold = detector_params.threshold().unwrap_or(
+        detector_params.threshold().unwrap_or(
+            ctx.config
+                .detectors
+                .get(&detector_id)
+                .ok_or_else(|| Error::DetectorNotFound(detector_id.clone()))?
+                .default_threshold,
+        ),
+    );
+    let request = ChatDetectionRequest::new(messages.clone(), detector_params.clone());
+    debug!(%detector_id, ?request, "sending chat detector request");
+    let client = ctx
+        .clients
+        .get_as::<TextChatDetectorClient>(&detector_id)
+        .unwrap();
+    let response = client
+        .text_chat(&detector_id, request, headers)
+        .await
+        .map(|results| {
+            results
+                .into_iter()
+                .filter(|detection| detection.score > threshold)
+                .collect()
+        })
+        .map_err(|error| Error::DetectorRequestFailed {
+            id: detector_id.clone(),
+            error,
+        })?;
+    debug!(%detector_id, ?response, "received chat detector response");
+    Ok::<Vec<DetectionResult>, Error>(response)
+}
+
 /// Calls a detector that implements the /api/v1/text/doc endpoint
 pub async fn detect_for_context(
     ctx: Arc<Context>,
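
Aside (not part of this commit's diff): detect_for_chat resolves the detection threshold from the per-request detector_params, falling back to the detector's configured default_threshold (the inner, repeated detector_params.threshold() call appears redundant; both forms resolve to the same value), and then drops detections whose score is at or below that threshold. A standalone, runnable sketch of just that logic, with a local Detection type standing in for the crate's DetectionResult:

// Standalone sketch of the threshold fallback and score filtering; not part of this commit.
struct Detection {
    score: f64,
}

fn filter_by_threshold(
    request_threshold: Option<f64>, // analogous to detector_params.threshold()
    default_threshold: f64,         // analogous to config.detectors[detector_id].default_threshold
    detections: Vec<Detection>,
) -> Vec<Detection> {
    // Per-request threshold wins; otherwise fall back to the configured default.
    let threshold = request_threshold.unwrap_or(default_threshold);
    detections
        .into_iter()
        .filter(|detection| detection.score > threshold)
        .collect()
}

fn main() {
    let detections = vec![Detection { score: 0.2 }, Detection { score: 0.9 }];
    // No per-request threshold, so the 0.5 default applies: only the 0.9 detection survives.
    let kept = filter_by_threshold(None, 0.5, detections);
    assert_eq!(kept.len(), 1);
}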
