Browse Source

Middleware tests + bugfixes

My tests caught some bugs! Amazing!
Carl Jackson 12 years ago
parent
commit
3fffa7df4a
3 changed files with 281 additions and 9 deletions
  1. +25
    -8
      web/middleware.go
  2. +255
    -0
      web/middleware_test.go
  3. +1
    -1
      web/mux.go

+ 25
- 8
web/middleware.go View File

@ -28,7 +28,8 @@ type mStack struct {
// fully assembled middleware stacks (the "c" stands for "cached"). // fully assembled middleware stacks (the "c" stands for "cached").
type cStack struct { type cStack struct {
C C
m http.Handler
m http.Handler
pool chan *cStack
} }
func (s *cStack) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (s *cStack) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@ -72,6 +73,7 @@ func (m *mStack) invalidate() {
old := m.pool old := m.pool
m.pool = make(chan *cStack, mPoolSize) m.pool = make(chan *cStack, mPoolSize)
close(old) close(old)
// Bleed down the old pool so it gets GC'd
for _ = range old { for _ = range old {
} }
} }
@ -94,26 +96,41 @@ func (m *mStack) newStack() *cStack {
} }
func (m *mStack) alloc() *cStack { func (m *mStack) alloc() *cStack {
// This is a little sloppy: this is only safe if this pointer
// dereference is atomic. Maybe someday I'll replace it with
// sync/atomic, but for now I happen to know that on all the
// architecures I care about it happens to be atomic.
p := m.pool
var cs *cStack
select { select {
case cs := <-m.pool:
case cs = <-p:
// This can happen if we race against an invalidation. It's
// completely peaceful, so long as we assume we can grab a cStack before
// our stack blows out.
if cs == nil { if cs == nil {
return m.alloc() return m.alloc()
} }
return cs
default: default:
return m.newStack()
cs = m.newStack()
} }
cs.pool = p
return cs
} }
func (m *mStack) release(cs *cStack) { func (m *mStack) release(cs *cStack) {
// It's possible that the pool has been invalidated and therefore
// closed, in which case we'll start panicing, which is dumb. I'm not
// sure this is actually better than just grabbing a lock, but whatever.
if cs.pool != m.pool {
return
}
// It's possible that the pool has been invalidated (and closed) between
// the check above and now, in which case we'll start panicing, which is
// dumb. I'm not sure this is actually better than just grabbing a lock,
// but whatever.
defer func() { defer func() {
recover() recover()
}() }()
select { select {
case m.pool <- cs:
case cs.pool <- cs:
default: default:
} }
} }


+ 255
- 0
web/middleware_test.go View File

@ -0,0 +1,255 @@
package web
import (
"net/http"
"net/http/httptest"
"reflect"
"testing"
"time"
)
func makeStack(ch chan string) *mStack {
router := func(c C, w http.ResponseWriter, r *http.Request) {
ch <- "router"
}
return &mStack{
stack: make([]mLayer, 0),
pool: make(chan *cStack, mPoolSize),
router: HandlerFunc(router),
}
}
func chanWare(ch chan string, s string) func(http.Handler) http.Handler {
return func(h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
ch <- s
h.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}
}
func simpleRequest(ch chan string, st *mStack) {
defer func() {
ch <- "end"
}()
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
cs := st.alloc()
defer st.release(cs)
cs.ServeHTTP(w, r)
}
func assertOrder(t *testing.T, ch chan string, strings ...string) {
for i, s := range strings {
var v string
select {
case v = <-ch:
case <-time.After(5 * time.Millisecond):
t.Fatalf("Expected %q as %d'th value, but timed out", s,
i+1)
}
if s != v {
t.Errorf("%d'th value was %q, expected %q", i+1, v, s)
}
}
}
func TestSimple(t *testing.T) {
t.Parallel()
ch := make(chan string)
st := makeStack(ch)
st.Use("one", chanWare(ch, "one"))
st.Use("two", chanWare(ch, "two"))
go simpleRequest(ch, st)
assertOrder(t, ch, "one", "two", "router", "end")
}
func TestTypes(t *testing.T) {
t.Parallel()
ch := make(chan string)
st := makeStack(ch)
st.Use("one", func(h http.Handler) http.Handler {
return h
})
st.Use("two", func(c *C, h http.Handler) http.Handler {
return h
})
}
func TestAddMore(t *testing.T) {
t.Parallel()
ch := make(chan string)
st := makeStack(ch)
st.Use("one", chanWare(ch, "one"))
go simpleRequest(ch, st)
assertOrder(t, ch, "one", "router", "end")
st.Use("two", chanWare(ch, "two"))
go simpleRequest(ch, st)
assertOrder(t, ch, "one", "two", "router", "end")
st.Use("three", chanWare(ch, "three"))
st.Use("four", chanWare(ch, "four"))
go simpleRequest(ch, st)
assertOrder(t, ch, "one", "two", "three", "four", "router", "end")
}
func TestInsert(t *testing.T) {
t.Parallel()
ch := make(chan string)
st := makeStack(ch)
st.Use("one", chanWare(ch, "one"))
st.Use("two", chanWare(ch, "two"))
go simpleRequest(ch, st)
assertOrder(t, ch, "one", "two", "router", "end")
err := st.Insert("sloth", chanWare(ch, "sloth"), "squirrel")
if err == nil {
t.Error("Expected error when referencing unknown middleware")
}
st.Insert("middle", chanWare(ch, "middle"), "two")
st.Insert("start", chanWare(ch, "start"), "one")
go simpleRequest(ch, st)
assertOrder(t, ch, "start", "one", "middle", "two", "router", "end")
}
func TestAbandon(t *testing.T) {
t.Parallel()
ch := make(chan string)
st := makeStack(ch)
st.Use("one", chanWare(ch, "one"))
st.Use("two", chanWare(ch, "two"))
st.Use("three", chanWare(ch, "three"))
go simpleRequest(ch, st)
assertOrder(t, ch, "one", "two", "three", "router", "end")
st.Abandon("two")
go simpleRequest(ch, st)
assertOrder(t, ch, "one", "three", "router", "end")
err := st.Abandon("panda")
if err == nil {
t.Error("Expected error when deleting unknown middleware")
}
st.Abandon("one")
st.Abandon("three")
go simpleRequest(ch, st)
assertOrder(t, ch, "router", "end")
st.Use("one", chanWare(ch, "one"))
go simpleRequest(ch, st)
assertOrder(t, ch, "one", "router", "end")
}
func TestMiddlewareList(t *testing.T) {
t.Parallel()
ch := make(chan string)
st := makeStack(ch)
st.Use("one", chanWare(ch, "one"))
st.Use("two", chanWare(ch, "two"))
st.Insert("mid", chanWare(ch, "mid"), "two")
st.Insert("before", chanWare(ch, "before"), "mid")
st.Abandon("one")
m := st.Middleware()
if !reflect.DeepEqual(m, []string{"before", "mid", "two"}) {
t.Error("Middleware list was not as expected")
}
go simpleRequest(ch, st)
assertOrder(t, ch, "before", "mid", "two", "router", "end")
}
// This is a pretty sketchtacular test
func TestCaching(t *testing.T) {
ch := make(chan string)
st := makeStack(ch)
cs1 := st.alloc()
cs2 := st.alloc()
if cs1 == cs2 {
t.Fatal("cs1 and cs2 are the same")
}
st.release(cs2)
cs3 := st.alloc()
if cs2 != cs3 {
t.Fatalf("Expected cs2 to equal cs3")
}
st.release(cs1)
st.release(cs3)
cs4 := st.alloc()
cs5 := st.alloc()
if cs4 != cs1 {
t.Fatal("Expected cs4 to equal cs1")
}
if cs5 != cs3 {
t.Fatal("Expected cs5 to equal cs3")
}
}
func TestInvalidation(t *testing.T) {
ch := make(chan string)
st := makeStack(ch)
cs1 := st.alloc()
cs2 := st.alloc()
st.release(cs1)
st.invalidate()
cs3 := st.alloc()
if cs3 == cs1 {
t.Fatal("Expected cs3 to be fresh, instead got cs1")
}
st.release(cs2)
cs4 := st.alloc()
if cs4 == cs2 {
t.Fatal("Expected cs4 to be fresh, instead got cs2")
}
}
func TestContext(t *testing.T) {
router := func(c C, w http.ResponseWriter, r *http.Request) {
if c.Env["reqId"].(int) != 2 {
t.Error("Request id was not 2 :(")
}
}
st := mStack{
stack: make([]mLayer, 0),
pool: make(chan *cStack, mPoolSize),
router: HandlerFunc(router),
}
st.Use("one", func(c *C, h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
if c.Env != nil || c.UrlParams != nil {
t.Error("Expected a clean context")
}
c.Env = make(map[string]interface{})
c.Env["reqId"] = 1
h.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
})
st.Use("two", func(c *C, h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
if c.Env == nil {
t.Error("Expected env from last middleware")
}
c.Env["reqId"] = c.Env["reqId"].(int) + 1
h.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
})
ch := make(chan string)
go simpleRequest(ch, &st)
assertOrder(t, ch, "end")
}

+ 1
- 1
web/mux.go View File

@ -49,7 +49,7 @@ func New() *Mux {
mux := Mux{ mux := Mux{
mStack: mStack{ mStack: mStack{
stack: make([]mLayer, 0), stack: make([]mLayer, 0),
pool: make(chan *cStack),
pool: make(chan *cStack, mPoolSize),
}, },
router: router{ router: router{
routes: make([]route, 0), routes: make([]route, 0),


Loading…
Cancel
Save