@@ -2,13 +2,13 @@
 from pathlib import Path
 from torch import nn
 import torch.nn.functional as F
-from typing import Type, List
+from typing import Type, List, Union
 from safetensors import safe_open
 from transformers.activations import ACT2FN
 from transformers.models.bert import BertConfig
 from opentelemetry import trace
 from text_embeddings_server.models import Model
-from text_embeddings_server.models.types import FlashBatch, Embedding
+from text_embeddings_server.models.types import FlashBatch, Embedding, PaddedBatch
 from text_embeddings_server.utils.flash_attn import attention
 from text_embeddings_server.utils.device import use_ipex

@@ -166,22 +166,41 @@ def __init__(self, prefix, handle, device, dtype, config: BertConfig):
         self.num_heads = config.num_attention_heads
         self.device = device

-    def forward(self, hidden_states, cu_seqlens, max_s):
+    def forward(self, hidden_states, cu_seqlens, max_s, attn_mask=None):
         residual = hidden_states
-
-        qkv = torch.addmm(self.qkv_bias, hidden_states, self.qkv_weight)
-        q, k, v = qkv.view(-1, self.num_heads * 3, self.head_size).split(
-            self.num_heads, dim=1
-        )
-
+        qkv = F.linear(hidden_states, self.qkv_weight.T, self.qkv_bias)
+        bs = 1
+        hidden_dim = hidden_states.size(-1)
+        is_flat = True
+        if hidden_states.dim() > 2:
+            is_flat = False
+            bs = hidden_states.size(0)
+            q, k, v = qkv.view(bs, -1, self.num_heads * 3, self.head_size).split(
+                self.num_heads, dim=2
+            )
+        else:
+            q, k, v = qkv.view(-1, self.num_heads * 3, self.head_size).split(
+                self.num_heads, dim=1
+            )
         attn_output = torch.empty_like(q)
-        attention(q, k, v, attn_output, cu_seqlens, max_s, self.softmax_scale)
+        attention(
+            q,
+            k,
+            v,
+            attn_output,
+            cu_seqlens,
+            max_s,
+            self.softmax_scale,
+            attn_mask=attn_mask,
+        )

         hidden_states = torch.addmm(
             self.dense_bias,
             attn_output.view(-1, self.num_heads * self.head_size),
             self.dense_weight,
         )
+        if not is_flat:
+            hidden_states = hidden_states.view(bs, -1, hidden_dim)
         hidden_states, _ = self.layer_norm.forward(hidden_states, residual)

         return hidden_states
@@ -224,19 +243,16 @@ def __init__(self, prefix, handle, device, dtype, config: BertConfig):
             f"{prefix}.output.LayerNorm", handle, device, dtype, config
         )

-    def forward(self, hidden_states, cu_seqlens, max_s):
-        hidden_states = self.attention.forward(hidden_states, cu_seqlens, max_s)
+    def forward(self, hidden_states, cu_seqlens, max_s, attn_mask=None):
+        hidden_states = self.attention.forward(
+            hidden_states, cu_seqlens, max_s, attn_mask
+        )
         residual = hidden_states
-
-        hidden_states = torch.addmm(
-            self.intermediate_bias, hidden_states, self.intermediate_weight
+        hidden_states = F.linear(
+            hidden_states, self.intermediate_weight.T, self.intermediate_bias
         )
         hidden_states = self.intermediate_act_fn(hidden_states)
-        hidden_states = torch.addmm(
-            self.output_bias,
-            hidden_states,
-            self.output_weight,
-        )
+        hidden_states = F.linear(hidden_states, self.output_weight.T, self.output_bias)
         hidden_states, _ = self.layer_norm.forward(hidden_states, residual)
         return hidden_states

@@ -248,9 +264,9 @@ def __init__(self, prefix, handle, device, dtype, config: BertConfig):
             for i in range(config.num_hidden_layers)
         ]

-    def forward(self, hidden_states, cu_seqlens, max_s):
+    def forward(self, hidden_states, cu_seqlens, max_s, attn_mask=None):
         for layer in self.layers:
-            hidden_states = layer.forward(hidden_states, cu_seqlens, max_s)
+            hidden_states = layer.forward(hidden_states, cu_seqlens, max_s, attn_mask)
         return hidden_states


@@ -259,10 +275,21 @@ def __init__(self, handle, device, dtype, config: BertConfig):
         self.embeddings = BertEmbeddings("embeddings", handle, device, dtype, config)
         self.encoder = BertEncoder("encoder", handle, device, dtype, config)

-    def forward(self, input_ids, token_type_ids, position_ids, cu_seqlens, max_s):
+    def forward(
+        self,
+        input_ids,
+        token_type_ids,
+        position_ids,
+        cu_seqlens,
+        max_s,
+        mask=None,
+        attn_mask=None,
+    ):
         embeddings = self.embeddings.forward(input_ids, token_type_ids, position_ids)
-        encoder_outputs = self.encoder.forward(embeddings, cu_seqlens, max_s)
-
+        encoder_outputs = self.encoder.forward(embeddings, cu_seqlens, max_s, attn_mask)
+        if mask is not None:
+            outputs = encoder_outputs[mask]
+            return outputs[cu_seqlens[:-1]]
         return encoder_outputs[cu_seqlens[:-1]]


@@ -277,6 +304,7 @@ def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):

         with safe_open(model_path / "model.safetensors", framework="pt") as f:
             model = FlashBertModel(f, device, dtype, config)
+        self.device = device
         if device.type == "hpu":
            from habana_frameworks.torch.hpu import wrap_in_hpu_graph

@@ -286,17 +314,38 @@ def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
         super(FlashBert, self).__init__(model=model, dtype=dtype, device=device)

     @property
-    def batch_type(self) -> Type[FlashBatch]:
-        return FlashBatch
+    def batch_type(self) -> Union[FlashBatch, PaddedBatch]:
+        # for hpu devices, we use PaddedBatch as we do not have real varlen fwd yet
+        return FlashBatch if self.device.type != "hpu" else PaddedBatch

     @tracer.start_as_current_span("embed")
-    def embed(self, batch: FlashBatch) -> List[Embedding]:
+    def embed(self, batch: Union[FlashBatch, PaddedBatch]) -> List[Embedding]:
+        if isinstance(batch, PaddedBatch):
+            input_lens = batch.attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+            max_input_lens = input_lens.max().item()
+            cu_seqlens = torch.cat(
+                (input_lens.new_tensor([0]), input_lens.cumsum(-1).int())
+            )
+            mask = batch.attention_mask.to(torch.bool)
+            batch_size = input_lens.size(0)
+            attn_mask = torch.empty(
+                [batch_size, 1, 1, mask.shape[-1]], device=self.device
+            ).fill_(float("-inf"))
+            attn_mask[:, :, :, :].masked_fill_(mask[:, None, None, :], 0)
+        elif isinstance(batch, FlashBatch):
+            cu_seqlens = batch.cu_seqlens
+            mask = None
+            attn_mask = None
+            max_input_lens = batch.max_s
+
         embedding = self.model.forward(
             input_ids=batch.input_ids,
             token_type_ids=batch.token_type_ids,
             position_ids=batch.position_ids,
-            cu_seqlens=batch.cu_seqlens,
-            max_s=batch.max_s,
+            cu_seqlens=cu_seqlens,
+            max_s=max_input_lens,
+            mask=mask,
+            attn_mask=attn_mask,
         )
         cpu_results = embedding.view(-1).tolist()

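
For orientation, a minimal standalone sketch of what the PaddedBatch branch of embed() computes above: per-sequence token counts, the cu_seqlens boundaries, and an additive attention mask (0 for real tokens, -inf for padding) that broadcasts over heads and query positions. The helper name build_padded_attention_inputs and the example mask are illustrative only, not part of the patch.

import torch

def build_padded_attention_inputs(attention_mask: torch.Tensor, device: torch.device):
    # Per-sequence token counts: running sum of the 0/1 mask, last column.
    input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
    # Cumulative sequence boundaries, e.g. lengths [3, 2] -> cu_seqlens [0, 3, 5].
    cu_seqlens = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
    bool_mask = attention_mask.to(torch.bool)
    batch_size, seq_len = bool_mask.shape
    # Additive mask broadcastable to [batch, heads, q_len, k_len]:
    # 0 where a key token is real, -inf where it is padding.
    attn_mask = torch.full((batch_size, 1, 1, seq_len), float("-inf"), device=device)
    attn_mask.masked_fill_(bool_mask[:, None, None, :], 0)
    return input_lens, cu_seqlens, bool_mask, attn_mask

# Example: two sequences of lengths 3 and 2, padded to length 4.
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
lens, cu, bmask, add_mask = build_padded_attention_inputs(mask, torch.device("cpu"))
print(lens.tolist())   # [3, 2]
print(cu.tolist())     # [0, 3, 5]
print(add_mask.shape)  # torch.Size([2, 1, 1, 4])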