Skip to content

Commit cab7649

Browse files
committed
feat: allow for specifying model on CreateMessage reqs
1 parent 956bfa8 commit cab7649

File tree

3 files changed

+23
-3
lines changed

3 files changed

+23
-3
lines changed

server/src/data/models.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4619,6 +4619,7 @@ impl ApiKeyRequestParams {
46194619
use_agentic_search: payload.use_agentic_search,
46204620
only_include_docs_used: payload.only_include_docs_used,
46214621
number_of_messages_to_include: payload.number_of_messages_to_include,
4622+
model: payload.model,
46224623
}
46234624
}
46244625

@@ -9506,6 +9507,7 @@ impl<'de> Deserialize<'de> for CreateMessageReqPayload {
95069507
typo_options: Option<TypoOptions>,
95079508
pub only_include_docs_used: Option<bool>,
95089509
pub number_of_messages_to_include: Option<u64>,
9510+
pub model: Option<String>,
95099511
}
95109512

95119513
let mut helper = Helper::deserialize(deserializer)?;
@@ -9548,6 +9550,7 @@ impl<'de> Deserialize<'de> for CreateMessageReqPayload {
95489550
use_agentic_search: helper.use_agentic_search,
95499551
only_include_docs_used: helper.only_include_docs_used,
95509552
number_of_messages_to_include: helper.number_of_messages_to_include,
9553+
model: helper.model,
95519554
})
95529555
}
95539556
}
@@ -9584,6 +9587,7 @@ impl<'de> Deserialize<'de> for RegenerateMessageReqPayload {
95849587
pub use_agentic_search: Option<bool>,
95859588
pub only_include_docs_used: Option<bool>,
95869589
pub number_of_messages_to_include: Option<u64>,
9590+
pub model: Option<String>,
95879591
}
95889592

95899593
let mut helper = Helper::deserialize(deserializer)?;
@@ -9623,6 +9627,7 @@ impl<'de> Deserialize<'de> for RegenerateMessageReqPayload {
96239627
use_agentic_search: helper.use_agentic_search,
96249628
only_include_docs_used: helper.only_include_docs_used,
96259629
number_of_messages_to_include: helper.number_of_messages_to_include,
9630+
model: helper.model,
96269631
})
96279632
}
96289633
}
@@ -9663,6 +9668,7 @@ impl<'de> Deserialize<'de> for EditMessageReqPayload {
96639668
pub use_agentic_search: Option<bool>,
96649669
pub only_include_docs_used: Option<bool>,
96659670
pub number_of_messages_to_include: Option<u64>,
9671+
pub model: Option<String>,
96669672
}
96679673

96689674
let mut helper = Helper::deserialize(deserializer)?;
@@ -9706,6 +9712,7 @@ impl<'de> Deserialize<'de> for EditMessageReqPayload {
97069712
use_agentic_search: helper.use_agentic_search,
97079713
only_include_docs_used: helper.only_include_docs_used,
97089714
number_of_messages_to_include: helper.number_of_messages_to_include,
9715+
model: helper.model,
97099716
})
97109717
}
97119718
}

server/src/handlers/message_handler.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,8 @@ pub struct CreateMessageReqPayload {
138138
pub use_agentic_search: Option<bool>,
139139
/// Number of messages to include in the context window. If not specified, this defaults to 10.
140140
pub number_of_messages_to_include: Option<u64>,
141+
/// Model name to use for the completion. If not specified, this defaults to the dataset's model.
142+
pub model: Option<String>,
141143
}
142144

143145
/// Create message
@@ -451,6 +453,8 @@ pub struct RegenerateMessageReqPayload {
451453
pub use_agentic_search: Option<bool>,
452454
/// Number of messages to include in the context window. If not specified, this defaults to 10.
453455
pub number_of_messages_to_include: Option<u64>,
456+
/// Model name to use for the completion. If not specified, this defaults to the dataset's model.
457+
pub model: Option<String>,
454458
}
455459

456460
#[derive(Serialize, Debug, ToSchema)]
@@ -509,6 +513,8 @@ pub struct EditMessageReqPayload {
509513
pub use_agentic_search: Option<bool>,
510514
/// Number of messages to include in the context window. If not specified, this defaults to 10.
511515
pub number_of_messages_to_include: Option<u64>,
516+
/// Model name to use for the completion. If not specified, this defaults to the dataset's model.
517+
pub model: Option<String>,
512518
}
513519

514520
impl From<EditMessageReqPayload> for CreateMessageReqPayload {
@@ -540,6 +546,7 @@ impl From<EditMessageReqPayload> for CreateMessageReqPayload {
540546
use_agentic_search: data.use_agentic_search,
541547
only_include_docs_used: data.only_include_docs_used,
542548
number_of_messages_to_include: data.number_of_messages_to_include,
549+
model: data.model,
543550
}
544551
}
545552
}
@@ -573,6 +580,7 @@ impl From<RegenerateMessageReqPayload> for CreateMessageReqPayload {
573580
use_agentic_search: data.use_agentic_search,
574581
only_include_docs_used: data.only_include_docs_used,
575582
number_of_messages_to_include: data.number_of_messages_to_include,
583+
model: data.model,
576584
}
577585
}
578586
}

server/src/operators/message_operator.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -854,7 +854,10 @@ pub async fn stream_response(
854854
};
855855

856856
let rag_prompt = dataset_config.RAG_PROMPT.clone();
857-
let chosen_model = dataset_config.LLM_DEFAULT_MODEL.clone();
857+
let chosen_model = create_message_req_payload
858+
.model
859+
.clone()
860+
.unwrap_or(dataset_config.LLM_DEFAULT_MODEL.clone());
858861

859862
let (search_event, score_chunks) = get_rag_chunks_query(
860863
create_message_req_payload.clone(),
@@ -1894,8 +1897,10 @@ pub async fn stream_response_with_agentic_search(
18941897
messages_len
18951898
};
18961899

1897-
let chosen_model = dataset_config.LLM_DEFAULT_MODEL.clone();
1898-
1900+
let chosen_model = create_message_req_payload
1901+
.model
1902+
.clone()
1903+
.unwrap_or(dataset_config.LLM_DEFAULT_MODEL.clone());
18991904
let tools = vec![ChatCompletionTool {
19001905
r#type: ChatCompletionToolType::Function,
19011906
function: ChatCompletionFunction {

0 commit comments

Comments
 (0)