Commit 53bdab3

HPU bucketing (#489)
Signed-off-by: kaixuanliu <kaixuan.liu@intel.com>
1 parent: 2be18ab

File tree: 5 files changed (+61, -7 lines)

backends/python/server/text_embeddings_server/models/classification_model.py

Lines changed: 11 additions & 0 deletions

@@ -26,6 +26,17 @@ def __init__(
         model = model.to(dtype).to(device)
 
         self.hidden_size = model.config.hidden_size
+        position_offset = 0
+        model_type = model.config.model_type
+        if model_type in ["xlm-roberta", "camembert", "roberta"]:
+            position_offset = model.config.pad_token_id + 1
+        if hasattr(model.config, "max_seq_length"):
+            self.max_input_length = model.config.max_seq_length
+        else:
+            self.max_input_length = (
+                model.config.max_position_embeddings - position_offset
+            )
+
         self.has_position_ids = (
             inspect.signature(model.forward).parameters.get("position_ids", None)
             is not None

backends/python/server/text_embeddings_server/models/default_model.py

Lines changed: 11 additions & 0 deletions

@@ -30,6 +30,17 @@ def __init__(
         self.hidden_size = model.config.hidden_size
         self.pooling = Pooling(self.hidden_size, pooling_mode=pool)
 
+        position_offset = 0
+        model_type = model.config.model_type
+        if model_type in ["xlm-roberta", "camembert", "roberta"]:
+            position_offset = model.config.pad_token_id + 1
+        if hasattr(model.config, "max_seq_length"):
+            self.max_input_length = model.config.max_seq_length
+        else:
+            self.max_input_length = (
+                model.config.max_position_embeddings - position_offset
+            )
+
         self.has_position_ids = (
             inspect.signature(model.forward).parameters.get("position_ids", None)
             is not None
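
Both classification_model.py and default_model.py derive max_input_length the same way: RoBERTa-family checkpoints (xlm-roberta, camembert, roberta) start their position ids at pad_token_id + 1, so the usable sequence length is shorter than max_position_embeddings. A standalone restatement of that logic (the helper name is illustrative, not part of the patch):

def derive_max_input_length(config) -> int:
    position_offset = 0
    # RoBERTa-family models begin position ids at pad_token_id + 1,
    # so that many position embeddings are never available to real tokens.
    if config.model_type in ["xlm-roberta", "camembert", "roberta"]:
        position_offset = config.pad_token_id + 1
    # Sentence-transformers-style configs may carry an explicit limit.
    if hasattr(config, "max_seq_length"):
        return config.max_seq_length
    return config.max_position_embeddings - position_offset

For example, xlm-roberta-base ships with max_position_embeddings = 514 and pad_token_id = 1, which yields 514 - 2 = 512 usable tokens.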

backends/python/server/text_embeddings_server/models/flash_bert.py

Lines changed: 6 additions & 0 deletions

@@ -269,6 +269,12 @@ def forward(self, input_ids, token_type_ids, position_ids, cu_seqlens, max_s):
 class FlashBert(Model):
     def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
         config = BertConfig.from_pretrained(model_path)
+
+        if hasattr(config, "max_seq_length"):
+            self.max_input_length = config.max_seq_length
+        else:
+            self.max_input_length = config.max_position_embeddings
+
         with safe_open(model_path / "model.safetensors", framework="pt") as f:
             model = FlashBertModel(f, device, dtype, config)
         if device.type == "hpu":
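
flash_bert.py applies the same fallback without a position offset, since BERT position ids start at 0. With a stock transformers config, the hasattr/else pattern reduces to a getattr one-liner (illustrative only, not part of the patch):

from transformers import BertConfig

config = BertConfig()  # stock BERT defaults; no max_seq_length attribute
max_input_length = getattr(config, "max_seq_length", config.max_position_embeddings)
print(max_input_length)  # 512, BertConfig's default max_position_embeddings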

backends/python/server/text_embeddings_server/models/types.py

Lines changed: 25 additions & 5 deletions

@@ -1,3 +1,4 @@
+import os, math  # math is needed by the HPU bucketing in PaddedBatch.from_pb
 import torch
 
 from abc import ABC, abstractmethod
@@ -8,6 +9,11 @@
 from text_embeddings_server.pb.embed_pb2 import Embedding, Score
 
 tracer = trace.get_tracer(__name__)
+PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 128))
+
+
+def round_up(number, k):
+    return (number + k - 1) // k * k
 
 
 class Batch(ABC):
@@ -30,11 +36,23 @@ class PaddedBatch(Batch):
 
     @classmethod
     @tracer.start_as_current_span("from_pb")
-    def from_pb(cls, pb: embed_pb2.EmbedRequest, device: torch.device) -> "PaddedBatch":
+    def from_pb(
+        cls, pb: embed_pb2.EmbedRequest, device: torch.device, max_input_length: int
+    ) -> "PaddedBatch":
+        if pb.max_length > max_input_length:
+            raise RuntimeError(f"input length {pb.max_length} exceeds the model's max_input_length ({max_input_length})")
+
+        batch_size = len(pb.cu_seq_lengths) - 1
+        if device.type == "hpu":
+            # To better utilize HPU, we need to do batch/seq_len bucketing
+            max_length = round_up(pb.max_length, PAD_SEQUENCE_TO_MULTIPLE_OF)
+            max_length = min(max_length, max_input_length)
+            new_bs = 2 ** math.ceil(math.log2(batch_size))
+        else:
+            new_bs = batch_size
+            max_length = pb.max_length
         # Allocate padded tensors all at once
-        all_tensors = torch.zeros(
-            [4, len(pb.cu_seq_lengths) - 1, pb.max_length], dtype=torch.int32
-        )
+        all_tensors = torch.zeros([4, new_bs, max_length], dtype=torch.int32)
 
         for i, start_index in enumerate(pb.cu_seq_lengths[:-1]):
             end_index = pb.cu_seq_lengths[i + 1]
@@ -77,7 +95,9 @@ class FlashBatch(Batch):
 
     @classmethod
     @tracer.start_as_current_span("from_pb")
-    def from_pb(cls, pb: embed_pb2.EmbedRequest, device: torch.device) -> "FlashBatch":
+    def from_pb(
+        cls, pb: embed_pb2.EmbedRequest, device: torch.device, max_input_length: int
+    ) -> "FlashBatch":
         batch_input_ids = torch.tensor(pb.input_ids, dtype=torch.int32, device=device)
         batch_token_type_ids = torch.tensor(
             pb.token_type_ids, dtype=torch.int32, device=device
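
The bucketing collapses the space of tensor shapes the device sees: the padded sequence length is rounded up to the next multiple of PAD_SEQUENCE_TO_MULTIPLE_OF (then capped at the model maximum), and the batch dimension is rounded up to the next power of two. Fewer distinct shapes means fewer HPU graph recompilations. A standalone sketch of the arithmetic with example values:

import math

PAD_SEQUENCE_TO_MULTIPLE_OF = 128  # default from the patch above

def round_up(number, k):
    return (number + k - 1) // k * k

# A batch of 5 sequences whose longest member is 130 tokens, model max 512:
max_length = min(round_up(130, PAD_SEQUENCE_TO_MULTIPLE_OF), 512)  # -> 256
new_bs = 2 ** math.ceil(math.log2(5))                              # -> 8
# The padded tensors are allocated as [4, 8, 256] instead of [4, 5, 130],
# and a later request of, say, 7 sequences x 200 tokens reuses the same shape.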

backends/python/server/text_embeddings_server/server.py

Lines changed: 8 additions & 2 deletions

@@ -25,14 +25,20 @@ async def Health(self, request, context):
         return embed_pb2.HealthResponse()
 
     async def Embed(self, request, context):
-        batch = self.model.batch_type.from_pb(request, self.model.device)
+        max_input_length = self.model.max_input_length
+        batch = self.model.batch_type.from_pb(
+            request, self.model.device, max_input_length
+        )
 
         embeddings = self.model.embed(batch)
 
         return embed_pb2.EmbedResponse(embeddings=embeddings)
 
     async def Predict(self, request, context):
-        batch = self.model.batch_type.from_pb(request, self.model.device)
+        max_input_length = self.model.max_input_length
+        batch = self.model.batch_type.from_pb(
+            request, self.model.device, max_input_length
+        )
 
         scores = self.model.predict(batch)
 
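
One operational note: PAD_SEQUENCE_TO_MULTIPLE_OF is read a single time, at import time, in types.py, so it must be present in the environment before the server process imports that module; a minimal sketch of setting it programmatically:

import os

# Must run before text_embeddings_server.models.types is imported,
# because that module reads the variable once at import time.
os.environ["PAD_SEQUENCE_TO_MULTIPLE_OF"] = "64"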
