Commit 90ea664

feat: add /decode route (#212)
1 parent: a57cf61

File tree: 7 files changed (+397, −18 lines)

core/src/infer.rs

Lines changed: 16 additions & 0 deletions
@@ -74,6 +74,22 @@ impl Infer {
         })
     }
 
+    #[instrument(skip(self))]
+    pub async fn decode(
+        &self,
+        ids: Vec<u32>,
+        skip_special_tokens: bool,
+    ) -> Result<String, TextEmbeddingsError> {
+        self.tokenization
+            .decode(ids, skip_special_tokens)
+            .await
+            .map_err(|err| {
+                metrics::increment_counter!("te_request_failure", "err" => "tokenization");
+                tracing::error!("{err}");
+                err
+            })
+    }
+
     #[instrument(skip(self))]
     pub fn try_acquire_permit(&self) -> Result<OwnedSemaphorePermit, TextEmbeddingsError> {
         // Limit concurrent requests by acquiring a permit from the semaphore
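The new `Infer::decode` simply forwards to the tokenization layer, bumping a failure counter and emitting a tracing error before propagating the error unchanged. Below is a minimal, runnable sketch of that `map_err` instrumentation pattern, assuming `tokio`, `tracing`, and `tracing-subscriber` as dependencies; `fallible_decode` and the `String` error type are stand-ins for this example, and the metrics call is left as a comment since wiring up a recorder (and matching the exact `metrics` crate version) is out of scope.

use tracing::instrument;

// Stand-in for `self.tokenization.decode(...)`; not TEI code.
async fn fallible_decode(ids: Vec<u32>) -> Result<String, String> {
    if ids.is_empty() {
        Err("`input_ids` cannot be empty".to_string())
    } else {
        Ok(format!("<decoded {} ids>", ids.len()))
    }
}

#[instrument]
async fn decode(ids: Vec<u32>) -> Result<String, String> {
    fallible_decode(ids).await.map_err(|err| {
        // The diff also bumps a failure counter at this point:
        // metrics::increment_counter!("te_request_failure", "err" => "tokenization");
        tracing::error!("{err}");
        err
    })
}

#[tokio::main]
async fn main() {
    tracing_subscriber::fmt::init();
    println!("{:?}", decode(vec![101, 2023, 102]).await);
    println!("{:?}", decode(vec![]).await); // logs an error, then returns Err
}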

core/src/tokenization.rs

Lines changed: 57 additions & 0 deletions
@@ -120,6 +120,37 @@ impl Tokenization {
         // Unwrap is safe here
         response_receiver.await.expect("Tokenization background task dropped the sender without sending a response. This is a bug.")
     }
+
+    #[instrument(skip_all)]
+    pub async fn decode(
+        &self,
+        ids: Vec<u32>,
+        skip_special_tokens: bool,
+    ) -> Result<String, TextEmbeddingsError> {
+        // Check if inputs is empty
+        if ids.is_empty() {
+            return Err(TextEmbeddingsError::Validation(
+                "`input_ids` cannot be empty".to_string(),
+            ));
+        }
+
+        // Create response channel
+        let (response_sender, response_receiver) = oneshot::channel();
+        // Send request to the background validation task
+        // Unwrap is safe here
+        self.sender
+            .send(TokenizerRequest::Decode(
+                ids,
+                skip_special_tokens,
+                response_sender,
+                Span::current(),
+            ))
+            .expect("Tokenization background task dropped the receiver. This is a bug.");
+
+        // Await on response channel
+        // Unwrap is safe here
+        response_receiver.await.expect("Tokenization background task dropped the sender without sending a response. This is a bug.")
+    }
 }
 
 /// Start tokenization workers
@@ -161,10 +192,30 @@ fn tokenizer_worker(
                     }
                 })
             }
+            TokenizerRequest::Decode(ids, skip_special_tokens, response_tx, parent_span) => {
+                parent_span.in_scope(|| {
+                    if !response_tx.is_closed() {
+                        // It's possible that the user dropped its request resulting in a send error.
+                        // We just discard the error
+                        let _ =
+                            response_tx.send(decode_ids(ids, skip_special_tokens, &mut tokenizer));
+                    }
+                })
+            }
         }
     }
 }
 
+fn decode_ids(
+    ids: Vec<u32>,
+    skip_special_tokens: bool,
+    tokenizer: &mut Tokenizer,
+) -> Result<String, TextEmbeddingsError> {
+    Ok(tokenizer
+        .with_truncation(None)?
+        .decode(&ids, skip_special_tokens)?)
+}
+
 fn tokenize_input(
     inputs: EncodingInput,
     add_special_tokens: bool,
@@ -263,4 +314,10 @@ enum TokenizerRequest {
         oneshot::Sender<Result<RawEncoding, TextEmbeddingsError>>,
         Span,
     ),
+    Decode(
+        Vec<u32>,
+        bool,
+        oneshot::Sender<Result<String, TextEmbeddingsError>>,
+        Span,
+    ),
 }
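`Tokenization::decode` follows the same request/reply convention as the existing encode path: the caller validates input, then ships the ids plus a oneshot reply sender through an mpsc queue to a dedicated worker, keeping the CPU-bound tokenizer off the async runtime. Below is a self-contained sketch of that channel pattern with illustrative names (`DemoRequest`, `demo_worker`) and a stand-in for the real tokenizer; it is not TEI code.

use tokio::sync::{mpsc, oneshot};

enum DemoRequest {
    // ids to decode, plus a oneshot channel to reply on
    Decode(Vec<u32>, oneshot::Sender<String>),
}

// Runs on a plain thread so blocking work never stalls the async runtime.
fn demo_worker(mut receiver: mpsc::UnboundedReceiver<DemoRequest>) {
    while let Some(request) = receiver.blocking_recv() {
        match request {
            DemoRequest::Decode(ids, reply) => {
                // Stand-in for `tokenizer.decode(&ids, skip_special_tokens)`.
                let text = ids
                    .iter()
                    .map(|id| id.to_string())
                    .collect::<Vec<_>>()
                    .join(" ");
                // The caller may have gone away; like the diff, discard the error.
                let _ = reply.send(text);
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let (sender, receiver) = mpsc::unbounded_channel();
    std::thread::spawn(move || demo_worker(receiver));

    let (reply_tx, reply_rx) = oneshot::channel();
    sender
        .send(DemoRequest::Decode(vec![101, 2023, 102], reply_tx))
        .expect("worker dropped the receiver");
    println!("{}", reply_rx.await.expect("worker dropped the reply sender"));
}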

docs/openapi.json

Lines changed: 95 additions & 1 deletion
@@ -12,6 +12,52 @@
     "version": "1.1.0"
   },
   "paths": {
+    "/decode": {
+      "post": {
+        "tags": [
+          "Text Embeddings Inference"
+        ],
+        "summary": "Decode input ids",
+        "description": "Decode input ids",
+        "operationId": "decode",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/DecodeRequest"
+              }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "description": "Decoded ids",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/DecodeResponse"
+                }
+              }
+            }
+          },
+          "422": {
+            "description": "Tokenization error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "message": "Tokenization error",
+                  "type": "tokenizer"
+                }
+              }
+            }
+          }
+        }
+      }
+    },
     "/embed": {
       "post": {
         "tags": [
@@ -647,7 +693,7 @@
           "content": {
             "application/json": {
               "schema": {
-                "$ref": "#/components/schemas/OpenAICompatErrorResponse"
+                "$ref": "#/components/schemas/ErrorResponse"
              },
               "example": {
                 "message": "Tokenization error",
@@ -771,6 +817,31 @@
           }
         }
       },
+      "DecodeRequest": {
+        "type": "object",
+        "required": [
+          "ids"
+        ],
+        "properties": {
+          "ids": {
+            "$ref": "#/components/schemas/InputIds"
+          },
+          "skip_special_tokens": {
+            "type": "boolean",
+            "default": "true",
+            "example": "true"
+          }
+        }
+      },
+      "DecodeResponse": {
+        "type": "array",
+        "items": {
+          "type": "string"
+        },
+        "example": [
+          "test"
+        ]
+      },
       "EmbedAllRequest": {
         "type": "object",
         "required": [
@@ -1003,6 +1074,29 @@
           }
         ]
       },
+      "InputIds": {
+        "oneOf": [
+          {
+            "type": "array",
+            "items": {
+              "type": "integer",
+              "format": "int32",
+              "minimum": 0
+            }
+          },
+          {
+            "type": "array",
+            "items": {
+              "type": "array",
+              "items": {
+                "type": "integer",
+                "format": "int32",
+                "minimum": 0
+              }
+            }
+          }
+        ]
+      },
       "ModelType": {
         "oneOf": [
           {
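For reference, a hedged sketch of calling the new route over HTTP. The host and port are assumptions, and `reqwest`, `serde_json`, and `tokio` are this example's own dependency choices rather than anything the router mandates; the payload and response shapes follow the `DecodeRequest` and `DecodeResponse` schemas above (note the spec declares the response body as an array of strings).

use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = reqwest::Client::new();

    // POST /decode with a DecodeRequest body; the ids here are arbitrary examples.
    let response = client
        .post("http://localhost:8080/decode") // host/port assumed
        .json(&json!({ "ids": [101, 7592, 102], "skip_special_tokens": true }))
        .send()
        .await?
        .error_for_status()?;

    // Per the DecodeResponse schema above, the body is a JSON array of strings.
    let decoded: Vec<String> = response.json().await?;
    println!("{decoded:?}");
    Ok(())
}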

proto/tei.proto

Lines changed: 11 additions & 0 deletions
@@ -30,6 +30,8 @@ service Rerank {
 service Tokenize {
     rpc Tokenize (EncodeRequest) returns (EncodeResponse);
     rpc TokenizeStream (stream EncodeRequest) returns (stream EncodeResponse);
+    rpc Decode (DecodeRequest) returns (DecodeResponse);
+    rpc DecodeStream (stream DecodeRequest) returns (stream DecodeResponse);
 }
 
 message InfoRequest {}
@@ -166,3 +168,12 @@ message SimpleToken {
 message EncodeResponse {
     repeated SimpleToken tokens = 1;
 }
+
+message DecodeRequest {
+    repeated uint32 ids = 1;
+    bool skip_special_tokens = 2;
+}
+
+message DecodeResponse {
+    string text = 1;
+}
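And a hedged sketch of the unary `Decode` rpc from a `tonic` client. It assumes the proto above is compiled by a `build.rs` that runs `tonic_build::compile_protos("proto/tei.proto")`; the router's own diff imports the generated types from `grpc::pb::tei::v1`, so the package path matches, but the address, port, and ids here are illustrative.

// Generated from proto/tei.proto; assumes a build.rs with
// `tonic_build::compile_protos("proto/tei.proto")`.
pub mod tei {
    pub mod v1 {
        tonic::include_proto!("tei.v1");
    }
}

use tei::v1::tokenize_client::TokenizeClient;
use tei::v1::DecodeRequest;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Address and port are assumptions of this example.
    let mut client = TokenizeClient::connect("http://localhost:50051").await?;

    let response = client
        .decode(DecodeRequest {
            ids: vec![101, 7592, 102],
            skip_special_tokens: true,
        })
        .await?;

    println!("{}", response.into_inner().text);
    Ok(())
}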

router/src/grpc/server.rs

Lines changed: 106 additions & 2 deletions
@@ -3,8 +3,8 @@ use crate::grpc::pb::tei::v1::{
     EncodeResponse, RerankStreamRequest, SimpleToken, SparseValue, TokenEmbedding,
 };
 use crate::grpc::{
-    EmbedRequest, EmbedResponse, InfoRequest, InfoResponse, PredictRequest, PredictResponse,
-    Prediction, Rank, RerankRequest, RerankResponse,
+    DecodeRequest, DecodeResponse, EmbedRequest, EmbedResponse, InfoRequest, InfoResponse,
+    PredictRequest, PredictResponse, Prediction, Rank, RerankRequest, RerankResponse,
 };
 use crate::ResponseMetadata;
 use crate::{grpc, shutdown, ErrorResponse, ErrorType, Info, ModelType};
@@ -331,6 +331,17 @@ impl TextEmbeddingsService {
             .collect();
         Ok(EncodeResponse { tokens })
     }
+
+    #[instrument(skip_all)]
+    async fn decode_inner(&self, request: DecodeRequest) -> Result<DecodeResponse, Status> {
+        let ids = request.ids;
+        let text = self
+            .infer
+            .decode(ids, request.skip_special_tokens)
+            .await
+            .map_err(ErrorResponse::from)?;
+        Ok(DecodeResponse { text })
+    }
 }
 
 #[tonic::async_trait]
@@ -1327,6 +1338,99 @@ impl grpc::tokenize_server::Tokenize for TextEmbeddingsService {
             response_receiver,
         )))
     }
+
+    async fn decode(
+        &self,
+        request: Request<DecodeRequest>,
+    ) -> Result<Response<DecodeResponse>, Status> {
+        let request = request.into_inner();
+        let tokens = self.decode_inner(request).await?;
+        Ok(Response::new(tokens))
+    }
+
+    type DecodeStreamStream = UnboundedReceiverStream<Result<DecodeResponse, Status>>;
+
+    async fn decode_stream(
+        &self,
+        request: Request<Streaming<DecodeRequest>>,
+    ) -> Result<Response<Self::DecodeStreamStream>, Status> {
+        let mut request_stream = request.into_inner();
+
+        // Create bounded channel to have an upper bound of spawned tasks
+        // We will have at most `max_parallel_stream_requests` messages from this stream in the queue
+        let (encode_sender, mut encode_receiver) = mpsc::channel::<(
+            DecodeRequest,
+            oneshot::Sender<Result<DecodeResponse, Status>>,
+        )>(self.max_parallel_stream_requests);
+
+        // Required for the async move below
+        let local = self.clone();
+
+        // Background task that uses the bounded channel
+        tokio::spawn(async move {
+            while let Some((request, mut sender)) = encode_receiver.recv().await {
+                // Required for the async move below
+                let task_local = local.clone();
+
+                // Create async task for this specific input
+                tokio::spawn(async move {
+                    // Select on closed to cancel work if the stream was closed
+                    tokio::select! {
+                        response = task_local.decode_inner(request) => {
+                            let _ = sender.send(response);
+                        }
+                        _ = sender.closed() => {}
+                    }
+                });
+            }
+        });
+
+        // Intermediate channels
+        // Required to keep the order of the requests
+        let (intermediate_sender, mut intermediate_receiver) = mpsc::unbounded_channel();
+
+        tokio::spawn(async move {
+            // Iterate on input
+            while let Some(request) = request_stream.next().await {
+                // Create return channel
+                let (result_sender, result_receiver) = oneshot::channel();
+                // Push to intermediate channel and preserve ordering
+                intermediate_sender
+                    .send(result_receiver)
+                    .expect("`intermediate_receiver` was dropped. This is a bug.");
+
+                match request {
+                    Ok(request) => encode_sender
+                        .send((request, result_sender))
+                        .await
+                        .expect("`encode_receiver` was dropped. This is a bug."),
+                    Err(status) => {
+                        // Request is malformed
+                        let _ = result_sender.send(Err(status));
+                    }
+                };
+            }
+        });
+
+        // Final channel for the outputs
+        let (response_sender, response_receiver) = mpsc::unbounded_channel();
+
+        tokio::spawn(async move {
+            while let Some(result_receiver) = intermediate_receiver.recv().await {
+                // Select on closed to cancel work if the stream was closed
+                tokio::select! {
+                    response = result_receiver => {
+                        let _ = response_sender.send(response.expect("`result_sender` was dropped. This is a bug."));
+                    }
+                    _ = response_sender.closed() => {}
+                }
+            }
+        });
+
+        Ok(Response::new(UnboundedReceiverStream::new(
+            response_receiver,
+        )))
+    }
 }
 
 pub async fn run(
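`decode_stream` mirrors the existing streaming endpoints: requests fan out to concurrent tasks through a bounded channel, while an unbounded queue of oneshot receivers, awaited in arrival order, guarantees responses come back in request order even when tasks finish out of order. The sketch below isolates that ordering trick under stated assumptions; `slow_double` is a stand-in workload, not TEI code.

use tokio::sync::{mpsc, oneshot};

// Stand-in workload with deliberately inverted completion times,
// so out-of-order completion would be visible without the queue.
async fn slow_double(x: u32) -> u32 {
    tokio::time::sleep(std::time::Duration::from_millis((5 - x as u64) * 10)).await;
    x * 2
}

#[tokio::main]
async fn main() {
    let (queue_tx, mut queue_rx) = mpsc::unbounded_channel::<oneshot::Receiver<u32>>();

    for x in 0..5u32 {
        let (tx, rx) = oneshot::channel();
        // Enqueue the receiver *before* spawning the work: queue order == request order.
        queue_tx.send(rx).expect("queue receiver dropped");
        tokio::spawn(async move {
            let _ = tx.send(slow_double(x).await);
        });
    }
    drop(queue_tx); // close the queue so the loop below terminates

    // Tasks finish in reverse order, but results print as 0, 2, 4, 6, 8.
    while let Some(rx) = queue_rx.recv().await {
        println!("{}", rx.await.expect("worker task dropped its sender"));
    }
}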
