From 3fffa7df4ac1454eccd0a711599418c0cd3be6b6 Mon Sep 17 00:00:00 2001 From: Carl Jackson Date: Sat, 22 Mar 2014 19:26:14 -0700 Subject: [PATCH] Middleware tests + bugfixes My tests caught some bugs! Amazing! --- web/middleware.go | 33 ++++-- web/middleware_test.go | 255 +++++++++++++++++++++++++++++++++++++++++ web/mux.go | 2 +- 3 files changed, 281 insertions(+), 9 deletions(-) create mode 100644 web/middleware_test.go diff --git a/web/middleware.go b/web/middleware.go index fe8da8c..b62c67a 100644 --- a/web/middleware.go +++ b/web/middleware.go @@ -28,7 +28,8 @@ type mStack struct { // fully assembled middleware stacks (the "c" stands for "cached"). type cStack struct { C - m http.Handler + m http.Handler + pool chan *cStack } func (s *cStack) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -72,6 +73,7 @@ func (m *mStack) invalidate() { old := m.pool m.pool = make(chan *cStack, mPoolSize) close(old) + // Bleed down the old pool so it gets GC'd for _ = range old { } } @@ -94,26 +96,41 @@ func (m *mStack) newStack() *cStack { } func (m *mStack) alloc() *cStack { + // This is a little sloppy: this is only safe if this pointer + // dereference is atomic. Maybe someday I'll replace it with + // sync/atomic, but for now I happen to know that on all the + // architecures I care about it happens to be atomic. + p := m.pool + var cs *cStack select { - case cs := <-m.pool: + case cs = <-p: + // This can happen if we race against an invalidation. It's + // completely peaceful, so long as we assume we can grab a cStack before + // our stack blows out. if cs == nil { return m.alloc() } - return cs default: - return m.newStack() + cs = m.newStack() } + + cs.pool = p + return cs } func (m *mStack) release(cs *cStack) { - // It's possible that the pool has been invalidated and therefore - // closed, in which case we'll start panicing, which is dumb. I'm not - // sure this is actually better than just grabbing a lock, but whatever. + if cs.pool != m.pool { + return + } + // It's possible that the pool has been invalidated (and closed) between + // the check above and now, in which case we'll start panicing, which is + // dumb. I'm not sure this is actually better than just grabbing a lock, + // but whatever. defer func() { recover() }() select { - case m.pool <- cs: + case cs.pool <- cs: default: } } diff --git a/web/middleware_test.go b/web/middleware_test.go new file mode 100644 index 0000000..747b541 --- /dev/null +++ b/web/middleware_test.go @@ -0,0 +1,255 @@ +package web + +import ( + "net/http" + "net/http/httptest" + "reflect" + "testing" + "time" +) + +func makeStack(ch chan string) *mStack { + router := func(c C, w http.ResponseWriter, r *http.Request) { + ch <- "router" + } + return &mStack{ + stack: make([]mLayer, 0), + pool: make(chan *cStack, mPoolSize), + router: HandlerFunc(router), + } +} + +func chanWare(ch chan string, s string) func(http.Handler) http.Handler { + return func(h http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ch <- s + h.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) + } +} + +func simpleRequest(ch chan string, st *mStack) { + defer func() { + ch <- "end" + }() + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + cs := st.alloc() + defer st.release(cs) + + cs.ServeHTTP(w, r) +} + +func assertOrder(t *testing.T, ch chan string, strings ...string) { + for i, s := range strings { + var v string + select { + case v = <-ch: + case <-time.After(5 * time.Millisecond): + t.Fatalf("Expected %q as %d'th value, but timed out", s, + i+1) + } + if s != v { + t.Errorf("%d'th value was %q, expected %q", i+1, v, s) + } + } +} + +func TestSimple(t *testing.T) { + t.Parallel() + + ch := make(chan string) + st := makeStack(ch) + st.Use("one", chanWare(ch, "one")) + st.Use("two", chanWare(ch, "two")) + go simpleRequest(ch, st) + assertOrder(t, ch, "one", "two", "router", "end") +} + +func TestTypes(t *testing.T) { + t.Parallel() + + ch := make(chan string) + st := makeStack(ch) + st.Use("one", func(h http.Handler) http.Handler { + return h + }) + st.Use("two", func(c *C, h http.Handler) http.Handler { + return h + }) +} + +func TestAddMore(t *testing.T) { + t.Parallel() + + ch := make(chan string) + st := makeStack(ch) + st.Use("one", chanWare(ch, "one")) + go simpleRequest(ch, st) + assertOrder(t, ch, "one", "router", "end") + + st.Use("two", chanWare(ch, "two")) + go simpleRequest(ch, st) + assertOrder(t, ch, "one", "two", "router", "end") + + st.Use("three", chanWare(ch, "three")) + st.Use("four", chanWare(ch, "four")) + go simpleRequest(ch, st) + assertOrder(t, ch, "one", "two", "three", "four", "router", "end") +} + +func TestInsert(t *testing.T) { + t.Parallel() + + ch := make(chan string) + st := makeStack(ch) + st.Use("one", chanWare(ch, "one")) + st.Use("two", chanWare(ch, "two")) + go simpleRequest(ch, st) + assertOrder(t, ch, "one", "two", "router", "end") + + err := st.Insert("sloth", chanWare(ch, "sloth"), "squirrel") + if err == nil { + t.Error("Expected error when referencing unknown middleware") + } + + st.Insert("middle", chanWare(ch, "middle"), "two") + st.Insert("start", chanWare(ch, "start"), "one") + go simpleRequest(ch, st) + assertOrder(t, ch, "start", "one", "middle", "two", "router", "end") +} + +func TestAbandon(t *testing.T) { + t.Parallel() + + ch := make(chan string) + st := makeStack(ch) + st.Use("one", chanWare(ch, "one")) + st.Use("two", chanWare(ch, "two")) + st.Use("three", chanWare(ch, "three")) + go simpleRequest(ch, st) + assertOrder(t, ch, "one", "two", "three", "router", "end") + + st.Abandon("two") + go simpleRequest(ch, st) + assertOrder(t, ch, "one", "three", "router", "end") + + err := st.Abandon("panda") + if err == nil { + t.Error("Expected error when deleting unknown middleware") + } + + st.Abandon("one") + st.Abandon("three") + go simpleRequest(ch, st) + assertOrder(t, ch, "router", "end") + + st.Use("one", chanWare(ch, "one")) + go simpleRequest(ch, st) + assertOrder(t, ch, "one", "router", "end") +} + +func TestMiddlewareList(t *testing.T) { + t.Parallel() + + ch := make(chan string) + st := makeStack(ch) + st.Use("one", chanWare(ch, "one")) + st.Use("two", chanWare(ch, "two")) + st.Insert("mid", chanWare(ch, "mid"), "two") + st.Insert("before", chanWare(ch, "before"), "mid") + st.Abandon("one") + + m := st.Middleware() + if !reflect.DeepEqual(m, []string{"before", "mid", "two"}) { + t.Error("Middleware list was not as expected") + } + + go simpleRequest(ch, st) + assertOrder(t, ch, "before", "mid", "two", "router", "end") +} + +// This is a pretty sketchtacular test +func TestCaching(t *testing.T) { + ch := make(chan string) + st := makeStack(ch) + cs1 := st.alloc() + cs2 := st.alloc() + if cs1 == cs2 { + t.Fatal("cs1 and cs2 are the same") + } + st.release(cs2) + cs3 := st.alloc() + if cs2 != cs3 { + t.Fatalf("Expected cs2 to equal cs3") + } + st.release(cs1) + st.release(cs3) + cs4 := st.alloc() + cs5 := st.alloc() + if cs4 != cs1 { + t.Fatal("Expected cs4 to equal cs1") + } + if cs5 != cs3 { + t.Fatal("Expected cs5 to equal cs3") + } +} + +func TestInvalidation(t *testing.T) { + ch := make(chan string) + st := makeStack(ch) + cs1 := st.alloc() + cs2 := st.alloc() + st.release(cs1) + st.invalidate() + cs3 := st.alloc() + if cs3 == cs1 { + t.Fatal("Expected cs3 to be fresh, instead got cs1") + } + st.release(cs2) + cs4 := st.alloc() + if cs4 == cs2 { + t.Fatal("Expected cs4 to be fresh, instead got cs2") + } +} + +func TestContext(t *testing.T) { + router := func(c C, w http.ResponseWriter, r *http.Request) { + if c.Env["reqId"].(int) != 2 { + t.Error("Request id was not 2 :(") + } + } + st := mStack{ + stack: make([]mLayer, 0), + pool: make(chan *cStack, mPoolSize), + router: HandlerFunc(router), + } + st.Use("one", func(c *C, h http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + if c.Env != nil || c.UrlParams != nil { + t.Error("Expected a clean context") + } + c.Env = make(map[string]interface{}) + c.Env["reqId"] = 1 + + h.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) + }) + + st.Use("two", func(c *C, h http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + if c.Env == nil { + t.Error("Expected env from last middleware") + } + c.Env["reqId"] = c.Env["reqId"].(int) + 1 + + h.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) + }) + ch := make(chan string) + go simpleRequest(ch, &st) + assertOrder(t, ch, "end") +} diff --git a/web/mux.go b/web/mux.go index d95e929..c8f8d2c 100644 --- a/web/mux.go +++ b/web/mux.go @@ -49,7 +49,7 @@ func New() *Mux { mux := Mux{ mStack: mStack{ stack: make([]mLayer, 0), - pool: make(chan *cStack), + pool: make(chan *cStack, mPoolSize), }, router: router{ routes: make([]route, 0),