Skip to content

Commit e19abc5

Browse files
authored
Merge pull request #8 from shellfu/wildcards
muxer: fix route matching and add wildcard support
2 parents 7984116 + 57851ce commit e19abc5

File tree

2 files changed

+137
-7
lines changed

2 files changed

+137
-7
lines changed

router.go

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,17 +126,34 @@ as its parameters.
126126
// handle the request
127127
// ...
128128
})
129-
130-
If there's an error compiling the regular expression that matches the path, it returns the error.
131129
*/
132130
func (r *Router) HandleRoute(method, path string, handler http.HandlerFunc) {
133-
// Parse path to extract parameter names
134131
paramNames := make([]string, 0)
132+
133+
// First handle catch-all wildcard
134+
if strings.Contains(path, "*") {
135+
base := strings.TrimSuffix(path, "*")
136+
base = strings.TrimSuffix(base, "/")
137+
// Match everything after the base path, but don't capture the leading slash
138+
pathRegex := regexp.QuoteMeta(base) + `/(.+)`
139+
paramNames = append(paramNames, "path")
140+
141+
r.routes = append(r.routes, Route{
142+
method: method,
143+
path: regexp.MustCompile("^" + pathRegex + "$"),
144+
handler: handler,
145+
params: paramNames,
146+
template: path,
147+
})
148+
return
149+
}
150+
151+
// Handle standard path parameters with the original pattern
135152
re := regexp.MustCompile(`:([\w-]+)`)
136153
pathRegex := re.ReplaceAllStringFunc(path, func(m string) string {
137154
paramName := m[1:]
138155
paramNames = append(paramNames, paramName)
139-
return `([-\w.]+)`
156+
return `([-\w.]+)` // Maintain original pattern
140157
})
141158

142159
exactPath := regexp.MustCompile("^" + pathRegex + "$")

router_test.go

Lines changed: 116 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"errors"
66
"fmt"
77
"io"
8-
"io/ioutil"
98
"net/http"
109
"net/http/httptest"
1110
"reflect"
@@ -179,7 +178,7 @@ func TestSubrouter(t *testing.T) {
179178
// Create subrouter for www.example.com
180179
example := router.Subrouter("www.example.com")
181180
example.HandlerFunc(http.MethodGet, "/example", func(w http.ResponseWriter, r *http.Request) {
182-
fmt.Fprint(w, "Example")
181+
fmt.Fprint(w, "Example") // nolint: errcheck
183182
})
184183

185184
// Create subrouter for /api
@@ -388,7 +387,7 @@ func TestMaxRequestBodySize(t *testing.T) {
388387
router := NewRouter(WithMaxRequestBodySize(maxRequestBodySize))
389388

390389
handlerFunc := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
391-
body, err := ioutil.ReadAll(r.Body)
390+
body, err := io.ReadAll(r.Body)
392391
if err != nil {
393392
w.WriteHeader(http.StatusRequestEntityTooLarge)
394393
return
@@ -619,3 +618,117 @@ func TestCurrentRoute(t *testing.T) {
619618
})
620619
}
621620
}
621+
622+
func TestNestedParams(t *testing.T) {
623+
router := NewRouter()
624+
625+
// Track captured params
626+
var capturedParams map[string]string
627+
628+
router.HandleRoute("GET", "/foo/:id/bar/:desc", func(w http.ResponseWriter, r *http.Request) {
629+
capturedParams = router.Params(r)
630+
})
631+
632+
req := httptest.NewRequest("GET", "/foo/123/bar/test-1", nil)
633+
w := httptest.NewRecorder()
634+
router.ServeHTTP(w, req)
635+
636+
expected := map[string]string{
637+
"id": "123",
638+
"desc": "test-1",
639+
}
640+
641+
if !reflect.DeepEqual(capturedParams, expected) {
642+
t.Errorf("expected params %v, got %v", expected, capturedParams)
643+
}
644+
}
645+
646+
func TestWildcardRoutes(t *testing.T) {
647+
tests := []struct {
648+
name string
649+
method string
650+
routePath string
651+
requestPath string
652+
expectedCode int
653+
expectedParam string
654+
wantMatch bool
655+
}{
656+
{
657+
name: "simple wildcard",
658+
method: http.MethodGet,
659+
routePath: "/validate/*",
660+
requestPath: "/validate/foo",
661+
expectedCode: http.StatusOK,
662+
expectedParam: "foo",
663+
wantMatch: true,
664+
},
665+
{
666+
name: "nested wildcard",
667+
method: http.MethodGet,
668+
routePath: "/validate/*",
669+
requestPath: "/validate/foo/bar",
670+
expectedCode: http.StatusOK,
671+
expectedParam: "foo/bar",
672+
wantMatch: true,
673+
},
674+
{
675+
name: "wildcard with query params",
676+
method: http.MethodGet,
677+
routePath: "/validate/*",
678+
requestPath: "/validate/foo?key=value",
679+
expectedCode: http.StatusOK,
680+
expectedParam: "foo",
681+
wantMatch: true,
682+
},
683+
{
684+
name: "no match without prefix",
685+
method: http.MethodGet,
686+
routePath: "/validate/*",
687+
requestPath: "/foo/bar",
688+
expectedCode: http.StatusNotFound,
689+
expectedParam: "",
690+
wantMatch: false,
691+
},
692+
{
693+
name: "method not allowed",
694+
method: http.MethodGet,
695+
routePath: "/validate/*",
696+
requestPath: "/validate/foo",
697+
expectedCode: http.StatusMethodNotAllowed,
698+
expectedParam: "",
699+
wantMatch: false,
700+
},
701+
}
702+
703+
for _, tc := range tests {
704+
t.Run(tc.name, func(t *testing.T) {
705+
router := NewRouter()
706+
707+
router.HandleRoute(tc.method, tc.routePath, func(w http.ResponseWriter, r *http.Request) {
708+
if tc.wantMatch {
709+
params := router.Params(r)
710+
if got := params["path"]; got != tc.expectedParam {
711+
t.Errorf("expected param %q, got %q", tc.expectedParam, got)
712+
}
713+
}
714+
w.WriteHeader(http.StatusOK)
715+
})
716+
717+
var method string
718+
if tc.name == "method not allowed" {
719+
method = http.MethodPost
720+
} else {
721+
method = tc.method
722+
}
723+
724+
req := httptest.NewRequest(method, tc.requestPath, nil)
725+
w := httptest.NewRecorder()
726+
727+
router.ServeHTTP(w, req)
728+
729+
if got := w.Code; got != tc.expectedCode {
730+
t.Errorf("expected status code %d, got %d", tc.expectedCode, got)
731+
}
732+
})
733+
}
734+
}

0 commit comments

Comments
 (0)