Skip to content

Commit bfde9df

Browse files
authored
feat: Traefik decision api support (#904)
Closes #899 See #521 See #441 See #487 See #263
1 parent 09be55f commit bfde9df

File tree

10 files changed

+280
-35
lines changed

10 files changed

+280
-35
lines changed

api/decision.go

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,19 @@ import (
3333

3434
const (
3535
DecisionPath = "/decisions"
36+
37+
xForwardedMethod = "X-Forwarded-Method"
38+
xForwardedProto = "X-Forwarded-Proto"
39+
xForwardedHost = "X-Forwarded-Host"
40+
xForwardedUri = "X-Forwarded-Uri"
3641
)
3742

3843
type decisionHandlerRegistry interface {
3944
x.RegistryWriter
4045
x.RegistryLogger
4146

4247
RuleMatcher() rule.Matcher
43-
ProxyRequestHandler() *proxy.RequestHandler
48+
ProxyRequestHandler() proxy.RequestHandler
4449
}
4550

4651
type DecisionHandler struct {
@@ -53,12 +58,11 @@ func NewJudgeHandler(r decisionHandlerRegistry) *DecisionHandler {
5358

5459
func (h *DecisionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
5560
if len(r.URL.Path) >= len(DecisionPath) && r.URL.Path[:len(DecisionPath)] == DecisionPath {
56-
r.URL.Scheme = "http"
57-
r.URL.Host = r.Host
58-
if r.TLS != nil || strings.EqualFold(r.Header.Get("X-Forwarded-Proto"), "https") {
59-
r.URL.Scheme = "https"
60-
}
61-
r.URL.Path = r.URL.Path[len(DecisionPath):]
61+
r.Method = x.OrDefaultString(r.Header.Get(xForwardedMethod), r.Method)
62+
r.URL.Scheme = x.OrDefaultString(r.Header.Get(xForwardedProto),
63+
x.IfThenElseString(r.TLS != nil, "https", "http"))
64+
r.URL.Host = x.OrDefaultString(r.Header.Get(xForwardedHost), r.Host)
65+
r.URL.Path = x.OrDefaultString(r.Header.Get(xForwardedUri), r.URL.Path[len(DecisionPath):])
6266

6367
h.decisions(w, r)
6468
} else {
@@ -112,7 +116,6 @@ func (h *DecisionHandler) decisions(w http.ResponseWriter, r *http.Request) {
112116
WithFields(fields).
113117
WithField("granted", false).
114118
Info("Access request denied")
115-
116119
h.r.ProxyRequestHandler().HandleError(w, r, rl, err)
117120
return
118121
}

api/decision_test.go

Lines changed: 137 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,25 +23,30 @@ package api_test
2323
import (
2424
"bytes"
2525
"context"
26+
"crypto/tls"
2627
"fmt"
2728
"io/ioutil"
2829
"net/http"
2930
"net/http/httptest"
31+
"net/url"
3032
"strconv"
3133
"testing"
3234

33-
"github.com/ory/viper"
34-
35-
"github.com/urfave/negroni"
36-
37-
"github.com/ory/oathkeeper/driver/configuration"
38-
"github.com/ory/oathkeeper/internal"
39-
4035
"github.com/julienschmidt/httprouter"
4136
"github.com/stretchr/testify/assert"
37+
"github.com/stretchr/testify/mock"
4238
"github.com/stretchr/testify/require"
39+
"github.com/urfave/negroni"
4340

41+
"github.com/ory/herodot"
42+
"github.com/ory/oathkeeper/api"
43+
"github.com/ory/oathkeeper/driver/configuration"
44+
"github.com/ory/oathkeeper/internal"
45+
"github.com/ory/oathkeeper/pipeline/authn"
46+
"github.com/ory/oathkeeper/proxy"
4447
"github.com/ory/oathkeeper/rule"
48+
"github.com/ory/viper"
49+
"github.com/ory/x/logrusx"
4550
)
4651

4752
func TestDecisionAPI(t *testing.T) {
@@ -344,3 +349,128 @@ func TestDecisionAPI(t *testing.T) {
344349
})
345350
}
346351
}
352+
353+
type decisionHandlerRegistryMock struct {
354+
mock.Mock
355+
}
356+
357+
func (m *decisionHandlerRegistryMock) RuleMatcher() rule.Matcher {
358+
return m
359+
}
360+
361+
func (m *decisionHandlerRegistryMock) ProxyRequestHandler() proxy.RequestHandler {
362+
return m
363+
}
364+
365+
func (*decisionHandlerRegistryMock) Writer() herodot.Writer {
366+
return nil
367+
}
368+
369+
func (*decisionHandlerRegistryMock) Logger() *logrusx.Logger {
370+
return logrusx.New("", "")
371+
}
372+
373+
func (m *decisionHandlerRegistryMock) Match(ctx context.Context, method string, u *url.URL) (*rule.Rule, error) {
374+
args := m.Called(ctx, method, u)
375+
return args.Get(0).(*rule.Rule), args.Error(1)
376+
}
377+
378+
func (*decisionHandlerRegistryMock) HandleError(w http.ResponseWriter, r *http.Request, rl *rule.Rule, handleErr error) {
379+
}
380+
381+
func (*decisionHandlerRegistryMock) HandleRequest(r *http.Request, rl *rule.Rule) (session *authn.AuthenticationSession, err error) {
382+
return &authn.AuthenticationSession{}, nil
383+
}
384+
385+
func (*decisionHandlerRegistryMock) InitializeAuthnSession(r *http.Request, rl *rule.Rule) *authn.AuthenticationSession {
386+
return nil
387+
}
388+
389+
func TestDecisionAPIHeaderUsage(t *testing.T) {
390+
r := new(decisionHandlerRegistryMock)
391+
h := api.NewJudgeHandler(r)
392+
defaultUrl := &url.URL{Scheme: "http", Host: "ory.sh", Path: "/foo"}
393+
defaultMethod := "GET"
394+
defaultTransform := func(req *http.Request) {}
395+
396+
for _, tc := range []struct {
397+
name string
398+
expectedMethod string
399+
expectedUrl *url.URL
400+
transform func(req *http.Request)
401+
}{
402+
{
403+
name: "all arguments are taken from the url and request method",
404+
expectedUrl: defaultUrl,
405+
expectedMethod: defaultMethod,
406+
transform: defaultTransform,
407+
},
408+
{
409+
name: "all arguments are taken from the url and request method, but scheme from URL TLS settings",
410+
expectedUrl: &url.URL{Scheme: "https", Host: defaultUrl.Host, Path: defaultUrl.Path},
411+
expectedMethod: defaultMethod,
412+
transform: func(req *http.Request) {
413+
req.TLS = &tls.ConnectionState{}
414+
},
415+
},
416+
{
417+
name: "all arguments are taken from the headers",
418+
expectedUrl: &url.URL{Scheme: "https", Host: "test.dev", Path: "/bar"},
419+
expectedMethod: "POST",
420+
transform: func(req *http.Request) {
421+
req.Header.Add("X-Forwarded-Method", "POST")
422+
req.Header.Add("X-Forwarded-Proto", "https")
423+
req.Header.Add("X-Forwarded-Host", "test.dev")
424+
req.Header.Add("X-Forwarded-Uri", "/bar")
425+
},
426+
},
427+
{
428+
name: "only scheme is taken from the headers",
429+
expectedUrl: &url.URL{Scheme: "https", Host: defaultUrl.Host, Path: defaultUrl.Path},
430+
expectedMethod: defaultMethod,
431+
transform: func(req *http.Request) {
432+
req.Header.Add("X-Forwarded-Proto", "https")
433+
},
434+
},
435+
{
436+
name: "only method is taken from the headers",
437+
expectedUrl: defaultUrl,
438+
expectedMethod: "POST",
439+
transform: func(req *http.Request) {
440+
req.Header.Add("X-Forwarded-Method", "POST")
441+
},
442+
},
443+
{
444+
name: "only host is taken from the headers",
445+
expectedUrl: &url.URL{Scheme: defaultUrl.Scheme, Host: "test.dev", Path: defaultUrl.Path},
446+
expectedMethod: defaultMethod,
447+
transform: func(req *http.Request) {
448+
req.Header.Add("X-Forwarded-Host", "test.dev")
449+
},
450+
},
451+
{
452+
name: "only path is taken from the headers",
453+
expectedUrl: &url.URL{Scheme: defaultUrl.Scheme, Host: defaultUrl.Host, Path: "/bar"},
454+
expectedMethod: defaultMethod,
455+
transform: func(req *http.Request) {
456+
req.Header.Add("X-Forwarded-Uri", "/bar")
457+
},
458+
},
459+
} {
460+
t.Run(tc.name, func(t *testing.T) {
461+
res := httptest.NewRecorder()
462+
reqUrl := *defaultUrl
463+
reqUrl.Path = api.DecisionPath + reqUrl.Path
464+
req := httptest.NewRequest(defaultMethod, reqUrl.String(), nil)
465+
tc.transform(req)
466+
467+
r.On("Match", mock.Anything,
468+
mock.MatchedBy(func(val string) bool { return val == tc.expectedMethod }),
469+
mock.MatchedBy(func(val *url.URL) bool { return *val == *tc.expectedUrl })).
470+
Return(&rule.Rule{}, nil)
471+
h.ServeHTTP(res, req, nil)
472+
473+
r.AssertExpectations(t)
474+
})
475+
}
476+
}

credentials/verifier_default.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"crypto/ecdsa"
66
"crypto/rsa"
77
"fmt"
8+
"strings"
89

910
"github.com/golang-jwt/jwt/v4"
1011
"github.com/pkg/errors"
@@ -42,7 +43,7 @@ func (v *VerifierDefault) Verify(
4243

4344
kid, ok := token.Header["kid"].(string)
4445
if !ok || kid == "" {
45-
return nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("The JSON Web Token must contain a kid header value but did not."))
46+
return nil, errors.WithStack(herodot.ErrBadRequest.WithReason("The JSON Web Token must contain a kid header value but did not."))
4647
}
4748

4849
key, err := v.r.CredentialsFetcher().ResolveKey(ctx, r.KeyURLs, kid, "sig")
@@ -74,10 +75,10 @@ func (v *VerifierDefault) Verify(
7475
return k, nil
7576
}
7677
default:
77-
return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf(`This request object uses unsupported signing algorithm "%s".`, token.Header["alg"]))
78+
return nil, errors.WithStack(herodot.ErrBadRequest.WithReasonf(`This request object uses unsupported signing algorithm "%s".`, token.Header["alg"]))
7879
}
7980

80-
return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf(`The signing key algorithm does not match the algorithm from the token header.`))
81+
return nil, errors.WithStack(herodot.ErrBadRequest.WithReasonf(`The signing key algorithm does not match the algorithm from the token header.`))
8182
})
8283
if err != nil {
8384
if e, ok := errors.Cause(err).(*jwt.ValidationError); ok {
@@ -100,13 +101,14 @@ func (v *VerifierDefault) Verify(
100101
parsedClaims := jwtx.ParseMapStringInterfaceClaims(claims)
101102
for _, audience := range r.Audiences {
102103
if !stringslice.Has(parsedClaims.Audience, audience) {
103-
return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Token audience %v is not intended for target audience %s.", parsedClaims.Audience, audience))
104+
return nil, herodot.ErrUnauthorized.WithReasonf("Token audience %v is not intended for target audience %s.", parsedClaims.Audience, audience)
104105
}
105106
}
106107

107108
if len(r.Issuers) > 0 {
108109
if !stringslice.Has(r.Issuers, parsedClaims.Issuer) {
109-
return nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("Token issuer does not match any trusted issuer."))
110+
return nil, herodot.ErrUnauthorized.WithReasonf("Token issuer does not match any trusted issuer %s.", parsedClaims.Issuer).
111+
WithDetail("received issuers", strings.Join(r.Issuers, ", "))
110112
}
111113
}
112114

@@ -117,7 +119,7 @@ func (v *VerifierDefault) Verify(
117119
if r.ScopeStrategy != nil {
118120
for _, sc := range r.Scope {
119121
if !r.ScopeStrategy(s, sc) {
120-
return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf(`JSON Web Token is missing required scope "%s".`, sc))
122+
return nil, herodot.ErrUnauthorized.WithReasonf(`JSON Web Token is missing required scope "%s".`, sc)
121123
}
122124
}
123125
} else {

driver/registry.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ type Registry interface {
3030
BuildDate() string
3131
BuildHash() string
3232

33-
ProxyRequestHandler() *proxy.RequestHandler
33+
ProxyRequestHandler() proxy.RequestHandler
3434
HealthEventManager() health.EventManager
3535
HealthHandler() *healthx.Handler
3636
RuleHandler() *api.RuleHandler

driver/registry_memory.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ type RegistryMemory struct {
5353
apiJudgeHandler *api.DecisionHandler
5454
healthxHandler *healthx.Handler
5555

56-
proxyRequestHandler *proxy.RequestHandler
56+
proxyRequestHandler proxy.RequestHandler
5757
proxyProxy *proxy.Proxy
5858
ruleFetcher rule.Fetcher
5959

@@ -89,7 +89,7 @@ func (r *RegistryMemory) WithRuleFetcher(fetcher rule.Fetcher) Registry {
8989
return r
9090
}
9191

92-
func (r *RegistryMemory) ProxyRequestHandler() *proxy.RequestHandler {
92+
func (r *RegistryMemory) ProxyRequestHandler() proxy.RequestHandler {
9393
if r.proxyRequestHandler == nil {
9494
r.proxyRequestHandler = proxy.NewRequestHandler(r, r.c)
9595
}

pipeline/errors/error_redirect.go

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@ import (
1212

1313
var _ Handler = new(ErrorRedirect)
1414

15+
const (
16+
xForwardedProto = "X-Forwarded-Proto"
17+
xForwardedHost = "X-Forwarded-Host"
18+
xForwardedUri = "X-Forwarded-Uri"
19+
)
20+
1521
type (
1622
ErrorRedirectConfig struct {
1723
To string `json:"to"`
@@ -40,7 +46,11 @@ func (a *ErrorRedirect) Handle(w http.ResponseWriter, r *http.Request, config js
4046
return err
4147
}
4248

43-
http.Redirect(w, r, a.RedirectURL(r, c), c.Code)
49+
r.URL.Scheme = x.OrDefaultString(r.Header.Get(xForwardedProto), r.URL.Scheme)
50+
r.URL.Host = x.OrDefaultString(r.Header.Get(xForwardedHost), r.URL.Host)
51+
r.URL.Path = x.OrDefaultString(r.Header.Get(xForwardedUri), r.URL.Path)
52+
53+
http.Redirect(w, r, a.RedirectURL(r.URL, c), c.Code)
4454
return nil
4555
}
4656

@@ -69,7 +79,7 @@ func (a *ErrorRedirect) GetID() string {
6979
return "redirect"
7080
}
7181

72-
func (a *ErrorRedirect) RedirectURL(r *http.Request, c *ErrorRedirectConfig) string {
82+
func (a *ErrorRedirect) RedirectURL(uri *url.URL, c *ErrorRedirectConfig) string {
7383
if c.ReturnToQueryParam == "" {
7484
return c.To
7585
}
@@ -78,8 +88,9 @@ func (a *ErrorRedirect) RedirectURL(r *http.Request, c *ErrorRedirectConfig) str
7888
if err != nil {
7989
return c.To
8090
}
91+
8192
q := u.Query()
82-
q.Set(c.ReturnToQueryParam, r.URL.String())
93+
q.Set(c.ReturnToQueryParam, uri.String())
8394
u.RawQuery = q.Encode()
8495
return u.String()
8596
}

0 commit comments

Comments
 (0)