
Commit fd38f9e

feat: add support for roberta
1 parent b718452 commit fd38f9e

3 files changed: 35 additions & 6 deletions


backends/candle/src/lib.rs

Lines changed: 3 additions & 1 deletion
@@ -39,7 +39,9 @@ impl CandleBackend {
         };
 
         // Check model type
-        if config.model_type != Some("bert".to_string()) {
+        if config.model_type != Some("bert".to_string())
+            && config.model_type != Some("xlm-roberta".to_string())
+        {
             return Err(BackendError::Start(format!(
                 "Model {:?} is not supported",
                 config.model_type
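
For context, config.model_type is read from the model's config.json. Below is a minimal standalone sketch of the check this hunk introduces, assuming a serde-deserialized config with an optional model_type field; the struct and helper are illustrative, not the backend's actual definitions:

use serde::Deserialize;

// Illustrative config: only the field the check above relies on.
#[derive(Deserialize)]
struct Config {
    // model_type is optional in config.json, hence Option<String>.
    model_type: Option<String>,
}

// Mirrors the check added in this hunk: "bert" and "xlm-roberta" are accepted.
fn is_supported(config: &Config) -> bool {
    matches!(
        config.model_type.as_deref(),
        Some("bert") | Some("xlm-roberta")
    )
}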

core/src/tokenization.rs

Lines changed: 23 additions & 5 deletions
@@ -13,7 +13,12 @@ pub struct Tokenization {
 }
 
 impl Tokenization {
-    pub fn new(workers: usize, tokenizer: Tokenizer, max_input_length: usize) -> Self {
+    pub fn new(
+        workers: usize,
+        tokenizer: Tokenizer,
+        max_input_length: usize,
+        position_offset: usize,
+    ) -> Self {
         // Create channel
         let (sender, receiver) = flume::unbounded();
 
@@ -24,7 +29,12 @@ impl Tokenization {
 
             // Spawn worker
             tokio::task::spawn_blocking(move || {
-                tokenizer_worker(tokenizer_clone, max_input_length, receiver_clone)
+                tokenizer_worker(
+                    tokenizer_clone,
+                    max_input_length,
+                    position_offset,
+                    receiver_clone,
+                )
             });
         }
 
@@ -66,6 +76,7 @@ impl Tokenization {
 fn tokenizer_worker(
     tokenizer: Tokenizer,
     max_input_length: usize,
+    position_offset: usize,
     receiver: flume::Receiver<TokenizerRequest>,
 ) {
     // Loop over requests
@@ -74,8 +85,13 @@ fn tokenizer_worker(
             if !response_tx.is_closed() {
                // It's possible that the user dropped its request resulting in a send error.
                // We just discard the error
-                let _ =
-                    response_tx.send(encode_input(inputs, truncate, max_input_length, &tokenizer));
+                let _ = response_tx.send(encode_input(
+                    inputs,
+                    truncate,
+                    max_input_length,
+                    position_offset,
+                    &tokenizer,
+                ));
            }
        })
    }
@@ -86,6 +102,7 @@ fn encode_input(
     inputs: String,
     truncate: bool,
     max_input_length: usize,
+    position_offset: usize,
     tokenizer: &Tokenizer,
 ) -> Result<Encoding, TextEmbeddingsError> {
     // Get the number of tokens in the input
@@ -109,7 +126,8 @@ fn encode_input(
     Ok(Encoding {
         input_ids: encoding.get_ids().to_vec(),
         token_type_ids: encoding.get_type_ids().to_vec(),
-        position_ids: (0..seq_len as u32).collect::<Vec<_>>(),
+        position_ids: (position_offset as u32..(seq_len + position_offset) as u32)
+            .collect::<Vec<_>>(),
     })
 }
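
To make the new position_ids range concrete, here is a small standalone sketch of the computation in encode_input; the sequence lengths and offsets are made-up illustration values, not the crate's API:

// Sketch of the range used above: positions start at position_offset instead of 0.
fn position_ids(seq_len: usize, position_offset: usize) -> Vec<u32> {
    (position_offset as u32..(seq_len + position_offset) as u32).collect()
}

fn main() {
    // BERT-style models keep offset 0: [0, 1, 2, 3]
    assert_eq!(position_ids(4, 0), vec![0, 1, 2, 3]);
    // RoBERTa-style models skip past the padding index: offset 2 gives [2, 3, 4, 5]
    assert_eq!(position_ids(4, 2), vec![2, 3, 4, 5]);
}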

router/src/main.rs

Lines changed: 9 additions & 0 deletions
@@ -116,6 +116,7 @@ pub struct ModelConfig {
     pub model_type: String,
     #[serde(alias = "n_positions")]
     pub max_position_embeddings: usize,
+    pub pad_token_id: usize,
 }
 
 #[tokio::main]
@@ -167,11 +168,19 @@ async fn main() -> Result<()> {
     );
     tokenizer.with_padding(None);
 
+    // Position IDs offset. Used for Roberta.
+    let position_offset = if config.pad_token_id == 0 {
+        0
+    } else {
+        config.pad_token_id + 1
+    };
+
     // Tokenization logic
     let tokenization = Tokenization::new(
         args.tokenization_workers,
         tokenizer,
         config.max_position_embeddings,
+        position_offset,
     );
 
     // Create backend
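
As a worked example of the offset rule above (the pad_token_id values are typical Hugging Face config defaults, not taken from this diff): BERT configs usually have pad_token_id = 0, so positions keep starting at 0, while XLM-RoBERTa configs have pad_token_id = 1, giving an offset of 2 so position ids start after the padding index, in line with RoBERTa's padding_idx + 1 convention. A hypothetical helper mirroring the inline if above:

// Hypothetical helper mirroring the rule in main.rs above.
fn position_offset(pad_token_id: usize) -> usize {
    if pad_token_id == 0 {
        0
    } else {
        pad_token_id + 1
    }
}

fn main() {
    assert_eq!(position_offset(0), 0); // BERT-style config
    assert_eq!(position_offset(1), 2); // XLM-RoBERTa-style config
}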
