16
16
use think \middleware \throttle \ThrottleAbstract ;
17
17
use think \Request ;
18
18
use think \Response ;
19
+ use think \Session ;
19
20
use TypeError ;
20
21
use function sprintf ;
21
22
@@ -37,7 +38,7 @@ class Throttle
37
38
'visit_rate ' => '' , // 节流频率, 空字符串表示不限制 eg: '', '10/m', '20/h', '300/d'
38
39
'visit_enable_show_rate_limit ' => true , // 在响应体中设置速率限制的头部信息
39
40
'visit_fail_code ' => 429 , // 访问受限时返回的 http 状态码,当没有 visit_fail_response 时生效
40
- 'visit_fail_text ' => 'Too Many Requests ' , // 访问受限时访问的文本信息,当没有 visit_fail_response 时生效
41
+ 'visit_fail_text ' => 'Too many requests, try again after __WAIT__ seconds. ' , // 访问受限时访问的文本信息
41
42
'visit_fail_response ' => null , // 访问受限时的响应信息闭包回调
42
43
'driver_name ' => CounterFixed::class, // 限流算法驱动
43
44
];
@@ -55,6 +56,7 @@ class Throttle
55
56
*/
56
57
protected CacheInterface $ cache ;
57
58
protected App $ app ;
59
+ protected Session $ session ;
58
60
59
61
/**
60
62
* 配置参数
@@ -74,12 +76,13 @@ class Throttle
74
76
* @param Cache $cache
75
77
* @param Config $config
76
78
*/
77
- public function __construct (Cache $ cache , Config $ config , App $ app )
79
+ public function __construct (Cache $ cache , Config $ config , App $ app, Session $ session )
78
80
{
79
81
$ this ->cache = $ cache ;
80
82
$ this ->config = array_merge (static ::$ default_config , $ config ->get ('throttle ' , []));
81
83
$ this ->app = $ app ;
82
84
$ this ->config_instance = $ config ;
85
+ $ this ->session = $ session ;
83
86
}
84
87
85
88
/**
@@ -95,7 +98,7 @@ public function handle(Request $request, Closure $next, array $params = []): Res
95
98
$ this ->config = array_merge ($ this ->config , $ params );
96
99
}
97
100
98
- $ allow = $ this ->allowRequestByConfig ($ request ) && $ this ->allowRequestByAnnotation ($ request );
101
+ $ allow = $ this ->allowRequestByAnnotation ($ request ) && $ this ->allowRequestByConfig ($ request );
99
102
if (!$ allow ) {
100
103
// 访问受限
101
104
throw $ this ->buildLimitException ($ this ->wait_seconds , $ request );
@@ -113,43 +116,73 @@ public function handle(Request $request, Closure $next, array $params = []): Res
113
116
}
114
117
115
118
/**
116
- * 根据**配置 **信息是否允许请求通过
119
+ * 根据**注解 **信息是否允许请求通过
117
120
* @param Request $request
118
121
* @return bool
119
122
*/
120
- protected function allowRequestByConfig (Request $ request ): bool
123
+ protected function allowRequestByAnnotation (Request $ request ): bool
121
124
{
122
- // 若请求类型不在限制内
123
- if (!in_array ($ request ->method (), $ this ->config ['visit_method ' ])) {
124
- return true ;
125
+ // 处理注解
126
+ $ controller = $ this ->getFullController ($ request );
127
+ if ($ controller ) {
128
+ $ action = $ request ->action ();
129
+ if (method_exists ($ controller , $ action )) {
130
+ $ reflectionMethod = new ReflectionMethod ($ controller , $ action );
131
+ $ attributes = $ reflectionMethod ->getAttributes (RateLimitAnnotation::class);
132
+ foreach ($ attributes as $ attribute ) {
133
+ $ annotation = $ attribute ->newInstance ();
134
+ $ key = $ this ->getCacheKey ($ request , $ annotation ->key , $ annotation ->driver , true );
135
+ if (!$ this ->allowRequest ($ key , $ annotation ->rate , $ annotation ->driver )) {
136
+ $ this ->config ['visit_fail_text ' ] = $ annotation ->message ;
137
+ return false ;
138
+ }
139
+ }
140
+ }
125
141
}
126
- $ driver = $ this ->config ['driver_name ' ];
127
- $ key = $ this ->getCacheKey ($ request , $ this ->config ['key ' ], $ driver );
128
- return $ this ->allowRequest ($ key , $ this ->config ['visit_rate ' ], $ driver );
142
+ return true ;
143
+ }
144
+
145
+ private function getFullController (Request $ request ): string
146
+ {
147
+ $ controller = $ request ->controller ();
148
+ if (empty ($ controller )) {
149
+ return '' ;
150
+ }
151
+ $ suffix = $ this ->config_instance ->get ('route.controller_suffix ' ) ? 'Controller ' : '' ;
152
+ $ layer = $ this ->config_instance ->get ('route.controller_layer ' ) ?: 'controller ' ;
153
+ $ controllerClassName = $ this ->app ->parseClass ($ layer , $ controller . $ suffix );
154
+ return $ controllerClassName ;
129
155
}
130
156
131
157
/**
132
158
* 生成缓存的 key
133
159
* @param Request $request
134
- * @param string|bool|Closure|null $key
160
+ * @param string|bool|Closure $key
135
161
* @param string $driver
136
162
* @return string
137
163
*/
138
- protected function getCacheKey (Request $ request , string |bool |Closure | null $ key , string $ driver ): string
164
+ protected function getCacheKey (Request $ request , string |bool |Closure $ key , string $ driver, bool $ annotation = false ): string
139
165
{
140
166
if ($ key instanceof Closure) {
141
167
$ key = Container::getInstance ()->invokeFunction ($ key , [$ this , $ request ]);
142
168
}
143
169
144
- if ($ key === null || $ key === false ) {
145
- // 关闭当前限制
170
+ if ($ key === false || $ key === '' ) {
171
+ // 不做限制
146
172
return '' ;
147
173
}
148
174
149
175
if ($ key === true ) {
150
176
$ key = $ request ->ip ();
151
177
} elseif (is_string ($ key ) && str_contains ($ key , '__ ' )) {
152
- $ key = str_replace (['__CONTROLLER__ ' , '__ACTION__ ' , '__IP__ ' ], [$ request ->controller (), $ request ->action (), $ request ->ip ()], $ key );
178
+ $ key = str_replace (['__CONTROLLER__ ' , '__ACTION__ ' , '__IP__ ' , '__SESSION__ ' ],
179
+ [$ request ->controller (), $ request ->action (), $ request ->ip (), $ this ->session ->getId ()],
180
+ $ key );
181
+ }
182
+
183
+ if ($ annotation ) {
184
+ // 注解需要以实际方法作为前缀
185
+ $ key = $ request ->controller () . $ request ->action () . $ key ;
153
186
}
154
187
155
188
return md5 ($ this ->config ['prefix ' ] . $ key . $ driver );
@@ -206,44 +239,19 @@ protected function parseRate(string $rate): array
206
239
}
207
240
208
241
/**
209
- * 根据**注解 **信息是否允许请求通过
242
+ * 根据**配置 **信息是否允许请求通过
210
243
* @param Request $request
211
244
* @return bool
212
245
*/
213
- protected function allowRequestByAnnotation (Request $ request ): bool
214
- {
215
- // 处理注解
216
- $ controller = $ this ->getFullController ($ request );
217
- if ($ controller ) {
218
- $ action = $ request ->action ();
219
- if (method_exists ($ controller , $ action )) {
220
- $ reflectionMethod = new ReflectionMethod ($ controller , $ action );
221
- $ attributes = $ reflectionMethod ->getAttributes (RateLimitAnnotation::class);
222
- foreach ($ attributes as $ attribute ) {
223
- $ annotation = $ attribute ->newInstance ();
224
- $ key = $ this ->getCacheKey ($ request , $ annotation ->key , $ annotation ->driver );
225
- $ key = $ controller . $ action . $ key ; // 注解需要以实际方法作为前缀
226
-
227
- if (!$ this ->allowRequest ($ key , $ annotation ->rate , $ annotation ->driver )) {
228
- $ this ->config ['visit_fail_text ' ] = $ annotation ->message ;
229
- return false ;
230
- }
231
- }
232
- }
233
- }
234
- return true ;
235
- }
236
-
237
- private function getFullController (Request $ request ): string
246
+ protected function allowRequestByConfig (Request $ request ): bool
238
247
{
239
- $ controller = $ request -> controller ();
240
- if (empty ( $ controller )) {
241
- return '' ;
248
+ // 若请求类型不在限制内
249
+ if (! in_array ( $ request -> method (), $ this -> config [ ' visit_method ' ] )) {
250
+ return true ;
242
251
}
243
- $ suffix = $ this ->config_instance ->get ('route.controller_suffix ' ) ? 'Controller ' : '' ;
244
- $ layer = $ this ->config_instance ->get ('route.controller_layer ' ) ?: 'controller ' ;
245
- $ controllerClassName = $ this ->app ->parseClass ($ layer , $ controller . $ suffix );
246
- return $ controllerClassName ;
252
+ $ driver = $ this ->config ['driver_name ' ];
253
+ $ key = $ this ->getCacheKey ($ request , $ this ->config ['key ' ], $ driver );
254
+ return $ this ->allowRequest ($ key , $ this ->config ['visit_rate ' ], $ driver );
247
255
}
248
256
249
257
/**
@@ -261,7 +269,7 @@ public function buildLimitException(int $wait_seconds, Request $request): HttpRe
261
269
throw new TypeError (sprintf ('The closure must return %s instance ' , Response::class));
262
270
}
263
271
} else {
264
- $ content = str_replace ('__WAIT__ ' , (string )$ wait_seconds , $ this ->config [ ' visit_fail_text ' ] );
272
+ $ content = str_replace ('__WAIT__ ' , (string )$ wait_seconds , $ this ->getFailMessage () );
265
273
$ response = Response::create ($ content )->code ($ this ->config ['visit_fail_code ' ]);
266
274
}
267
275
if ($ this ->config ['visit_enable_show_rate_limit ' ]) {
@@ -270,6 +278,15 @@ public function buildLimitException(int $wait_seconds, Request $request): HttpRe
270
278
return new HttpResponseException ($ response );
271
279
}
272
280
281
+ /**
282
+ * 获取受限时的信息
283
+ * @return string
284
+ */
285
+ public function getFailMessage (): string
286
+ {
287
+ return $ this ->config ['visit_fail_text ' ];
288
+ }
289
+
273
290
/**
274
291
* 设置速率
275
292
* @param string $rate '10/m' '20/300'
0 commit comments