|
| 1 | +// SPDX-License-Identifier: MIT |
| 2 | +// SPDX-FileCopyrightText: © 2016 LabStack and Echo contributors |
| 3 | + |
| 4 | +package echojwt |
| 5 | + |
| 6 | +import ( |
| 7 | + "errors" |
| 8 | + "fmt" |
| 9 | + "github.com/labstack/echo/v4" |
| 10 | + "github.com/labstack/echo/v4/middleware" |
| 11 | + "net/textproto" |
| 12 | + "strings" |
| 13 | +) |
| 14 | + |
| 15 | +const ( |
| 16 | + // extractorLimit is arbitrary number to limit values extractor can return. this limits possible resource exhaustion |
| 17 | + // attack vector |
| 18 | + extractorLimit = 20 |
| 19 | +) |
| 20 | + |
| 21 | +var errHeaderExtractorValueMissing = errors.New("missing value in request header") |
| 22 | +var errHeaderExtractorValueInvalid = errors.New("invalid value in request header") |
| 23 | +var errQueryExtractorValueMissing = errors.New("missing value in the query string") |
| 24 | +var errParamExtractorValueMissing = errors.New("missing value in path params") |
| 25 | +var errCookieExtractorValueMissing = errors.New("missing value in cookies") |
| 26 | +var errFormExtractorValueMissing = errors.New("missing value in the form") |
| 27 | + |
| 28 | +// CreateExtractors creates ValuesExtractors from given lookups. |
| 29 | +// Lookups is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used |
| 30 | +// to extract key from the request. |
| 31 | +// Possible values: |
| 32 | +// - "header:<name>" or "header:<name>:<cut-prefix>" |
| 33 | +// `<cut-prefix>` is argument value to cut/trim prefix of the extracted value. This is useful if header |
| 34 | +// value has static prefix like `Authorization: <auth-scheme> <authorisation-parameters>` where part that we |
| 35 | +// want to cut is `<auth-scheme> ` note the space at the end. |
| 36 | +// In case of basic authentication `Authorization: Basic <credentials>` prefix we want to remove is `Basic `. |
| 37 | +// - "query:<name>" |
| 38 | +// - "param:<name>" |
| 39 | +// - "form:<name>" |
| 40 | +// - "cookie:<name>" |
| 41 | +// |
| 42 | +// Multiple sources example: |
| 43 | +// - "header:Authorization,header:X-Api-Key" |
| 44 | +func CreateExtractors(lookups string) ([]middleware.ValuesExtractor, error) { |
| 45 | + if lookups == "" { |
| 46 | + return nil, nil |
| 47 | + } |
| 48 | + sources := strings.Split(lookups, ",") |
| 49 | + var extractors = make([]middleware.ValuesExtractor, 0) |
| 50 | + for _, source := range sources { |
| 51 | + parts := strings.Split(source, ":") |
| 52 | + if len(parts) < 2 { |
| 53 | + return nil, fmt.Errorf("extractor source for lookup could not be split into needed parts: %v", source) |
| 54 | + } |
| 55 | + |
| 56 | + switch parts[0] { |
| 57 | + case "query": |
| 58 | + extractors = append(extractors, valuesFromQuery(parts[1])) |
| 59 | + case "param": |
| 60 | + extractors = append(extractors, valuesFromParam(parts[1])) |
| 61 | + case "cookie": |
| 62 | + extractors = append(extractors, valuesFromCookie(parts[1])) |
| 63 | + case "form": |
| 64 | + extractors = append(extractors, valuesFromForm(parts[1])) |
| 65 | + case "header": |
| 66 | + prefix := "" |
| 67 | + if len(parts) > 2 { |
| 68 | + prefix = parts[2] |
| 69 | + } |
| 70 | + extractors = append(extractors, valuesFromHeader(parts[1], prefix)) |
| 71 | + } |
| 72 | + } |
| 73 | + return extractors, nil |
| 74 | +} |
| 75 | + |
| 76 | +// valuesFromHeader returns a functions that extracts values from the request header. |
| 77 | +// valuePrefix is parameter to remove first part (prefix) of the extracted value. This is useful if header value has static |
| 78 | +// prefix like `Authorization: <auth-scheme> <authorisation-parameters>` where part that we want to remove is `<auth-scheme> ` |
| 79 | +// note the space at the end. In case of basic authentication `Authorization: Basic <credentials>` prefix we want to remove |
| 80 | +// is `Basic `. In case of JWT tokens `Authorization: Bearer <token>` prefix is `Bearer `. |
| 81 | +// If prefix is left empty the whole value is returned. |
| 82 | +func valuesFromHeader(header string, valuePrefix string) middleware.ValuesExtractor { |
| 83 | + prefixLen := len(valuePrefix) |
| 84 | + // standard library parses http.Request header keys in canonical form but we may provide something else so fix this |
| 85 | + header = textproto.CanonicalMIMEHeaderKey(header) |
| 86 | + return func(c echo.Context) ([]string, error) { |
| 87 | + values := c.Request().Header.Values(header) |
| 88 | + if len(values) == 0 { |
| 89 | + return nil, errHeaderExtractorValueMissing |
| 90 | + } |
| 91 | + |
| 92 | + result := make([]string, 0) |
| 93 | + for i, value := range values { |
| 94 | + if prefixLen == 0 { |
| 95 | + result = append(result, value) |
| 96 | + if i >= extractorLimit-1 { |
| 97 | + break |
| 98 | + } |
| 99 | + continue |
| 100 | + } |
| 101 | + if len(value) > prefixLen && strings.EqualFold(value[:prefixLen], valuePrefix) { |
| 102 | + result = append(result, value[prefixLen:]) |
| 103 | + if i >= extractorLimit-1 { |
| 104 | + break |
| 105 | + } |
| 106 | + } |
| 107 | + } |
| 108 | + |
| 109 | + if len(result) == 0 { |
| 110 | + if prefixLen > 0 { |
| 111 | + return nil, errHeaderExtractorValueInvalid |
| 112 | + } |
| 113 | + return nil, errHeaderExtractorValueMissing |
| 114 | + } |
| 115 | + return result, nil |
| 116 | + } |
| 117 | +} |
| 118 | + |
| 119 | +// valuesFromQuery returns a function that extracts values from the query string. |
| 120 | +func valuesFromQuery(param string) middleware.ValuesExtractor { |
| 121 | + return func(c echo.Context) ([]string, error) { |
| 122 | + result := c.QueryParams()[param] |
| 123 | + if len(result) == 0 { |
| 124 | + return nil, errQueryExtractorValueMissing |
| 125 | + } else if len(result) > extractorLimit-1 { |
| 126 | + result = result[:extractorLimit] |
| 127 | + } |
| 128 | + return result, nil |
| 129 | + } |
| 130 | +} |
| 131 | + |
| 132 | +// valuesFromParam returns a function that extracts values from the url param string. |
| 133 | +func valuesFromParam(param string) middleware.ValuesExtractor { |
| 134 | + return func(c echo.Context) ([]string, error) { |
| 135 | + result := make([]string, 0) |
| 136 | + paramVales := c.ParamValues() |
| 137 | + for i, p := range c.ParamNames() { |
| 138 | + if param == p { |
| 139 | + result = append(result, paramVales[i]) |
| 140 | + if i >= extractorLimit-1 { |
| 141 | + break |
| 142 | + } |
| 143 | + } |
| 144 | + } |
| 145 | + if len(result) == 0 { |
| 146 | + return nil, errParamExtractorValueMissing |
| 147 | + } |
| 148 | + return result, nil |
| 149 | + } |
| 150 | +} |
| 151 | + |
| 152 | +// valuesFromCookie returns a function that extracts values from the named cookie. |
| 153 | +func valuesFromCookie(name string) middleware.ValuesExtractor { |
| 154 | + return func(c echo.Context) ([]string, error) { |
| 155 | + cookies := c.Cookies() |
| 156 | + if len(cookies) == 0 { |
| 157 | + return nil, errCookieExtractorValueMissing |
| 158 | + } |
| 159 | + |
| 160 | + result := make([]string, 0) |
| 161 | + for i, cookie := range cookies { |
| 162 | + if name == cookie.Name { |
| 163 | + result = append(result, cookie.Value) |
| 164 | + if i >= extractorLimit-1 { |
| 165 | + break |
| 166 | + } |
| 167 | + } |
| 168 | + } |
| 169 | + if len(result) == 0 { |
| 170 | + return nil, errCookieExtractorValueMissing |
| 171 | + } |
| 172 | + return result, nil |
| 173 | + } |
| 174 | +} |
| 175 | + |
| 176 | +// valuesFromForm returns a function that extracts values from the form field. |
| 177 | +func valuesFromForm(name string) middleware.ValuesExtractor { |
| 178 | + return func(c echo.Context) ([]string, error) { |
| 179 | + if c.Request().Form == nil { |
| 180 | + _ = c.Request().ParseMultipartForm(32 << 20) // same what `c.Request().FormValue(name)` does |
| 181 | + } |
| 182 | + values := c.Request().Form[name] |
| 183 | + if len(values) == 0 { |
| 184 | + return nil, errFormExtractorValueMissing |
| 185 | + } |
| 186 | + if len(values) > extractorLimit-1 { |
| 187 | + values = values[:extractorLimit] |
| 188 | + } |
| 189 | + result := append([]string{}, values...) |
| 190 | + return result, nil |
| 191 | + } |
| 192 | +} |
0 commit comments