diff --git a/graceful/conn_set.go b/graceful/conn_set.go deleted file mode 100644 index c08099c..0000000 --- a/graceful/conn_set.go +++ /dev/null @@ -1,80 +0,0 @@ -package graceful - -import ( - "runtime" - "sync" -) - -type connShard struct { - mu sync.Mutex - // We sort of abuse this field to also act as a "please shut down" flag. - // If it's nil, you should die at your earliest opportunity. - set map[*conn]struct{} -} - -type connSet struct { - // This is an incrementing connection counter so we round-robin - // connections across shards. Use atomic when touching it. - id uint64 - shards []*connShard -} - -var idleSet connSet - -// We pretty aggressively preallocate set entries in the hopes that we never -// have to allocate memory with the lock held. This is definitely a premature -// optimization and is probably misguided, but luckily it costs us essentially -// nothing. -const prealloc = 2048 - -func init() { - // To keep the expected contention rate constant we'd have to grow this - // as numcpu**2. In practice, CPU counts don't generally grow without - // bound, and contention is probably going to be small enough that - // nobody cares anyways. - idleSet.shards = make([]*connShard, 2*runtime.NumCPU()) - for i := range idleSet.shards { - idleSet.shards[i] = &connShard{ - set: make(map[*conn]struct{}, prealloc), - } - } -} - -func (cs connSet) markIdle(c *conn) { - c.busy = false - shard := cs.shards[int(c.id%uint64(len(cs.shards)))] - shard.mu.Lock() - if shard.set == nil { - shard.mu.Unlock() - c.die = true - } else { - shard.set[c] = struct{}{} - shard.mu.Unlock() - } -} - -func (cs connSet) markBusy(c *conn) { - c.busy = true - shard := cs.shards[int(c.id%uint64(len(cs.shards)))] - shard.mu.Lock() - if shard.set == nil { - shard.mu.Unlock() - c.die = true - } else { - delete(shard.set, c) - shard.mu.Unlock() - } -} - -func (cs connSet) killall() { - for _, shard := range cs.shards { - shard.mu.Lock() - set := shard.set - shard.set = nil - shard.mu.Unlock() - - for conn := range set { - conn.closeIfIdle() - } - } -} diff --git a/graceful/conn_test.go b/graceful/conn_test.go deleted file mode 100644 index d86f12e..0000000 --- a/graceful/conn_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package graceful - -import ( - "net" - "time" -) - -// Stub out a net.Conn. This is going to be painful. - -type fakeAddr struct{} - -func (f fakeAddr) Network() string { - return "fake" -} -func (f fakeAddr) String() string { - return "fake" -} - -type fakeConn struct { - onRead, onWrite, onClose, onLocalAddr, onRemoteAddr func() - onSetDeadline, onSetReadDeadline, onSetWriteDeadline func() -} - -// Here's my number, so... -func callMeMaybe(f func()) { - // I apologize for nothing. 
- if f != nil { - f() - } -} - -func (f fakeConn) Read(b []byte) (int, error) { - callMeMaybe(f.onRead) - return len(b), nil -} -func (f fakeConn) Write(b []byte) (int, error) { - callMeMaybe(f.onWrite) - return len(b), nil -} -func (f fakeConn) Close() error { - callMeMaybe(f.onClose) - return nil -} -func (f fakeConn) LocalAddr() net.Addr { - callMeMaybe(f.onLocalAddr) - return fakeAddr{} -} -func (f fakeConn) RemoteAddr() net.Addr { - callMeMaybe(f.onRemoteAddr) - return fakeAddr{} -} -func (f fakeConn) SetDeadline(t time.Time) error { - callMeMaybe(f.onSetDeadline) - return nil -} -func (f fakeConn) SetReadDeadline(t time.Time) error { - callMeMaybe(f.onSetReadDeadline) - return nil -} -func (f fakeConn) SetWriteDeadline(t time.Time) error { - callMeMaybe(f.onSetWriteDeadline) - return nil -} diff --git a/graceful/graceful.go b/graceful/graceful.go index 8958126..98edd8b 100644 --- a/graceful/graceful.go +++ b/graceful/graceful.go @@ -3,18 +3,7 @@ Package graceful implements graceful shutdown for HTTP servers by closing idle connections after receiving a signal. By default, this package listens for interrupts (i.e., SIGINT), but when it detects that it is running under Einhorn it will additionally listen for SIGUSR2 as well, giving your application -automatic support for graceful upgrades. - -It's worth mentioning explicitly that this package is a hack to shim graceful -shutdown behavior into the net/http package provided in Go 1.2. It was written -by carefully reading the sequence of function calls net/http happened to use as -of this writing and finding enough surface area with which to add appropriate -behavior. There's a very good chance that this package will cease to work in -future versions of Go, but with any luck the standard library will add support -of its own by then (https://code.google.com/p/go/issues/detail?id=4674). - -If you're interested in figuring out how this package works, we suggest you read -the documentation for WrapConn() and net.go. +automatic support for graceful restarts/code upgrades. */ package graceful @@ -22,19 +11,11 @@ import ( "crypto/tls" "net" "net/http" -) -/* -You might notice that these methods look awfully similar to the methods of the -same name from the go standard library--that's because they were stolen from -there! If go were more like, say, Ruby, it'd actually be possible to shim just -the Serve() method, since we can do everything we want from there. However, it's -not possible to get the other methods which call Serve() (ListenAndServe(), say) -to call your shimmed copy--they always call the original. + "github.com/zenazn/goji/graceful/listener" +) -Since I couldn't come up with a better idea, I just copy-and-pasted both -ListenAndServe and ListenAndServeTLS here more-or-less verbatim. "Oh well!" -*/ +// Most of the code here is lifted straight from net/http // Type Server is exactly the same as an http.Server, but provides more graceful // implementations of its methods. @@ -98,3 +79,19 @@ func Serve(l net.Listener, handler http.Handler) error { server := &Server{Handler: handler} return server.Serve(l) } + +// WrapListener wraps an arbitrary net.Listener for use with graceful shutdowns. +// In the background, it uses the listener sub-package to Wrap the listener in +// Deadline mode. If another mode of operation is desired, you should call +// listener.Wrap yourself: this function is smart enough to not double-wrap +// listeners. 
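+//
+// As a sketch, a caller that wants Manual mode instead of the default could
+// write the following (l is an existing net.Listener and handler an
+// http.Handler; in Manual mode the caller is then responsible for calling
+// listener.MarkInUse and listener.MarkIdle itself):
+//
+//	wl := graceful.WrapListener(listener.Wrap(l, listener.Manual))
+//	err := http.Serve(wl, handler)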
+func WrapListener(l net.Listener) net.Listener { + if lt, ok := l.(*listener.T); ok { + appendListener(lt) + return lt + } + + lt := listener.Wrap(l, listener.Deadline) + appendListener(lt) + return lt +} diff --git a/graceful/listener/conn.go b/graceful/listener/conn.go new file mode 100644 index 0000000..22e986f --- /dev/null +++ b/graceful/listener/conn.go @@ -0,0 +1,147 @@ +package listener + +import ( + "errors" + "io" + "net" + "sync" + "time" +) + +type conn struct { + net.Conn + + shard *shard + mode mode + + mu sync.Mutex // Protects the state machine below + busy bool // connection is in use (i.e., not idle) + closed bool // connection is closed + disowned bool // if true, this connection is no longer under our management +} + +// This intentionally looks a lot like the one in package net. +var errClosing = errors.New("use of closed network connection") + +func (c *conn) init() error { + c.shard.wg.Add(1) + if shouldExit := c.shard.markIdle(c); shouldExit { + c.Close() + return errClosing + } + return nil +} + +func (c *conn) Read(b []byte) (n int, err error) { + defer func() { + c.mu.Lock() + defer c.mu.Unlock() + + if c.disowned { + return + } + + // This protects against a Close/Read race. We're not really + // concerned about the general case (it's fundamentally racy), + // but are mostly trying to prevent a race between a new request + // getting read off the wire in one thread while the connection + // is being gracefully shut down in another. + if c.closed && err == nil { + n = 0 + err = errClosing + return + } + + if c.mode != Manual && !c.busy && !c.closed { + c.busy = true + c.shard.markInUse(c) + } + }() + + return c.Conn.Read(b) +} + +func (c *conn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + + if !c.closed && !c.disowned { + c.closed = true + c.shard.markInUse(c) + defer c.shard.wg.Done() + } + + return c.Conn.Close() +} + +func (c *conn) SetReadDeadline(t time.Time) error { + c.mu.Lock() + if !c.disowned && c.mode == Deadline { + defer c.markIdle() + } + c.mu.Unlock() + return c.Conn.SetReadDeadline(t) +} + +func (c *conn) ReadFrom(r io.Reader) (int64, error) { + return io.Copy(c.Conn, r) +} + +func (c *conn) markIdle() { + c.mu.Lock() + defer c.mu.Unlock() + + if !c.busy { + return + } + c.busy = false + + if exit := c.shard.markIdle(c); exit && !c.closed && !c.disowned { + c.closed = true + c.shard.markInUse(c) + defer c.shard.wg.Done() + c.Conn.Close() + return + } +} + +func (c *conn) markInUse() { + c.mu.Lock() + defer c.mu.Unlock() + + if !c.busy && !c.closed && !c.disowned { + c.busy = true + c.shard.markInUse(c) + } +} + +func (c *conn) closeIfIdle() error { + c.mu.Lock() + defer c.mu.Unlock() + + if !c.busy && !c.closed && !c.disowned { + c.closed = true + c.shard.markInUse(c) + defer c.shard.wg.Done() + return c.Conn.Close() + } + + return nil +} + +var errAlreadyDisowned = errors.New("listener: conn already disowned") + +func (c *conn) disown() error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.disowned { + return errAlreadyDisowned + } + + c.shard.markInUse(c) + c.disowned = true + c.shard.wg.Done() + + return nil +} diff --git a/graceful/listener/conn_test.go b/graceful/listener/conn_test.go new file mode 100644 index 0000000..cd5d878 --- /dev/null +++ b/graceful/listener/conn_test.go @@ -0,0 +1,182 @@ +package listener + +import ( + "io" + "strings" + "testing" + "time" +) + +func TestManualRead(t *testing.T) { + t.Parallel() + l, c, wc := singleConn(t, Manual) + + go c.AllowRead() + wc.Read(make([]byte, 1024)) + + if err := l.CloseIdle(); 
err != nil { + t.Fatalf("error closing idle connections: %v", err) + } + if !c.Closed() { + t.Error("Read() should not make connection not-idle") + } +} + +func TestAutomaticRead(t *testing.T) { + t.Parallel() + l, c, wc := singleConn(t, Automatic) + + go c.AllowRead() + wc.Read(make([]byte, 1024)) + + if err := l.CloseIdle(); err != nil { + t.Fatalf("error closing idle connections: %v", err) + } + if c.Closed() { + t.Error("expected Read() to mark connection as in-use") + } +} + +func TestDeadlineRead(t *testing.T) { + t.Parallel() + l, c, wc := singleConn(t, Deadline) + + go c.AllowRead() + if _, err := wc.Read(make([]byte, 1024)); err != nil { + t.Fatalf("error reading from connection: %v", err) + } + + if err := l.CloseIdle(); err != nil { + t.Fatalf("error closing idle connections: %v", err) + } + if c.Closed() { + t.Error("expected Read() to mark connection as in-use") + } +} + +func TestDisownedRead(t *testing.T) { + t.Parallel() + l, c, wc := singleConn(t, Deadline) + + if err := Disown(wc); err != nil { + t.Fatalf("unexpected error disowning conn: %v", err) + } + if err := l.Close(); err != nil { + t.Fatalf("unexpected error closing listener: %v", err) + } + if err := l.Drain(); err != nil { + t.Fatalf("unexpected error draining listener: %v", err) + } + + go c.AllowRead() + if _, err := wc.Read(make([]byte, 1024)); err != nil { + t.Fatalf("error reading from connection: %v", err) + } +} + +func TestCloseConn(t *testing.T) { + t.Parallel() + l, _, wc := singleConn(t, Deadline) + + if err := MarkInUse(wc); err != nil { + t.Fatalf("error marking conn in use: %v", err) + } + if err := wc.Close(); err != nil { + t.Errorf("error closing connection: %v", err) + } + // This will hang if wc.Close() doesn't un-track the connection + if err := l.Drain(); err != nil { + t.Errorf("error draining listener: %v", err) + } +} + +func TestManualReadDeadline(t *testing.T) { + t.Parallel() + l, c, wc := singleConn(t, Manual) + + if err := MarkInUse(wc); err != nil { + t.Fatalf("error marking connection in use: %v", err) + } + if err := wc.SetReadDeadline(time.Now()); err != nil { + t.Fatalf("error setting read deadline: %v", err) + } + if err := l.CloseIdle(); err != nil { + t.Fatalf("error closing idle connections: %v", err) + } + if c.Closed() { + t.Error("SetReadDeadline() should not mark manual conn as idle") + } +} + +func TestAutomaticReadDeadline(t *testing.T) { + t.Parallel() + l, c, wc := singleConn(t, Automatic) + + if err := MarkInUse(wc); err != nil { + t.Fatalf("error marking connection in use: %v", err) + } + if err := wc.SetReadDeadline(time.Now()); err != nil { + t.Fatalf("error setting read deadline: %v", err) + } + if err := l.CloseIdle(); err != nil { + t.Fatalf("error closing idle connections: %v", err) + } + if c.Closed() { + t.Error("SetReadDeadline() should not mark automatic conn as idle") + } +} + +func TestDeadlineReadDeadline(t *testing.T) { + t.Parallel() + l, c, wc := singleConn(t, Deadline) + + if err := MarkInUse(wc); err != nil { + t.Fatalf("error marking connection in use: %v", err) + } + if err := wc.SetReadDeadline(time.Now()); err != nil { + t.Fatalf("error setting read deadline: %v", err) + } + if err := l.CloseIdle(); err != nil { + t.Fatalf("error closing idle connections: %v", err) + } + if !c.Closed() { + t.Error("SetReadDeadline() should mark deadline conn as idle") + } +} + +type readerConn struct { + fakeConn +} + +func (rc *readerConn) ReadFrom(r io.Reader) (int64, error) { + return 123, nil +} + +func TestReadFrom(t *testing.T) { + t.Parallel() + + l := 
makeFakeListener("net.Listener") + wl := Wrap(l, Manual) + c := &readerConn{ + fakeConn{ + read: make(chan struct{}), + write: make(chan struct{}), + closed: make(chan struct{}), + me: fakeAddr{"tcp", "local"}, + you: fakeAddr{"tcp", "remote"}, + }, + } + + go l.Enqueue(c) + wc, err := wl.Accept() + if err != nil { + t.Fatalf("error accepting connection: %v", err) + } + + // The io.MultiReader is a convenient hack to ensure that we're using + // our ReadFrom, not strings.Reader's WriteTo. + r := io.MultiReader(strings.NewReader("hello world")) + if _, err := io.Copy(wc, r); err != nil { + t.Fatalf("error copying: %v", err) + } +} diff --git a/graceful/listener/fake_test.go b/graceful/listener/fake_test.go new file mode 100644 index 0000000..083f6a8 --- /dev/null +++ b/graceful/listener/fake_test.go @@ -0,0 +1,123 @@ +package listener + +import ( + "net" + "time" +) + +type fakeAddr struct { + network, addr string +} + +func (f fakeAddr) Network() string { + return f.network +} +func (f fakeAddr) String() string { + return f.addr +} + +type fakeListener struct { + ch chan net.Conn + closed chan struct{} + addr net.Addr +} + +func makeFakeListener(addr string) *fakeListener { + a := fakeAddr{"tcp", addr} + return &fakeListener{ + ch: make(chan net.Conn), + closed: make(chan struct{}), + addr: a, + } +} + +func (f *fakeListener) Accept() (net.Conn, error) { + select { + case c := <-f.ch: + return c, nil + case <-f.closed: + return nil, errClosing + } +} +func (f *fakeListener) Close() error { + close(f.closed) + return nil +} + +func (f *fakeListener) Addr() net.Addr { + return f.addr +} + +func (f *fakeListener) Enqueue(c net.Conn) { + f.ch <- c +} + +type fakeConn struct { + read, write, closed chan struct{} + me, you net.Addr +} + +func makeFakeConn(me, you string) *fakeConn { + return &fakeConn{ + read: make(chan struct{}), + write: make(chan struct{}), + closed: make(chan struct{}), + me: fakeAddr{"tcp", me}, + you: fakeAddr{"tcp", you}, + } +} + +func (f *fakeConn) Read(buf []byte) (int, error) { + select { + case <-f.read: + return len(buf), nil + case <-f.closed: + return 0, errClosing + } +} + +func (f *fakeConn) Write(buf []byte) (int, error) { + select { + case <-f.write: + return len(buf), nil + case <-f.closed: + return 0, errClosing + } +} + +func (f *fakeConn) Close() error { + close(f.closed) + return nil +} + +func (f *fakeConn) LocalAddr() net.Addr { + return f.me +} +func (f *fakeConn) RemoteAddr() net.Addr { + return f.you +} +func (f *fakeConn) SetDeadline(t time.Time) error { + return nil +} +func (f *fakeConn) SetReadDeadline(t time.Time) error { + return nil +} +func (f *fakeConn) SetWriteDeadline(t time.Time) error { + return nil +} + +func (f *fakeConn) Closed() bool { + select { + case <-f.closed: + return true + default: + return false + } +} + +func (f *fakeConn) AllowRead() { + f.read <- struct{}{} +} +func (f *fakeConn) AllowWrite() { + f.write <- struct{}{} +} diff --git a/graceful/listener/listener.go b/graceful/listener/listener.go new file mode 100644 index 0000000..bebfb31 --- /dev/null +++ b/graceful/listener/listener.go @@ -0,0 +1,165 @@ +/* +Package listener provides a way to incorporate graceful shutdown to any +net.Listener. + +This package provides low-level primitives, not a high-level API. If you're +looking for a package that provides graceful shutdown for HTTP servers, I +recommend this package's parent package, github.com/zenazn/goji/graceful. 
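+
+As a sketch of direct use (l is an existing net.Listener; serve is a
+hypothetical accept loop; error handling elided):
+
+	wl := listener.Wrap(l, listener.Automatic)
+	go serve(wl)
+
+	// ...later, when it's time to shut down:
+	wl.Close() // stop accepting new connections
+	wl.Drain() // close idle connections and wait for the rest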
+*/
+package listener
+
+import (
+	"errors"
+	"net"
+	"runtime"
+	"sync"
+	"sync/atomic"
+)
+
+type mode int8
+
+const (
+	// Manual mode is completely manual: users must use MarkIdle and
+	// MarkInUse to indicate when connections are busy servicing requests
+	// or are eligible for termination.
+	Manual mode = iota
+	// Automatic mode is what most users probably want: calling Read on a
+	// connection will mark it as in use, but users must manually call
+	// MarkIdle to indicate when connections may be safely closed.
+	Automatic
+	// Deadline mode is like Automatic mode, except that calling
+	// SetReadDeadline on a connection will also mark it as being idle.
+	// This is useful for servers like net/http, where SetReadDeadline is
+	// used to implement read timeouts on new requests.
+	Deadline
+)
+
+// Wrap a net.Listener, returning a net.Listener which supports idle connection
+// tracking and shutdown. Listeners can be placed into one of three modes,
+// exported as constants from this package: most users will probably want the
+// "Automatic" mode.
+func Wrap(l net.Listener, m mode) *T {
+	t := &T{
+		l:    l,
+		mode: m,
+		// To keep the expected contention rate constant we'd have to
+		// grow this as numcpu**2. In practice, CPU counts don't
+		// generally grow without bound, and contention is probably
+		// going to be small enough that nobody cares anyways.
+		shards: make([]shard, 2*runtime.NumCPU()),
+	}
+	for i := range t.shards {
+		t.shards[i].init(t)
+	}
+	return t
+}
+
+// T is the type of this package's graceful listeners.
+type T struct {
+	mu sync.Mutex
+	l  net.Listener
+
+	// TODO(carl): a count of currently outstanding connections.
+	connCount uint64
+	shards    []shard
+
+	mode mode
+}
+
+var _ net.Listener = &T{}
+
+// Accept waits for and returns the next connection to the listener. The
+// returned net.Conn's idleness is tracked, and idle connections can be closed
+// from the associated T.
+func (t *T) Accept() (net.Conn, error) {
+	c, err := t.l.Accept()
+	if err != nil {
+		return nil, err
+	}
+
+	connID := atomic.AddUint64(&t.connCount, 1)
+	shard := &t.shards[connID%uint64(len(t.shards))]
+	wc := &conn{
+		Conn:  c,
+		shard: shard,
+		mode:  t.mode,
+	}
+
+	if err = wc.init(); err != nil {
+		return nil, err
+	}
+	return wc, nil
+}
+
+// Addr returns the wrapped listener's network address.
+func (t *T) Addr() net.Addr {
+	return t.l.Addr()
+}
+
+// Close closes the wrapped listener.
+func (t *T) Close() error {
+	return t.l.Close()
+}
+
+// CloseIdle closes all connections that are currently marked as being idle. It
+// makes no attempt, however, to wait for in-use connections to die, or to
+// close connections which become idle in the future. Call this function if
+// you're interested in shedding useless connections, but otherwise wish to
+// continue serving requests.
+func (t *T) CloseIdle() error {
+	for i := range t.shards {
+		t.shards[i].closeIdle(false)
+	}
+	// Not sure if returning errors is actually useful here :/
+	return nil
+}
+
+// Drain immediately closes all idle connections, arranges for any connection
+// accepted in the future to be closed as soon as it is tracked, and waits for
+// all outstanding connections to finish.
+//
+// Once a listener has been drained, there is no way to re-enable it. You
+// probably want to Close the listener before draining it, otherwise new
+// connections will be accepted and immediately closed.
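+//
+// In other words, the usual shutdown sequence is (a sketch):
+//
+//	t.Close()
+//	t.Drain()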
+func (t *T) Drain() error { + for i := range t.shards { + t.shards[i].closeIdle(true) + } + for i := range t.shards { + t.shards[i].wait() + } + return nil +} + +var notManagedErr = errors.New("listener: passed net.Conn is not managed by us") + +// Disown causes a connection to no longer be tracked by the listener. The +// passed connection must have been returned by a call to Accept from this +// listener. +func Disown(c net.Conn) error { + if cn, ok := c.(*conn); ok { + return cn.disown() + } + return notManagedErr +} + +// MarkIdle marks the given connection as being idle, and therefore eligible for +// closing at any time. The passed connection must have been returned by a call +// to Accept from this listener. +func MarkIdle(c net.Conn) error { + if cn, ok := c.(*conn); ok { + cn.markIdle() + return nil + } + return notManagedErr +} + +// MarkInUse marks this connection as being in use, removing it from the set of +// connections which are eligible for closing. The passed connection must have +// been returned by a call to Accept from this listener. +func MarkInUse(c net.Conn) error { + if cn, ok := c.(*conn); ok { + cn.markInUse() + return nil + } + return notManagedErr +} diff --git a/graceful/listener/listener_test.go b/graceful/listener/listener_test.go new file mode 100644 index 0000000..160434b --- /dev/null +++ b/graceful/listener/listener_test.go @@ -0,0 +1,143 @@ +package listener + +import ( + "net" + "testing" + "time" +) + +// Helper for tests acting on a single accepted connection +func singleConn(t *testing.T, m mode) (*T, *fakeConn, net.Conn) { + l := makeFakeListener("net.Listener") + wl := Wrap(l, m) + c := makeFakeConn("local", "remote") + + go l.Enqueue(c) + wc, err := wl.Accept() + if err != nil { + t.Fatalf("error accepting connection: %v", err) + } + return wl, c, wc +} + +func TestAddr(t *testing.T) { + t.Parallel() + l, c, wc := singleConn(t, Manual) + + if a := l.Addr(); a.String() != "net.Listener" { + t.Errorf("addr was %v, wanted net.Listener", a) + } + + if c.LocalAddr() != wc.LocalAddr() { + t.Errorf("local addresses don't match: %v, %v", c.LocalAddr(), + wc.LocalAddr()) + } + if c.RemoteAddr() != wc.RemoteAddr() { + t.Errorf("remote addresses don't match: %v, %v", c.RemoteAddr(), + wc.RemoteAddr()) + } +} + +func TestBasicCloseIdle(t *testing.T) { + t.Parallel() + l, c, _ := singleConn(t, Manual) + + if err := l.CloseIdle(); err != nil { + t.Fatalf("error closing idle connections: %v", err) + } + if !c.Closed() { + t.Error("idle connection not closed") + } +} + +func TestMark(t *testing.T) { + t.Parallel() + l, c, wc := singleConn(t, Manual) + + if err := MarkInUse(wc); err != nil { + t.Fatalf("error marking %v in-use: %v", wc, err) + } + if err := l.CloseIdle(); err != nil { + t.Fatalf("error closing idle connections: %v", err) + } + if c.Closed() { + t.Errorf("manually in-use connection was closed") + } + + if err := MarkIdle(wc); err != nil { + t.Fatalf("error marking %v idle: %v", wc, err) + } + if err := l.CloseIdle(); err != nil { + t.Fatalf("error closing idle connections: %v", err) + } + if !c.Closed() { + t.Error("manually idle connection was not closed") + } +} + +func TestDisown(t *testing.T) { + t.Parallel() + l, c, wc := singleConn(t, Manual) + + if err := Disown(wc); err != nil { + t.Fatalf("error disowning connection: %v", err) + } + if err := l.CloseIdle(); err != nil { + t.Fatalf("error closing idle connections: %v", err) + } + + if c.Closed() { + t.Errorf("disowned connection got closed") + } +} + +func TestDrain(t *testing.T) { + 
t.Parallel() + l, _, wc := singleConn(t, Manual) + + MarkInUse(wc) + start := time.Now() + go func() { + time.Sleep(50 * time.Millisecond) + MarkIdle(wc) + }() + if err := l.Drain(); err != nil { + t.Fatalf("error draining listener: %v", err) + } + end := time.Now() + if dt := end.Sub(start); dt < 50*time.Millisecond { + t.Errorf("expected at least 50ms wait, but got %v", dt) + } +} + +func TestErrors(t *testing.T) { + t.Parallel() + _, c, wc := singleConn(t, Manual) + if err := Disown(c); err == nil { + t.Error("expected error when disowning unmanaged net.Conn") + } + if err := MarkIdle(c); err == nil { + t.Error("expected error when marking unmanaged net.Conn idle") + } + if err := MarkInUse(c); err == nil { + t.Error("expected error when marking unmanaged net.Conn in use") + } + + if err := Disown(wc); err != nil { + t.Fatalf("unexpected error disowning socket: %v", err) + } + if err := Disown(wc); err == nil { + t.Error("expected error disowning socket twice") + } +} + +func TestClose(t *testing.T) { + t.Parallel() + l, c, _ := singleConn(t, Manual) + if err := l.Close(); err != nil { + t.Fatalf("error while closing listener: %v", err) + } + if c.Closed() { + t.Error("connection closed when listener was?") + } +} diff --git a/graceful/listener/race_test.go b/graceful/listener/race_test.go new file mode 100644 index 0000000..835d6b4 --- /dev/null +++ b/graceful/listener/race_test.go @@ -0,0 +1,103 @@ +package listener + +import ( + "fmt" + "math/rand" + "runtime" + "sync/atomic" + "testing" + "time" +) + +func init() { + // Just to make sure we get some variety + runtime.GOMAXPROCS(4 * runtime.NumCPU()) +} + +// Chosen by random die roll +const seed = 4611413766552969250 + +// This is mostly just fuzzing to see what happens. +func TestRace(t *testing.T) { + t.Parallel() + + l := makeFakeListener("net.Listener") + wl := Wrap(l, Automatic) + + var flag int32 + + go func() { + for i := 0; ; i++ { + laddr := fmt.Sprintf("local%d", i) + raddr := fmt.Sprintf("remote%d", i) + c := makeFakeConn(laddr, raddr) + go func() { + defer func() { + if r := recover(); r != nil { + if atomic.LoadInt32(&flag) != 0 { + return + } + panic(r) + } + }() + l.Enqueue(c) + }() + wc, err := wl.Accept() + if err != nil { + if atomic.LoadInt32(&flag) != 0 { + return + } + t.Fatalf("error accepting connection: %v", err) + } + + go func() { + for { + time.Sleep(50 * time.Millisecond) + c.AllowRead() + } + }() + + go func(i int64) { + rng := rand.New(rand.NewSource(i + seed)) + buf := make([]byte, 1024) + for j := 0; j < 1024; j++ { + if _, err := wc.Read(buf); err != nil { + if atomic.LoadInt32(&flag) != 0 { + // Peaceful; the connection has + // probably been closed while + // idle + return + } + t.Errorf("error reading in conn %d: %v", + i, err) + } + time.Sleep(time.Duration(rng.Intn(100)) * time.Millisecond) + // This one is to make sure the connection + // hasn't closed underneath us + if _, err := wc.Read(buf); err != nil { + t.Errorf("error reading in conn %d: %v", + i, err) + } + MarkIdle(wc) + time.Sleep(time.Duration(rng.Intn(100)) * time.Millisecond) + } + }(int64(i)) + + time.Sleep(time.Duration(i) * time.Millisecond / 2) + } + }() + + if testing.Short() { + time.Sleep(2 * time.Second) + } else { + time.Sleep(10 * time.Second) + } + start := time.Now() + atomic.StoreInt32(&flag, 1) + wl.Close() + wl.Drain() + end := time.Now() + if dt := end.Sub(start); dt > 300*time.Millisecond { + t.Errorf("took %v to drain; expected shorter", dt) + } +} diff --git a/graceful/listener/shard.go 
b/graceful/listener/shard.go new file mode 100644 index 0000000..e47deac --- /dev/null +++ b/graceful/listener/shard.go @@ -0,0 +1,71 @@ +package listener + +import "sync" + +type shard struct { + l *T + + mu sync.Mutex + set map[*conn]struct{} + wg sync.WaitGroup + drain bool + + // We pack shards together in an array, but we don't want them packed + // too closely, since we want to give each shard a dedicated CPU cache + // line. This amount of padding works out well for the common case of + // x64 processors (64-bit pointers with a 64-byte cache line). + _ [12]byte +} + +// We pretty aggressively preallocate set entries in the hopes that we never +// have to allocate memory with the lock held. This is definitely a premature +// optimization and is probably misguided, but luckily it costs us essentially +// nothing. +const prealloc = 2048 + +func (s *shard) init(l *T) { + s.l = l + s.set = make(map[*conn]struct{}, prealloc) +} + +func (s *shard) markIdle(c *conn) (shouldClose bool) { + s.mu.Lock() + if s.drain { + s.mu.Unlock() + return true + } + s.set[c] = struct{}{} + s.mu.Unlock() + return false +} + +func (s *shard) markInUse(c *conn) { + s.mu.Lock() + delete(s.set, c) + s.mu.Unlock() +} + +func (s *shard) closeIdle(drain bool) { + s.mu.Lock() + if drain { + s.drain = true + } + set := s.set + s.set = make(map[*conn]struct{}, prealloc) + // We have to drop the shard lock here to avoid deadlock: we cannot + // acquire the shard lock after the connection lock, and the closeIfIdle + // call below will grab a connection lock. + s.mu.Unlock() + + for conn := range set { + // This might return an error (from Close), but I don't think we + // can do anything about it, so let's just pretend it didn't + // happen. (I also expect that most errors returned in this way + // are going to be pretty boring) + conn.closeIfIdle() + } +} + +func (s *shard) wait() { + s.wg.Wait() +} diff --git a/graceful/middleware.go b/graceful/middleware.go index 4aace7d..7f1371e 100644 --- a/graceful/middleware.go +++ b/graceful/middleware.go @@ -7,6 +7,9 @@ import ( "io" "net" "net/http" + "sync/atomic" + + "github.com/zenazn/goji/graceful/listener" ) /* @@ -62,10 +65,8 @@ type basicWriter struct { func (b *basicWriter) maybeClose() { b.headerWritten = true - select { - case <-kill: + if atomic.LoadInt32(&closing) != 0 { b.ResponseWriter.Header().Set("Connection", "close") - default: } } @@ -103,8 +104,8 @@ func (f *fancyWriter) Hijack() (c net.Conn, b *bufio.ReadWriter, e error) { hj := f.basicWriter.ResponseWriter.(http.Hijacker) c, b, e = hj.Hijack() - if conn, ok := c.(*conn); ok { - c = conn.hijack() + if e == nil { + e = listener.Disown(c) } return diff --git a/graceful/middleware_test.go b/graceful/middleware_test.go index 15b4fcc..fe336f4 100644 --- a/graceful/middleware_test.go +++ b/graceful/middleware_test.go @@ -4,6 +4,7 @@ package graceful import ( "net/http" + "sync/atomic" "testing" ) @@ -36,7 +37,7 @@ func testClose(t *testing.T, h http.Handler, expectClose bool) { } func TestNormal(t *testing.T) { - kill = make(chan struct{}) + atomic.StoreInt32(&closing, 0) h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte{}) }) @@ -44,26 +45,26 @@ func TestNormal(t *testing.T) { } func TestClose(t *testing.T) { - kill = make(chan struct{}) + atomic.StoreInt32(&closing, 0) h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - close(kill) + atomic.StoreInt32(&closing, 1) }) testClose(t, h, true) } func TestCloseWriteHeader(t *testing.T) { - kill = make(chan 
struct{}) + atomic.StoreInt32(&closing, 0) h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - close(kill) + atomic.StoreInt32(&closing, 1) w.WriteHeader(200) }) testClose(t, h, true) } func TestCloseWrite(t *testing.T) { - kill = make(chan struct{}) + atomic.StoreInt32(&closing, 0) h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - close(kill) + atomic.StoreInt32(&closing, 1) w.Write([]byte{}) }) testClose(t, h, true) diff --git a/graceful/net.go b/graceful/net.go deleted file mode 100644 index 784003d..0000000 --- a/graceful/net.go +++ /dev/null @@ -1,185 +0,0 @@ -package graceful - -import ( - "io" - "net" - "sync" - "sync/atomic" - "time" -) - -type listener struct { - net.Listener -} - -// WrapListener wraps an arbitrary net.Listener for use with graceful shutdowns. -// All net.Conn's Accept()ed by this listener will be auto-wrapped as if -// WrapConn() were called on them. -func WrapListener(l net.Listener) net.Listener { - return listener{l} -} - -func (l listener) Accept() (net.Conn, error) { - conn, err := l.Listener.Accept() - return WrapConn(conn), err -} - -type conn struct { - mu sync.Mutex - cs *connSet - net.Conn - id uint64 - busy, die bool - dead bool - hijacked bool -} - -/* -WrapConn wraps an arbitrary connection for use with graceful shutdowns. The -graceful shutdown process will ensure that this connection is closed before -terminating the process. - -In order to use this function, you must call SetReadDeadline() before the call -to Read() you might make to read a new request off the wire. The connection is -eligible for abrupt closing at any point between when the call to -SetReadDeadline() returns and when the call to Read returns with new data. It -does not matter what deadline is given to SetReadDeadline()--if a deadline is -inappropriate, providing one extremely far into the future will suffice. - -Unfortunately, this means that it's difficult to use SetReadDeadline() in a -great many perfectly reasonable circumstances, such as to extend a deadline -after more data has been read, without the connection being eligible for -"graceful" termination at an undesirable time. Since this package was written -explicitly to target net/http, which does not as of this writing do any of this, -fixing the semantics here does not seem especially urgent. -*/ -func WrapConn(c net.Conn) net.Conn { - if c == nil { - return nil - } - - // Avoid race with termination code. - wgLock.Lock() - defer wgLock.Unlock() - - // Determine whether the app is shutting down. - if acceptingRequests { - wg.Add(1) - return &conn{ - Conn: c, - id: atomic.AddUint64(&idleSet.id, 1), - } - } else { - return nil - } -} - -func (c *conn) Read(b []byte) (n int, err error) { - c.mu.Lock() - if !c.hijacked { - defer func() { - c.mu.Lock() - if c.hijacked { - // It's a little unclear to me how this case - // would happen, but we *did* drop the lock, so - // let's play it safe. - return - } - - if c.dead { - // Dead sockets don't tell tales. This is to - // prevent the case where a Read manages to suck - // an entire request off the wire in a race with - // someone trying to close idle connections. - // Whoever grabs the conn lock first wins, and - // if that's the closing process, we need to - // "take back" the read. 
- n = 0 - err = io.EOF - } else { - idleSet.markBusy(c) - } - c.mu.Unlock() - }() - } - c.mu.Unlock() - - return c.Conn.Read(b) -} - -func (c *conn) SetReadDeadline(t time.Time) error { - c.mu.Lock() - if !c.hijacked { - defer c.markIdle() - } - c.mu.Unlock() - return c.Conn.SetReadDeadline(t) -} - -func (c *conn) Close() error { - kill := false - c.mu.Lock() - kill, c.dead = !c.dead, true - idleSet.markBusy(c) - c.mu.Unlock() - - if kill { - defer wg.Done() - } - return c.Conn.Close() -} - -type writerOnly struct { - w io.Writer -} - -func (w writerOnly) Write(buf []byte) (int, error) { - return w.w.Write(buf) -} - -func (c *conn) ReadFrom(r io.Reader) (int64, error) { - if rf, ok := c.Conn.(io.ReaderFrom); ok { - return rf.ReadFrom(r) - } - return io.Copy(writerOnly{c}, r) -} - -func (c *conn) markIdle() { - kill := false - c.mu.Lock() - idleSet.markIdle(c) - if c.die { - kill, c.dead = !c.dead, true - } - c.mu.Unlock() - - if kill { - defer wg.Done() - c.Conn.Close() - } - -} - -func (c *conn) closeIfIdle() { - kill := false - c.mu.Lock() - c.die = true - if !c.busy && !c.hijacked { - kill, c.dead = !c.dead, true - } - c.mu.Unlock() - - if kill { - defer wg.Done() - c.Conn.Close() - } -} - -func (c *conn) hijack() net.Conn { - c.mu.Lock() - idleSet.markBusy(c) - c.hijacked = true - c.mu.Unlock() - - return c.Conn -} diff --git a/graceful/race_test.go b/graceful/race_test.go deleted file mode 100644 index cd48bae..0000000 --- a/graceful/race_test.go +++ /dev/null @@ -1,12 +0,0 @@ -// +build race - -package graceful - -import "testing" - -func TestWaitGroupRace(t *testing.T) { - go func() { - go WrapConn(fakeConn{}).Close() - }() - Shutdown() -} diff --git a/graceful/serve.go b/graceful/serve.go index 626f807..f4ed19f 100644 --- a/graceful/serve.go +++ b/graceful/serve.go @@ -6,19 +6,14 @@ import ( "net" "net/http" "time" + + "github.com/zenazn/goji/graceful/listener" ) // About 200 years, also known as "forever" const forever time.Duration = 200 * 365 * 24 * time.Hour func (srv *Server) Serve(l net.Listener) error { - go func() { - <-kill - l.Close() - idleSet.killall() - }() - l = WrapListener(l) - // Spawn a shadow http.Server to do the actual servering. We do this // because we need to sketch on some of the parameters you passed in, // and it's nice to keep our sketching to ourselves. @@ -29,14 +24,9 @@ func (srv *Server) Serve(l net.Listener) error { } shadow.Handler = Middleware(shadow.Handler) - err := shadow.Serve(l) + wrap := listener.Wrap(l, listener.Deadline) + appendListener(wrap) - // We expect an error when we close the listener, so we indiscriminately - // swallow Serve errors when we're in a shutdown state. - select { - case <-kill: - return nil - default: - return err - } + err := shadow.Serve(wrap) + return peacefulError(err) } diff --git a/graceful/serve13.go b/graceful/serve13.go index e0acd18..5f5e4ed 100644 --- a/graceful/serve13.go +++ b/graceful/serve13.go @@ -3,67 +3,75 @@ package graceful import ( + "log" "net" "net/http" + + "github.com/zenazn/goji/graceful/listener" ) -func (srv *Server) Serve(l net.Listener) error { - l = WrapListener(l) +// This is a slightly hacky shim to disable keepalives when shutting a server +// down. We could have added extra functionality in listener or signal.go to +// deal with this case, but this seems simpler. +type gracefulServer struct { + net.Listener + s *http.Server +} - // Spawn a shadow http.Server to do the actual servering. 
We do this - // because we need to sketch on some of the parameters you passed in, - // and it's nice to keep our sketching to ourselves. - shadow := *(*http.Server)(srv) +func (g gracefulServer) Close() error { + g.s.SetKeepAlivesEnabled(false) + return g.Listener.Close() +} + +// A chaining http.ConnState wrapper +type connState func(net.Conn, http.ConnState) - cs := shadow.ConnState - shadow.ConnState = func(nc net.Conn, s http.ConnState) { - if c, ok := nc.(*conn); ok { - // There are a few other states defined, most notably - // StateActive. Unfortunately it doesn't look like it's - // possible to make use of StateActive to implement - // graceful shutdown, since StateActive is set after a - // complete request has been read off the wire with an - // intent to process it. If we were to race a graceful - // shutdown against a connection that was just read off - // the wire (but not yet in StateActive), we would - // accidentally close the connection out from underneath - // an active request. - // - // We already needed to work around this for Go 1.2 by - // shimming out a full net.Conn object, so we can just - // fall back to the old behavior there. - // - // I started a golang-nuts thread about this here: - // https://groups.google.com/forum/#!topic/golang-nuts/Xi8yjBGWfCQ - // I'd be very eager to find a better way to do this, so - // reach out to me if you have any ideas. - switch s { - case http.StateIdle: - c.markIdle() - case http.StateHijacked: - c.hijack() - } +func (c connState) Wrap(nc net.Conn, s http.ConnState) { + // There are a few other states defined, most notably StateActive. + // Unfortunately it doesn't look like it's possible to make use of + // StateActive to implement graceful shutdown, since StateActive is set + // after a complete request has been read off the wire with an intent to + // process it. If we were to race a graceful shutdown against a + // connection that was just read off the wire (but not yet in + // StateActive), we would accidentally close the connection out from + // underneath an active request. + // + // We already needed to work around this for Go 1.2 by shimming out a + // full net.Conn object, so we can just fall back to the old behavior + // there. + // + // I started a golang-nuts thread about this here: + // https://groups.google.com/forum/#!topic/golang-nuts/Xi8yjBGWfCQ I'd + // be very eager to find a better way to do this, so reach out to me if + // you have any ideas. + switch s { + case http.StateIdle: + if err := listener.MarkIdle(nc); err != nil { + log.Printf("error marking conn as idle: %v", + err) } - if cs != nil { - cs(nc, s) + case http.StateHijacked: + if err := listener.Disown(nc); err != nil { + log.Printf("error disowning hijacked conn: %v", + err) } } + if c != nil { + c(nc, s) + } +} - go func() { - <-kill - l.Close() - shadow.SetKeepAlivesEnabled(false) - idleSet.killall() - }() +func (srv *Server) Serve(l net.Listener) error { + // Spawn a shadow http.Server to do the actual servering. We do this + // because we need to sketch on some of the parameters you passed in, + // and it's nice to keep our sketching to ourselves. + shadow := *(*http.Server)(srv) + shadow.ConnState = connState(shadow.ConnState).Wrap - err := shadow.Serve(l) + l = gracefulServer{l, &shadow} + wrap := listener.Wrap(l, listener.Automatic) + appendListener(wrap) - // We expect an error when we close the listener, so we indiscriminately - // swallow Serve errors when we're in a shutdown state. 
- select { - case <-kill: - return nil - default: - return err - } + err := shadow.Serve(wrap) + return peacefulError(err) } diff --git a/graceful/signal.go b/graceful/signal.go index 82a13d6..6bd3d84 100644 --- a/graceful/signal.go +++ b/graceful/signal.go @@ -1,32 +1,22 @@ package graceful import ( + "net" "os" "os/signal" "sync" -) - -// This is the channel that the connections select on. When it is closed, the -// connections should gracefully exit. -var kill = make(chan struct{}) - -// This is the channel that the Wait() function selects on. It should only be -// closed once all the posthooks have been called. -var wait = make(chan struct{}) - -// Whether new requests should be accepted. When false, new requests are refused. -var acceptingRequests bool = true + "sync/atomic" -// This is the WaitGroup that indicates when all the connections have gracefully -// shut down. -var wg sync.WaitGroup -var wgLock sync.Mutex + "github.com/zenazn/goji/graceful/listener" +) -// This lock protects the list of pre- and post- hooks below. -var hookLock sync.Mutex +var mu sync.Mutex // protects everything that follows +var listeners = make([]*listener.T, 0) var prehooks = make([]func(), 0) var posthooks = make([]func(), 0) +var closing int32 +var wait = make(chan struct{}) var stdSignals = []os.Signal{os.Interrupt} var sigchan = make(chan os.Signal, 1) @@ -71,8 +61,8 @@ func Shutdown() { // shutdown actions. All listeners will be called in the order they were added, // from a single goroutine. func PreHook(f func()) { - hookLock.Lock() - defer hookLock.Unlock() + mu.Lock() + defer mu.Unlock() prehooks = append(prehooks, f) } @@ -82,12 +72,12 @@ func PreHook(f func()) { // from a single goroutine, and are guaranteed to be called after all listening // connections have been closed, but before Wait() returns. // -// If you've Hijack()ed any connections that must be gracefully shut down in -// some other way (since this library disowns all hijacked connections), it's -// reasonable to use a PostHook() to signal and wait for them. +// If you've Hijacked any connections that must be gracefully shut down in some +// other way (since this library disowns all hijacked connections), it's +// reasonable to use a PostHook to signal and wait for them. func PostHook(f func()) { - hookLock.Lock() - defer hookLock.Unlock() + mu.Lock() + defer mu.Unlock() posthooks = append(posthooks, f) } @@ -95,19 +85,23 @@ func PostHook(f func()) { func waitForSignal() { <-sigchan - // Prevent servicing of any new requests. - wgLock.Lock() - acceptingRequests = false - wgLock.Unlock() - - hookLock.Lock() - defer hookLock.Unlock() + mu.Lock() + defer mu.Unlock() for _, f := range prehooks { f() } - close(kill) + atomic.StoreInt32(&closing, 1) + var wg sync.WaitGroup + wg.Add(len(listeners)) + for _, l := range listeners { + go func(l *listener.T) { + defer wg.Done() + l.Close() + l.Drain() + }(l) + } wg.Wait() for _, f := range posthooks { @@ -123,3 +117,29 @@ func waitForSignal() { func Wait() { <-wait } + +func appendListener(l *listener.T) { + mu.Lock() + defer mu.Unlock() + + listeners = append(listeners, l) +} + +const errClosing = "use of closed network connection" + +// During graceful shutdown, calls to Accept will start returning errors. This +// is inconvenient, since we know these sorts of errors are peaceful, so we +// silently swallow them. 
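+//
+// Concretely, the only error suppressed below is the *net.OpError that
+// Accept returns once its listener has been closed, whose nested error text
+// matches errClosing.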
+func peacefulError(err error) error { + if atomic.LoadInt32(&closing) == 0 { + return err + } + // Unfortunately Go doesn't really give us a better way to select errors + // than this, so *shrug*. + if oe, ok := err.(*net.OpError); ok { + if oe.Op == "accept" && oe.Err.Error() == errClosing { + return nil + } + } + return err +}
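
For reference, end-to-end use of the rewritten package looks roughly like the
following sketch (handler is any http.Handler the application supplies; error
handling is elided):

	l, err := net.Listen("tcp", ":8000")
	if err != nil {
		log.Fatal(err)
	}
	// Serve returns once the listener has been closed by a graceful
	// shutdown; the shutdown-induced Accept error is swallowed by
	// peacefulError above.
	if err := graceful.Serve(l, handler); err != nil {
		log.Fatal(err)
	}
	// Block until draining has finished and all post-hooks have run.
	graceful.Wait()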