@@ -72,26 +72,17 @@ async def _run_async(self) -> TextGenerationBenchmark:
72
72
)
73
73
load_gen = LoadGenerator (self ._load_gen_mode , self ._load_gen_rate )
74
74
75
- coroutines = []
75
+ tasks = []
76
76
start_time = time .time ()
77
77
counter = 0
78
-
79
78
try :
80
- for text_generation_request , task_start_time in zip (
81
- self ._request_generator , load_gen .times ()
82
- ):
83
- coro = Task (
84
- func = self ._backend .submit ,
85
- params = {"request" : text_generation_request .prompt },
86
- err_container = TextGenerationError ,
87
- )
88
-
79
+ for task , task_start_time in zip (self ._task_iterator (), load_gen .times ()):
89
80
pending_time = task_start_time - time .time ()
90
81
91
82
if pending_time > 0 :
92
83
await asyncio .sleep (pending_time )
93
84
94
- coroutines .append (self ._run_task_async (coro , result_set ))
85
+ tasks .append (self ._run_task_async (task , result_set ))
95
86
counter += 1
96
87
97
88
if (
@@ -108,17 +99,24 @@ async def _run_async(self) -> TextGenerationBenchmark:
108
99
await asyncio .sleep (pending_duration )
109
100
raise asyncio .CancelledError ()
110
101
111
- await asyncio .gather (* coroutines )
112
-
102
+ await asyncio .gather (* tasks )
113
103
except asyncio .CancelledError :
114
104
# Cancel all pending tasks
115
- for coro in coroutines :
116
- if not coro .done ():
117
- coro .cancel ()
105
+ for task in tasks :
106
+ if not task .done ():
107
+ task .cancel ()
118
108
119
109
return result_set
120
110
121
111
async def _run_task_async (self , task : Task , result_set : TextGenerationBenchmark ):
122
112
result_set .request_started ()
123
113
res = await task .run_async ()
124
114
result_set .request_completed (res )
115
+
116
+ def _task_iterator (self ) -> Iterable [Task ]:
117
+ for request in self ._request_generator :
118
+ yield Task (
119
+ func = self ._backend .submit ,
120
+ params = {"request" : request },
121
+ err_container = TextGenerationError ,
122
+ )
0 commit comments