@@ -48,6 +48,7 @@ enum Kind {
48
48
49
49
#[ derive( Debug , PartialEq , Clone , Copy ) ]
50
50
enum ChunkedState {
51
+ Start ,
51
52
Size ,
52
53
SizeLws ,
53
54
Extension ,
@@ -73,7 +74,7 @@ impl Decoder {
73
74
74
75
pub ( crate ) fn chunked ( ) -> Decoder {
75
76
Decoder {
76
- kind : Kind :: Chunked ( ChunkedState :: Size , 0 ) ,
77
+ kind : Kind :: Chunked ( ChunkedState :: new ( ) , 0 ) ,
77
78
}
78
79
}
79
80
@@ -181,7 +182,22 @@ macro_rules! byte (
181
182
} )
182
183
) ;
183
184
185
+ macro_rules! or_overflow {
186
+ ( $e: expr) => (
187
+ match $e {
188
+ Some ( val) => val,
189
+ None => return Poll :: Ready ( Err ( io:: Error :: new(
190
+ io:: ErrorKind :: InvalidData ,
191
+ "invalid chunk size: overflow" ,
192
+ ) ) ) ,
193
+ }
194
+ )
195
+ }
196
+
184
197
impl ChunkedState {
198
+ fn new ( ) -> ChunkedState {
199
+ ChunkedState :: Start
200
+ }
185
201
fn step < R : MemRead > (
186
202
& self ,
187
203
cx : & mut Context < ' _ > ,
@@ -191,6 +207,7 @@ impl ChunkedState {
191
207
) -> Poll < Result < ChunkedState , io:: Error > > {
192
208
use self :: ChunkedState :: * ;
193
209
match * self {
210
+ Start => ChunkedState :: read_start ( cx, body, size) ,
194
211
Size => ChunkedState :: read_size ( cx, body, size) ,
195
212
SizeLws => ChunkedState :: read_size_lws ( cx, body) ,
196
213
Extension => ChunkedState :: read_extension ( cx, body) ,
@@ -205,25 +222,46 @@ impl ChunkedState {
205
222
End => Poll :: Ready ( Ok ( ChunkedState :: End ) ) ,
206
223
}
207
224
}
208
- fn read_size < R : MemRead > (
225
+
226
+ fn read_start < R : MemRead > (
209
227
cx : & mut Context < ' _ > ,
210
228
rdr : & mut R ,
211
229
size : & mut u64 ,
212
230
) -> Poll < Result < ChunkedState , io:: Error > > {
213
- trace ! ( "Read chunk hex size " ) ;
231
+ trace ! ( "Read chunk start " ) ;
214
232
215
- macro_rules! or_overflow {
216
- ( $e: expr) => (
217
- match $e {
218
- Some ( val) => val,
219
- None => return Poll :: Ready ( Err ( io:: Error :: new(
220
- io:: ErrorKind :: InvalidData ,
221
- "invalid chunk size: overflow" ,
222
- ) ) ) ,
223
- }
224
- )
233
+ let radix = 16 ;
234
+ match byte ! ( rdr, cx) {
235
+ b @ b'0' ..=b'9' => {
236
+ * size = or_overflow ! ( size. checked_mul( radix) ) ;
237
+ * size = or_overflow ! ( size. checked_add( ( b - b'0' ) as u64 ) ) ;
238
+ }
239
+ b @ b'a' ..=b'f' => {
240
+ * size = or_overflow ! ( size. checked_mul( radix) ) ;
241
+ * size = or_overflow ! ( size. checked_add( ( b + 10 - b'a' ) as u64 ) ) ;
242
+ }
243
+ b @ b'A' ..=b'F' => {
244
+ * size = or_overflow ! ( size. checked_mul( radix) ) ;
245
+ * size = or_overflow ! ( size. checked_add( ( b + 10 - b'A' ) as u64 ) ) ;
246
+ }
247
+ _ => {
248
+ return Poll :: Ready ( Err ( io:: Error :: new (
249
+ io:: ErrorKind :: InvalidInput ,
250
+ "Invalid chunk size line: missing size digit" ,
251
+ ) ) ) ;
252
+ }
225
253
}
226
254
255
+ Poll :: Ready ( Ok ( ChunkedState :: Size ) )
256
+ }
257
+
258
+ fn read_size < R : MemRead > (
259
+ cx : & mut Context < ' _ > ,
260
+ rdr : & mut R ,
261
+ size : & mut u64 ,
262
+ ) -> Poll < Result < ChunkedState , io:: Error > > {
263
+ trace ! ( "Read chunk hex size" ) ;
264
+
227
265
let radix = 16 ;
228
266
match byte ! ( rdr, cx) {
229
267
b @ b'0' ..=b'9' => {
@@ -478,7 +516,7 @@ mod tests {
478
516
use std:: io:: ErrorKind :: { InvalidData , InvalidInput , UnexpectedEof } ;
479
517
480
518
async fn read ( s : & str ) -> u64 {
481
- let mut state = ChunkedState :: Size ;
519
+ let mut state = ChunkedState :: new ( ) ;
482
520
let rdr = & mut s. as_bytes ( ) ;
483
521
let mut size = 0 ;
484
522
loop {
@@ -495,7 +533,7 @@ mod tests {
495
533
}
496
534
497
535
async fn read_err ( s : & str , expected_err : io:: ErrorKind ) {
498
- let mut state = ChunkedState :: Size ;
536
+ let mut state = ChunkedState :: new ( ) ;
499
537
let rdr = & mut s. as_bytes ( ) ;
500
538
let mut size = 0 ;
501
539
loop {
@@ -532,6 +570,9 @@ mod tests {
532
570
// Missing LF or CRLF
533
571
read_err ( "F\r F" , InvalidInput ) . await ;
534
572
read_err ( "F" , UnexpectedEof ) . await ;
573
+ // Missing digit
574
+ read_err ( "\r \n \r \n " , InvalidInput ) . await ;
575
+ read_err ( "\r \n " , InvalidInput ) . await ;
535
576
// Invalid hex digit
536
577
read_err ( "X\r \n " , InvalidInput ) . await ;
537
578
read_err ( "1X\r \n " , InvalidInput ) . await ;
0 commit comments