diff --git a/web/func_equal.go b/web/func_equal.go index 3206b04..9c8f7cb 100644 --- a/web/func_equal.go +++ b/web/func_equal.go @@ -4,10 +4,6 @@ import ( "reflect" ) -func isFunc(fn interface{}) bool { - return reflect.ValueOf(fn).Kind() == reflect.Func -} - /* This is more than a little sketchtacular. Go's rules for function pointer equality are pretty restrictive: nil function pointers always compare equal, and @@ -25,12 +21,10 @@ purposes. If you're curious, you can read more about the representation of functions here: http://golang.org/s/go11func We're in effect comparing the pointers of the indirect layer. + +This function also works on non-function values. */ func funcEqual(a, b interface{}) bool { - if !isFunc(a) || !isFunc(b) { - panic("funcEqual: type error!") - } - av := reflect.ValueOf(&a).Elem() bv := reflect.ValueOf(&b).Elem() diff --git a/web/handler.go b/web/handler.go new file mode 100644 index 0000000..746c9f0 --- /dev/null +++ b/web/handler.go @@ -0,0 +1,42 @@ +package web + +import ( + "log" + "net/http" +) + +const unknownHandler = `Unknown handler type %T. See http://godoc.org/github.com/zenazn/goji/web#HandlerType for a list of acceptable types.` + +type netHTTPHandlerWrap struct{ http.Handler } +type netHTTPHandlerFuncWrap struct { + fn func(http.ResponseWriter, *http.Request) +} +type handlerFuncWrap struct { + fn func(C, http.ResponseWriter, *http.Request) +} + +func (h netHTTPHandlerWrap) ServeHTTPC(c C, w http.ResponseWriter, r *http.Request) { + h.Handler.ServeHTTP(w, r) +} +func (h netHTTPHandlerFuncWrap) ServeHTTPC(c C, w http.ResponseWriter, r *http.Request) { + h.fn(w, r) +} +func (h handlerFuncWrap) ServeHTTPC(c C, w http.ResponseWriter, r *http.Request) { + h.fn(c, w, r) +} + +func parseHandler(h HandlerType) Handler { + switch f := h.(type) { + case func(c C, w http.ResponseWriter, r *http.Request): + return handlerFuncWrap{f} + case func(w http.ResponseWriter, r *http.Request): + return netHTTPHandlerFuncWrap{f} + case Handler: + return f + case http.Handler: + return netHTTPHandlerWrap{f} + default: + log.Fatalf(unknownHandler, h) + panic("log.Fatalf does not return") + } +} diff --git a/web/match.go b/web/match.go new file mode 100644 index 0000000..1a44144 --- /dev/null +++ b/web/match.go @@ -0,0 +1,66 @@ +package web + +// The key used to store route Matches in the Goji environment. If this key is +// present in the environment and contains a value of type Match, routing will +// not be performed, and the Match's Handler will be used instead. +const MatchKey = "goji.web.Match" + +// Match is the type of routing matches. It is inserted into C.Env under +// MatchKey when the Mux.Router middleware is invoked. If MatchKey is present at +// route dispatch time, the Handler of the corresponding Match will be called +// instead of performing routing as usual. +// +// By computing a Match and inserting it into the Goji environment as part of a +// middleware stack (see Mux.Router, for instance), it is possible to customize +// Goji's routing behavior or replace it entirely. +type Match struct { + // Pattern is the Pattern that matched during routing. Will be nil if no + // route matched (Handler will be set to the Mux's NotFound handler) + Pattern Pattern + // The Handler corresponding to the matched pattern. + Handler Handler +} + +// GetMatch returns the Match stored in the Goji environment, or an empty Match +// if none exists (valid Matches always have a Handler property). +func GetMatch(c C) Match { + if c.Env == nil { + return Match{} + } + mi, ok := c.Env[MatchKey] + if !ok { + return Match{} + } + if m, ok := mi.(Match); ok { + return m + } + return Match{} +} + +// RawPattern returns the PatternType that was originally passed to ParsePattern +// or any of the HTTP method functions (Get, Post, etc.). +func (m Match) RawPattern() PatternType { + switch v := m.Pattern.(type) { + case regexpPattern: + return v.re + case stringPattern: + return v.raw + default: + return v + } +} + +// RawHandler returns the HandlerType that was originally passed to the HTTP +// method functions (Get, Post, etc.). +func (m Match) RawHandler() HandlerType { + switch v := m.Handler.(type) { + case netHTTPHandlerWrap: + return v.Handler + case handlerFuncWrap: + return v.fn + case netHTTPHandlerFuncWrap: + return v.fn + default: + return v + } +} diff --git a/web/match_test.go b/web/match_test.go new file mode 100644 index 0000000..aefff04 --- /dev/null +++ b/web/match_test.go @@ -0,0 +1,50 @@ +package web + +import ( + "net/http" + "regexp" + "testing" +) + +var rawPatterns = []PatternType{ + "/hello/:name", + regexp.MustCompile("^/hello/(?P[^/]+)$"), + testPattern{}, +} + +func TestRawPattern(t *testing.T) { + t.Parallel() + + for _, p := range rawPatterns { + m := Match{Pattern: ParsePattern(p)} + if rp := m.RawPattern(); rp != p { + t.Errorf("got %#v, expected %#v", rp, p) + } + } +} + +type httpHandlerOnly struct{} + +func (httpHandlerOnly) ServeHTTP(w http.ResponseWriter, r *http.Request) {} + +type handlerOnly struct{} + +func (handlerOnly) ServeHTTPC(c C, w http.ResponseWriter, r *http.Request) {} + +var rawHandlers = []HandlerType{ + func(w http.ResponseWriter, r *http.Request) {}, + func(c C, w http.ResponseWriter, r *http.Request) {}, + httpHandlerOnly{}, + handlerOnly{}, +} + +func TestRawHandler(t *testing.T) { + t.Parallel() + + for _, h := range rawHandlers { + m := Match{Handler: parseHandler(h)} + if rh := m.RawHandler(); !funcEqual(rh, h) { + t.Errorf("got %#v, expected %#v", rh, h) + } + } +} diff --git a/web/mux.go b/web/mux.go index c48ed4b..228c0a0 100644 --- a/web/mux.go +++ b/web/mux.go @@ -86,6 +86,31 @@ func (m *Mux) Abandon(middleware MiddlewareType) error { // Router functions +type routerMiddleware struct { + m *Mux + c *C + h http.Handler +} + +func (rm routerMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if rm.c.Env == nil { + rm.c.Env = make(map[interface{}]interface{}, 1) + } + rm.c.Env[MatchKey] = rm.m.rt.getMatch(rm.c, w, r) + rm.h.ServeHTTP(w, r) +} + +// Router is a middleware that performs routing and stores the resulting Match +// in Goji's environment. If a routing Match is present at the end of the +// middleware stack, that Match is used instead of re-routing. +// +// This middleware is especially useful to create post-routing middleware, e.g. +// a request logger which prints which pattern or handler was selected, or an +// authentication middleware which only applies to certain routes. +func (m *Mux) Router(c *C, h http.Handler) http.Handler { + return routerMiddleware{m, c, h} +} + /* Dispatch to the given handler when the pattern matches, regardless of HTTP method. diff --git a/web/router.go b/web/router.go index f2dd6b5..1fbc41f 100644 --- a/web/router.go +++ b/web/router.go @@ -1,7 +1,6 @@ package web import ( - "log" "net/http" "sort" "strings" @@ -30,7 +29,7 @@ const ( // The key used to communicate to the NotFound handler what methods would have // been allowed if they'd been provided. -const ValidMethodsKey = "goji.web.validMethods" +const ValidMethodsKey = "goji.web.ValidMethods" var validMethodsMap = map[string]method{ "CONNECT": mCONNECT, @@ -58,32 +57,6 @@ type router struct { machine *routeMachine } -type netHTTPWrap struct { - http.Handler -} - -func (h netHTTPWrap) ServeHTTPC(c C, w http.ResponseWriter, r *http.Request) { - h.Handler.ServeHTTP(w, r) -} - -const unknownHandler = `Unknown handler type %T. See http://godoc.org/github.com/zenazn/goji/web#HandlerType for a list of acceptable types.` - -func parseHandler(h interface{}) Handler { - switch f := h.(type) { - case Handler: - return f - case http.Handler: - return netHTTPWrap{f} - case func(c C, w http.ResponseWriter, r *http.Request): - return HandlerFunc(f) - case func(w http.ResponseWriter, r *http.Request): - return netHTTPWrap{http.HandlerFunc(f)} - default: - log.Fatalf(unknownHandler, h) - panic("log.Fatalf does not return") - } -} - func httpMethod(mname string) method { if method, ok := validMethodsMap[mname]; ok { return method @@ -102,7 +75,7 @@ func (rt *router) compile() *routeMachine { return &sm } -func (rt *router) route(c *C, w http.ResponseWriter, r *http.Request) { +func (rt *router) getMatch(c *C, w http.ResponseWriter, r *http.Request) Match { rm := rt.getMachine() if rm == nil { rm = rt.compile() @@ -110,13 +83,14 @@ func (rt *router) route(c *C, w http.ResponseWriter, r *http.Request) { methods, route := rm.route(c, w, r) if route != nil { - route.handler.ServeHTTPC(*c, w, r) - return + return Match{ + Pattern: route.pattern, + Handler: route.handler, + } } if methods == 0 { - rt.notFound.ServeHTTPC(*c, w, r) - return + return Match{Handler: rt.notFound} } var methodsList = make([]string, 0) @@ -134,10 +108,18 @@ func (rt *router) route(c *C, w http.ResponseWriter, r *http.Request) { } else { c.Env[ValidMethodsKey] = methodsList } - rt.notFound.ServeHTTPC(*c, w, r) + return Match{Handler: rt.notFound} +} + +func (rt *router) route(c *C, w http.ResponseWriter, r *http.Request) { + match := GetMatch(*c) + if match.Handler == nil { + match = rt.getMatch(c, w, r) + } + match.Handler.ServeHTTPC(*c, w, r) } -func (rt *router) handleUntyped(p interface{}, m method, h interface{}) { +func (rt *router) handleUntyped(p PatternType, m method, h HandlerType) { rt.handle(ParsePattern(p), m, parseHandler(h)) } diff --git a/web/router_middleware_test.go b/web/router_middleware_test.go new file mode 100644 index 0000000..d8f0a28 --- /dev/null +++ b/web/router_middleware_test.go @@ -0,0 +1,35 @@ +package web + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestRouterMiddleware(t *testing.T) { + t.Parallel() + + m := New() + ch := make(chan string, 1) + m.Get("/a", chHandler(ch, "a")) + m.Get("/b", chHandler(ch, "b")) + m.Use(m.Router) + m.Use(func(c *C, h http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + m := GetMatch(*c) + if rp := m.RawPattern(); rp != "/a" { + t.Fatalf("RawPattern was not /a: %v", rp) + } + r.URL.Path = "/b" + h.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) + }) + + r, _ := http.NewRequest("GET", "/a", nil) + w := httptest.NewRecorder() + m.ServeHTTP(w, r) + if v := <-ch; v != "a" { + t.Error("Routing was not frozen! %s", v) + } +}