Skip to content

Commit 31dbdd2

Browse files
authored
Response limit middleware (#917)
1 parent 88020f3 commit 31dbdd2

File tree

5 files changed

+206
-0
lines changed

5 files changed

+206
-0
lines changed

backend/config.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ const (
2020
SQLMaxOpenConnsDefault = "GF_SQL_MAX_OPEN_CONNS_DEFAULT"
2121
SQLMaxIdleConnsDefault = "GF_SQL_MAX_IDLE_CONNS_DEFAULT"
2222
SQLMaxConnLifetimeSecondsDefault = "GF_SQL_MAX_CONN_LIFETIME_SECONDS_DEFAULT"
23+
ResponseLimit = "GF_RESPONSE_LIMIT"
2324
)
2425

2526
type configKey struct{}
@@ -225,6 +226,18 @@ func (c *GrafanaCfg) UserFacingDefaultError() (string, error) {
225226
return value, nil
226227
}
227228

229+
func (c *GrafanaCfg) ResponseLimit() int64 {
230+
count, ok := c.config[ResponseLimit]
231+
if !ok {
232+
return 0
233+
}
234+
i, err := strconv.ParseInt(count, 10, 64)
235+
if err != nil {
236+
return 0
237+
}
238+
return i
239+
}
240+
228241
type userAgentKey struct{}
229242

230243
// UserAgentFromContext returns user agent from context.
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package httpclient
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
"io"
7+
)
8+
9+
// Similar implementation to http/net MaxBytesReader
10+
// https://pkg.go.dev/net/http#MaxBytesReader
11+
// What's happening differently here, is that the field that
12+
// is limited is the response and not the request, thus
13+
// the error handling/message needed to be accurate.
14+
15+
// ErrResponseBodyTooLarge indicates response body is too large
16+
var ErrResponseBodyTooLarge = errors.New("http: response body too large")
17+
18+
// MaxBytesReader is similar to io.LimitReader but is intended for
19+
// limiting the size of incoming request bodies. In contrast to
20+
// io.LimitReader, MaxBytesReader's result is a ReadCloser, returns a
21+
// non-EOF error for a Read beyond the limit, and closes the
22+
// underlying reader when its Close method is called.
23+
//
24+
// MaxBytesReader prevents clients from accidentally or maliciously
25+
// sending a large request and wasting server resources.
26+
func MaxBytesReader(r io.ReadCloser, n int64) io.ReadCloser {
27+
return &maxBytesReader{r: r, n: n}
28+
}
29+
30+
type maxBytesReader struct {
31+
r io.ReadCloser // underlying reader
32+
n int64 // max bytes remaining
33+
err error // sticky error
34+
}
35+
36+
func (l *maxBytesReader) Read(p []byte) (n int, err error) {
37+
if l.err != nil {
38+
return 0, l.err
39+
}
40+
if len(p) == 0 {
41+
return 0, nil
42+
}
43+
// If they asked for a 32KB byte read but only 5 bytes are
44+
// remaining, no need to read 32KB. 6 bytes will answer the
45+
// question of the whether we hit the limit or go past it.
46+
if int64(len(p)) > l.n+1 {
47+
p = p[:l.n+1]
48+
}
49+
n, err = l.r.Read(p)
50+
51+
if int64(n) <= l.n {
52+
l.n -= int64(n)
53+
l.err = err
54+
return n, err
55+
}
56+
57+
n = int(l.n)
58+
l.n = 0
59+
60+
l.err = fmt.Errorf("error: %w, response limit is set to: %d", ErrResponseBodyTooLarge, n)
61+
return n, l.err
62+
}
63+
64+
func (l *maxBytesReader) Close() error {
65+
return l.r.Close()
66+
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package httpclient
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
"io"
7+
"strings"
8+
"testing"
9+
10+
"github.com/stretchr/testify/require"
11+
)
12+
13+
func TestMaxBytesReader(t *testing.T) {
14+
tcs := []struct {
15+
limit int64
16+
bodyLength int
17+
body string
18+
err error
19+
}{
20+
{limit: 1, bodyLength: 1, body: "d", err: errors.New("error: http: response body too large, response limit is set to: 1")},
21+
{limit: 1000000, bodyLength: 5, body: "dummy", err: nil},
22+
{limit: 0, bodyLength: 0, body: "", err: errors.New("error: http: response body too large, response limit is set to: 0")},
23+
}
24+
for _, tc := range tcs {
25+
t.Run(fmt.Sprintf("Test MaxBytesReader with limit: %d", tc.limit), func(t *testing.T) {
26+
body := io.NopCloser(strings.NewReader("dummy"))
27+
readCloser := MaxBytesReader(body, tc.limit)
28+
29+
bodyBytes, err := io.ReadAll(readCloser)
30+
if err != nil {
31+
require.EqualError(t, tc.err, err.Error())
32+
} else {
33+
require.NoError(t, tc.err)
34+
}
35+
36+
require.Len(t, bodyBytes, tc.bodyLength)
37+
require.Equal(t, string(bodyBytes), tc.body)
38+
})
39+
}
40+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package httpclient
2+
3+
import (
4+
"net/http"
5+
)
6+
7+
// ResponseLimitMiddlewareName is the middleware name used by ResponseLimitMiddleware.
8+
const ResponseLimitMiddlewareName = "response-limit"
9+
10+
func ResponseLimitMiddleware(limit int64) Middleware {
11+
return NamedMiddlewareFunc(ResponseLimitMiddlewareName, func(opts Options, next http.RoundTripper) http.RoundTripper {
12+
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
13+
res, err := next.RoundTrip(req)
14+
if err != nil {
15+
return nil, err
16+
}
17+
18+
if limit <= 0 {
19+
return res, nil
20+
}
21+
22+
if res != nil && res.StatusCode != http.StatusSwitchingProtocols {
23+
res.Body = MaxBytesReader(res.Body, limit)
24+
}
25+
26+
return res, nil
27+
})
28+
})
29+
}
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
package httpclient
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"io"
8+
"net/http"
9+
"strings"
10+
"testing"
11+
12+
"github.com/stretchr/testify/require"
13+
)
14+
15+
func TestResponseLimitMiddleware(t *testing.T) {
16+
tcs := []struct {
17+
limit int64
18+
bodyLength int
19+
body string
20+
err error
21+
}{
22+
{limit: 1, bodyLength: 1, body: "d", err: errors.New("error: http: response body too large, response limit is set to: 1")},
23+
{limit: 1000000, bodyLength: 5, body: "dummy", err: nil},
24+
{limit: 0, bodyLength: 5, body: "dummy", err: nil},
25+
}
26+
for _, tc := range tcs {
27+
t.Run(fmt.Sprintf("Test ResponseLimitMiddleware with limit: %d", tc.limit), func(t *testing.T) {
28+
finalRoundTripper := RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
29+
return &http.Response{StatusCode: http.StatusOK, Request: req, Body: io.NopCloser(strings.NewReader("dummy"))}, nil
30+
})
31+
32+
mw := ResponseLimitMiddleware(tc.limit)
33+
rt := mw.CreateMiddleware(Options{}, finalRoundTripper)
34+
require.NotNil(t, rt)
35+
middlewareName, ok := mw.(MiddlewareName)
36+
require.True(t, ok)
37+
require.Equal(t, ResponseLimitMiddlewareName, middlewareName.MiddlewareName())
38+
39+
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://test.com/query", nil)
40+
require.NoError(t, err)
41+
res, err := rt.RoundTrip(req)
42+
require.NoError(t, err)
43+
require.NotNil(t, res)
44+
require.NotNil(t, res.Body)
45+
46+
bodyBytes, err := io.ReadAll(res.Body)
47+
if err != nil {
48+
require.EqualError(t, tc.err, err.Error())
49+
} else {
50+
require.NoError(t, tc.err)
51+
}
52+
require.NoError(t, res.Body.Close())
53+
54+
require.Len(t, bodyBytes, tc.bodyLength)
55+
require.Equal(t, string(bodyBytes), tc.body)
56+
})
57+
}
58+
}

0 commit comments

Comments
 (0)