
Commit 6e4133b

Optimize flash bert path for hpu device (#509)
Signed-off-by: kaixuanliu <kaixuan.liu@intel.com>
1 parent 53bdab3 commit 6e4133b

File tree: 3 files changed, +100 -84 lines


backends/python/server/text_embeddings_server/models/flash_bert.py

Lines changed: 79 additions & 30 deletions
@@ -2,13 +2,13 @@
 from pathlib import Path
 from torch import nn
 import torch.nn.functional as F
-from typing import Type, List
+from typing import Type, List, Union
 from safetensors import safe_open
 from transformers.activations import ACT2FN
 from transformers.models.bert import BertConfig
 from opentelemetry import trace
 from text_embeddings_server.models import Model
-from text_embeddings_server.models.types import FlashBatch, Embedding
+from text_embeddings_server.models.types import FlashBatch, Embedding, PaddedBatch
 from text_embeddings_server.utils.flash_attn import attention
 from text_embeddings_server.utils.device import use_ipex

@@ -166,22 +166,41 @@ def __init__(self, prefix, handle, device, dtype, config: BertConfig):
         self.num_heads = config.num_attention_heads
         self.device = device

-    def forward(self, hidden_states, cu_seqlens, max_s):
+    def forward(self, hidden_states, cu_seqlens, max_s, attn_mask=None):
         residual = hidden_states
-
-        qkv = torch.addmm(self.qkv_bias, hidden_states, self.qkv_weight)
-        q, k, v = qkv.view(-1, self.num_heads * 3, self.head_size).split(
-            self.num_heads, dim=1
-        )
-
+        qkv = F.linear(hidden_states, self.qkv_weight.T, self.qkv_bias)
+        bs = 1
+        hidden_dim = hidden_states.size(-1)
+        is_flat = True
+        if hidden_states.dim() > 2:
+            is_flat = False
+            bs = hidden_states.size(0)
+            q, k, v = qkv.view(bs, -1, self.num_heads * 3, self.head_size).split(
+                self.num_heads, dim=2
+            )
+        else:
+            q, k, v = qkv.view(-1, self.num_heads * 3, self.head_size).split(
+                self.num_heads, dim=1
+            )
         attn_output = torch.empty_like(q)
-        attention(q, k, v, attn_output, cu_seqlens, max_s, self.softmax_scale)
+        attention(
+            q,
+            k,
+            v,
+            attn_output,
+            cu_seqlens,
+            max_s,
+            self.softmax_scale,
+            attn_mask=attn_mask,
+        )

         hidden_states = torch.addmm(
             self.dense_bias,
             attn_output.view(-1, self.num_heads * self.head_size),
             self.dense_weight,
         )
+        if not is_flat:
+            hidden_states = hidden_states.view(bs, -1, hidden_dim)
         hidden_states, _ = self.layer_norm.forward(hidden_states, residual)

         return hidden_states
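For reference, a minimal standalone sketch (hypothetical sizes, plain tensors rather than the module's weights above) of the two input layouts the rewritten attention forward now accepts, and of the torch.addmm -> F.linear swap it relies on: torch.addmm(bias, x, W) equals F.linear(x, W.T, bias), but F.linear also accepts the 3-D padded layout used on HPU.

    import torch
    import torch.nn.functional as F

    # Hypothetical sizes for illustration only.
    num_heads, head_size = 4, 16
    hidden = num_heads * head_size
    qkv_weight = torch.randn(hidden, 3 * hidden)  # stored (in, out), as addmm expects
    qkv_bias = torch.randn(3 * hidden)

    def split_qkv(hidden_states):
        # torch.addmm(bias, x, W) == F.linear(x, W.T, bias); F.linear also takes 3-D input.
        qkv = F.linear(hidden_states, qkv_weight.T, qkv_bias)
        if hidden_states.dim() > 2:  # padded HPU path: (batch, seq, hidden)
            bs = hidden_states.size(0)
            return qkv.view(bs, -1, num_heads * 3, head_size).split(num_heads, dim=2)
        # flat varlen path: (total_tokens, hidden)
        return qkv.view(-1, num_heads * 3, head_size).split(num_heads, dim=1)

    q, k, v = split_qkv(torch.randn(7, hidden))           # flat: 7 tokens in total
    q_p, k_p, v_p = split_qkv(torch.randn(2, 5, hidden))  # padded: 2 sequences of length 5
    print(q.shape, q_p.shape)  # torch.Size([7, 4, 16]) torch.Size([2, 5, 4, 16])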
@@ -224,19 +243,16 @@ def __init__(self, prefix, handle, device, dtype, config: BertConfig):
             f"{prefix}.output.LayerNorm", handle, device, dtype, config
         )

-    def forward(self, hidden_states, cu_seqlens, max_s):
-        hidden_states = self.attention.forward(hidden_states, cu_seqlens, max_s)
+    def forward(self, hidden_states, cu_seqlens, max_s, attn_mask=None):
+        hidden_states = self.attention.forward(
+            hidden_states, cu_seqlens, max_s, attn_mask
+        )
         residual = hidden_states
-
-        hidden_states = torch.addmm(
-            self.intermediate_bias, hidden_states, self.intermediate_weight
+        hidden_states = F.linear(
+            hidden_states, self.intermediate_weight.T, self.intermediate_bias
         )
         hidden_states = self.intermediate_act_fn(hidden_states)
-        hidden_states = torch.addmm(
-            self.output_bias,
-            hidden_states,
-            self.output_weight,
-        )
+        hidden_states = F.linear(hidden_states, self.output_weight.T, self.output_bias)
         hidden_states, _ = self.layer_norm.forward(hidden_states, residual)
         return hidden_states

@@ -248,9 +264,9 @@ def __init__(self, prefix, handle, device, dtype, config: BertConfig):
             for i in range(config.num_hidden_layers)
         ]

-    def forward(self, hidden_states, cu_seqlens, max_s):
+    def forward(self, hidden_states, cu_seqlens, max_s, attn_mask=None):
         for layer in self.layers:
-            hidden_states = layer.forward(hidden_states, cu_seqlens, max_s)
+            hidden_states = layer.forward(hidden_states, cu_seqlens, max_s, attn_mask)
         return hidden_states

@@ -259,10 +275,21 @@ def __init__(self, handle, device, dtype, config: BertConfig):
         self.embeddings = BertEmbeddings("embeddings", handle, device, dtype, config)
         self.encoder = BertEncoder("encoder", handle, device, dtype, config)

-    def forward(self, input_ids, token_type_ids, position_ids, cu_seqlens, max_s):
+    def forward(
+        self,
+        input_ids,
+        token_type_ids,
+        position_ids,
+        cu_seqlens,
+        max_s,
+        mask=None,
+        attn_mask=None,
+    ):
         embeddings = self.embeddings.forward(input_ids, token_type_ids, position_ids)
-        encoder_outputs = self.encoder.forward(embeddings, cu_seqlens, max_s)
-
+        encoder_outputs = self.encoder.forward(embeddings, cu_seqlens, max_s, attn_mask)
+        if mask is not None:
+            outputs = encoder_outputs[mask]
+            return outputs[cu_seqlens[:-1]]
         return encoder_outputs[cu_seqlens[:-1]]
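A toy illustration (invented shapes, no model weights) of what the new mask branch in FlashBertModel.forward does on the padded HPU path: boolean indexing drops the padding tokens and concatenates the sequences, after which cu_seqlens[:-1] selects the first (CLS) token of each sequence.

    import torch

    # Two padded sequences of length 4 with real lengths 3 and 2 (hypothetical shapes).
    hidden = 8
    encoder_outputs = torch.randn(2, 4, hidden)            # padded (batch, seq, hidden)
    mask = torch.tensor([[1, 1, 1, 0],
                         [1, 1, 0, 0]], dtype=torch.bool)  # per-token validity
    input_lens = mask.sum(-1)                              # tensor([3, 2])
    cu_seqlens = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1)))  # [0, 3, 5]

    flat = encoder_outputs[mask]   # (5, hidden): valid tokens, sequences concatenated
    cls = flat[cu_seqlens[:-1]]    # (2, hidden): token 0 of each sequence
    assert torch.equal(cls, encoder_outputs[:, 0])  # matches the padded CLS positions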

@@ -277,6 +304,7 @@ def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):

         with safe_open(model_path / "model.safetensors", framework="pt") as f:
             model = FlashBertModel(f, device, dtype, config)
+        self.device = device
         if device.type == "hpu":
             from habana_frameworks.torch.hpu import wrap_in_hpu_graph
@@ -286,17 +314,38 @@ def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
         super(FlashBert, self).__init__(model=model, dtype=dtype, device=device)

     @property
-    def batch_type(self) -> Type[FlashBatch]:
-        return FlashBatch
+    def batch_type(self) -> Union[FlashBatch, PaddedBatch]:
+        # for hpu devices, we use PaddedBatch as we do not have real varlen fwd yet
+        return FlashBatch if self.device.type != "hpu" else PaddedBatch

     @tracer.start_as_current_span("embed")
-    def embed(self, batch: FlashBatch) -> List[Embedding]:
+    def embed(self, batch: Union[FlashBatch, PaddedBatch]) -> List[Embedding]:
+        if isinstance(batch, PaddedBatch):
+            input_lens = batch.attention_mask.cumsum(-1)[:, -1].to(torch.int32)
+            max_input_lens = input_lens.max().item()
+            cu_seqlens = torch.cat(
+                (input_lens.new_tensor([0]), input_lens.cumsum(-1).int())
+            )
+            mask = batch.attention_mask.to(torch.bool)
+            batch_size = input_lens.size(0)
+            attn_mask = torch.empty(
+                [batch_size, 1, 1, mask.shape[-1]], device=self.device
+            ).fill_(float("-inf"))
+            attn_mask[:, :, :, :].masked_fill_(mask[:, None, None, :], 0)
+        elif isinstance(batch, FlashBatch):
+            cu_seqlens = batch.cu_seqlens
+            mask = None
+            attn_mask = None
+            max_input_lens = batch.max_s
+
         embedding = self.model.forward(
             input_ids=batch.input_ids,
             token_type_ids=batch.token_type_ids,
             position_ids=batch.position_ids,
-            cu_seqlens=batch.cu_seqlens,
-            max_s=batch.max_s,
+            cu_seqlens=cu_seqlens,
+            max_s=max_input_lens,
+            mask=mask,
+            attn_mask=attn_mask,
         )
         cpu_results = embedding.view(-1).tolist()
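The mask construction in the new PaddedBatch branch, repeated here on a toy attention_mask (CPU tensors, no model) so the resulting values are easy to inspect: the additive mask is 0 on real tokens and -inf on padding, broadcast over (batch, heads, q_len, k_len).

    import torch

    # Hypothetical PaddedBatch-style attention_mask: 2 sequences padded to length 4.
    attention_mask = torch.tensor([[1, 1, 1, 0],
                                   [1, 1, 0, 0]])

    input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)  # tensor([3, 2])
    max_input_lens = input_lens.max().item()                       # 3
    cu_seqlens = torch.cat(
        (input_lens.new_tensor([0]), input_lens.cumsum(-1).int())
    )                                                              # tensor([0, 3, 5])

    mask = attention_mask.to(torch.bool)
    batch_size = input_lens.size(0)
    attn_mask = torch.empty([batch_size, 1, 1, mask.shape[-1]]).fill_(float("-inf"))
    attn_mask.masked_fill_(mask[:, None, None, :], 0)
    print(attn_mask[0, 0, 0], attn_mask[1, 0, 0])  # [0, 0, 0, -inf] and [0, 0, -inf, -inf]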

backends/python/server/text_embeddings_server/utils/device.py

Lines changed: 10 additions & 0 deletions
@@ -6,6 +6,13 @@
 import torch
 import subprocess

+ALLOW_REDUCED_PRECISION = os.getenv(
+    "ALLOW_REDUCED_PRECISION_FP16_BF16", "true"
+).lower() in [
+    "true",
+    "1",
+]
+

 def _is_ipex_available():
     def get_major_and_minor_from_version(full_version):
@@ -55,6 +62,9 @@ def get_device():
     elif is_hpu():
         import habana_frameworks.torch.core as htcore

+        # WA for perf degradation from pytorch 2.5
+        if ALLOW_REDUCED_PRECISION:
+            torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
         if hasattr(torch, "hpu") and torch.hpu.is_available():  # type: ignore
             device = torch.device("hpu")
     elif use_ipex():
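A quick sketch of how the new ALLOW_REDUCED_PRECISION_FP16_BF16 toggle is parsed. The constant is evaluated at import time of device.py, so the variable has to be set before the module is imported.

    import os

    # Mirrors the parsing above: only "true" or "1" (case-insensitive) keeps the workaround on.
    os.environ["ALLOW_REDUCED_PRECISION_FP16_BF16"] = "0"  # e.g. an operator opting out
    allow = os.getenv("ALLOW_REDUCED_PRECISION_FP16_BF16", "true").lower() in ["true", "1"]
    print(allow)  # False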

backends/python/server/text_embeddings_server/utils/flash_attn.py

Lines changed: 11 additions & 54 deletions
@@ -62,6 +62,7 @@ def hpu_attn(
     k,
     v,
     out,
+    attn_mask,
     seqlen_q,
     seqlen_k,
     max_seqlen_q,
@@ -71,66 +72,21 @@ def hpu_attn(
 ):
     from habana_frameworks.torch.hpex.kernels import FusedSDPA

-    total_q, num_head, head_size = q.size()
-    total_k, num_head_k, _ = k.size()
-    batch_size = seqlen_q.size(0) - 1
-    seqlen_q_ = seqlen_q.clone()
-    seqlen_q_[:batch_size] = seqlen_q[1:]
-    seqlen_q = (seqlen_q_ - seqlen_q)[:batch_size]
-    seqlen_k_ = seqlen_k.clone()
-    seqlen_k_[:batch_size] = seqlen_k[1:]
-    seqlen_k = (seqlen_k_ - seqlen_k)[:batch_size]
-
-    pad_q = torch.zeros(
-        [batch_size, max_seqlen_q, num_head, head_size],
-        dtype=q.dtype,
-        device=q.device,
-    )
-    pad_k = torch.zeros(
-        [batch_size, max_seqlen_k, num_head_k, head_size],
-        dtype=k.dtype,
-        device=k.device,
-    )
-    pad_v = torch.zeros(
-        [batch_size, max_seqlen_k, num_head_k, head_size],
-        dtype=v.dtype,
-        device=v.device,
-    )
-    q_mask = torch.arange(0, max_seqlen_q, device=q.device)[None, :].repeat(
-        batch_size, 1
-    )
-    q_mask = q_mask < seqlen_q[:, None].repeat(1, q_mask.size(-1))
-    k_mask = torch.arange(0, max_seqlen_k, device=k.device)[None, :].repeat(
-        batch_size, 1
-    )
-    k_mask = k_mask < seqlen_k[:, None].repeat(1, k_mask.size(-1))
-    align_mask_seqlen = max_seqlen_k
-    attn_mask = torch.empty(
-        [batch_size, 1, 1, align_mask_seqlen],
-        dtype=q.dtype,
-        device=q.device,
-    ).fill_(float("-inf"))
-    attn_mask[:, :, :, :max_seqlen_k].masked_fill_(k_mask[:, None, None, :], 0)
-
-    pad_q[q_mask] = q
-    pad_k[k_mask] = k
-    pad_v[k_mask] = v
-
-    pad_q = pad_q.permute(0, 2, 1, 3)
-    pad_k = pad_k.permute(0, 2, 1, 3)
-    pad_v = pad_v.permute(0, 2, 1, 3)
+    q = q.transpose(1, 2)
+    k = k.transpose(1, 2)
+    v = v.transpose(1, 2)
     if is_causal:
         attn_mask = None

-    out_ = FusedSDPA.apply(
-        pad_q, pad_k, pad_v, attn_mask, 0.0, is_causal, softmax_scale
-    )
-    out_ = out_.permute(0, 2, 1, 3)
-    out.copy_(out_[q_mask])
+    out_ = FusedSDPA.apply(q, k, v, attn_mask, 0.0, is_causal, softmax_scale)
+    out_ = out_.transpose(1, 2)
+    out.copy_(out_)
     return out


-def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, is_causal=False):
+def attention(
+    q, k, v, out, cu_seqlens, max_s, softmax_scale, is_causal=False, attn_mask=None
+):
     if HAS_FLASH_ATTN_V2:
         if use_ipex:
             import intel_extension_for_pytorch as ipex
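A CPU stand-in for the simplified hpu_attn body, using torch's stock scaled_dot_product_attention purely to mimic the shape handling (assumption: only the layouts and mask semantics carry over; the real path calls Habana's FusedSDPA with the argument order shown in the diff, and the scale= keyword needs PyTorch 2.1+): transpose to (batch, heads, seq, head_size), run SDPA with the additive padding mask, transpose back.

    import math
    import torch
    import torch.nn.functional as F

    bs, seq, heads, head_size = 2, 5, 4, 16
    q = torch.randn(bs, seq, heads, head_size)
    k = torch.randn(bs, seq, heads, head_size)
    v = torch.randn(bs, seq, heads, head_size)
    attn_mask = torch.zeros(bs, 1, 1, seq)
    attn_mask[0, :, :, 3:] = float("-inf")  # pretend sequence 0 has 2 padding tokens

    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
        attn_mask=attn_mask, scale=1.0 / math.sqrt(head_size),
    ).transpose(1, 2)
    print(out.shape)  # torch.Size([2, 5, 4, 16]) -- same layout the caller passed in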
@@ -157,6 +113,7 @@ def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, is_causal=False):
             k,
             v,
             out,
+            attn_mask,
             cu_seqlens,
             cu_seqlens,
             max_s,
