package web

import (
	"net/http"
	"net/http/httptest"
	"reflect"
	"testing"
	"time"
)

// makeStack returns an empty middleware stack whose router reports "router"
// on ch every time it is invoked.
func makeStack(ch chan string) *mStack {
	router := func(c C, w http.ResponseWriter, r *http.Request) {
		ch <- "router"
	}
	return &mStack{
		stack:  make([]mLayer, 0),
		pool:   make(chan *cStack, mPoolSize),
		router: HandlerFunc(router),
	}
}

// chanWare returns a middleware that reports s on ch before delegating to the
// wrapped handler, so tests can observe the order in which middleware runs.
func chanWare(ch chan string, s string) func(http.Handler) http.Handler {
	return func(h http.Handler) http.Handler {
		fn := func(w http.ResponseWriter, r *http.Request) {
			ch <- s
			h.ServeHTTP(w, r)
		}
		return http.HandlerFunc(fn)
	}
}

// simpleRequest serves one GET / request through a stack allocated from st.
// It then releases that stack back to st and finally reports "end" on ch
// (deferred calls run in reverse order, so the release happens before "end").
func simpleRequest(ch chan string, st *mStack) {
	defer func() {
		ch <- "end"
	}()
	// The method and URL are constants, so parsing cannot fail.
	r, _ := http.NewRequest("GET", "/", nil)
	w := httptest.NewRecorder()
	cs := st.alloc()
	defer st.release(cs)

	cs.ServeHTTP(w, r)
}

// assertOrder receives len(strings) values from ch and checks that they arrive
// in exactly the given order. If an expected value does not arrive in time, it
// stops the test with t.Fatalf.
func assertOrder(t *testing.T, ch chan string, strings ...string) {
	for i, s := range strings {
		var v string
		// The timeout only matters when something is already broken. Keep it
		// generous so that parallel tests or the race detector slowing things
		// down cannot cause spurious failures.
		select {
		case v = <-ch:
		case <-time.After(time.Second):
			t.Fatalf("Expected %q as %d'th value, but timed out", s, i+1)
		}
		if s != v {
			t.Errorf("%d'th value was %q, expected %q", i+1, v, s)
		}
	}
}

func TestSimple(t *testing.T) {
	t.Parallel()
	ch := make(chan string)
	st := makeStack(ch)
	st.Use("one", chanWare(ch, "one"))
	st.Use("two", chanWare(ch, "two"))
	go simpleRequest(ch, st)
	assertOrder(t, ch, "one", "two", "router", "end")
}

// TestTypes checks that both supported middleware signatures can be registered
// and used to serve a request.
func TestTypes(t *testing.T) {
	t.Parallel()
	ch := make(chan string)
	st := makeStack(ch)
	st.Use("one", func(h http.Handler) http.Handler {
		return h
	})
	st.Use("two", func(c *C, h http.Handler) http.Handler {
		return h
	})
	// Serve a request so both middleware forms are actually built and run,
	// not merely registered.
	go simpleRequest(ch, st)
	assertOrder(t, ch, "router", "end")
}

func TestAddMore(t *testing.T) {
	t.Parallel()
	ch := make(chan string)
	st := makeStack(ch)
	st.Use("one", chanWare(ch, "one"))
	go simpleRequest(ch, st)
	assertOrder(t, ch, "one", "router", "end")

	st.Use("two", chanWare(ch, "two"))
	go simpleRequest(ch, st)
	assertOrder(t, ch, "one", "two", "router", "end")

	st.Use("three", chanWare(ch, "three"))
	st.Use("four", chanWare(ch, "four"))
	go simpleRequest(ch, st)
	assertOrder(t, ch, "one", "two", "three", "four", "router", "end")
}

func TestInsert(t *testing.T) {
	t.Parallel()
	ch := make(chan string)
	st := makeStack(ch)
	st.Use("one", chanWare(ch, "one"))
	st.Use("two", chanWare(ch, "two"))
	go simpleRequest(ch, st)
	assertOrder(t, ch, "one", "two", "router", "end")

	err := st.Insert("sloth", chanWare(ch, "sloth"), "squirrel")
	if err == nil {
		t.Error("Expected error when referencing unknown middleware")
	}

	// These inserts reference existing middleware and must succeed.
	if err := st.Insert("middle", chanWare(ch, "middle"), "two"); err != nil {
		t.Fatalf("Unexpected error inserting middle: %v", err)
	}
	if err := st.Insert("start", chanWare(ch, "start"), "one"); err != nil {
		t.Fatalf("Unexpected error inserting start: %v", err)
	}
	go simpleRequest(ch, st)
	assertOrder(t, ch, "start", "one", "middle", "two", "router", "end")
}

func TestAbandon(t *testing.T) {
	t.Parallel()
	ch := make(chan string)
	st := makeStack(ch)
	st.Use("one", chanWare(ch, "one"))
	st.Use("two", chanWare(ch, "two"))
	st.Use("three", chanWare(ch, "three"))
	go simpleRequest(ch, st)
	assertOrder(t, ch, "one", "two", "three", "router", "end")

	if err := st.Abandon("two"); err != nil {
		t.Fatalf("Unexpected error abandoning two: %v", err)
	}
	go simpleRequest(ch, st)
	assertOrder(t, ch, "one", "three", "router", "end")

	err := st.Abandon("panda")
	if err == nil {
		t.Error("Expected error when deleting unknown middleware")
	}

	if err := st.Abandon("one"); err != nil {
		t.Fatalf("Unexpected error abandoning one: %v", err)
	}
	if err := st.Abandon("three"); err != nil {
		t.Fatalf("Unexpected error abandoning three: %v", err)
	}
	go simpleRequest(ch, st)
	assertOrder(t, ch, "router", "end")

	st.Use("one", chanWare(ch, "one"))
	go simpleRequest(ch, st)
	assertOrder(t, ch, "one", "router", "end")
}

func TestMiddlewareList(t *testing.T) {
	t.Parallel()
	ch := make(chan string)
	st := makeStack(ch)
	st.Use("one", chanWare(ch, "one"))
	st.Use("two", chanWare(ch, "two"))
	if err := st.Insert("mid", chanWare(ch, "mid"), "two"); err != nil {
		t.Fatalf("Unexpected error inserting mid: %v", err)
	}
	if err := st.Insert("before", chanWare(ch, "before"), "mid"); err != nil {
		t.Fatalf("Unexpected error inserting before: %v", err)
	}
	if err := st.Abandon("one"); err != nil {
		t.Fatalf("Unexpected error abandoning one: %v", err)
	}

	m := st.Middleware()
	if !reflect.DeepEqual(m, []string{"before", "mid", "two"}) {
		t.Error("Middleware list was not as expected")
	}

	go simpleRequest(ch, st)
	assertOrder(t, ch, "before", "mid", "two", "router", "end")
}

// TestCaching depends on an implementation detail: released stacks come back
// from the pool (a buffered channel) in the order they were released.
func TestCaching(t *testing.T) {
	ch := make(chan string)
	st := makeStack(ch)
	cs1 := st.alloc()
	cs2 := st.alloc()
	if cs1 == cs2 {
		t.Fatal("cs1 and cs2 are the same")
	}
	st.release(cs2)
	cs3 := st.alloc()
	if cs2 != cs3 {
		t.Fatalf("Expected cs2 to equal cs3")
	}
	st.release(cs1)
	st.release(cs3)
	cs4 := st.alloc()
	cs5 := st.alloc()
	if cs4 != cs1 {
		t.Fatal("Expected cs4 to equal cs1")
	}
	if cs5 != cs3 {
		t.Fatal("Expected cs5 to equal cs3")
	}
}

// TestInvalidation checks that after invalidate, a stack is never handed out
// again if it was pooled before the call, or if it was checked out before the
// call and released afterwards.
func TestInvalidation(t *testing.T) {
	ch := make(chan string)
	st := makeStack(ch)
	cs1 := st.alloc()
	cs2 := st.alloc()
	st.release(cs1)
	st.invalidate()
	cs3 := st.alloc()
	if cs3 == cs1 {
		t.Fatal("Expected cs3 to be fresh, instead got cs1")
	}
	st.release(cs2)
	cs4 := st.alloc()
	if cs4 == cs2 {
		t.Fatal("Expected cs4 to be fresh, instead got cs2")
	}
}

// TestContext checks that each request starts with a clean context and that
// changes a middleware makes to it are visible to later middleware and to the
// router.
func TestContext(t *testing.T) {
	ch := make(chan string)
	router := func(c C, w http.ResponseWriter, r *http.Request) {
		// Use a comma-ok assertion so that a missing or mistyped id fails the
		// test instead of panicking the request goroutine.
		if id, ok := c.Env["reqId"].(int); !ok || id != 2 {
			t.Error("Request id was not 2 :(")
		}
		// Signal that the router actually ran, so skipping it is detected.
		ch <- "router"
	}
	st := mStack{
		stack:  make([]mLayer, 0),
		pool:   make(chan *cStack, mPoolSize),
		router: HandlerFunc(router),
	}
	st.Use("one", func(c *C, h http.Handler) http.Handler {
		fn := func(w http.ResponseWriter, r *http.Request) {
			if c.Env != nil || c.UrlParams != nil {
				t.Error("Expected a clean context")
			}
			c.Env = make(map[string]interface{})
			c.Env["reqId"] = 1
			h.ServeHTTP(w, r)
		}
		return http.HandlerFunc(fn)
	})
	st.Use("two", func(c *C, h http.Handler) http.Handler {
		fn := func(w http.ResponseWriter, r *http.Request) {
			// Writing into a nil map would panic, so only bump the id when
			// the previous middleware actually set up Env.
			if c.Env == nil {
				t.Error("Expected env from last middleware")
			} else {
				c.Env["reqId"] = c.Env["reqId"].(int) + 1
			}
			h.ServeHTTP(w, r)
		}
		return http.HandlerFunc(fn)
	})

	go simpleRequest(ch, &st)
	assertOrder(t, ch, "router", "end")
}