@@ -305,6 +305,7 @@ def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
305
305
with safe_open (model_path / "model.safetensors" , framework = "pt" ) as f :
306
306
model = FlashBertModel (f , device , dtype , config )
307
307
self .device = device
308
+ self .dtype = dtype
308
309
if device .type == "hpu" :
309
310
from habana_frameworks .torch .hpu import wrap_in_hpu_graph
310
311
@@ -326,12 +327,15 @@ def embed(self, batch: Union[FlashBatch, PaddedBatch]) -> List[Embedding]:
326
327
cu_seqlens = torch .cat (
327
328
(input_lens .new_tensor ([0 ]), input_lens .cumsum (- 1 ).int ())
328
329
)
329
- mask = batch .attention_mask .to ( torch . bool )
330
+ mask = batch .attention_mask .bool ( )
330
331
batch_size = input_lens .size (0 )
331
- attn_mask = torch .empty (
332
- [batch_size , 1 , 1 , mask .shape [- 1 ]], device = self .device
333
- ).fill_ (float ("-inf" ))
334
- attn_mask [:, :, :, :].masked_fill_ (mask [:, None , None , :], 0 )
332
+ attn_mask = torch .full (
333
+ [batch_size , 1 , 1 , mask .shape [- 1 ]],
334
+ fill_value = torch .finfo (self .dtype ).min ,
335
+ device = self .device ,
336
+ dtype = self .dtype ,
337
+ )
338
+ attn_mask .masked_fill_ (mask [:, None , None , :], 0 )
335
339
elif isinstance (batch , FlashBatch ):
336
340
cu_seqlens = batch .cu_seqlens
337
341
mask = None
0 commit comments