11
11
from guidellm .config import settings
12
12
from guidellm .core .request import TextGenerationRequest
13
13
from guidellm .request .base import GenerationMode , RequestGenerator
14
- from guidellm .utils import clean_text , filter_text , load_text , split_text
14
+ from guidellm .utils import clean_text , filter_text , load_text , split_text , load_images
15
15
16
16
__all__ = ["EmulatedConfig" , "EmulatedRequestGenerator" , "EndlessTokens" ]
17
17
@@ -30,6 +30,7 @@ class EmulatedConfig:
30
30
generated_tokens_variance (Optional[int]): Variance for generated tokens.
31
31
generated_tokens_min (Optional[int]): Minimum number of generated tokens.
32
32
generated_tokens_max (Optional[int]): Maximum number of generated tokens.
33
+ images (int): Number of input images per request; defaults to 0 (no images).
33
34
"""
34
35
35
36
@staticmethod
@@ -47,7 +48,7 @@ def create_config(config: Optional[Union[str, Path, Dict]]) -> "EmulatedConfig":
47
48
"""
48
49
if not config :
49
50
logger .debug ("Creating default configuration" )
50
- return EmulatedConfig (prompt_tokens = 1024 , generated_tokens = 256 )
51
+ return EmulatedConfig (prompt_tokens = 1024 , generated_tokens = 256 , images = 0 )
51
52
52
53
if isinstance (config , dict ):
53
54
logger .debug ("Loading configuration from dict: {}" , config )
@@ -105,6 +106,8 @@ def create_config(config: Optional[Union[str, Path, Dict]]) -> "EmulatedConfig":
105
106
generated_tokens_min : Optional [int ] = None
106
107
generated_tokens_max : Optional [int ] = None
107
108
109
+ images : int = 0
110
+
108
111
@property
109
112
def prompt_tokens_range (self ) -> Tuple [int , int ]:
110
113
"""
@@ -327,6 +330,8 @@ def __init__(
327
330
settings .emulated_data .filter_start ,
328
331
settings .emulated_data .filter_end ,
329
332
)
333
+ if self ._config .images > 0 :
334
+ self ._images = load_images (settings .emulated_data .image_source )
330
335
self ._rng = np .random .default_rng (random_seed )
331
336
332
337
# NOTE: Must be after all the parameters since the queue population
@@ -355,6 +360,7 @@ def create_item(self) -> TextGenerationRequest:
355
360
logger .debug ("Creating new text generation request" )
356
361
target_prompt_token_count = self ._config .sample_prompt_tokens (self ._rng )
357
362
prompt = self .sample_prompt (target_prompt_token_count )
363
+ images = self .sample_images ()
358
364
prompt_token_count = len (self .tokenizer .tokenize (prompt ))
359
365
output_token_count = self ._config .sample_output_tokens (self ._rng )
360
366
logger .debug ("Generated prompt: {}" , prompt )
@@ -363,6 +369,7 @@ def create_item(self) -> TextGenerationRequest:
363
369
prompt = prompt ,
364
370
prompt_token_count = prompt_token_count ,
365
371
output_token_count = output_token_count ,
372
+ images = images ,
366
373
)
367
374
368
375
def sample_prompt (self , tokens : int ) -> str :
@@ -395,3 +402,9 @@ def sample_prompt(self, tokens: int) -> str:
395
402
right = mid
396
403
397
404
return self ._tokens .create_text (start_line_index , left )
405
+
406
+
407
def sample_images(self):
    """
    Pick the images to attach to a single generated request.

    Draws ``self._config.images`` distinct entries from ``self._images``
    using ``self._rng``, sampling without replacement so the same image is
    never attached twice to one request.

    :return: A list with the sampled entries of ``self._images``; an empty
        list when the configuration requests no images.
    :raises ValueError: If more images are requested than are loaded.
    """
    requested = self._config.images

    # ``self._images`` is only assigned when images > 0, so bail out before
    # touching it; otherwise text-only configurations raise AttributeError.
    if requested <= 0:
        return []

    pool = self._images
    available = len(pool)

    # Sampling without replacement cannot yield more items than the pool
    # holds; fail with an explicit message instead of numpy's generic one.
    if requested > available:
        raise ValueError(
            f"Requested {requested} images per request but only "
            f"{available} images are available"
        )

    picked = self._rng.choice(available, size=requested, replace=False)

    return [pool[index] for index in picked]
0 commit comments