from transformers import RobertaConfig

from vllm.config import VllmConfig
+from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.pooler import ClassifierPooler, CLSPool
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
@@ -50,39 +51,12 @@ def __init__(self, config: RobertaConfig):
    def forward(
        self,
        input_ids: torch.Tensor,
-        seq_lens: torch.Tensor,
        position_ids: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        input_shape = input_ids.size()
        inputs_embeds = self.word_embeddings(input_ids)

-        # Replace position ids because in RoBERTa models
-        # they have to start at padding_idx + 1 and ignore
-        # existing padding tokens
-        # References:
-        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
-        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
-        pos_list = []
-        token_list = []
-        offset = 0
-        for seq_len in seq_lens:
-            pos_list.append(position_ids[offset:offset + seq_len])
-            token_list.append(input_ids[offset:offset + seq_len])
-            offset += seq_len
-
-        new_pos_list = []
-        for positions, tokens in zip(pos_list, token_list):
-            # Verify assumption that incoming position are
-            # always a sequence from 0 to N.
-            expected_pos = torch.arange(positions.size()[0],
-                                        dtype=torch.long,
-                                        device=inputs_embeds.device)
-            assert torch.equal(positions, expected_pos)
-            new_pos_list.append(
-                create_position_ids_from_input_ids(tokens, self.padding_idx))
-        position_ids = torch.cat(new_pos_list)
-
        # Position embeddings.
        position_embeddings = self.position_embeddings(position_ids)
        if token_type_ids is None:
@@ -124,6 +98,32 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
    _pooler: An instance of Pooler used for pooling operations.
    """

+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+        self.padding_idx = vllm_config.model_config.hf_config.pad_token_id
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        token_type_ids: Optional[torch.Tensor] = None,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+
+        # Fix RoBERTa positions here, outside of the CUDA graph:
+        # because we need to extract the sequences from input_ids,
+        # the control flow is data dependent.
+        replace_roberta_positions(input_ids=input_ids,
+                                  position_ids=positions,
+                                  padding_idx=self.padding_idx)
+
+        return self.model(input_ids=input_ids,
+                          position_ids=positions,
+                          token_type_ids=token_type_ids,
+                          inputs_embeds=inputs_embeds,
+                          intermediate_tensors=intermediate_tensors)
+
    def _build_model(self,
                     vllm_config: VllmConfig,
                     prefix: str = "") -> Union[BertModel, BertWithRope]:
@@ -180,6 +180,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
+        self.padding_idx = vllm_config.model_config.hf_config.pad_token_id

        self.num_labels = config.num_labels
        self.roberta = BertModel(vllm_config=vllm_config,
@@ -206,6 +207,9 @@ def forward(
        inputs_embeds: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
+        replace_roberta_positions(input_ids=input_ids,
+                                  position_ids=positions,
+                                  padding_idx=self.padding_idx)
        return self.roberta(input_ids=input_ids,
                            position_ids=positions,
                            inputs_embeds=inputs_embeds,
@@ -235,3 +239,36 @@ def create_position_ids_from_input_ids(input_ids,
                           past_key_values_length) * mask

    return incremental_indices.long() + padding_idx
+
+
+def replace_roberta_positions(input_ids: torch.Tensor,
+                              position_ids: torch.Tensor,
+                              padding_idx: int) -> None:
+
+    seq_lens: Optional[torch.Tensor] = None
+    attn_metadata = get_forward_context().attn_metadata
+    if attn_metadata is not None:  # can be None during warmup
+        if isinstance(attn_metadata, dict):
+            attn_metadata = next(iter(attn_metadata.values()))
+        # TODO: remove "seq_lens_tensor" after V0 is removed
+        seq_lens = getattr(attn_metadata, "seq_lens_tensor",
+                           getattr(attn_metadata, "seq_lens", None))
+
+    if seq_lens is not None:
+        assert isinstance(seq_lens, torch.Tensor)
+
+        # Replace position ids because in RoBERTa models
+        # they have to start at padding_idx + 1 and ignore
+        # existing padding tokens
+        # References:
+        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
+        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
+        token_list = torch.split(input_ids[:torch.sum(seq_lens)],
+                                 seq_lens.tolist())
+
+        offset = 0
+        for tokens in token_list:
+            length = tokens.shape[0]
+            position_ids[offset:offset + length] = \
+                create_position_ids_from_input_ids(tokens, padding_idx)
+            offset = offset + length
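
To make the new data flow concrete, here is a minimal, self-contained sketch of what replace_roberta_positions computes once seq_lens is available from the forward context: positions restart per sequence at padding_idx + 1, and padding tokens keep padding_idx. The helper below is a simplified 1-D re-implementation of create_position_ids_from_input_ids for illustration only, and the token ids and sequence lengths are made-up values, not taken from vLLM.

# Illustrative sketch (not part of the patch): the RoBERTa position-id
# convention applied per sequence over a flattened batch.
import torch


def create_position_ids_1d(input_ids: torch.Tensor,
                           padding_idx: int,
                           past_key_values_length: int = 0) -> torch.Tensor:
    # Simplified 1-D variant of the HF helper referenced above:
    # non-pad tokens get cumulative positions starting at padding_idx + 1,
    # pad tokens keep padding_idx.
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) +
                           past_key_values_length) * mask
    return incremental_indices.long() + padding_idx


# Two sequences flattened into one batch, the way they arrive at the model;
# padding_idx = 1 (RoBERTa's pad_token_id). Token ids are arbitrary.
input_ids = torch.tensor([0, 500, 600, 2,   # sequence 1 (length 4)
                          0, 700, 2])       # sequence 2 (length 3)
seq_lens = torch.tensor([4, 3])
position_ids = torch.arange(input_ids.numel())  # incoming 0..N-1 positions

# Per-sequence rewrite, mirroring the loop in replace_roberta_positions.
offset = 0
for tokens in torch.split(input_ids, seq_lens.tolist()):
    length = tokens.shape[0]
    position_ids[offset:offset + length] = \
        create_position_ids_1d(tokens, padding_idx=1)
    offset += length

print(position_ids)  # tensor([2, 3, 4, 5, 2, 3, 4])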