Skip to content

Commit 98f79c4

Browse files
committed
update fixed rate profile generator, add tests, set up initial executor test
1 parent bf72422 commit 98f79c4

File tree

3 files changed

+72
-20
lines changed

3 files changed

+72
-20
lines changed

src/guidellm/executor/profile_generator.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -69,30 +69,37 @@ def next_profile(
6969

7070
@ProfileGenerator.register_generator(ProfileGenerationModes.FIXED)
7171
class FixedRateProfileGenerator(ProfileGenerator):
72-
def __init__(self, rate: List[float], rate_type: str, **kwargs):
72+
def __init__(self, rate_type: str, rate: Optional[List[float]] = None, **kwargs):
7373
super().__init__(ProfileGenerationModes.FIXED)
74-
if rate_type == "synchronous" and len(rate) > 0:
74+
if rate_type == "synchronous" and rate and len(rate) > 0:
7575
raise ValueError("custom rates are not supported in synchronous mode")
7676
self._rates = rate
7777
self._rate_index = 0
78+
self._generated = False
7879
self._rate_type = rate_type
7980

8081
def next_profile(
8182
self, current_report: TextGenerationBenchmarkReport
8283
) -> Optional[Profile]:
83-
if self._rate_index >= len(self._rates):
84-
return None
84+
if self._rate_type == "synchronous":
85+
if self._generated:
86+
return None
8587

86-
current_rate = self._rates[self._rate_index]
87-
self._rate_index += 1
88+
self._generated = True
8889

89-
if self._rate_type == "synchronous":
9090
return Profile(
9191
load_gen_mode=LoadGenerationModes.SYNCHRONOUS, load_gen_rate=None
9292
)
9393

9494
if self._rate_type in {"constant", "poisson"}:
95+
if self._rate_index >= len(self._rates):
96+
return None
97+
98+
current_rate = self._rates[self._rate_index]
99+
self._rate_index += 1
100+
95101
load_gen_mode = RateTypeLoadGenModeMap[self._rate_type]
102+
96103
return Profile(
97104
load_gen_mode=load_gen_mode, load_gen_rate=current_rate
98105
)

tests/unit/executor/test_executor.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from unittest.mock import MagicMock
2+
from src.guidellm.backend.base import Backend
3+
from src.guidellm.executor.executor import Executor
4+
from src.guidellm.request.base import RequestGenerator
5+
6+
def test_executor_creation():
7+
mock_request_generator = MagicMock(spec=RequestGenerator)
8+
mock_backend = MagicMock(spec=Backend)
9+
rate_type = "sweep"
10+
profile_args = None
11+
max_requests = None,
12+
max_duration = None,
13+
executor = Executor(mock_request_generator, mock_backend, rate_type, profile_args, max_requests, max_duration);
14+
assert executor.request_generator == mock_request_generator
15+
assert executor.backend == mock_backend
16+
assert executor.max_requests == max_requests
17+
assert executor.max_duration == max_duration
Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import pytest
2-
2+
from unittest.mock import MagicMock
33
from guidellm.executor import (ProfileGenerator, FixedRateProfileGenerator, SweepProfileGenerator)
4+
from src.guidellm.core.result import TextGenerationBenchmarkReport
5+
from src.guidellm.scheduler.load_generator import LoadGenerationModes
46

57
def test_invalid_profile_generation_mode_error():
68
rate = [1]
@@ -10,19 +12,45 @@ def test_invalid_profile_generation_mode_error():
1012
ProfileGenerator.create_generator(profile_mode, **({ "rate": rate, "rate_type": rate_type}))
1113

1214
def test_sweep_profile_generator_creation():
13-
profile = ProfileGenerator.create_generator("sweep", **({}))
14-
assert isinstance(profile, SweepProfileGenerator)
15-
assert profile._sync_run == False
16-
assert profile._max_found == False
17-
assert profile._pending_rates == None
18-
assert profile._pending_rates == None
15+
profile_generator = ProfileGenerator.create_generator("sweep", **({}))
16+
assert isinstance(profile_generator, SweepProfileGenerator)
17+
assert profile_generator._sync_run == False
18+
assert profile_generator._max_found == False
19+
assert profile_generator._pending_rates == None
20+
assert profile_generator._pending_rates == None
1921

2022
def test_fixed_rate_profile_generator_creation():
2123
rate = [1]
2224
rate_type = "constant"
23-
profile = ProfileGenerator.create_generator("fixed_rate", **({ "rate": rate, "rate_type": rate_type}))
24-
assert isinstance(profile, FixedRateProfileGenerator)
25-
assert profile._rates == rate
26-
assert profile._rate_type == rate_type
27-
assert profile._rate_index == 0
28-
assert profile._rate_index == 0
25+
profile_generator = ProfileGenerator.create_generator("fixed_rate", **({ "rate": rate, "rate_type": rate_type}))
26+
assert isinstance(profile_generator, FixedRateProfileGenerator)
27+
assert profile_generator._rates == rate
28+
assert profile_generator._rate_type == rate_type
29+
assert profile_generator._rate_index == 0
30+
assert profile_generator._rate_index == 0
31+
32+
def test_synchronous_mode_rate_list_error():
33+
rate = [1]
34+
rate_type = "synchronous"
35+
with pytest.raises(ValueError, match="custom rates are not supported in synchronous mode"):
36+
ProfileGenerator.create_generator("fixed_rate", **({ "rate": rate, "rate_type": rate_type}))
37+
38+
def test_next_profile_with_multiple_rates():
39+
rates = [1, 2]
40+
rate_type = "constant"
41+
profile_generator = ProfileGenerator.create_generator("fixed_rate", **({ "rate": rates, "rate_type": rate_type}))
42+
mock_report = MagicMock(spec=TextGenerationBenchmarkReport)
43+
for rate in rates:
44+
current_profile = profile_generator.next_profile(mock_report)
45+
assert current_profile.load_gen_rate == rate
46+
assert current_profile.load_gen_mode.name == LoadGenerationModes.CONSTANT.name
47+
assert profile_generator.next_profile(mock_report) == None
48+
49+
def test_next_profile_with_sync_mode():
50+
rate_type = "synchronous"
51+
profile_generator = ProfileGenerator.create_generator("fixed_rate", **({ "rate_type": rate_type}))
52+
mock_report = MagicMock(spec=TextGenerationBenchmarkReport)
53+
current_profile = profile_generator.next_profile(mock_report)
54+
assert current_profile.load_gen_rate == None
55+
assert current_profile.load_gen_mode.name == LoadGenerationModes.SYNCHRONOUS.name
56+
assert profile_generator.next_profile(mock_report) == None

0 commit comments

Comments
 (0)