Skip to content

Commit 64c5b05

Browse files
committed
Use webhook dispatcher when creating and updating endpoints (#2280)
* Use webhook dispatcher when creating and updating endpoints
1 parent e7fc093 commit 64c5b05

13 files changed

+277
-40
lines changed

api/handlers/endpoint.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ func (h *Handler) CreateEndpoint(w http.ResponseWriter, r *http.Request) {
8484
Licenser: h.A.Licenser,
8585
E: e,
8686
ProjectID: project.UID,
87+
FeatureFlag: h.A.FFlag,
88+
Logger: h.A.Logger,
8789
}
8890

8991
endpoint, err := ce.Run(r.Context())
@@ -319,6 +321,8 @@ func (h *Handler) UpdateEndpoint(w http.ResponseWriter, r *http.Request) {
319321
EndpointRepo: postgres.NewEndpointRepo(h.A.DB),
320322
ProjectRepo: postgres.NewProjectRepo(h.A.DB),
321323
Licenser: h.A.Licenser,
324+
FeatureFlag: h.A.FFlag,
325+
Logger: h.A.Logger,
322326
E: e,
323327
Endpoint: endpoint,
324328
Project: project,

internal/pkg/fflag/fflag.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,12 @@ func NewFFlag(enableFeatureFlags []string) *FFlag {
7878
return f
7979
}
8080

81+
func NoopFflag() *FFlag {
82+
return &FFlag{
83+
Features: clone(DefaultFeaturesState),
84+
}
85+
}
86+
8187
func clone(src map[FeatureFlagKey]FeatureFlagState) map[FeatureFlagKey]FeatureFlagState {
8288
dst := make(map[FeatureFlagKey]FeatureFlagState)
8389
for k, v := range src {

internal/pkg/middleware/middleware.go

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@ import (
66
"encoding/base64"
77
"errors"
88
"fmt"
9-
"github.com/frain-dev/convoy/internal/pkg/license"
109
"net/http"
1110
"strconv"
1211
"strings"
1312
"time"
1413

14+
"github.com/frain-dev/convoy/internal/pkg/license"
15+
1516
"github.com/frain-dev/convoy/internal/pkg/limiter"
1617
rlimiter "github.com/frain-dev/convoy/internal/pkg/limiter/redis"
1718

@@ -43,7 +44,49 @@ const (
4344
pageableCtx types.ContextKey = "pageable"
4445
)
4546

46-
var ErrValidLicenseRequired = errors.New("access to this resource requires a valid license")
47+
var (
48+
ErrValidLicenseRequired = errors.New("access to this resource requires a valid license")
49+
50+
// skipLoggingPaths defines paths that should not be logged by the request logger
51+
skipLoggingPaths []string
52+
)
53+
54+
// shouldSkipLogging checks if the given path should be excluded from logging
55+
func shouldSkipLogging(r map[string]interface{}, w map[string]interface{}) bool {
56+
for _, skipPath := range skipLoggingPaths {
57+
if strings.Contains(r["requestURL"].(string), skipPath) {
58+
return true
59+
}
60+
}
61+
62+
headers := w["header"].(map[string]string)
63+
64+
if strings.Contains(headers["content-type"], "application/javascript") {
65+
return true
66+
}
67+
68+
if strings.Contains(headers["content-type"], "image") {
69+
return true
70+
}
71+
72+
if strings.Contains(headers["content-type"], "font") {
73+
return true
74+
}
75+
76+
if strings.Contains(headers["content-type"], "text/html") {
77+
return true
78+
}
79+
80+
if strings.Contains(headers["content-type"], "text/javascript") {
81+
return true
82+
}
83+
84+
if strings.Contains(headers["content-type"], "text/css") {
85+
return true
86+
}
87+
88+
return false
89+
}
4790

4891
type AuthorizedLogin struct {
4992
Username string `json:"username,omitempty"`
@@ -316,6 +359,10 @@ func LogHttpRequest(a *types.APIOptions) func(next http.Handler) http.Handler {
316359
"httpResponse": responseFields,
317360
}
318361

362+
if shouldSkipLogging(requestFields, responseFields) {
363+
return
364+
}
365+
319366
log.FromContext(r.Context()).WithFields(logFields).Log(lvl, requestFields["requestURL"])
320367
}()
321368

net/dispatcher.go

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ var (
3737
ErrLoggerIsRequired = errors.New("logger is required")
3838
ErrInvalidIPPrefix = errors.New("invalid IP prefix")
3939
ErrTracerIsRequired = errors.New("tracer cannot be nil")
40+
ErrNon2xxResponse = errors.New("endpoint returned a non-2xx response")
4041
)
4142

4243
type DispatcherOption func(d *Dispatcher) error
@@ -96,7 +97,7 @@ type Dispatcher struct {
9697
ff *fflag.FFlag
9798
l license.Licenser
9899

99-
logger *log.Logger
100+
logger log.StdLogger
100101
transport *http.Transport
101102
client *http.Client
102103
rules *netjail.Rules
@@ -242,7 +243,7 @@ func TLSConfigOption(insecureSkipVerify bool, licenser license.Licenser, caCertT
242243
}
243244
}
244245

245-
func LoggerOption(logger *log.Logger) DispatcherOption {
246+
func LoggerOption(logger log.StdLogger) DispatcherOption {
246247
return func(d *Dispatcher) error {
247248
if logger == nil {
248249
return ErrLoggerIsRequired
@@ -289,7 +290,7 @@ func (d *Dispatcher) validateProxy(proxyURL string) (*url.URL, bool, error) {
289290
return nil, false, nil
290291
}
291292

292-
func (d *Dispatcher) SendRequest(ctx context.Context, endpoint, method string, jsonData json.RawMessage, signatureHeader string, hmac string, maxResponseSize int64, headers httpheader.HTTPHeader, idempotencyKey string, timeout time.Duration) (*Response, error) {
293+
func (d *Dispatcher) SendWebhook(ctx context.Context, endpoint string, jsonData json.RawMessage, signatureHeader string, hmac string, maxResponseSize int64, headers httpheader.HTTPHeader, idempotencyKey string, timeout time.Duration) (*Response, error) {
293294
d.logger.Debugf("rules: %+v", d.rules)
294295

295296
ctx, cancel := context.WithTimeout(ctx, timeout)
@@ -307,7 +308,7 @@ func (d *Dispatcher) SendRequest(ctx context.Context, endpoint, method string, j
307308
ctx = netjail.ContextWithRules(ctx, d.rules)
308309
}
309310

310-
req, err := http.NewRequestWithContext(ctx, method, endpoint, bytes.NewBuffer(jsonData))
311+
req, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewBuffer(jsonData))
311312
if err != nil {
312313
d.logger.WithError(err).Error("error occurred while creating request")
313314
return r, err
@@ -462,3 +463,48 @@ func (d *Dispatcher) do(ctx context.Context, req *http.Request, res *Response, m
462463

463464
return nil
464465
}
466+
467+
// Ping sends a GET request to the specified endpoint and verifies it returns a 2xx response.
468+
// It returns an error if the endpoint is unreachable or returns a non-2xx status code.
469+
func (d *Dispatcher) Ping(ctx context.Context, endpoint string, timeout time.Duration) error {
470+
d.logger.Debugf("rules: %+v", d.rules)
471+
472+
ctx, cancel := context.WithTimeout(ctx, timeout)
473+
defer cancel()
474+
475+
if d.ff.CanAccessFeature(fflag.IpRules) && d.l.IpRules() {
476+
ctx = netjail.ContextWithRules(ctx, d.rules)
477+
}
478+
479+
req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil)
480+
if err != nil {
481+
d.logger.WithError(err).Error("error creating ping request")
482+
return err
483+
}
484+
485+
trace := &httptrace.ClientTrace{
486+
GotConn: func(connInfo httptrace.GotConnInfo) {
487+
d.logger.Debugf("IP address resolved for %s to: %s", endpoint, connInfo.Conn.RemoteAddr())
488+
},
489+
}
490+
491+
ctx = httptrace.WithClientTrace(ctx, trace)
492+
req = req.WithContext(ctx)
493+
494+
req.Header.Add("User-Agent", defaultUserAgent())
495+
496+
response, err := d.client.Do(req)
497+
if err != nil {
498+
d.logger.WithError(err).Error("error sending ping request")
499+
return err
500+
}
501+
defer response.Body.Close()
502+
503+
if response.StatusCode < 200 || response.StatusCode > 299 {
504+
err = fmt.Errorf("%w: got status code %d", ErrNon2xxResponse, response.StatusCode)
505+
d.logger.WithError(err).Error("ping request failed")
506+
return err
507+
}
508+
509+
return nil
510+
}

net/dispatcher_test.go

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ func TestDispatcher_SendRequest(t *testing.T) {
295295
defer deferFn()
296296
}
297297

298-
got, err := d.SendRequest(context.Background(), tt.args.endpoint, tt.args.method, tt.args.jsonData, tt.args.project.Config.Signature.Header.String(), tt.args.hmac, config.MaxResponseSize, tt.args.headers, "", time.Minute)
298+
got, err := d.SendWebhook(context.Background(), tt.args.endpoint, tt.args.jsonData, tt.args.project.Config.Signature.Header.String(), tt.args.hmac, config.MaxResponseSize, tt.args.headers, "", time.Minute)
299299
if tt.wantErr {
300300
require.NotNil(t, err)
301301
require.Contains(t, err.Error(), tt.want.Error)
@@ -398,7 +398,7 @@ func TestNewDispatcher(t *testing.T) {
398398
}
399399
}
400400

401-
// TestDispatcherSendRequest tests the basic functionality of SendRequest
401+
// TestDispatcherSendRequest tests the basic functionality of SendWebhook
402402
func TestDispatcherSendRequest(t *testing.T) {
403403
// Start a test server
404404
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -437,10 +437,9 @@ func TestDispatcherSendRequest(t *testing.T) {
437437
}
438438

439439
// Send request
440-
resp, err := dispatcher.SendRequest(
440+
resp, err := dispatcher.SendWebhook(
441441
context.Background(),
442442
server.URL,
443-
"POST",
444443
jsonData,
445444
"X-Signature",
446445
"test-hmac",
@@ -483,10 +482,9 @@ func TestDispatcherWithTimeout(t *testing.T) {
483482
require.NoError(t, err)
484483

485484
// Send request with a short timeout
486-
_, err = dispatcher.SendRequest(
485+
_, err = dispatcher.SendWebhook(
487486
context.Background(),
488487
server.URL,
489-
"GET",
490488
nil,
491489
"X-Signature",
492490
"test-hmac",
@@ -529,10 +527,9 @@ func TestDispatcherWithBlockedIP(t *testing.T) {
529527
require.NoError(t, err)
530528

531529
// Attempt to send a request
532-
_, err = dispatcher.SendRequest(
530+
_, err = dispatcher.SendWebhook(
533531
context.Background(),
534532
server.URL,
535-
"GET",
536533
nil,
537534
"X-Signature",
538535
"test-hmac",

services/create_endpoint.go

Lines changed: 61 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,13 @@ package services
33
import (
44
"context"
55
"errors"
6+
"fmt"
7+
"github.com/frain-dev/convoy/config"
8+
"github.com/frain-dev/convoy/internal/pkg/fflag"
69
"github.com/frain-dev/convoy/internal/pkg/keys"
10+
"github.com/frain-dev/convoy/net"
711
"net/http"
12+
"net/url"
813
"time"
914

1015
"github.com/frain-dev/convoy"
@@ -23,9 +28,10 @@ type CreateEndpointService struct {
2328
EndpointRepo datastore.EndpointRepository
2429
ProjectRepo datastore.ProjectRepository
2530
Licenser license.Licenser
26-
27-
E models.CreateEndpoint
28-
ProjectID string
31+
FeatureFlag *fflag.FFlag
32+
Logger log.StdLogger
33+
E models.CreateEndpoint
34+
ProjectID string
2935
}
3036

3137
func (a *CreateEndpointService) Run(ctx context.Context) (*datastore.Endpoint, error) {
@@ -34,12 +40,12 @@ func (a *CreateEndpointService) Run(ctx context.Context) (*datastore.Endpoint, e
3440
return nil, &ServiceError{ErrMsg: "failed to load endpoint project", Err: err}
3541
}
3642

37-
url, err := util.ValidateEndpoint(a.E.URL, project.Config.SSL.EnforceSecureEndpoints, a.Licenser.CustomCertificateAuthority())
43+
endpointUrl, err := a.ValidateEndpoint(ctx, project.Config.SSL.EnforceSecureEndpoints)
3844
if err != nil {
3945
return nil, &ServiceError{ErrMsg: err.Error()}
4046
}
4147

42-
a.E.URL = url
48+
a.E.URL = endpointUrl
4349

4450
truthValue := true
4551
switch project.Type {
@@ -125,6 +131,56 @@ func (a *CreateEndpointService) Run(ctx context.Context) (*datastore.Endpoint, e
125131
return endpoint, nil
126132
}
127133

134+
func (a *CreateEndpointService) ValidateEndpoint(ctx context.Context, enforceSecure bool) (string, error) {
135+
if util.IsStringEmpty(a.E.URL) {
136+
return "", errors.New("please provide the endpoint url")
137+
}
138+
139+
u, pingErr := url.Parse(a.E.URL)
140+
if pingErr != nil {
141+
return "", pingErr
142+
}
143+
144+
switch u.Scheme {
145+
case "http":
146+
if enforceSecure {
147+
return "", errors.New("only https endpoints allowed")
148+
}
149+
case "https":
150+
cfg, innerErr := config.Get()
151+
if innerErr != nil {
152+
return "", innerErr
153+
}
154+
155+
caCertTLSCfg, innerErr := config.GetCaCert()
156+
if innerErr != nil {
157+
return "", innerErr
158+
}
159+
160+
dispatcher, innerErr := net.NewDispatcher(
161+
a.Licenser,
162+
a.FeatureFlag,
163+
net.LoggerOption(a.Logger),
164+
net.ProxyOption(cfg.Server.HTTP.HttpProxy),
165+
net.AllowListOption(cfg.Dispatcher.AllowList),
166+
net.BlockListOption(cfg.Dispatcher.BlockList),
167+
net.TLSConfigOption(cfg.Dispatcher.InsecureSkipVerify, a.Licenser, caCertTLSCfg),
168+
)
169+
if innerErr != nil {
170+
return "", innerErr
171+
}
172+
173+
pingErr = dispatcher.Ping(ctx, a.E.URL, 10*time.Second)
174+
if pingErr != nil {
175+
return "", fmt.Errorf("failed to ping tls endpoint: %v", pingErr)
176+
}
177+
default:
178+
return "", errors.New("invalid endpoint scheme")
179+
}
180+
181+
return u.String(), nil
182+
}
183+
128184
func ValidateEndpointAuthentication(auth *datastore.EndpointAuthentication) (*datastore.EndpointAuthentication, error) {
129185
if auth != nil && !util.IsStringEmpty(string(auth.Type)) {
130186
if err := util.Validate(auth); err != nil {

0 commit comments

Comments
 (0)