feat: Add an option for specifying model name #685

Open · wants to merge 1 commit into main

5 changes: 5 additions & 0 deletions README.md
@@ -153,6 +153,11 @@ Options:
Optionally control the number of tokenizer workers used for payload tokenization, validation and truncation. Defaults to the number of CPU cores on the machine

[env: TOKENIZATION_WORKERS=]

--served-model-name <SERVED_MODEL_NAME>
The name of the model that is returned when serving OpenAI requests. If not specified, defaults to the value of model-id.

[env: SERVED_MODEL_NAME=]

--dtype <DTYPE>
The dtype to be forced upon the model
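
For context on how the new option surfaces to clients, here is a hedged usage sketch (not part of this PR): it posts to the router's OpenAI-compatible `/v1/embeddings` route and prints the `model` field of the response. It assumes a server started with `--served-model-name my-embedder`, a listener on `localhost:3000` (adjust to your `--port` or Docker mapping), and the `reqwest` (blocking + json features) and `serde_json` crates.

```rust
use serde_json::{json, Value};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical local endpoint; adjust host/port to your deployment.
    let client = reqwest::blocking::Client::new();
    let response: Value = client
        .post("http://localhost:3000/v1/embeddings")
        .json(&json!({ "input": "hello world", "model": "my-embedder" }))
        .send()?
        .json()?;

    // With --served-model-name set, this prints "my-embedder";
    // without it, the router keeps reporting the --model-id value.
    println!("model reported by server: {}", response["model"]);
    Ok(())
}
```
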
6 changes: 6 additions & 0 deletions docs/openapi.json
@@ -1052,6 +1052,7 @@
"required": [
"model_id",
"model_dtype",
"served_model_name",
"model_type",
"max_concurrent_requests",
"max_input_length",
@@ -1107,6 +1108,11 @@
"description": "Model info",
"example": "thenlper/gte-base"
},
"served_model_name": {
"type": "string",
"description": "Model name specified by user",
"example": "thenlper/gte-base"
},
"model_sha": {
"type": "string",
"example": "fca14538aa9956a46526bd1d0d11d69e19b5a101",
5 changes: 5 additions & 0 deletions docs/source/en/cli_arguments.md
@@ -46,6 +46,11 @@ Options:

[env: DTYPE=]
[possible values: float16, float32]

--served-model-name <SERVED_MODEL_NAME>
The name of the model that is returned when serving OpenAI requests. If not specified, defaults to the value of model-id.

[env: SERVED_MODEL_NAME=]

--pooling <POOLING>
Optionally control the pooling method for embedding models.
1 change: 1 addition & 0 deletions proto/tei.proto
@@ -58,6 +58,7 @@ message InfoResponse {
optional uint32 max_batch_requests = 11;
uint32 max_client_batch_size = 12;
uint32 tokenization_workers = 13;
optional string served_model_name = 14;
}

message Metadata {
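
Because `served_model_name` is declared as a proto3 `optional` field, the prost/tonic codegen that builds the router's gRPC types should expose it as an `Option<String>`, keeping "not configured" distinguishable from an empty string. A hand-written sketch of that shape (hypothetical struct, not the generated code):

```rust
// Hypothetical mirror of the generated message shape; the real struct is
// produced by the router's tonic/prost build step from proto/tei.proto.
#[derive(Debug, Clone)]
struct InfoResponseSketch {
    model_id: String,
    // proto3 `optional string served_model_name = 14;` is expected to surface
    // as Option<String>, so absence is distinguishable from an empty string.
    served_model_name: Option<String>,
}

fn main() {
    let info = InfoResponseSketch {
        model_id: "thenlper/gte-base".to_string(),
        served_model_name: None,
    };
    match &info.served_model_name {
        Some(name) => println!("client-facing name: {name}"),
        None => println!("no served name set; clients see model_id: {}", info.model_id),
    }
}
```
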
1 change: 1 addition & 0 deletions router/src/grpc/server.rs
@@ -572,6 +572,7 @@ impl grpc::info_server::Info for TextEmbeddingsService {
model_id: self.info.model_id.clone(),
model_sha: self.info.model_sha.clone(),
model_dtype: self.info.model_dtype.clone(),
served_model_name: self.info.served_model_name.clone(),
model_type: model_type.into(),
max_concurrent_requests: self.info.max_concurrent_requests as u32,
max_input_length: self.info.max_input_length as u32,
2 changes: 1 addition & 1 deletion router/src/http/server.rs
@@ -1287,7 +1287,7 @@ async fn openai_embed(
let response = OpenAICompatResponse {
object: "list",
data: embeddings,
model: info.model_id.clone(),
model: info.served_model_name.clone().unwrap_or_else(|| info.model_id.clone()),
usage: OpenAICompatUsage {
prompt_tokens: compute_tokens,
total_tokens: compute_tokens,
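
On the HTTP side the only behavioral change is the `model` field of the OpenAI-compatible response. A standalone sketch of that fallback (hypothetical helper, not the router's code), using `unwrap_or_else` so the `model_id` clone is only made when no served name was configured:

```rust
/// Hypothetical helper mirroring the fallback above: prefer the configured
/// served_model_name, otherwise report the model_id.
fn reported_model(served_model_name: Option<&String>, model_id: &str) -> String {
    served_model_name
        .cloned()
        .unwrap_or_else(|| model_id.to_string())
}

fn main() {
    assert_eq!(reported_model(None, "thenlper/gte-base"), "thenlper/gte-base");
    assert_eq!(
        reported_model(Some(&"my-embedder".to_string()), "thenlper/gte-base"),
        "my-embedder"
    );
    println!("fallback behaves as expected");
}
```
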
4 changes: 4 additions & 0 deletions router/src/lib.rs
@@ -46,6 +46,7 @@ pub async fn run(
revision: Option<String>,
tokenization_workers: Option<usize>,
dtype: Option<DType>,
served_model_name: Option<String>,
pooling: Option<text_embeddings_backend::Pool>,
max_concurrent_requests: usize,
max_batch_tokens: usize,
@@ -279,6 +280,7 @@ pub async fn run(
model_id,
model_sha: revision,
model_dtype: dtype.to_string(),
served_model_name,
model_type,
max_concurrent_requests,
max_input_length,
@@ -493,6 +495,8 @@ pub struct Info {
pub model_sha: Option<String>,
#[cfg_attr(feature = "http", schema(example = "float16"))]
pub model_dtype: String,
#[cfg_attr(feature = "http", schema(example = "thenlper/gte-base"))]
pub served_model_name: Option<String>,
pub model_type: ModelType,
/// Router Parameters
#[cfg_attr(feature = "http", schema(example = "128"))]
6 changes: 6 additions & 0 deletions router/src/main.rs
@@ -36,6 +36,11 @@ struct Args {
#[clap(long, env, value_enum)]
dtype: Option<DType>,

/// The name of the model that is being served. If not specified, defaults to
/// model-id.
#[clap(long, env)]
served_model_name: Option<String>,

/// Optionally control the pooling method for embedding models.
///
/// If `pooling` is not set, the pooling configuration will be parsed from the
@@ -214,6 +219,7 @@ async fn main() -> Result<()> {
args.revision,
args.tokenization_workers,
args.dtype,
args.served_model_name,
args.pooling,
args.max_concurrent_requests,
args.max_batch_tokens,
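
The flag itself is wired like the router's other options: a single `#[clap(long, env)]` field yields both `--served-model-name` and the `SERVED_MODEL_NAME` environment variable, and an `Option<String>` keeps it truly optional. A self-contained sketch of that pattern (demo-only struct, assuming the `clap` crate with its `derive` and `env` features enabled):

```rust
use clap::Parser;

/// Demo-only argument struct illustrating the pattern used in router/src/main.rs.
#[derive(Parser, Debug)]
struct DemoArgs {
    /// Stays None when neither --served-model-name nor SERVED_MODEL_NAME is given,
    /// which is what lets the router fall back to model-id later on.
    #[clap(long, env)]
    served_model_name: Option<String>,
}

fn main() {
    let args = DemoArgs::parse();
    println!("served_model_name = {:?}", args.served_model_name);
}
```
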
1 change: 1 addition & 0 deletions router/tests/common.rs
@@ -51,6 +51,7 @@ pub async fn start_server(model_id: String, revision: Option<String>, dtype: DTy
Some(1),
Some(dtype),
None,
None,
4,
1024,
None,