Skip to content

Commit 57851ce

Browse files
committed
muxer: fix route matching and add wildcard support
This patch fixes a critical issue in the route matching logic where nested paths like /validate/foo/bar/ and other paths with wildcards would incorrectly receive "Method not allowed" errors, despite having valid handlers registered. The problem was that our regex pattern handling didn't properly support wildcards (*) in routes, which is essential for things like auth middleware where we need to match and validate arbitrary paths. The fix introduces proper wildcard handling while maintaining backward compatibility with the existing param-based routes. Key changes: - Add support for * wildcard at the end of routes - Properly handle nested paths in wildcard matches using /(.+) - Maintain backward compatibility for :param style routes using ([-\w.]+) - Add comprehensive test coverage for wildcard behavior
1 parent 7984116 commit 57851ce

File tree

2 files changed

+137
-7
lines changed

2 files changed

+137
-7
lines changed

router.go

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,17 +126,34 @@ as its parameters.
126126
// handle the request
127127
// ...
128128
})
129-
130-
If there's an error compiling the regular expression that matches the path, it returns the error.
131129
*/
132130
func (r *Router) HandleRoute(method, path string, handler http.HandlerFunc) {
133-
// Parse path to extract parameter names
134131
paramNames := make([]string, 0)
132+
133+
// First handle catch-all wildcard
134+
if strings.Contains(path, "*") {
135+
base := strings.TrimSuffix(path, "*")
136+
base = strings.TrimSuffix(base, "/")
137+
// Match everything after the base path, but don't capture the leading slash
138+
pathRegex := regexp.QuoteMeta(base) + `/(.+)`
139+
paramNames = append(paramNames, "path")
140+
141+
r.routes = append(r.routes, Route{
142+
method: method,
143+
path: regexp.MustCompile("^" + pathRegex + "$"),
144+
handler: handler,
145+
params: paramNames,
146+
template: path,
147+
})
148+
return
149+
}
150+
151+
// Handle standard path parameters with the original pattern
135152
re := regexp.MustCompile(`:([\w-]+)`)
136153
pathRegex := re.ReplaceAllStringFunc(path, func(m string) string {
137154
paramName := m[1:]
138155
paramNames = append(paramNames, paramName)
139-
return `([-\w.]+)`
156+
return `([-\w.]+)` // Maintain original pattern
140157
})
141158

142159
exactPath := regexp.MustCompile("^" + pathRegex + "$")

router_test.go

Lines changed: 116 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"errors"
66
"fmt"
77
"io"
8-
"io/ioutil"
98
"net/http"
109
"net/http/httptest"
1110
"reflect"
@@ -179,7 +178,7 @@ func TestSubrouter(t *testing.T) {
179178
// Create subrouter for www.example.com
180179
example := router.Subrouter("www.example.com")
181180
example.HandlerFunc(http.MethodGet, "/example", func(w http.ResponseWriter, r *http.Request) {
182-
fmt.Fprint(w, "Example")
181+
fmt.Fprint(w, "Example") // nolint: errcheck
183182
})
184183

185184
// Create subrouter for /api
@@ -388,7 +387,7 @@ func TestMaxRequestBodySize(t *testing.T) {
388387
router := NewRouter(WithMaxRequestBodySize(maxRequestBodySize))
389388

390389
handlerFunc := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
391-
body, err := ioutil.ReadAll(r.Body)
390+
body, err := io.ReadAll(r.Body)
392391
if err != nil {
393392
w.WriteHeader(http.StatusRequestEntityTooLarge)
394393
return
@@ -619,3 +618,117 @@ func TestCurrentRoute(t *testing.T) {
619618
})
620619
}
621620
}
621+
622+
func TestNestedParams(t *testing.T) {
623+
router := NewRouter()
624+
625+
// Track captured params
626+
var capturedParams map[string]string
627+
628+
router.HandleRoute("GET", "/foo/:id/bar/:desc", func(w http.ResponseWriter, r *http.Request) {
629+
capturedParams = router.Params(r)
630+
})
631+
632+
req := httptest.NewRequest("GET", "/foo/123/bar/test-1", nil)
633+
w := httptest.NewRecorder()
634+
router.ServeHTTP(w, req)
635+
636+
expected := map[string]string{
637+
"id": "123",
638+
"desc": "test-1",
639+
}
640+
641+
if !reflect.DeepEqual(capturedParams, expected) {
642+
t.Errorf("expected params %v, got %v", expected, capturedParams)
643+
}
644+
}
645+
646+
func TestWildcardRoutes(t *testing.T) {
647+
tests := []struct {
648+
name string
649+
method string
650+
routePath string
651+
requestPath string
652+
expectedCode int
653+
expectedParam string
654+
wantMatch bool
655+
}{
656+
{
657+
name: "simple wildcard",
658+
method: http.MethodGet,
659+
routePath: "/validate/*",
660+
requestPath: "/validate/foo",
661+
expectedCode: http.StatusOK,
662+
expectedParam: "foo",
663+
wantMatch: true,
664+
},
665+
{
666+
name: "nested wildcard",
667+
method: http.MethodGet,
668+
routePath: "/validate/*",
669+
requestPath: "/validate/foo/bar",
670+
expectedCode: http.StatusOK,
671+
expectedParam: "foo/bar",
672+
wantMatch: true,
673+
},
674+
{
675+
name: "wildcard with query params",
676+
method: http.MethodGet,
677+
routePath: "/validate/*",
678+
requestPath: "/validate/foo?key=value",
679+
expectedCode: http.StatusOK,
680+
expectedParam: "foo",
681+
wantMatch: true,
682+
},
683+
{
684+
name: "no match without prefix",
685+
method: http.MethodGet,
686+
routePath: "/validate/*",
687+
requestPath: "/foo/bar",
688+
expectedCode: http.StatusNotFound,
689+
expectedParam: "",
690+
wantMatch: false,
691+
},
692+
{
693+
name: "method not allowed",
694+
method: http.MethodGet,
695+
routePath: "/validate/*",
696+
requestPath: "/validate/foo",
697+
expectedCode: http.StatusMethodNotAllowed,
698+
expectedParam: "",
699+
wantMatch: false,
700+
},
701+
}
702+
703+
for _, tc := range tests {
704+
t.Run(tc.name, func(t *testing.T) {
705+
router := NewRouter()
706+
707+
router.HandleRoute(tc.method, tc.routePath, func(w http.ResponseWriter, r *http.Request) {
708+
if tc.wantMatch {
709+
params := router.Params(r)
710+
if got := params["path"]; got != tc.expectedParam {
711+
t.Errorf("expected param %q, got %q", tc.expectedParam, got)
712+
}
713+
}
714+
w.WriteHeader(http.StatusOK)
715+
})
716+
717+
var method string
718+
if tc.name == "method not allowed" {
719+
method = http.MethodPost
720+
} else {
721+
method = tc.method
722+
}
723+
724+
req := httptest.NewRequest(method, tc.requestPath, nil)
725+
w := httptest.NewRecorder()
726+
727+
router.ServeHTTP(w, req)
728+
729+
if got := w.Code; got != tc.expectedCode {
730+
t.Errorf("expected status code %d, got %d", tc.expectedCode, got)
731+
}
732+
})
733+
}
734+
}

0 commit comments

Comments
 (0)