diff --git a/callback.go b/callback.go index 6ff9b0f..298b56f 100644 --- a/callback.go +++ b/callback.go @@ -7,13 +7,21 @@ import ( ) type CallbackConfiguration struct { - MaxWaitSeconds uint `json:"max_wait_seconds"` - Port uint `json:"port"` - Route string `json:"route"` - SSL bool `json:"ssl" envconfig:"IMMUNE_SSL"` - SSLKeyFile string `json:"ssl_key_file" envconfig:"IMMUNE_SSL_KEY_FILE"` - SSLCertFile string `json:"ssl_cert_file" envconfig:"IMMUNE_SSL_CERT_FILE"` - IDLocation string `json:"id_location"` + MaxWaitSeconds uint `json:"max_wait_seconds"` + Port uint `json:"port"` + Route string `json:"route"` + SSL bool `json:"ssl" envconfig:"IMMUNE_SSL"` + SSLKeyFile string `json:"ssl_key_file" envconfig:"IMMUNE_SSL_KEY_FILE"` + SSLCertFile string `json:"ssl_cert_file" envconfig:"IMMUNE_SSL_CERT_FILE"` + IDLocation string `json:"id_location"` + Signature SignatureConfiguration `json:"signature"` +} + +type SignatureConfiguration struct { + ReplayAttacks bool `json:"replay_attacks" envconfig:"IMMUNE_REPLAY_ATTACKS"` + Secret string `json:"secret" envconfig:"IMMUNE_SIGNATURE_SECRET"` + Header string `json:"header" envconfig:"IMMUNE_SIGNATURE_HEADER"` + Hash string `json:"hash" envconfig:"IMMUNE_SIGNATURE_HASH"` } const CallbackIDFieldName = "immune_callback_id" @@ -52,3 +60,7 @@ type CallbackServer interface { Start(ctx context.Context) error Stop() } + +type CallbackSignatureVerifier interface { + VerifyCallbackSignature(s *Signal) error +} diff --git a/callback/server.go b/callback/server.go index ba6716c..ef6488f 100644 --- a/callback/server.go +++ b/callback/server.go @@ -1,9 +1,11 @@ package callback import ( + "bytes" "context" "encoding/json" "fmt" + "io" "net/http" "strconv" "time" @@ -95,10 +97,20 @@ func (s *server) Start(ctx context.Context) error { func handleCallback(outbound chan<- *immune.Signal) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { sig := &immune.Signal{} - err := json.NewDecoder(r.Body).Decode(sig) + clone := r.Clone(context.Background()) + + buf, err := io.ReadAll(r.Body) if err != nil { - sig.Err = fmt.Errorf("failed to decode callback body: %v", err) + sig.Err = fmt.Errorf("failed to read callback body: %v", err) + } else { + err = json.Unmarshal(buf, sig) + if err != nil { + sig.Err = fmt.Errorf("failed to decode callback body: %v", err) + } } + + clone.Body = io.NopCloser(bytes.NewBuffer(buf)) + sig.Request = clone w.WriteHeader(http.StatusOK) outbound <- sig } diff --git a/callback/server_test.go b/callback/server_test.go index a9ca88d..47e4ab7 100644 --- a/callback/server_test.go +++ b/callback/server_test.go @@ -57,7 +57,9 @@ func Test_handleCallback(t *testing.T) { require.Equal(t, tt.wantErrMsg, s.Error()) return } - require.Equal(t, tt.wantSignal, <-tt.args.outbound) + sig := <-tt.args.outbound + sig.Request = nil + require.Equal(t, tt.wantSignal, sig) }) } } diff --git a/callback/signature_verifier.go b/callback/signature_verifier.go new file mode 100644 index 0000000..6a676ff --- /dev/null +++ b/callback/signature_verifier.go @@ -0,0 +1,113 @@ +package callback + +import ( + "crypto/hmac" + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "encoding/hex" + "hash" + "io/ioutil" + "strconv" + "time" + + "github.com/frain-dev/immune" + "github.com/pkg/errors" + "golang.org/x/crypto/sha3" +) + +const ConvoyTimestampHeader = "Convoy-Timestamp" + +type SignatureVerifier struct { + ReplayAttacks bool `json:"replay_attacks"` + Secret string `json:"secret"` + Header string `json:"header"` + Hash string `json:"hash"` + hashFn func() hash.Hash +} + +func NewSignatureVerifier(replayAttacks bool, secret, header, hash string) (immune.CallbackSignatureVerifier, error) { + fn, err := getHashFunction(hash) + if err != nil { + return nil, err + } + + return &SignatureVerifier{ + ReplayAttacks: replayAttacks, + Secret: secret, + Header: header, + Hash: hash, + hashFn: fn, + }, nil +} + +func (sv *SignatureVerifier) VerifyCallbackSignature(s *immune.Signal) error { + r := s.Request + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return errors.Wrap(err, "unable to read request body") + } + + signatureHex := []byte(r.Header.Get(sv.Header)) + var signature = make([]byte, hex.DecodedLen(len(signatureHex))) + _, err = hex.Decode(signature, signatureHex) + if err != nil { + return errors.Wrap(err, "unable to hex decode signature body") + } + + hasher := hmac.New(sv.hashFn, []byte(sv.Secret)) + + if sv.ReplayAttacks { + timestampStr := r.Header.Get(ConvoyTimestampHeader) + timestamp, err := strconv.ParseInt(timestampStr, 10, 64) + if err != nil { + return errors.Wrap(err, "unable to parse signature timestamp") + } + + t := time.Unix(timestamp, 0) + d := time.Since(t) + if d > time.Minute { + return errors.Errorf("replay attack timestamp is more than a minute ago") + } + + hasher.Write([]byte(timestampStr)) + hasher.Write([]byte(",")) + } + + hasher.Write(buf) + if !hmac.Equal(signature, hasher.Sum(nil)) { + return errors.New("signature invalid") + } + return nil +} + +func getHashFunction(algorithm string) (func() hash.Hash, error) { + switch algorithm { + case "MD5": + return md5.New, nil + case "SHA1": + return sha1.New, nil + case "SHA224": + return sha256.New224, nil + case "SHA256": + return sha256.New, nil + case "SHA384": + return sha512.New384, nil + case "SHA512": + return sha512.New, nil + case "SHA3_224": + return sha3.New224, nil + case "SHA3_256": + return sha3.New256, nil + case "SHA3_384": + return sha3.New384, nil + case "SHA3_512": + return sha3.New512, nil + case "SHA512_224": + return sha512.New512_224, nil + case "SHA512_256": + return sha512.New512_256, nil + } + return nil, errors.New("unknown hash algorithm") +} diff --git a/callback/signature_verifier_test.go b/callback/signature_verifier_test.go new file mode 100644 index 0000000..8ddbfea --- /dev/null +++ b/callback/signature_verifier_test.go @@ -0,0 +1,366 @@ +package callback + +import ( + "crypto/hmac" + "encoding/hex" + "fmt" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/frain-dev/immune" + "github.com/stretchr/testify/require" +) + +func TestSignatureVerifier_VerifyCallbackSignature(t *testing.T) { + type fields struct { + ReplayAttacks bool + Secret string + Header string + Hash string + } + type args struct { + s *immune.Signal + } + tests := []struct { + name string + fields fields + args args + bodyStr string + t int64 + mismatchRequestBody bool // so we can make cases where the signature should not match, see it's usage + wantErr bool + wantErrMsg string + }{ + { + name: "should_verify_signature_header_SHA512", + bodyStr: `{"name":"Daniel"}`, + t: time.Now().Unix(), + fields: fields{ + ReplayAttacks: true, + Secret: "1234", + Header: "X-Test", + Hash: "SHA512", + }, + args: args{s: &immune.Signal{}}, + wantErr: false, + }, + { + name: "should_verify_signature_header_MD5", + bodyStr: `{"name":"Daniel"}`, + t: time.Now().Unix(), + fields: fields{ + ReplayAttacks: true, + Secret: "1234", + Header: "X-Test", + Hash: "MD5", + }, + args: args{s: &immune.Signal{}}, + wantErr: false, + }, + { + name: "should_verify_signature_header_SHA1", + bodyStr: `{"name":"Daniel"}`, + t: time.Now().Unix(), + fields: fields{ + ReplayAttacks: true, + Secret: "1234", + Header: "X-Test", + Hash: "SHA1", + }, + args: args{s: &immune.Signal{}}, + wantErr: false, + }, + { + name: "should_verify_signature_header_SHA224", + bodyStr: `{"name":"Daniel"}`, + t: time.Now().Unix(), + fields: fields{ + ReplayAttacks: true, + Secret: "1234", + Header: "X-Test", + Hash: "SHA224", + }, + args: args{s: &immune.Signal{}}, + wantErr: false, + }, + { + name: "should_verify_signature_header_SHA256", + bodyStr: `{"name":"Daniel"}`, + t: time.Now().Unix(), + fields: fields{ + ReplayAttacks: true, + Secret: "1234", + Header: "X-Test", + Hash: "SHA256", + }, + args: args{s: &immune.Signal{}}, + wantErr: false, + }, + { + name: "should_verify_signature_header_SHA384", + bodyStr: `{"name":"Daniel"}`, + t: time.Now().Unix(), + fields: fields{ + ReplayAttacks: true, + Secret: "1234", + Header: "X-Test", + Hash: "SHA384", + }, + args: args{s: &immune.Signal{}}, + wantErr: false, + }, + { + name: "should_verify_signature_header_SHA3_224", + bodyStr: `{"name":"Daniel"}`, + t: time.Now().Unix(), + fields: fields{ + ReplayAttacks: true, + Secret: "1234", + Header: "X-Test", + Hash: "SHA3_224", + }, + args: args{s: &immune.Signal{}}, + wantErr: false, + }, + { + name: "should_verify_signature_header_SHA3_256", + bodyStr: `{"name":"Daniel"}`, + t: time.Now().Unix(), + fields: fields{ + ReplayAttacks: true, + Secret: "1234", + Header: "X-Test", + Hash: "SHA3_256", + }, + args: args{s: &immune.Signal{}}, + wantErr: false, + }, + { + name: "should_verify_signature_header_SHA3_384", + bodyStr: `{"name":"Daniel"}`, + t: time.Now().Unix(), + fields: fields{ + ReplayAttacks: true, + Secret: "1234", + Header: "X-Test", + Hash: "SHA3_384", + }, + args: args{s: &immune.Signal{}}, + wantErr: false, + }, + { + name: "should_verify_signature_header_SHA3_512", + bodyStr: `{"name":"Daniel"}`, + t: time.Now().Unix(), + fields: fields{ + ReplayAttacks: true, + Secret: "1234", + Header: "X-Test", + Hash: "SHA3_512", + }, + args: args{s: &immune.Signal{}}, + wantErr: false, + }, + { + name: "should_verify_signature_header_SHA512_224", + bodyStr: `{"name":"Daniel"}`, + t: time.Now().Unix(), + fields: fields{ + ReplayAttacks: true, + Secret: "1234", + Header: "X-Test", + Hash: "SHA512_224", + }, + args: args{s: &immune.Signal{}}, + wantErr: false, + }, + { + name: "should_verify_signature_header_SHA512_256", + bodyStr: `{"name":"Daniel"}`, + t: time.Now().Unix(), + fields: fields{ + ReplayAttacks: true, + Secret: "1234", + Header: "X-Test", + Hash: "SHA512_256", + }, + args: args{s: &immune.Signal{}}, + wantErr: false, + }, + { + name: "should_error_for_old_timestamp", + bodyStr: `{"name":"Daniel"}`, + t: time.Now().Add(-time.Hour).Unix(), + fields: fields{ + ReplayAttacks: true, + Secret: "1234", + Header: "X-Test", + Hash: "SHA512_256", + }, + args: args{s: &immune.Signal{}}, + wantErr: true, + wantErrMsg: "replay attack timestamp is more than a minute ago", + }, + { + name: "should_error_for_invalid_signature", + bodyStr: `{"name":"Daniel"}`, + t: time.Now().Unix(), + fields: fields{ + ReplayAttacks: true, + Secret: "1234", + Header: "X-Test", + Hash: "SHA512", + }, + mismatchRequestBody: true, + args: args{s: &immune.Signal{}}, + wantErr: true, + wantErrMsg: "signature invalid", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sv, err := NewSignatureVerifier( + tt.fields.ReplayAttacks, + tt.fields.Secret, + tt.fields.Header, + tt.fields.Hash, + ) + require.NoError(t, err) + + r, err := http.NewRequest(http.MethodPost, "/", nil) + require.NoError(t, err) + + tt.args.s.Request = r + + generateSignatureHeader(sv.(*SignatureVerifier), tt.bodyStr, tt.t, tt.mismatchRequestBody, tt.args.s.Request) + err = sv.VerifyCallbackSignature(tt.args.s) + if tt.wantErr { + require.Error(t, err) + require.Equal(t, tt.wantErrMsg, err.Error()) + return + } + + require.NoError(t, err) + }) + } +} + +func generateSignatureHeader(sv *SignatureVerifier, bodyStr string, t int64, mismatchRequestBody bool, r *http.Request) { + body := strings.NewReader(bodyStr) + + var signedPayload strings.Builder + var timestamp string + + if sv.ReplayAttacks { + timestamp = fmt.Sprint(t) + r.Header.Set(ConvoyTimestampHeader, timestamp) + signedPayload.WriteString(timestamp) + signedPayload.WriteString(",") + } + signedPayload.WriteString(bodyStr) + + if mismatchRequestBody { + r.Body = io.NopCloser(strings.NewReader(bodyStr + bodyStr)) // scramble the request body + } else { + r.Body = io.NopCloser(body) // set the normal value + } + + h := hmac.New(sv.hashFn, []byte(sv.Secret)) + h.Write([]byte(signedPayload.String())) + e := hex.EncodeToString(h.Sum(nil)) + r.Header.Set(sv.Header, e) +} + +func Test_getHashFunction(t *testing.T) { + type args struct { + algorithm string + } + tests := []struct { + name string + args args + wantErr bool + wantErrMsg string + }{ + { + name: "should_verify_signature_header_SHA512", + args: args{algorithm: "SHA512"}, + wantErr: false, + }, + { + name: "should_verify_signature_header_MD5", + args: args{algorithm: "MD5"}, + wantErr: false, + }, + { + name: "should_verify_signature_header_SHA1", + args: args{algorithm: "SHA1"}, + wantErr: false, + }, + { + name: "should_verify_signature_header_SHA224", + args: args{algorithm: "SHA224"}, + wantErr: false, + }, + { + name: "should_verify_signature_header_SHA256", + args: args{algorithm: "SHA256"}, + wantErr: false, + }, + { + name: "should_verify_signature_header_SHA384", + args: args{algorithm: "SHA384"}, + wantErr: false, + }, + { + name: "should_verify_signature_header_SHA3_224", + args: args{algorithm: "SHA3_224"}, + wantErr: false, + }, + { + name: "should_verify_signature_header_SHA3_256", + args: args{algorithm: "SHA3_256"}, + wantErr: false, + }, + { + name: "should_verify_signature_header_SHA3_384", + args: args{algorithm: "SHA3_384"}, + wantErr: false, + }, + { + name: "should_verify_signature_header_SHA3_512", + args: args{algorithm: "SHA3_512"}, + wantErr: false, + }, + { + name: "should_verify_signature_header_SHA512_224", + args: args{algorithm: "SHA512_224"}, + wantErr: false, + }, + { + name: "should_verify_signature_header_SHA512_256", + args: args{algorithm: "SHA512_256"}, + wantErr: false, + }, + { + name: "should_error_for_unknown_hash_algorithm", + args: args{algorithm: "abc"}, + wantErr: true, + wantErrMsg: "unknown hash algorithm", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := getHashFunction(tt.args.algorithm) + if tt.wantErr { + require.Error(t, err) + require.Equal(t, tt.wantErrMsg, err.Error()) + return + } + + require.NoError(t, err) + }) + } +} diff --git a/exec/executor.go b/exec/executor.go index f44d684..8bc6abf 100644 --- a/exec/executor.go +++ b/exec/executor.go @@ -24,6 +24,7 @@ type Executor struct { idFn func() string client *http.Client dbTruncator database.Truncator + sv immune.CallbackSignatureVerifier vm *immune.VariableMap s immune.CallbackServer } @@ -32,6 +33,7 @@ func NewExecutor( s immune.CallbackServer, client *http.Client, vm *immune.VariableMap, + sv immune.CallbackSignatureVerifier, maxCallbackWaitSeconds uint, baseURL string, callbackIDLocation string, @@ -39,6 +41,7 @@ func NewExecutor( return &Executor{ s: s, vm: vm, + sv: sv, idFn: idFn, client: client, baseURL: baseURL, @@ -191,6 +194,12 @@ func (ex *Executor) ExecuteTestCase(ctx context.Context, tc *immune.TestCase) er if sig.ImmuneCallBackID != uid { return errors.Errorf("test_case %s: incorrect callback_id: expected_callback_id '%s', got_callback_id '%s'", tc.Name, uid, sig.ImmuneCallBackID) } + + err = ex.sv.VerifyCallbackSignature(sig) + if err != nil { + return errors.Wrap(err, "failed to verify callback signature header") + } + log.Infof("callback %d for test_case %s received", i, tc.Name) } } diff --git a/exec/executor_test.go b/exec/executor_test.go index 8ca3f8b..e87883c 100644 --- a/exec/executor_test.go +++ b/exec/executor_test.go @@ -14,7 +14,7 @@ import ( ) func TestExecutor_ExecuteSetupTestCase(t *testing.T) { - ex := NewExecutor(nil, http.DefaultClient, nil, 10, "http://localhost:5005", "data", nil, nil) + ex := NewExecutor(nil, http.DefaultClient, nil, nil, 10, "http://localhost:5005", "data", nil, nil) type fields struct { vm *immune.VariableMap @@ -352,7 +352,7 @@ func TestExecutor_ExecuteSetupTestCase(t *testing.T) { } func TestExecutor_ExecuteTestCase(t *testing.T) { - ex := NewExecutor(nil, http.DefaultClient, nil, 10, "http://localhost:5005", "data", nil, nil) + ex := NewExecutor(nil, http.DefaultClient, nil, nil, 10, "http://localhost:5005", "data", nil, nil) type fields struct { vm *immune.VariableMap } @@ -366,7 +366,7 @@ func TestExecutor_ExecuteTestCase(t *testing.T) { idFn func() string fields fields args args - arrangeFn func(server *mocks.MockCallbackServer, tr *mocks.MockTruncator) func() + arrangeFn func(server *mocks.MockCallbackServer, tr *mocks.MockTruncator, sv *mocks.MockCallbackSignatureVerifier) func() wantErr bool wantErrMsg string }{ @@ -404,13 +404,14 @@ func TestExecutor_ExecuteTestCase(t *testing.T) { idFn: func() string { return "12345" }, - arrangeFn: func(server *mocks.MockCallbackServer, tr *mocks.MockTruncator) func() { + arrangeFn: func(server *mocks.MockCallbackServer, tr *mocks.MockTruncator, sv *mocks.MockCallbackSignatureVerifier) func() { var rc chan<- *immune.Signal server.EXPECT().ReceiveCallback(gomock.AssignableToTypeOf(rc)).Times(2).DoAndReturn(func(c chan<- *immune.Signal) { c <- &immune.Signal{ImmuneCallBackID: "12345"} }) tr.EXPECT().Truncate(gomock.Any()).Times(1) + sv.EXPECT().VerifyCallbackSignature(gomock.Any()).Times(2) httpmock.Activate() httpmock.RegisterResponder(http.MethodPost, "http://localhost:5005/update_user/1234", @@ -446,7 +447,7 @@ func TestExecutor_ExecuteTestCase(t *testing.T) { idFn: func() string { return "12345" }, - arrangeFn: func(server *mocks.MockCallbackServer, tr *mocks.MockTruncator) func() { + arrangeFn: func(server *mocks.MockCallbackServer, tr *mocks.MockTruncator, sv *mocks.MockCallbackSignatureVerifier) func() { tr.EXPECT().Truncate(gomock.Any()).Times(1) httpmock.Activate() @@ -492,7 +493,7 @@ func TestExecutor_ExecuteTestCase(t *testing.T) { idFn: func() string { return "12345" }, - arrangeFn: func(server *mocks.MockCallbackServer, tr *mocks.MockTruncator) func() { + arrangeFn: func(server *mocks.MockCallbackServer, tr *mocks.MockTruncator, sv *mocks.MockCallbackSignatureVerifier) func() { var rc chan<- *immune.Signal server.EXPECT().ReceiveCallback(gomock.AssignableToTypeOf(rc)).Times(1).DoAndReturn(func(c chan<- *immune.Signal) { c <- &immune.Signal{Err: errors.New("failed to decode callback body")} @@ -539,7 +540,7 @@ func TestExecutor_ExecuteTestCase(t *testing.T) { }, }, }, - arrangeFn: func(server *mocks.MockCallbackServer, tr *mocks.MockTruncator) func() { + arrangeFn: func(server *mocks.MockCallbackServer, tr *mocks.MockTruncator, sv *mocks.MockCallbackSignatureVerifier) func() { httpmock.Activate() httpmock.RegisterResponder(http.MethodPost, "http://localhost:5005/update_user/1234", @@ -584,7 +585,7 @@ func TestExecutor_ExecuteTestCase(t *testing.T) { idFn: func() string { return "12345" }, - arrangeFn: func(server *mocks.MockCallbackServer, tr *mocks.MockTruncator) func() { + arrangeFn: func(server *mocks.MockCallbackServer, tr *mocks.MockTruncator, sv *mocks.MockCallbackSignatureVerifier) func() { httpmock.Activate() httpmock.RegisterResponder(http.MethodPost, "http://localhost:5005/update_user", @@ -630,7 +631,7 @@ func TestExecutor_ExecuteTestCase(t *testing.T) { idFn: func() string { return "12345" }, - arrangeFn: func(server *mocks.MockCallbackServer, tr *mocks.MockTruncator) func() { + arrangeFn: func(server *mocks.MockCallbackServer, tr *mocks.MockTruncator, sv *mocks.MockCallbackSignatureVerifier) func() { httpmock.Activate() httpmock.RegisterResponder(http.MethodPost, "http://localhost:5005/update_user", @@ -675,7 +676,7 @@ func TestExecutor_ExecuteTestCase(t *testing.T) { idFn: func() string { return "12345" }, - arrangeFn: func(server *mocks.MockCallbackServer, tr *mocks.MockTruncator) func() { + arrangeFn: func(server *mocks.MockCallbackServer, tr *mocks.MockTruncator, sv *mocks.MockCallbackSignatureVerifier) func() { httpmock.Activate() httpmock.RegisterResponder(http.MethodPost, "http://localhost:5005/update_user", @@ -720,7 +721,7 @@ func TestExecutor_ExecuteTestCase(t *testing.T) { idFn: func() string { return "12345" }, - arrangeFn: func(server *mocks.MockCallbackServer, tr *mocks.MockTruncator) func() { + arrangeFn: func(server *mocks.MockCallbackServer, tr *mocks.MockTruncator, sv *mocks.MockCallbackSignatureVerifier) func() { httpmock.Activate() httpmock.RegisterResponder(http.MethodPost, "http://localhost:5005/update_user", @@ -765,7 +766,7 @@ func TestExecutor_ExecuteTestCase(t *testing.T) { idFn: func() string { return "12345" }, - arrangeFn: func(server *mocks.MockCallbackServer, tr *mocks.MockTruncator) func() { + arrangeFn: func(server *mocks.MockCallbackServer, tr *mocks.MockTruncator, sv *mocks.MockCallbackSignatureVerifier) func() { httpmock.Activate() httpmock.RegisterResponder(http.MethodPost, "http://localhost:5005/update_user", @@ -810,7 +811,7 @@ func TestExecutor_ExecuteTestCase(t *testing.T) { idFn: func() string { return "12345" }, - arrangeFn: func(server *mocks.MockCallbackServer, tr *mocks.MockTruncator) func() { + arrangeFn: func(server *mocks.MockCallbackServer, tr *mocks.MockTruncator, sv *mocks.MockCallbackSignatureVerifier) func() { httpmock.Activate() httpmock.RegisterResponder(http.MethodPost, "http://localhost:5005/update_user", @@ -855,7 +856,7 @@ func TestExecutor_ExecuteTestCase(t *testing.T) { idFn: func() string { return "12345" }, - arrangeFn: func(server *mocks.MockCallbackServer, tr *mocks.MockTruncator) func() { + arrangeFn: func(server *mocks.MockCallbackServer, tr *mocks.MockTruncator, sv *mocks.MockCallbackSignatureVerifier) func() { var rc chan<- *immune.Signal server.EXPECT().ReceiveCallback(gomock.AssignableToTypeOf(rc)).Times(1).Do(func(rc chan<- *immune.Signal) { rc <- &immune.Signal{ImmuneCallBackID: "1234"} @@ -872,21 +873,74 @@ func TestExecutor_ExecuteTestCase(t *testing.T) { wantErr: true, wantErrMsg: "test_case abc: incorrect callback_id: expected_callback_id '12345', got_callback_id '1234'", }, + { + name: "should_fail_to_verify_callback_signature", + fields: fields{ + vm: &immune.VariableMap{ + VariableToValue: immune.M{ + "user_id": "1234", + }, + }, + }, + args: args{ + ctx: context.Background(), + tc: &immune.TestCase{ + Name: "abc", + Setup: nil, + StatusCode: 200, + HTTPMethod: "POST", + Endpoint: "/update_user", + ResponseBody: true, + Callback: immune.Callback{ + Enabled: true, + Times: 2, + }, + RequestBody: immune.M{ + "email": "dan@gmail.com", + "phone": 23453530833, + "data": map[string]interface{}{}, + }, + }, + }, + idFn: func() string { + return "1234" + }, + arrangeFn: func(server *mocks.MockCallbackServer, tr *mocks.MockTruncator, sv *mocks.MockCallbackSignatureVerifier) func() { + var rc chan<- *immune.Signal + server.EXPECT().ReceiveCallback(gomock.AssignableToTypeOf(rc)).Times(1).Do(func(rc chan<- *immune.Signal) { + rc <- &immune.Signal{ImmuneCallBackID: "1234"} + }) + sv.EXPECT().VerifyCallbackSignature(gomock.Any()).Times(1).Return(errors.New("failed")) + + httpmock.Activate() + + httpmock.RegisterResponder(http.MethodPost, "http://localhost:5005/update_user", + httpmock.NewStringResponder(http.StatusOK, `{"immune_callback_id":"1234"}`)) + + return func() { + httpmock.DeactivateAndReset() + } + }, + wantErr: true, + wantErrMsg: "failed to verify callback signature header: failed", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() mockDBTruncator := mocks.NewMockTruncator(ctrl) - mockCallbackServer := mocks.NewMockCallbackServer(ctrl) + sv := mocks.NewMockCallbackSignatureVerifier(ctrl) + if tt.arrangeFn != nil { - deferFn := tt.arrangeFn(mockCallbackServer, mockDBTruncator) + deferFn := tt.arrangeFn(mockCallbackServer, mockDBTruncator, sv) defer deferFn() } ex.s = mockCallbackServer ex.dbTruncator = mockDBTruncator + ex.sv = sv ex.vm = tt.fields.vm ex.idFn = tt.idFn err := ex.ExecuteTestCase(tt.args.ctx, tt.args.tc) diff --git a/exec/signature.go b/exec/signature.go new file mode 100644 index 0000000..81c33c4 --- /dev/null +++ b/exec/signature.go @@ -0,0 +1 @@ +package exec diff --git a/funcs/funcs.go b/funcs/funcs.go index 156303d..c87fcc2 100644 --- a/funcs/funcs.go +++ b/funcs/funcs.go @@ -10,13 +10,14 @@ import ( "github.com/google/uuid" ) -func SetupGroup(ctx context.Context, ex *exec.Executor) error { +func SetupGroup(ctx context.Context, ex *exec.Executor, signatureConfig *immune.SignatureConfiguration) error { req := `{ "config": { + "replay_attacks": %t, "disableEndpoint": true, "signature": { - "hash": "SHA256", - "header": "X-Retro-Signature" + "hash": "%s", + "header": "%s" }, "strategy": { "default": { @@ -30,7 +31,7 @@ func SetupGroup(ctx context.Context, ex *exec.Executor) error { "name": "immune-group-%s" }` - req = fmt.Sprintf(req, uuid.New().String()) + req = fmt.Sprintf(req, signatureConfig.ReplayAttacks, signatureConfig.Hash, signatureConfig.Header, uuid.New().String()) mapper := map[string]interface{}{} err := json.Unmarshal([]byte(req), &mapper) if err != nil { @@ -82,17 +83,17 @@ func SetupApp(ctx context.Context, ex *exec.Executor) error { return ex.ExecuteSetupTestCase(ctx, tc) } -func SetupAppEndpoint(ctx context.Context, targetURL string, ex *exec.Executor) error { +func SetupAppEndpoint(ctx context.Context, targetURL string, secret string, ex *exec.Executor) error { req := `{ "url": "%s", - "secret": "12345", + "secret": "%s", "description": "Local ngrok endpoint", "events": [ "payment.failed" ] }` - req = fmt.Sprintf(req, targetURL) + req = fmt.Sprintf(req, targetURL, secret) mapper := map[string]interface{}{} err := json.Unmarshal([]byte(req), &mapper) if err != nil { diff --git a/immune.json b/immune.json index 7a78ddd..6f25e0e 100644 --- a/immune.json +++ b/immune.json @@ -7,22 +7,28 @@ "ssl_key_file": "", "max_wait_seconds": 20, "route": "/", - "id_location": "data" + "id_location": "data", + "signature": { + "replay_attacks": true, + "secret": "12345", + "header": "X-Retro-Signature", + "hash": "SHA512" + } }, "database": { "type": "mongo", "dsn": "mongodb+srv://admin:7h5tAfZiYuCEe6KC42873272642331@cluster1.eqj2e.mongodb.net/convoy-immune" }, - "event_target_url": "https://5721-102-219-153-96.ngrok.io", + "event_target_url": "https://7e58-197-211-61-35.eu.ngrok.io", "test_cases": [ { "name": "test_convoy_can_push_event_to_app_with_one_endpoint", - "setup": ["setup_group", "setup_app", "setup_endpoint", "setup_event"], + "setup": ["setup_group", "setup_app", "setup_endpoint"], "http_method": "POST", "endpoint": "/events?groupId={group_id}", "callback": { "enabled": true, - "times": 2 + "times": 1 }, "request_body": { "app_id": "{app_id}", diff --git a/mocks/callback.go b/mocks/callback.go index 7413536..5b9d4df 100644 --- a/mocks/callback.go +++ b/mocks/callback.go @@ -72,3 +72,40 @@ func (mr *MockCallbackServerMockRecorder) Stop() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockCallbackServer)(nil).Stop)) } + +// MockCallbackSignatureVerifier is a mock of CallbackSignatureVerifier interface. +type MockCallbackSignatureVerifier struct { + ctrl *gomock.Controller + recorder *MockCallbackSignatureVerifierMockRecorder +} + +// MockCallbackSignatureVerifierMockRecorder is the mock recorder for MockCallbackSignatureVerifier. +type MockCallbackSignatureVerifierMockRecorder struct { + mock *MockCallbackSignatureVerifier +} + +// NewMockCallbackSignatureVerifier creates a new mock instance. +func NewMockCallbackSignatureVerifier(ctrl *gomock.Controller) *MockCallbackSignatureVerifier { + mock := &MockCallbackSignatureVerifier{ctrl: ctrl} + mock.recorder = &MockCallbackSignatureVerifierMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockCallbackSignatureVerifier) EXPECT() *MockCallbackSignatureVerifierMockRecorder { + return m.recorder +} + +// VerifyCallbackSignature mocks base method. +func (m *MockCallbackSignatureVerifier) VerifyCallbackSignature(s *immune.Signal) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "VerifyCallbackSignature", s) + ret0, _ := ret[0].(error) + return ret0 +} + +// VerifyCallbackSignature indicates an expected call of VerifyCallbackSignature. +func (mr *MockCallbackSignatureVerifierMockRecorder) VerifyCallbackSignature(s interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "VerifyCallbackSignature", reflect.TypeOf((*MockCallbackSignatureVerifier)(nil).VerifyCallbackSignature), s) +} diff --git a/signal.go b/signal.go index 73e97c3..dda2efa 100644 --- a/signal.go +++ b/signal.go @@ -1,12 +1,15 @@ package immune +import "net/http" + // A Signal represents a single callback type Signal struct { // ImmuneCallBackID collects the callback id from the request body, it's json tag // must always match immune.CallbackIDFieldName ImmuneCallBackID string `json:"immune_callback_id"` - Err error + Request *http.Request // the http request that carried this callback signal + Err error } func (s *Signal) Error() string { diff --git a/system/run.go b/system/run.go index bb44adf..8dba45d 100644 --- a/system/run.go +++ b/system/run.go @@ -2,6 +2,7 @@ package system import ( "context" + "fmt" "net/http" "github.com/frain-dev/immune" @@ -47,16 +48,28 @@ func (s *System) Run(ctx context.Context) error { idFn := func() string { return uuid.New().String() } - ex := exec.NewExecutor(cs, http.DefaultClient, s.Variables, s.Callback.MaxWaitSeconds, s.BaseURL, s.Callback.IDLocation, truncator, idFn) - //log.Info("starting execution of setup test cases") - //for i := range s.SetupTestCases { - // err = ex.ExecuteSetupTestCase(ctx, &s.SetupTestCases[i]) - // if err != nil { - // return err - // } - //} - //log.Info("finished execution of setup test cases") + sv, err := callback.NewSignatureVerifier( + s.Callback.Signature.ReplayAttacks, + s.Callback.Signature.Secret, + s.Callback.Signature.Header, + s.Callback.Signature.Hash, + ) + if err != nil { + return fmt.Errorf("failed to get new signature verifier: %v", err) + } + + ex := exec.NewExecutor( + cs, + http.DefaultClient, + s.Variables, + sv, + s.Callback.MaxWaitSeconds, + s.BaseURL, + s.Callback.IDLocation, + truncator, + idFn, + ) log.Info("starting execution of test cases") for i := range s.TestCases { @@ -64,7 +77,7 @@ func (s *System) Run(ctx context.Context) error { for _, setupName := range tc.Setup { switch setupName { case "setup_group": - err = funcs.SetupGroup(ctx, ex) + err = funcs.SetupGroup(ctx, ex, &s.Callback.Signature) if err != nil { return err } @@ -74,7 +87,7 @@ func (s *System) Run(ctx context.Context) error { return err } case "setup_endpoint": - err = funcs.SetupAppEndpoint(ctx, s.EventTargetURL, ex) + err = funcs.SetupAppEndpoint(ctx, s.EventTargetURL, s.Callback.Signature.Secret, ex) if err != nil { return err } diff --git a/system/system.go b/system/system.go index 00e4ecd..eb5093e 100644 --- a/system/system.go +++ b/system/system.go @@ -66,6 +66,22 @@ func processOverride(sys, override *System) { if override.Callback.SSLCertFile != "" { sys.Callback.SSLCertFile = override.Callback.SSLCertFile } + + if _, ok := os.LookupEnv("IMMUNE_REPLAY_ATTACKS"); ok { + sys.Callback.Signature.ReplayAttacks = override.Callback.Signature.ReplayAttacks + } + + if override.Callback.Signature.Secret != "" { + sys.Callback.Signature.Secret = override.Callback.Signature.Secret + } + + if override.Callback.Signature.Header != "" { + sys.Callback.Signature.Header = override.Callback.Signature.Header + } + + if override.Callback.Signature.Hash != "" { + sys.Callback.Signature.Hash = override.Callback.Signature.Hash + } } const maxCallbackWait = 5 @@ -82,6 +98,16 @@ func (s *System) Clean() error { } } + if s.Callback.Signature.Header == "" { + return errors.New("callback signature header cannot be empty") + } + if s.Callback.Signature.Hash == "" { + return errors.New("callback signature hash cannot be empty") + } + if s.Callback.Signature.Secret == "" { + return errors.New("callback signature secret cannot be empty") + } + _, err := url.Parse(s.BaseURL) if err != nil { return fmt.Errorf("base url is not a vaild url: %v", err)