Skip to content

Commit b06b752

Browse files
authored
Optimize the performance of FlashBert on HPU by using fast mode softmax (#555)
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
1 parent 8eb7a84 commit b06b752

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

backends/python/server/text_embeddings_server/models/flash_bert.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
305305
with safe_open(model_path / "model.safetensors", framework="pt") as f:
306306
model = FlashBertModel(f, device, dtype, config)
307307
self.device = device
308+
self.dtype = dtype
308309
if device.type == "hpu":
309310
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
310311

@@ -326,12 +327,15 @@ def embed(self, batch: Union[FlashBatch, PaddedBatch]) -> List[Embedding]:
326327
cu_seqlens = torch.cat(
327328
(input_lens.new_tensor([0]), input_lens.cumsum(-1).int())
328329
)
329-
mask = batch.attention_mask.to(torch.bool)
330+
mask = batch.attention_mask.bool()
330331
batch_size = input_lens.size(0)
331-
attn_mask = torch.empty(
332-
[batch_size, 1, 1, mask.shape[-1]], device=self.device
333-
).fill_(float("-inf"))
334-
attn_mask[:, :, :, :].masked_fill_(mask[:, None, None, :], 0)
332+
attn_mask = torch.full(
333+
[batch_size, 1, 1, mask.shape[-1]],
334+
fill_value=torch.finfo(self.dtype).min,
335+
device=self.device,
336+
dtype=self.dtype,
337+
)
338+
attn_mask.masked_fill_(mask[:, None, None, :], 0)
335339
elif isinstance(batch, FlashBatch):
336340
cu_seqlens = batch.cu_seqlens
337341
mask = None

backends/python/server/text_embeddings_server/utils/flash_attn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def hpu_attn(
7878
if is_causal:
7979
attn_mask = None
8080

81-
out_ = FusedSDPA.apply(q, k, v, attn_mask, 0.0, is_causal, softmax_scale)
81+
out_ = FusedSDPA.apply(q, k, v, attn_mask, 0.0, is_causal, softmax_scale, "fast", False)
8282
out_ = out_.transpose(1, 2)
8383
out.copy_(out_)
8484
return out

0 commit comments

Comments (0)