diff --git a/mux.go b/mux.go index 3857173..6136ab9 100644 --- a/mux.go +++ b/mux.go @@ -14,7 +14,7 @@ import ( // NewRouter returns a new router instance. func NewRouter() *Router { - return &Router{namedRoutes: make(map[string]*Route)} + return &Router{namedRoutes: make(map[string]*Route), KeepContext: false} } // Router registers routes to be matched and dispatches a handler. @@ -46,6 +46,8 @@ type Router struct { namedRoutes map[string]*Route // See Router.StrictSlash(). This defines the flag for new routes. strictSlash bool + // If true, do not clear the the request context after handling the request + KeepContext bool } // Match matches registered routes against the request. @@ -82,7 +84,9 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { } handler = r.NotFoundHandler } - defer context.Clear(req) + if !r.KeepContext { + defer context.Clear(req) + } handler.ServeHTTP(w, req) } diff --git a/mux_test.go b/mux_test.go index 55159bd..8789697 100644 --- a/mux_test.go +++ b/mux_test.go @@ -8,6 +8,8 @@ import ( "fmt" "net/http" "testing" + + "github.com/gorilla/context" ) type routeTest struct { @@ -656,6 +658,36 @@ func testRoute(t *testing.T, test routeTest) { } } +// Tests that the context is cleared or not cleared properly depending on +// the configuration of the router +func TestKeepContext(t *testing.T) { + func1 := func(w http.ResponseWriter, r *http.Request) {} + + r := NewRouter() + r.HandleFunc("/", func1).Name("func1") + + req, _ := http.NewRequest("GET", "http://localhost/", nil) + context.Set(req, "t", 1) + + res := new(http.ResponseWriter) + r.ServeHTTP(*res, req) + + if _, ok := context.GetOk(req, "t"); ok { + t.Error("Context should have been cleared at end of request") + } + + r.KeepContext = true + + req, _ = http.NewRequest("GET", "http://localhost/", nil) + context.Set(req, "t", 1) + + r.ServeHTTP(*res, req) + if _, ok := context.GetOk(req, "t"); !ok { + t.Error("Context should NOT have been cleared at end of request") + } + +} + // https://plus.google.com/101022900381697718949/posts/eWy6DjFJ6uW func TestSubrouterHeader(t *testing.T) { expected := "func1 response"