Skip to content

Commit 7984116

Browse files
authored
Merge pull request #7 from louieliu97/ll-add-path-template
Add the ability to store and retrieve the path template
2 parents 8b43379 + 4ce90dd commit 7984116

File tree

3 files changed

+140
-10
lines changed

3 files changed

+140
-10
lines changed

route.go

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package muxer
22

33
import (
4+
"errors"
45
"net/http"
56
"regexp"
67
)
@@ -11,10 +12,11 @@ It contains the regular expression that matches the request path, the HTTP metho
1112
the handler to be executed for that request, and the parameter names extracted from the path.
1213
*/
1314
type Route struct {
14-
path *regexp.Regexp
15-
method string
16-
handler http.Handler
17-
params []string
15+
path *regexp.Regexp
16+
method string
17+
handler http.Handler
18+
params []string
19+
template string
1820
}
1921

2022
func (r *Route) match(path string) map[string]string {
@@ -30,3 +32,16 @@ func (r *Route) match(path string) map[string]string {
3032

3133
return params
3234
}
35+
36+
// PathTemplate retrieves the path template of the current route
37+
func (r *Route) PathTemplate() (string, error) {
38+
if r == nil {
39+
return "", errors.New("route is nil, no template")
40+
}
41+
42+
if r.template == "" {
43+
return r.template, errors.New("template is empty")
44+
}
45+
46+
return r.template, nil
47+
}

router.go

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ type contextKey string
1212
const (
1313
// ParamsKey is the key used to store the extracted parameters in the request context.
1414
ParamsKey contextKey = "params"
15+
// RouteContextKey is the key used to store the matched route in the request context
16+
RouteContextKey contextKey = "matched_route"
1517
)
1618

1719
/*
@@ -131,19 +133,20 @@ func (r *Router) HandleRoute(method, path string, handler http.HandlerFunc) {
131133
// Parse path to extract parameter names
132134
paramNames := make([]string, 0)
133135
re := regexp.MustCompile(`:([\w-]+)`)
134-
path = re.ReplaceAllStringFunc(path, func(m string) string {
136+
pathRegex := re.ReplaceAllStringFunc(path, func(m string) string {
135137
paramName := m[1:]
136138
paramNames = append(paramNames, paramName)
137139
return `([-\w.]+)`
138140
})
139141

140-
exactPath := regexp.MustCompile("^" + path + "$")
142+
exactPath := regexp.MustCompile("^" + pathRegex + "$")
141143

142144
r.routes = append(r.routes, Route{
143-
method: method,
144-
path: exactPath,
145-
handler: handler,
146-
params: paramNames,
145+
method: method,
146+
path: exactPath,
147+
handler: handler,
148+
params: paramNames,
149+
template: path,
147150
})
148151
}
149152

@@ -208,6 +211,7 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
208211

209212
ctx := req.Context()
210213
ctx = context.WithValue(ctx, ParamsKey, params)
214+
ctx = context.WithValue(ctx, RouteContextKey, &route)
211215

212216
handler := route.handler
213217
for i := len(r.middleware) - 1; i >= 0; i-- {
@@ -256,3 +260,14 @@ the given order before executing the main handler.
256260
func (r *Router) Use(middleware ...func(http.Handler) http.Handler) {
257261
r.middleware = append(r.middleware, middleware...)
258262
}
263+
264+
// CurrentRoute returns the matched route for the current request, if any.
265+
// This only works when called inside the handler of the matched route
266+
// because the matched route is stored inside the request's context,
267+
// which is wiped after the handler returns.
268+
func CurrentRoute(r *http.Request) *Route {
269+
if rv := r.Context().Value(RouteContextKey); rv != nil {
270+
return rv.(*Route)
271+
}
272+
return nil
273+
}

router_test.go

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package muxer
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
67
"io"
78
"io/ioutil"
@@ -519,3 +520,102 @@ func TestEnableCORSOption(t *testing.T) {
519520
})
520521
}
521522
}
523+
524+
func TestPathTemplate(t *testing.T) {
525+
tests := []struct {
526+
name string
527+
route *Route
528+
expectedOutput string
529+
expectedError error
530+
}{
531+
{
532+
name: "Error with nil Route",
533+
route: nil,
534+
expectedOutput: "",
535+
expectedError: errors.New("route is nil, no template"),
536+
},
537+
{
538+
name: "Error with empty template",
539+
route: &Route{template: ""},
540+
expectedOutput: "",
541+
expectedError: errors.New("template is empty"),
542+
},
543+
{
544+
name: "Valid Route with Template and path param",
545+
route: &Route{template: "/users/:id"},
546+
expectedOutput: "/users/:id",
547+
expectedError: nil,
548+
},
549+
{
550+
name: "Valid Route with simple Template",
551+
route: &Route{template: "/metrics"},
552+
expectedOutput: "/metrics",
553+
expectedError: nil,
554+
},
555+
}
556+
557+
for _, tt := range tests {
558+
t.Run(tt.name, func(t *testing.T) {
559+
output, err := tt.route.PathTemplate()
560+
561+
if tt.expectedOutput != output {
562+
t.Errorf("expected output %v, got %v", tt.expectedOutput, output)
563+
}
564+
if tt.expectedError != nil {
565+
if tt.expectedError.Error() != err.Error() {
566+
t.Errorf("expected error %v, got %v", tt.expectedError, err)
567+
}
568+
} else {
569+
if err != nil {
570+
t.Errorf("expected error to be nil, got %v", err)
571+
}
572+
}
573+
})
574+
}
575+
}
576+
577+
func TestCurrentRoute(t *testing.T) {
578+
route := &Route{template: "/users/:id"}
579+
580+
tests := []struct {
581+
name string
582+
contextKey interface{}
583+
contextValue interface{}
584+
expectedRoute *Route
585+
}{
586+
{
587+
name: "Route in context",
588+
contextKey: RouteContextKey,
589+
contextValue: route,
590+
expectedRoute: route,
591+
},
592+
{
593+
name: "No route in context",
594+
contextKey: "some_other_key",
595+
contextValue: "some_value",
596+
expectedRoute: nil,
597+
},
598+
{
599+
name: "Empty context",
600+
contextKey: nil,
601+
contextValue: nil,
602+
expectedRoute: nil,
603+
},
604+
}
605+
606+
for _, tt := range tests {
607+
t.Run(tt.name, func(t *testing.T) {
608+
req, _ := http.NewRequest(http.MethodGet, "/users/123", nil)
609+
610+
if tt.contextKey != nil {
611+
req = req.WithContext(context.WithValue(req.Context(), tt.contextKey, tt.contextValue))
612+
}
613+
614+
result := CurrentRoute(req)
615+
616+
if tt.expectedRoute != result {
617+
t.Errorf("expected route %v got %v", tt.expectedRoute, result)
618+
}
619+
})
620+
}
621+
}

0 commit comments

Comments
 (0)