Skip to content

Commit 8c4ea83

Browse files
author
Dmytro Parfeniuk
committed
🚧 WIP
1 parent faa88cc commit 8c4ea83

File tree

9 files changed

+64
-46
lines changed

9 files changed

+64
-46
lines changed

src/domain/backend/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import uuid
21
from abc import ABC, abstractmethod
32
from dataclasses import dataclass
43
from enum import Enum
@@ -91,16 +90,17 @@ def submit(self, request: TextGenerationRequest) -> TextGenerationResult:
9190
"""
9291

9392
logger.info(f"Submitting request with prompt: {request.prompt}")
94-
result_id = str(uuid.uuid4())
95-
result = TextGenerationResult(result_id)
93+
result = TextGenerationResult(request=request)
9694
result.start(request.prompt)
9795

96+
breakpoint() # TODO: remove
9897
for response in self.make_request(request):
9998
if response.type_ == "token_iter" and response.add_token:
10099
result.output_token(response.add_token)
101100
elif response.type_ == "final":
102101
result.end(
103-
response.output,
102+
# NOTE: clarify if the `or ""` makes any sense
103+
response.output or "",
104104
response.prompt_token_count,
105105
response.output_token_count,
106106
)

src/domain/core/result.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def __init__(self, request: TextGenerationRequest):
4040
self._first_token_time: Optional[float] = None
4141
self._decode_times = Distribution()
4242

43+
breakpoint() # TODO: remove
4344
logger.debug(f"Initialized TextGenerationResult for request: {self._request}")
4445

4546
def __repr__(self) -> str:

src/domain/load_generator/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,13 @@
88

99

1010
class LoadGenerationMode(str, Enum):
11+
"""
12+
Available values:
13+
* SYNCHRONOUS
14+
* CONSTANT
15+
* POISSON
16+
"""
17+
1118
SYNCHRONOUS = "sync"
1219
CONSTANT = "constant"
1320
POISSON = "poisson"

src/domain/request/base.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import contextlib
12
import threading
23
import time
34
from abc import ABC, abstractmethod
@@ -46,8 +47,7 @@ def __init__(
4647
logger.debug("No tokenizer provided")
4748

4849
if self._mode == "async":
49-
self._thread = threading.Thread(target=self._populate_queue)
50-
self._thread.daemon = True
50+
self._thread = threading.Thread(target=self._populate_queue, daemon=True)
5151
self._thread.start()
5252
logger.info(
5353
"RequestGenerator started in async mode with queue size: {}",
@@ -142,7 +142,7 @@ def _populate_queue(self):
142142
Populate the request queue in the background.
143143
"""
144144
while not self._stop_event.is_set():
145-
try:
145+
with contextlib.suppress(Full):
146146
if self._queue.qsize() < self._async_queue_size:
147147
item = self.create_item()
148148
self._queue.put(item, timeout=0.1)
@@ -151,7 +151,7 @@ def _populate_queue(self):
151151
self._queue.qsize(),
152152
)
153153
else:
154-
time.sleep(0.1)
155-
except Full:
156-
continue
154+
# print("\nSleeping on _populate_queue...")
155+
time.sleep(1) # TODO: Change me
156+
157157
logger.info("RequestGenerator stopped populating queue")

src/domain/scheduler/scheduler.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,13 @@
22
import time
33
from typing import Iterable, Optional
44

5-
from loguru import logger
6-
75
from domain.backend import Backend
86
from domain.core import TextGenerationBenchmark, TextGenerationError
97
from domain.load_generator import LoadGenerationMode, LoadGenerator
108
from domain.request import RequestGenerator
119

1210
from .task import Task
1311

14-
__all__ = ["Scheduler"]
15-
1612

1713
class Scheduler:
1814
def __init__(
@@ -76,19 +72,26 @@ async def _run_async(self) -> TextGenerationBenchmark:
7672
)
7773
load_gen = LoadGenerator(self._load_gen_mode, self._load_gen_rate)
7874

79-
tasks = []
75+
coroutines = []
8076
start_time = time.time()
8177
counter = 0
78+
8279
try:
83-
for task, task_start_time in zip(self._task_iterator(), load_gen.times()):
80+
for text_generation_request, task_start_time in zip(
81+
self._request_generator, load_gen.times()
82+
):
83+
coro = Task(
84+
func=self._backend.submit,
85+
params={"request": text_generation_request.prompt},
86+
err_container=TextGenerationError,
87+
)
88+
8489
pending_time = task_start_time - time.time()
8590

8691
if pending_time > 0:
8792
await asyncio.sleep(pending_time)
8893

89-
tasks.append(
90-
asyncio.create_task(self._run_task_async(task, result_set))
91-
)
94+
coroutines.append(self._run_task_async(coro, result_set))
9295
counter += 1
9396

9497
if (
@@ -105,24 +108,17 @@ async def _run_async(self) -> TextGenerationBenchmark:
105108
await asyncio.sleep(pending_duration)
106109
raise asyncio.CancelledError()
107110

108-
await asyncio.gather(*tasks)
111+
await asyncio.gather(*coroutines)
112+
109113
except asyncio.CancelledError:
110114
# Cancel all pending tasks
111-
for task in tasks:
112-
if not task.done():
113-
task.cancel()
115+
for coro in coroutines:
116+
if not coro.done():
117+
coro.cancel()
114118

115119
return result_set
116120

117121
async def _run_task_async(self, task: Task, result_set: TextGenerationBenchmark):
118122
result_set.request_started()
119123
res = await task.run_async()
120124
result_set.request_completed(res)
121-
122-
def _task_iterator(self) -> Iterable[Task]:
123-
for request in self._request_generator:
124-
yield Task(
125-
func=self._backend.submit,
126-
params={"request": request},
127-
err_container=TextGenerationError,
128-
)

src/domain/scheduler/task.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import asyncio
2+
import functools
23
import threading
34
from typing import Any, Callable, Dict, Optional
45

56
from loguru import logger
67

7-
__all__ = ["Task"]
8-
98

109
class Task:
1110
"""
@@ -26,10 +25,12 @@ def __init__(
2625
params: Optional[Dict[str, Any]] = None,
2726
err_container: Optional[Callable] = None,
2827
):
29-
self._func = func
30-
self._params = params or {}
31-
self._err_container = err_container
28+
self._func: Callable[..., Any] = func
29+
self._params: Dict[str, Any] = params or {}
30+
self._err_container: Optional[Callable] = err_container
3231
self._cancel_event = asyncio.Event()
32+
self._loop = asyncio.get_running_loop()
33+
3334
logger.info(
3435
f"Task created with function: {self._func.__name__} and "
3536
f"params: {self._params}"
@@ -51,7 +52,9 @@ async def run_async(self) -> Any:
5152
self._thread.start()
5253

5354
result = await asyncio.gather(
54-
asyncio.to_thread(self._func, **self._params),
55+
self._loop.run_in_executor(
56+
None, functools.partial(self._func, **self._params)
57+
),
5558
self._check_cancelled(),
5659
return_exceptions=True,
5760
)

tests/conftest.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import random
2-
from typing import List
2+
from typing import List, Optional
33

44
import pytest
55
from loguru import logger
@@ -49,11 +49,11 @@ def openai_backend_factory():
4949
Call without provided arguments returns default Backend service.
5050
"""
5151

52-
def inner_wrapper(*_, **kwargs) -> OpenAIBackend:
52+
def inner_wrapper(*_, base_url: Optional[str] = None, **kwargs) -> OpenAIBackend:
5353
static = {"backend_type": BackendEngine.OPENAI_SERVER}
5454
defaults = {
5555
"openai_api_key": "required but not used",
56-
"internal_callback_url": "http://localhost:8080",
56+
"internal_callback_url": base_url or "http://localhost:8080",
5757
}
5858

5959
defaults.update(kwargs)

tests/integration/executor/test_report_generation.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
import random
2+
13
import pytest
24

35
from domain.backend.base import BackendEngine
46
from domain.core.result import TextGenerationBenchmarkReport
57
from domain.executor import Executor, ProfileGenerationMode, SingleProfileGenerator
6-
from tests.dummy.services import TestRequestGenerator
8+
from tests import dummy
79

810

911
@pytest.mark.parametrize(
@@ -22,7 +24,9 @@ def test_executor_openai_unsupported_generation_modes(
2224
* Profile generation modes: sync,
2325
"""
2426

25-
request_genrator = TestRequestGenerator(tokenizer="bert-base-uncased")
27+
request_genrator = dummy.services.TestRequestGenerator(
28+
tokenizer="bert-base-uncased"
29+
)
2630
profile_generator_args = {"rate_type": profile_generation_mode, "rate": 1.0}
2731

2832
with pytest.raises(ValueError):
@@ -36,12 +40,20 @@ def test_executor_openai_unsupported_generation_modes(
3640
)
3741

3842

39-
def test_executor_openai_single_report_generation(openai_backend_factory):
43+
def test_executor_openai_single_report_generation(mocker, openai_backend_factory):
4044
"""
4145
Check OpenAI Single Report Generation.
46+
47+
1. create dummy data for all the OpenAI responses
48+
2. create an `Executor` instance
49+
3. run the executor
50+
4. check that the executor schedules tasks for submitting requests
51+
5. validate the output report
4252
"""
4353

44-
request_genrator = TestRequestGenerator(tokenizer="bert-base-uncased")
54+
request_genrator = dummy.services.TestRequestGenerator(
55+
tokenizer="bert-base-uncased"
56+
)
4557
profile_generation_mode = ProfileGenerationMode.SINGLE
4658
profile_generator_args = {"rate_type": profile_generation_mode, "rate": 1.0}
4759

tests/unit/request/test_base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
from tests.dummy.services import TestRequestGenerator
88

99

10-
@pytest.mark.smoke
11-
def test_request_generator_sync_constructor():
10+
@pytest.mark.smoke
def test_request_generator_sync_constructor():
1211
generator = TestRequestGenerator(mode="sync")
1312
assert generator.mode == "sync"
1413
assert generator.async_queue_size == 50 # Default value

0 commit comments

Comments
 (0)