diff --git a/web/pattern.go b/web/pattern.go index 952fc81..4916827 100644 --- a/web/pattern.go +++ b/web/pattern.go @@ -19,13 +19,13 @@ type regexpPattern struct { func (p regexpPattern) Prefix() string { return p.prefix } -func (p regexpPattern) Match(r *http.Request, c *C) bool { +func (p regexpPattern) Match(r *http.Request, c *C, dryrun bool) bool { matches := p.re.FindStringSubmatch(r.URL.Path) if matches == nil || len(matches) == 0 { return false } - if len(matches) == 1 { + if c == nil || dryrun || len(matches) == 1 { return true } @@ -149,7 +149,7 @@ func (s stringPattern) Prefix() string { return s.literals[0] } -func (s stringPattern) Match(r *http.Request, c *C) bool { +func (s stringPattern) Match(r *http.Request, c *C, dryrun bool) bool { path := r.URL.Path matches := make([]string, len(s.pats)) for i := 0; i < len(s.pats); i++ { @@ -181,6 +181,10 @@ func (s stringPattern) Match(r *http.Request, c *C) bool { } } + if c == nil || dryrun { + return true + } + if c.UrlParams == nil && len(matches) > 0 { c.UrlParams = make(map[string]string, len(matches)-1) } diff --git a/web/pattern_test.go b/web/pattern_test.go index 7fbe132..b6911d7 100644 --- a/web/pattern_test.go +++ b/web/pattern_test.go @@ -149,7 +149,7 @@ func TestPatterns(t *testing.T) { } func runTest(t *testing.T, p Pattern, test patternTest) { - result := p.Match(test.r, test.c) + result := p.Match(test.r, test.c, false) if result != test.match { t.Errorf("Expected match(%v, %#v) to return %v", p, test.r.URL.Path, test.match) diff --git a/web/router.go b/web/router.go index 63198af..f3f63ae 100644 --- a/web/router.go +++ b/web/router.go @@ -28,6 +28,8 @@ const ( mPOST | mPUT | mTRACE | mIDK ) +const validMethods = "goji.web.validMethods" + type route struct { // Theory: most real world routes have a string prefix which is both // cheap(-ish) to test against and pretty selective. And, conveniently, @@ -61,8 +63,9 @@ type Pattern interface { // Returns true if the request satisfies the pattern. This function is // free to examine both the request and the context to make this // decision. After it is certain that the request matches, this function - // should mutate or create c.UrlParams if necessary. - Match(r *http.Request, c *C) bool + // should mutate or create c.UrlParams if necessary, unless dryrun is + // set. + Match(r *http.Request, c *C, dryrun bool) bool } func parsePattern(p interface{}, isPrefix bool) Pattern { @@ -139,16 +142,64 @@ func httpMethod(mname string) method { func (rt *router) route(c C, w http.ResponseWriter, r *http.Request) { m := httpMethod(r.Method) + var methods method for _, route := range rt.routes { - if route.method&m == 0 || - !strings.HasPrefix(r.URL.Path, route.prefix) || - !route.pattern.Match(r, &c) { + if !strings.HasPrefix(r.URL.Path, route.prefix) || + !route.pattern.Match(r, &c, false) { + continue } - route.handler.ServeHTTPC(c, w, r) + + if route.method&m != 0 { + route.handler.ServeHTTPC(c, w, r) + return + } else if route.pattern.Match(r, &c, true) { + methods |= route.method + } + } + + if methods == 0 { + rt.notFound.ServeHTTPC(c, w, r) return } + // Oh god kill me now + var methodsList = make([]string, 0) + if methods&mCONNECT != 0 { + methodsList = append(methodsList, "CONNECT") + } + if methods&mDELETE != 0 { + methodsList = append(methodsList, "DELETE") + } + if methods&mGET != 0 { + methodsList = append(methodsList, "GET") + } + if methods&mHEAD != 0 { + methodsList = append(methodsList, "HEAD") + } + if methods&mOPTIONS != 0 { + methodsList = append(methodsList, "OPTIONS") + } + if methods&mPATCH != 0 { + methodsList = append(methodsList, "PATCH") + } + if methods&mPOST != 0 { + methodsList = append(methodsList, "POST") + } + if methods&mPUT != 0 { + methodsList = append(methodsList, "PUT") + } + if methods&mTRACE != 0 { + methodsList = append(methodsList, "TRACE") + } + + if c.Env == nil { + c.Env = map[string]interface{}{ + validMethods: methodsList, + } + } else { + c.Env[validMethods] = methodsList + } rt.notFound.ServeHTTPC(c, w, r) } @@ -270,6 +321,10 @@ func (m *router) Sub(pattern string, handler interface{}) { // Set the fallback (i.e., 404) handler for this mux. See the documentation for // type Mux for a description of what types are accepted for handler. +// +// As a convenience, the environment variable "goji.web.validMethods" will be +// set to the list of HTTP methods that could have been routed had they been +// provided on an otherwise identical request func (m *router) NotFound(handler interface{}) { m.notFound = parseHandler(handler) } diff --git a/web/router_test.go b/web/router_test.go index 77a372d..44331ae 100644 --- a/web/router_test.go +++ b/web/router_test.go @@ -3,6 +3,7 @@ package web import ( "net/http" "net/http/httptest" + "reflect" "regexp" "testing" "time" @@ -63,10 +64,12 @@ func (t testPattern) Prefix() string { return "" } -func (t testPattern) Match(r *http.Request, c *C) bool { +func (t testPattern) Match(r *http.Request, c *C, dryrun bool) bool { return true } +var _ Pattern = testPattern{} + func TestPatternTypes(t *testing.T) { t.Parallel() rt := makeRouter() @@ -173,3 +176,50 @@ func TestSub(t *testing.T) { t.Errorf("Timeout waiting for hello") } } + +var validMethodsTable = map[string][]string{ + "/hello/carl": {"DELETE", "GET", "PATCH", "POST", "PUT"}, + "/hello/bob": {"DELETE", "GET", "HEAD", "PATCH", "PUT"}, + "/hola/carl": {"DELETE", "GET", "PUT"}, + "/hola/bob": {"DELETE"}, + "/does/not/compute": {}, +} + +func TestValidMethods(t *testing.T) { + t.Parallel() + rt := makeRouter() + ch := make(chan []string, 1) + + rt.NotFound(func(c C, w http.ResponseWriter, r *http.Request) { + if c.Env == nil { + ch <- []string{} + return + } + methods, ok := c.Env[validMethods] + if !ok { + ch <- []string{} + return + } + ch <- methods.([]string) + }) + + rt.Get("/hello/carl", http.NotFound) + rt.Post("/hello/carl", http.NotFound) + rt.Head("/hello/bob", http.NotFound) + rt.Get("/hello/:name", http.NotFound) + rt.Put("/hello/:name", http.NotFound) + rt.Patch("/hello/:name", http.NotFound) + rt.Get("/:greet/carl", http.NotFound) + rt.Put("/:greet/carl", http.NotFound) + rt.Delete("/:greet/:anyone", http.NotFound) + + for path, eMethods := range validMethodsTable { + r, _ := http.NewRequest("BOGUS", path, nil) + rt.route(C{}, httptest.NewRecorder(), r) + aMethods := <-ch + if !reflect.DeepEqual(eMethods, aMethods) { + t.Errorf("For %q, expected %v, got %v", path, eMethods, + aMethods) + } + } +}