11package middleware
22
33import (
4+ "bytes"
45 "fmt"
6+ "io"
57 "math"
68 "math/rand"
79 "net/http"
@@ -25,34 +27,44 @@ const (
2527// By default retries on network errors and 5xx responses. Can be configured to retry on specific status codes
2628// or to exclude specific codes from retry.
2729//
30+ // For requests with bodies (POST, PUT, PATCH), the middleware handles body replay:
31+ // - If req.GetBody is set (automatic for strings.Reader, bytes.Buffer, bytes.Reader), it uses that
32+ // - If req.GetBody is nil and body buffering is disabled (default), requests won't be retried
33+ // - If body buffering is enabled with RetryBufferBodies(true), bodies up to maxBufferSize are buffered
34+ //
2835// Default configuration:
2936// - 3 attempts
3037// - Initial delay: 100ms
3138// - Max delay: 30s
3239// - Exponential backoff
3340// - 10% jitter
3441// - Retries on 5xx status codes
42+ // - Body buffering disabled (preserves streaming, no retries for bodies without GetBody)
3543type RetryMiddleware struct {
36- next http.RoundTripper
37- attempts int
38- initialDelay time.Duration
39- maxDelay time.Duration
40- backoff BackoffType
41- jitterFactor float64
42- retryCodes []int
43- excludeCodes []int
44+ next http.RoundTripper
45+ attempts int
46+ initialDelay time.Duration
47+ maxDelay time.Duration
48+ backoff BackoffType
49+ jitterFactor float64
50+ retryCodes []int
51+ excludeCodes []int
52+ bufferBodies bool
53+ maxBufferSize int64
4454}
4555
4656// Retry creates retry middleware with provided options
4757func Retry (attempts int , initialDelay time.Duration , opts ... RetryOption ) RoundTripperHandler {
4858 return func (next http.RoundTripper ) http.RoundTripper {
4959 r := & RetryMiddleware {
50- next : next ,
51- attempts : attempts ,
52- initialDelay : initialDelay ,
53- maxDelay : 30 * time .Second ,
54- backoff : BackoffExponential ,
55- jitterFactor : 0.1 ,
60+ next : next ,
61+ attempts : attempts ,
62+ initialDelay : initialDelay ,
63+ maxDelay : 30 * time .Second ,
64+ backoff : BackoffExponential ,
65+ jitterFactor : 0.1 ,
66+ bufferBodies : false , // disabled by default to preserve streaming; when enabled, reads entire body into memory
67+ maxBufferSize : 10 * 1024 * 1024 , // 10MB limit when buffering enabled
5668 }
5769
5870 for _ , opt := range opts {
@@ -69,10 +81,28 @@ func Retry(attempts int, initialDelay time.Duration, opts ...RetryOption) RoundT
6981
7082// RoundTrip implements http.RoundTripper
7183func (r * RetryMiddleware ) RoundTrip (req * http.Request ) (* http.Response , error ) {
84+ // determine effective attempts based on body handling
85+ attempts := r .attempts
86+ hasBody := req .Body != nil && req .Body != http .NoBody
87+
88+ // prepare body for retries if needed
89+ if hasBody && req .GetBody == nil && r .attempts > 1 {
90+ if r .bufferBodies {
91+ // try to buffer body for retries
92+ if err := r .bufferRequestBody (req ); err != nil {
93+ // buffering failed or body too large
94+ return nil , err
95+ }
96+ } else {
97+ // buffering disabled - can't retry with body
98+ attempts = 1
99+ }
100+ }
101+
72102 var lastResponse * http.Response
73103 var lastError error
74104
75- for attempt := 0 ; attempt < r . attempts ; attempt ++ {
105+ for attempt := 0 ; attempt < attempts ; attempt ++ {
76106 if req .Context ().Err () != nil {
77107 return nil , req .Context ().Err ()
78108 }
@@ -84,6 +114,15 @@ func (r *RetryMiddleware) RoundTrip(req *http.Request) (*http.Response, error) {
84114 return nil , req .Context ().Err ()
85115 case <- time .After (delay ):
86116 }
117+
118+ // reset body for retry
119+ if req .GetBody != nil {
120+ newBody , err := req .GetBody ()
121+ if err != nil {
122+ return nil , fmt .Errorf ("retry: failed to get new request body: %w" , err )
123+ }
124+ req .Body = newBody
125+ }
87126 }
88127
89128 resp , err := r .next .RoundTrip (req )
@@ -101,11 +140,36 @@ func (r *RetryMiddleware) RoundTrip(req *http.Request) (*http.Response, error) {
101140 }
102141
103142 if lastError != nil {
104- return lastResponse , fmt .Errorf ("retry: transport error after %d attempts: %w" , r . attempts , lastError )
143+ return lastResponse , fmt .Errorf ("retry: transport error after %d attempts: %w" , attempts , lastError )
105144 }
106145 return lastResponse , nil
107146}
108147
148+ // bufferRequestBody attempts to buffer the request body for retries
149+ // this consumes the original body - returns error if body is too large
150+ func (r * RetryMiddleware ) bufferRequestBody (req * http.Request ) error {
151+ // read entire body (with limit for safety)
152+ bodyBytes , err := io .ReadAll (io .LimitReader (req .Body , r .maxBufferSize + 1 ))
153+ if err != nil {
154+ return fmt .Errorf ("retry: failed to read request body: %w" , err )
155+ }
156+ _ = req .Body .Close ()
157+
158+ // check if body exceeds limit
159+ if int64 (len (bodyBytes )) > r .maxBufferSize {
160+ return fmt .Errorf ("retry: request body too large (%d bytes exceeds %d byte limit) - cannot retry" ,
161+ len (bodyBytes ), r .maxBufferSize )
162+ }
163+
164+ // set up body and GetBody for retries
165+ req .Body = io .NopCloser (bytes .NewReader (bodyBytes ))
166+ req .GetBody = func () (io.ReadCloser , error ) {
167+ return io .NopCloser (bytes .NewReader (bodyBytes )), nil
168+ }
169+
170+ return nil
171+ }
172+
109173func (r * RetryMiddleware ) calcDelay (attempt int ) time.Duration {
110174 if attempt == 0 {
111175 return 0
@@ -192,3 +256,17 @@ func RetryExcludeCodes(codes ...int) RetryOption {
192256 r .excludeCodes = codes
193257 }
194258}
259+
260+ // RetryBufferBodies enables or disables automatic body buffering for retries
261+ func RetryBufferBodies (enabled bool ) RetryOption {
262+ return func (r * RetryMiddleware ) {
263+ r .bufferBodies = enabled
264+ }
265+ }
266+
267+ // RetryMaxBufferSize sets the maximum size of request bodies that will be buffered
268+ func RetryMaxBufferSize (size int64 ) RetryOption {
269+ return func (r * RetryMiddleware ) {
270+ r .maxBufferSize = size
271+ }
272+ }
0 commit comments