Skip to content

Commit 24e6527

Browse files
committed
Add support for images in requests
1 parent cb1f244 commit 24e6527

File tree

3 files changed

+51
-6
lines changed

3 files changed

+51
-6
lines changed

src/guidellm/backend/openai.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import AsyncGenerator, Dict, List, Optional
2+
import io, base64
23

34
from loguru import logger
45
from openai import AsyncOpenAI, OpenAI
@@ -103,11 +104,11 @@ async def make_request(
103104

104105
request_args.update(self._request_args)
105106

107+
messages = self._build_messages(request)
108+
106109
stream = await self._async_client.chat.completions.create(
107110
model=self.model,
108-
messages=[
109-
{"role": "user", "content": request.prompt},
110-
],
111+
messages=messages,
111112
stream=True,
112113
**request_args,
113114
)
@@ -167,3 +168,21 @@ def validate_connection(self):
167168
except Exception as error:
168169
logger.error("Failed to validate OpenAI connection: {}", error)
169170
raise error
171+
172+
def _build_messages(self, request: TextGenerationRequest) -> Dict:
173+
if request.number_images == 0:
174+
messages = [{"role": "user", "content": request.prompt}]
175+
else:
176+
content = []
177+
for image in request.images:
178+
stream = io.BytesIO()
179+
im_format = image.image.format or "PNG"
180+
image.image.save(stream, format=im_format)
181+
im_b64 = base64.b64encode(stream.getvalue()).decode("ascii")
182+
image_url = {"url": f"data:image/{im_format.lower()};base64,{im_b64}"}
183+
content.append({"type": "image_url", "image_url": image_url})
184+
185+
content.append({"type": "text", "text": request.prompt})
186+
messages = [{"role": "user", "content": content}]
187+
188+
return messages

src/guidellm/core/request.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import uuid
2-
from typing import Any, Dict, Optional
2+
from typing import Any, Dict, Optional, List
33

44
from pydantic import Field
55

66
from guidellm.core.serializable import Serializable
7+
from guidellm.utils import ImageDescriptor
78

89

910
class TextGenerationRequest(Serializable):
@@ -16,6 +17,10 @@ class TextGenerationRequest(Serializable):
1617
description="The unique identifier for the request.",
1718
)
1819
prompt: str = Field(description="The input prompt for the text generation.")
20+
images: Optional[List[ImageDescriptor]] = Field(
21+
default=None,
22+
description="Input images.",
23+
)
1924
prompt_token_count: Optional[int] = Field(
2025
default=None,
2126
description="The number of tokens in the input prompt.",
@@ -29,6 +34,13 @@ class TextGenerationRequest(Serializable):
2934
description="The parameters for the text generation request.",
3035
)
3136

37+
@property
def number_images(self) -> int:
    """Count of input images attached to this request (0 when none)."""
    return 0 if self.images is None else len(self.images)
43+
3244
def __str__(self) -> str:
3345
prompt_short = (
3446
self.prompt[:32] + "..."
@@ -41,4 +53,5 @@ def __str__(self) -> str:
4153
f"prompt={prompt_short}, prompt_token_count={self.prompt_token_count}, "
4254
f"output_token_count={self.output_token_count}, "
4355
f"params={self.params})"
56+
f"images={self.number_images}"
4457
)

src/guidellm/request/emulated.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from guidellm.config import settings
1212
from guidellm.core.request import TextGenerationRequest
1313
from guidellm.request.base import GenerationMode, RequestGenerator
14-
from guidellm.utils import clean_text, filter_text, load_text, split_text
14+
from guidellm.utils import clean_text, filter_text, load_text, split_text, load_images
1515

1616
__all__ = ["EmulatedConfig", "EmulatedRequestGenerator", "EndlessTokens"]
1717

@@ -30,6 +30,7 @@ class EmulatedConfig:
3030
generated_tokens_variance (Optional[int]): Variance for generated tokens.
3131
generated_tokens_min (Optional[int]): Minimum number of generated tokens.
3232
generated_tokens_max (Optional[int]): Maximum number of generated tokens.
33+
images (Optional[int]): Number of input images.
3334
"""
3435

3536
@staticmethod
@@ -47,7 +48,7 @@ def create_config(config: Optional[Union[str, Path, Dict]]) -> "EmulatedConfig":
4748
"""
4849
if not config:
4950
logger.debug("Creating default configuration")
50-
return EmulatedConfig(prompt_tokens=1024, generated_tokens=256)
51+
return EmulatedConfig(prompt_tokens=1024, generated_tokens=256, images=0)
5152

5253
if isinstance(config, dict):
5354
logger.debug("Loading configuration from dict: {}", config)
@@ -105,6 +106,8 @@ def create_config(config: Optional[Union[str, Path, Dict]]) -> "EmulatedConfig":
105106
generated_tokens_min: Optional[int] = None
106107
generated_tokens_max: Optional[int] = None
107108

109+
images: int = 0
110+
108111
@property
109112
def prompt_tokens_range(self) -> Tuple[int, int]:
110113
"""
@@ -327,6 +330,8 @@ def __init__(
327330
settings.emulated_data.filter_start,
328331
settings.emulated_data.filter_end,
329332
)
333+
if self._config.images > 0:
334+
self._images = load_images(settings.emulated_data.image_source)
330335
self._rng = np.random.default_rng(random_seed)
331336

332337
# NOTE: Must be after all the parameters since the queue population
@@ -355,6 +360,7 @@ def create_item(self) -> TextGenerationRequest:
355360
logger.debug("Creating new text generation request")
356361
target_prompt_token_count = self._config.sample_prompt_tokens(self._rng)
357362
prompt = self.sample_prompt(target_prompt_token_count)
363+
images = self.sample_images()
358364
prompt_token_count = len(self.tokenizer.tokenize(prompt))
359365
output_token_count = self._config.sample_output_tokens(self._rng)
360366
logger.debug("Generated prompt: {}", prompt)
@@ -363,6 +369,7 @@ def create_item(self) -> TextGenerationRequest:
363369
prompt=prompt,
364370
prompt_token_count=prompt_token_count,
365371
output_token_count=output_token_count,
372+
images=images,
366373
)
367374

368375
def sample_prompt(self, tokens: int) -> str:
@@ -395,3 +402,9 @@ def sample_prompt(self, tokens: int) -> str:
395402
right = mid
396403

397404
return self._tokens.create_text(start_line_index, left)
405+
406+
407+
def sample_images(self):
    """
    Sample the configured number of input images for a request.

    ``self._images`` is only loaded in ``__init__`` when
    ``config.images > 0``, so for text-only configurations there is
    nothing to sample from; guard that case instead of raising
    AttributeError.

    :return: a list of randomly chosen images (without replacement),
        or None when the configuration requests no images.
    """
    if self._config.images <= 0:
        # Matches the TextGenerationRequest.images default of None.
        return None

    image_indices = self._rng.choice(
        len(self._images), size=self._config.images, replace=False
    )
    return [self._images[i] for i in image_indices]

0 commit comments

Comments
 (0)