
Refactor package graceful

This change refactors package graceful into two packages: one very
well-tested package that deals with graceful shutdown of arbitrary
net.Listeners in the abstract, and one less-well-tested package that
works with the nitty-gritty details of net/http and signal handling.

This is a breaking API change for advanced users of package graceful:
the WrapConn function no longer exists. This shouldn't affect most users
or use cases.
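For most callers the high-level API is untouched. As a point of reference, here is a minimal sketch of a program using only the package-level Serve and Wait that this commit keeps (hypothetical example, not part of the change):

package main

import (
	"log"
	"net"
	"net/http"

	"github.com/zenazn/goji/graceful"
)

func main() {
	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("hello"))
	})
	l, err := net.Listen("tcp", ":8000")
	if err != nil {
		log.Fatal(err)
	}
	// Serve behaves like http.Serve, but on SIGINT it stops accepting,
	// closes idle connections, and lets in-flight requests finish.
	if err := graceful.Serve(l, http.DefaultServeMux); err != nil {
		log.Print(err)
	}
	// Wait blocks until draining finishes and all post-hooks have run.
	graceful.Wait()
}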
Carl Jackson, 11 years ago (acf12155c0)
17 changed files with 1086 additions and 475 deletions
  1. +0 -80 graceful/conn_set.go
  2. +0 -63 graceful/conn_test.go
  3. +20 -23 graceful/graceful.go
  4. +147 -0 graceful/listener/conn.go
  5. +182 -0 graceful/listener/conn_test.go
  6. +123 -0 graceful/listener/fake_test.go
  7. +165 -0 graceful/listener/listener.go
  8. +143 -0 graceful/listener/listener_test.go
  9. +103 -0 graceful/listener/race_test.go
  10. +71 -0 graceful/listener/shard.go
  11. +6 -5 graceful/middleware.go
  12. +8 -7 graceful/middleware_test.go
  13. +0 -185 graceful/net.go
  14. +0 -12 graceful/race_test.go
  15. +6 -16 graceful/serve.go
  16. +59 -51 graceful/serve13.go
  17. +53 -33 graceful/signal.go

+0 -80 graceful/conn_set.go

@ -1,80 +0,0 @@
package graceful
import (
"runtime"
"sync"
)
type connShard struct {
mu sync.Mutex
// We sort of abuse this field to also act as a "please shut down" flag.
// If it's nil, you should die at your earliest opportunity.
set map[*conn]struct{}
}
type connSet struct {
// This is an incrementing connection counter so we round-robin
// connections across shards. Use atomic when touching it.
id uint64
shards []*connShard
}
var idleSet connSet
// We pretty aggressively preallocate set entries in the hopes that we never
// have to allocate memory with the lock held. This is definitely a premature
// optimization and is probably misguided, but luckily it costs us essentially
// nothing.
const prealloc = 2048
func init() {
// To keep the expected contention rate constant we'd have to grow this
// as numcpu**2. In practice, CPU counts don't generally grow without
// bound, and contention is probably going to be small enough that
// nobody cares anyways.
idleSet.shards = make([]*connShard, 2*runtime.NumCPU())
for i := range idleSet.shards {
idleSet.shards[i] = &connShard{
set: make(map[*conn]struct{}, prealloc),
}
}
}
func (cs connSet) markIdle(c *conn) {
c.busy = false
shard := cs.shards[int(c.id%uint64(len(cs.shards)))]
shard.mu.Lock()
if shard.set == nil {
shard.mu.Unlock()
c.die = true
} else {
shard.set[c] = struct{}{}
shard.mu.Unlock()
}
}
func (cs connSet) markBusy(c *conn) {
c.busy = true
shard := cs.shards[int(c.id%uint64(len(cs.shards)))]
shard.mu.Lock()
if shard.set == nil {
shard.mu.Unlock()
c.die = true
} else {
delete(shard.set, c)
shard.mu.Unlock()
}
}
func (cs connSet) killall() {
for _, shard := range cs.shards {
shard.mu.Lock()
set := shard.set
shard.set = nil
shard.mu.Unlock()
for conn := range set {
conn.closeIfIdle()
}
}
}

+0 -63 graceful/conn_test.go

@ -1,63 +0,0 @@
package graceful
import (
"net"
"time"
)
// Stub out a net.Conn. This is going to be painful.
type fakeAddr struct{}
func (f fakeAddr) Network() string {
return "fake"
}
func (f fakeAddr) String() string {
return "fake"
}
type fakeConn struct {
onRead, onWrite, onClose, onLocalAddr, onRemoteAddr func()
onSetDeadline, onSetReadDeadline, onSetWriteDeadline func()
}
// Here's my number, so...
func callMeMaybe(f func()) {
// I apologize for nothing.
if f != nil {
f()
}
}
func (f fakeConn) Read(b []byte) (int, error) {
callMeMaybe(f.onRead)
return len(b), nil
}
func (f fakeConn) Write(b []byte) (int, error) {
callMeMaybe(f.onWrite)
return len(b), nil
}
func (f fakeConn) Close() error {
callMeMaybe(f.onClose)
return nil
}
func (f fakeConn) LocalAddr() net.Addr {
callMeMaybe(f.onLocalAddr)
return fakeAddr{}
}
func (f fakeConn) RemoteAddr() net.Addr {
callMeMaybe(f.onRemoteAddr)
return fakeAddr{}
}
func (f fakeConn) SetDeadline(t time.Time) error {
callMeMaybe(f.onSetDeadline)
return nil
}
func (f fakeConn) SetReadDeadline(t time.Time) error {
callMeMaybe(f.onSetReadDeadline)
return nil
}
func (f fakeConn) SetWriteDeadline(t time.Time) error {
callMeMaybe(f.onSetWriteDeadline)
return nil
}

+20 -23 graceful/graceful.go

@ -3,18 +3,7 @@ Package graceful implements graceful shutdown for HTTP servers by closing idle
connections after receiving a signal. By default, this package listens for
interrupts (i.e., SIGINT), but when it detects that it is running under Einhorn
it will additionally listen for SIGUSR2 as well, giving your application
automatic support for graceful upgrades.
It's worth mentioning explicitly that this package is a hack to shim graceful
shutdown behavior into the net/http package provided in Go 1.2. It was written
by carefully reading the sequence of function calls net/http happened to use as
of this writing and finding enough surface area with which to add appropriate
behavior. There's a very good chance that this package will cease to work in
future versions of Go, but with any luck the standard library will add support
of its own by then (https://code.google.com/p/go/issues/detail?id=4674).
If you're interested in figuring out how this package works, we suggest you read
the documentation for WrapConn() and net.go.
automatic support for graceful restarts/code upgrades.
*/
package graceful
@ -22,19 +11,11 @@ import (
"crypto/tls" "crypto/tls"
"net" "net"
"net/http" "net/http"
)
/*
You might notice that these methods look awfully similar to the methods of the
same name from the go standard library--that's because they were stolen from
there! If go were more like, say, Ruby, it'd actually be possible to shim just
the Serve() method, since we can do everything we want from there. However, it's
not possible to get the other methods which call Serve() (ListenAndServe(), say)
to call your shimmed copy--they always call the original.
"github.com/zenazn/goji/graceful/listener"
)
Since I couldn't come up with a better idea, I just copy-and-pasted both
ListenAndServe and ListenAndServeTLS here more-or-less verbatim. "Oh well!"
*/
// Most of the code here is lifted straight from net/http
// Type Server is exactly the same as an http.Server, but provides more graceful
// implementations of its methods.
@ -98,3 +79,19 @@ func Serve(l net.Listener, handler http.Handler) error {
server := &Server{Handler: handler}
return server.Serve(l)
}
// WrapListener wraps an arbitrary net.Listener for use with graceful shutdowns.
// In the background, it uses the listener sub-package to Wrap the listener in
// Deadline mode. If another mode of operation is desired, you should call
// listener.Wrap yourself: this function is smart enough to not double-wrap
// listeners.
func WrapListener(l net.Listener) net.Listener {
if lt, ok := l.(*listener.T); ok {
appendListener(lt)
return lt
}
lt := listener.Wrap(l, listener.Deadline)
appendListener(lt)
return lt
}
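To illustrate the escape hatch described in the comment above, a hedged sketch of a caller that prefers Manual mode over the default Deadline mode (the listenManual helper and its package are hypothetical, not part of this commit):

package server

import (
	"net"

	"github.com/zenazn/goji/graceful"
	"github.com/zenazn/goji/graceful/listener"
)

// listenManual wraps the listener in Manual mode before registering it.
// WrapListener recognizes the existing *listener.T and does not wrap it
// again; it only adds the listener to the set drained on shutdown.
func listenManual(addr string) (net.Listener, error) {
	l, err := net.Listen("tcp", addr)
	if err != nil {
		return nil, err
	}
	return graceful.WrapListener(listener.Wrap(l, listener.Manual)), nil
}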

+147 -0 graceful/listener/conn.go

@ -0,0 +1,147 @@
package listener
import (
"errors"
"io"
"net"
"sync"
"time"
)
type conn struct {
net.Conn
shard *shard
mode mode
mu sync.Mutex // Protects the state machine below
busy bool // connection is in use (i.e., not idle)
closed bool // connection is closed
disowned bool // if true, this connection is no longer under our management
}
// This intentionally looks a lot like the one in package net.
var errClosing = errors.New("use of closed network connection")
func (c *conn) init() error {
c.shard.wg.Add(1)
if shouldExit := c.shard.markIdle(c); shouldExit {
c.Close()
return errClosing
}
return nil
}
func (c *conn) Read(b []byte) (n int, err error) {
defer func() {
c.mu.Lock()
defer c.mu.Unlock()
if c.disowned {
return
}
// This protects against a Close/Read race. We're not really
// concerned about the general case (it's fundamentally racy),
// but are mostly trying to prevent a race between a new request
// getting read off the wire in one thread while the connection
// is being gracefully shut down in another.
if c.closed && err == nil {
n = 0
err = errClosing
return
}
if c.mode != Manual && !c.busy && !c.closed {
c.busy = true
c.shard.markInUse(c)
}
}()
return c.Conn.Read(b)
}
func (c *conn) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if !c.closed && !c.disowned {
c.closed = true
c.shard.markInUse(c)
defer c.shard.wg.Done()
}
return c.Conn.Close()
}
func (c *conn) SetReadDeadline(t time.Time) error {
c.mu.Lock()
if !c.disowned && c.mode == Deadline {
defer c.markIdle()
}
c.mu.Unlock()
return c.Conn.SetReadDeadline(t)
}
func (c *conn) ReadFrom(r io.Reader) (int64, error) {
return io.Copy(c.Conn, r)
}
func (c *conn) markIdle() {
c.mu.Lock()
defer c.mu.Unlock()
if !c.busy {
return
}
c.busy = false
if exit := c.shard.markIdle(c); exit && !c.closed && !c.disowned {
c.closed = true
c.shard.markInUse(c)
defer c.shard.wg.Done()
c.Conn.Close()
return
}
}
func (c *conn) markInUse() {
c.mu.Lock()
defer c.mu.Unlock()
if !c.busy && !c.closed && !c.disowned {
c.busy = true
c.shard.markInUse(c)
}
}
func (c *conn) closeIfIdle() error {
c.mu.Lock()
defer c.mu.Unlock()
if !c.busy && !c.closed && !c.disowned {
c.closed = true
c.shard.markInUse(c)
defer c.shard.wg.Done()
return c.Conn.Close()
}
return nil
}
var errAlreadyDisowned = errors.New("listener: conn already disowned")
func (c *conn) disown() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.disowned {
return errAlreadyDisowned
}
c.shard.markInUse(c)
c.disowned = true
c.shard.wg.Done()
return nil
}

+182 -0 graceful/listener/conn_test.go

@ -0,0 +1,182 @@
package listener
import (
"io"
"strings"
"testing"
"time"
)
func TestManualRead(t *testing.T) {
t.Parallel()
l, c, wc := singleConn(t, Manual)
go c.AllowRead()
wc.Read(make([]byte, 1024))
if err := l.CloseIdle(); err != nil {
t.Fatalf("error closing idle connections: %v", err)
}
if !c.Closed() {
t.Error("Read() should not make connection not-idle")
}
}
func TestAutomaticRead(t *testing.T) {
t.Parallel()
l, c, wc := singleConn(t, Automatic)
go c.AllowRead()
wc.Read(make([]byte, 1024))
if err := l.CloseIdle(); err != nil {
t.Fatalf("error closing idle connections: %v", err)
}
if c.Closed() {
t.Error("expected Read() to mark connection as in-use")
}
}
func TestDeadlineRead(t *testing.T) {
t.Parallel()
l, c, wc := singleConn(t, Deadline)
go c.AllowRead()
if _, err := wc.Read(make([]byte, 1024)); err != nil {
t.Fatalf("error reading from connection: %v", err)
}
if err := l.CloseIdle(); err != nil {
t.Fatalf("error closing idle connections: %v", err)
}
if c.Closed() {
t.Error("expected Read() to mark connection as in-use")
}
}
func TestDisownedRead(t *testing.T) {
t.Parallel()
l, c, wc := singleConn(t, Deadline)
if err := Disown(wc); err != nil {
t.Fatalf("unexpected error disowning conn: %v", err)
}
if err := l.Close(); err != nil {
t.Fatalf("unexpected error closing listener: %v", err)
}
if err := l.Drain(); err != nil {
t.Fatalf("unexpected error draining listener: %v", err)
}
go c.AllowRead()
if _, err := wc.Read(make([]byte, 1024)); err != nil {
t.Fatalf("error reading from connection: %v", err)
}
}
func TestCloseConn(t *testing.T) {
t.Parallel()
l, _, wc := singleConn(t, Deadline)
if err := MarkInUse(wc); err != nil {
t.Fatalf("error marking conn in use: %v", err)
}
if err := wc.Close(); err != nil {
t.Errorf("error closing connection: %v", err)
}
// This will hang if wc.Close() doesn't un-track the connection
if err := l.Drain(); err != nil {
t.Errorf("error draining listener: %v", err)
}
}
func TestManualReadDeadline(t *testing.T) {
t.Parallel()
l, c, wc := singleConn(t, Manual)
if err := MarkInUse(wc); err != nil {
t.Fatalf("error marking connection in use: %v", err)
}
if err := wc.SetReadDeadline(time.Now()); err != nil {
t.Fatalf("error setting read deadline: %v", err)
}
if err := l.CloseIdle(); err != nil {
t.Fatalf("error closing idle connections: %v", err)
}
if c.Closed() {
t.Error("SetReadDeadline() should not mark manual conn as idle")
}
}
func TestAutomaticReadDeadline(t *testing.T) {
t.Parallel()
l, c, wc := singleConn(t, Automatic)
if err := MarkInUse(wc); err != nil {
t.Fatalf("error marking connection in use: %v", err)
}
if err := wc.SetReadDeadline(time.Now()); err != nil {
t.Fatalf("error setting read deadline: %v", err)
}
if err := l.CloseIdle(); err != nil {
t.Fatalf("error closing idle connections: %v", err)
}
if c.Closed() {
t.Error("SetReadDeadline() should not mark automatic conn as idle")
}
}
func TestDeadlineReadDeadline(t *testing.T) {
t.Parallel()
l, c, wc := singleConn(t, Deadline)
if err := MarkInUse(wc); err != nil {
t.Fatalf("error marking connection in use: %v", err)
}
if err := wc.SetReadDeadline(time.Now()); err != nil {
t.Fatalf("error setting read deadline: %v", err)
}
if err := l.CloseIdle(); err != nil {
t.Fatalf("error closing idle connections: %v", err)
}
if !c.Closed() {
t.Error("SetReadDeadline() should mark deadline conn as idle")
}
}
type readerConn struct {
fakeConn
}
func (rc *readerConn) ReadFrom(r io.Reader) (int64, error) {
return 123, nil
}
func TestReadFrom(t *testing.T) {
t.Parallel()
l := makeFakeListener("net.Listener")
wl := Wrap(l, Manual)
c := &readerConn{
fakeConn{
read: make(chan struct{}),
write: make(chan struct{}),
closed: make(chan struct{}),
me: fakeAddr{"tcp", "local"},
you: fakeAddr{"tcp", "remote"},
},
}
go l.Enqueue(c)
wc, err := wl.Accept()
if err != nil {
t.Fatalf("error accepting connection: %v", err)
}
// The io.MultiReader is a convenient hack to ensure that we're using
// our ReadFrom, not strings.Reader's WriteTo.
r := io.MultiReader(strings.NewReader("hello world"))
if _, err := io.Copy(wc, r); err != nil {
t.Fatalf("error copying: %v", err)
}
}

+123 -0 graceful/listener/fake_test.go

@ -0,0 +1,123 @@
package listener
import (
"net"
"time"
)
type fakeAddr struct {
network, addr string
}
func (f fakeAddr) Network() string {
return f.network
}
func (f fakeAddr) String() string {
return f.addr
}
type fakeListener struct {
ch chan net.Conn
closed chan struct{}
addr net.Addr
}
func makeFakeListener(addr string) *fakeListener {
a := fakeAddr{"tcp", addr}
return &fakeListener{
ch: make(chan net.Conn),
closed: make(chan struct{}),
addr: a,
}
}
func (f *fakeListener) Accept() (net.Conn, error) {
select {
case c := <-f.ch:
return c, nil
case <-f.closed:
return nil, errClosing
}
}
func (f *fakeListener) Close() error {
close(f.closed)
return nil
}
func (f *fakeListener) Addr() net.Addr {
return f.addr
}
func (f *fakeListener) Enqueue(c net.Conn) {
f.ch <- c
}
type fakeConn struct {
read, write, closed chan struct{}
me, you net.Addr
}
func makeFakeConn(me, you string) *fakeConn {
return &fakeConn{
read: make(chan struct{}),
write: make(chan struct{}),
closed: make(chan struct{}),
me: fakeAddr{"tcp", me},
you: fakeAddr{"tcp", you},
}
}
func (f *fakeConn) Read(buf []byte) (int, error) {
select {
case <-f.read:
return len(buf), nil
case <-f.closed:
return 0, errClosing
}
}
func (f *fakeConn) Write(buf []byte) (int, error) {
select {
case <-f.write:
return len(buf), nil
case <-f.closed:
return 0, errClosing
}
}
func (f *fakeConn) Close() error {
close(f.closed)
return nil
}
func (f *fakeConn) LocalAddr() net.Addr {
return f.me
}
func (f *fakeConn) RemoteAddr() net.Addr {
return f.you
}
func (f *fakeConn) SetDeadline(t time.Time) error {
return nil
}
func (f *fakeConn) SetReadDeadline(t time.Time) error {
return nil
}
func (f *fakeConn) SetWriteDeadline(t time.Time) error {
return nil
}
func (f *fakeConn) Closed() bool {
select {
case <-f.closed:
return true
default:
return false
}
}
func (f *fakeConn) AllowRead() {
f.read <- struct{}{}
}
func (f *fakeConn) AllowWrite() {
f.write <- struct{}{}
}

+165 -0 graceful/listener/listener.go

@ -0,0 +1,165 @@
/*
Package listener provides a way to incorporate graceful shutdown to any
net.Listener.
This package provides low-level primitives, not a high-level API. If you're
looking for a package that provides graceful shutdown for HTTP servers, I
recommend this package's parent package, github.com/zenazn/goji/graceful.
*/
package listener
import (
"errors"
"net"
"runtime"
"sync"
"sync/atomic"
)
type mode int8
const (
// Manual mode is completely manual: users must use MarkIdle and
// MarkInUse to indicate when connections are busy servicing requests or
// are eligible for termination.
Manual mode = iota
// Automatic mode is what most users probably want: calling Read on a
// connection will mark it as in use, but users must manually call
// MarkIdle to indicate when connections may be safely closed.
Automatic
// Deadline mode is like automatic mode, except that calling
// SetReadDeadline on a connection will also mark it as being idle. This
// is useful for many servers like net/http, where SetReadDeadline is
// used to implement read timeouts on new requests.
Deadline
)
// Wrap a net.Listener, returning a net.Listener which supports idle connection
tracking and shutdown. Listeners can be placed into one of three modes,
// exported as variables from this package: most users will probably want the
// "Automatic" mode.
func Wrap(l net.Listener, m mode) *T {
t := &T{
l: l,
mode: m,
// To keep the expected contention rate constant we'd have to
// grow this as numcpu**2. In practice, CPU counts don't
// generally grow without bound, and contention is probably
// going to be small enough that nobody cares anyways.
shards: make([]shard, 2*runtime.NumCPU()),
}
for i := range t.shards {
t.shards[i].init(t)
}
return t
}
// T is the type of this package's graceful listeners.
type T struct {
mu sync.Mutex
l net.Listener
// TODO(carl): a count of currently outstanding connections.
connCount uint64
shards []shard
mode mode
}
var _ net.Listener = &T{}
// Accept waits for and returns the next connection to the listener. The
// returned net.Conn's idleness is tracked, and idle connections can be closed
// from the associated T.
func (t *T) Accept() (net.Conn, error) {
c, err := t.l.Accept()
if err != nil {
return nil, err
}
connID := atomic.AddUint64(&t.connCount, 1)
shard := &t.shards[int(connID)%len(t.shards)]
wc := &conn{
Conn: c,
shard: shard,
mode: t.mode,
}
if err = wc.init(); err != nil {
return nil, err
}
return wc, nil
}
// Addr returns the wrapped listener's network address.
func (t *T) Addr() net.Addr {
return t.l.Addr()
}
// Close closes the wrapped listener.
func (t *T) Close() error {
return t.l.Close()
}
// CloseIdle closes all connections that are currently marked as being idle. It,
// however, makes no attempt to wait for in-use connections to die, or to close
// connections which become idle in the future. Call this function if you're
// interested in shedding useless connections, but otherwise wish to continue
// serving requests.
func (t *T) CloseIdle() error {
for i := range t.shards {
t.shards[i].closeIdle(false)
}
// Not sure if returning errors is actually useful here :/
return nil
}
// Drain immediately closes all idle connections, prevents new connections from
// being accepted, and waits for all outstanding connections to finish.
//
// Once a listener has been drained, there is no way to re-enable it. You
// probably want to Close the listener before draining it, otherwise new
// connections will be accepted and immediately closed.
func (t *T) Drain() error {
for i := range t.shards {
t.shards[i].closeIdle(true)
}
for i := range t.shards {
t.shards[i].wait()
}
return nil
}
var notManagedErr = errors.New("listener: passed net.Conn is not managed by us")
// Disown causes a connection to no longer be tracked by the listener. The
// passed connection must have been returned by a call to Accept from this
// listener.
func Disown(c net.Conn) error {
if cn, ok := c.(*conn); ok {
return cn.disown()
}
return notManagedErr
}
// MarkIdle marks the given connection as being idle, and therefore eligible for
// closing at any time. The passed connection must have been returned by a call
// to Accept from this listener.
func MarkIdle(c net.Conn) error {
if cn, ok := c.(*conn); ok {
cn.markIdle()
return nil
}
return notManagedErr
}
// MarkInUse marks this connection as being in use, removing it from the set of
// connections which are eligible for closing. The passed connection must have
// been returned by a call to Accept from this listener.
func MarkInUse(c net.Conn) error {
if cn, ok := c.(*conn); ok {
cn.markInUse()
return nil
}
return notManagedErr
}
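The exported surface above is enough for a small hand-rolled accept loop. A sketch of Manual-mode use, assuming only the API in this file (the echo handler is invented for illustration):

package echo

import (
	"net"

	"github.com/zenazn/goji/graceful/listener"
)

// serveEcho echoes bytes on each connection until stop is closed, then
// shuts the listener down gracefully.
func serveEcho(raw net.Listener, stop <-chan struct{}) {
	l := listener.Wrap(raw, listener.Manual)
	go func() {
		<-stop
		l.Close() // stop accepting new connections
	}()
	for {
		c, err := l.Accept()
		if err != nil {
			break // listener closed
		}
		go func(c net.Conn) {
			defer c.Close()
			buf := make([]byte, 1024)
			for {
				n, err := c.Read(buf)
				if err != nil {
					return
				}
				// Manual mode: bracket each request ourselves so that
				// CloseIdle and Drain only interrupt the connection
				// between requests.
				listener.MarkInUse(c)
				c.Write(buf[:n])
				listener.MarkIdle(c)
			}
		}(c)
	}
	l.Drain() // close idle connections, wait for busy ones to finish
}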

+143 -0 graceful/listener/listener_test.go

@ -0,0 +1,143 @@
package listener
import (
"net"
"testing"
"time"
)
// Helper for tests acting on a single accepted connection
func singleConn(t *testing.T, m mode) (*T, *fakeConn, net.Conn) {
l := makeFakeListener("net.Listener")
wl := Wrap(l, m)
c := makeFakeConn("local", "remote")
go l.Enqueue(c)
wc, err := wl.Accept()
if err != nil {
t.Fatalf("error accepting connection: %v", err)
}
return wl, c, wc
}
func TestAddr(t *testing.T) {
t.Parallel()
l, c, wc := singleConn(t, Manual)
if a := l.Addr(); a.String() != "net.Listener" {
t.Errorf("addr was %v, wanted net.Listener", a)
}
if c.LocalAddr() != wc.LocalAddr() {
t.Errorf("local addresses don't match: %v, %v", c.LocalAddr(),
wc.LocalAddr())
}
if c.RemoteAddr() != wc.RemoteAddr() {
t.Errorf("remote addresses don't match: %v, %v", c.RemoteAddr(),
wc.RemoteAddr())
}
}
func TestBasicCloseIdle(t *testing.T) {
t.Parallel()
l, c, _ := singleConn(t, Manual)
if err := l.CloseIdle(); err != nil {
t.Fatalf("error closing idle connections: %v", err)
}
if !c.Closed() {
t.Error("idle connection not closed")
}
}
func TestMark(t *testing.T) {
t.Parallel()
l, c, wc := singleConn(t, Manual)
if err := MarkInUse(wc); err != nil {
t.Fatalf("error marking %v in-use: %v", wc, err)
}
if err := l.CloseIdle(); err != nil {
t.Fatalf("error closing idle connections: %v", err)
}
if c.Closed() {
t.Errorf("manually in-use connection was closed")
}
if err := MarkIdle(wc); err != nil {
t.Fatalf("error marking %v idle: %v", wc, err)
}
if err := l.CloseIdle(); err != nil {
t.Fatalf("error closing idle connections: %v", err)
}
if !c.Closed() {
t.Error("manually idle connection was not closed")
}
}
func TestDisown(t *testing.T) {
t.Parallel()
l, c, wc := singleConn(t, Manual)
if err := Disown(wc); err != nil {
t.Fatalf("error disowning connection: %v", err)
}
if err := l.CloseIdle(); err != nil {
t.Fatalf("error closing idle connections: %v", err)
}
if c.Closed() {
t.Errorf("disowned connection got closed")
}
}
func TestDrain(t *testing.T) {
t.Parallel()
l, _, wc := singleConn(t, Manual)
MarkInUse(wc)
start := time.Now()
go func() {
time.Sleep(50 * time.Millisecond)
MarkIdle(wc)
}()
if err := l.Drain(); err != nil {
t.Fatalf("error draining listener: %v", err)
}
end := time.Now()
if dt := end.Sub(start); dt < 50*time.Millisecond {
t.Errorf("expected at least 50ms wait, but got %v", dt)
}
}
func TestErrors(t *testing.T) {
t.Parallel()
_, c, wc := singleConn(t, Manual)
if err := Disown(c); err == nil {
t.Error("expected error when disowning unmanaged net.Conn")
}
if err := MarkIdle(c); err == nil {
t.Error("expected error when marking unmanaged net.Conn idle")
}
if err := MarkInUse(c); err == nil {
t.Error("expected error when marking unmanaged net.Conn in use")
}
if err := Disown(wc); err != nil {
t.Fatalf("unexpected error disowning socket: %v", err)
}
if err := Disown(wc); err == nil {
t.Error("expected error disowning socket twice")
}
}
func TestClose(t *testing.T) {
t.Parallel()
l, c, _ := singleConn(t, Manual)
if err := l.Close(); err != nil {
t.Fatalf("error while closing listener: %v", err)
}
if c.Closed() {
t.Error("connection closed when listener was?")
}
}

+103 -0 graceful/listener/race_test.go

@ -0,0 +1,103 @@
package listener
import (
"fmt"
"math/rand"
"runtime"
"sync/atomic"
"testing"
"time"
)
func init() {
// Just to make sure we get some variety
runtime.GOMAXPROCS(4 * runtime.NumCPU())
}
// Chosen by random die roll
const seed = 4611413766552969250
// This is mostly just fuzzing to see what happens.
func TestRace(t *testing.T) {
t.Parallel()
l := makeFakeListener("net.Listener")
wl := Wrap(l, Automatic)
var flag int32
go func() {
for i := 0; ; i++ {
laddr := fmt.Sprintf("local%d", i)
raddr := fmt.Sprintf("remote%d", i)
c := makeFakeConn(laddr, raddr)
go func() {
defer func() {
if r := recover(); r != nil {
if atomic.LoadInt32(&flag) != 0 {
return
}
panic(r)
}
}()
l.Enqueue(c)
}()
wc, err := wl.Accept()
if err != nil {
if atomic.LoadInt32(&flag) != 0 {
return
}
t.Fatalf("error accepting connection: %v", err)
}
go func() {
for {
time.Sleep(50 * time.Millisecond)
c.AllowRead()
}
}()
go func(i int64) {
rng := rand.New(rand.NewSource(i + seed))
buf := make([]byte, 1024)
for j := 0; j < 1024; j++ {
if _, err := wc.Read(buf); err != nil {
if atomic.LoadInt32(&flag) != 0 {
// Peaceful; the connection has
// probably been closed while
// idle
return
}
t.Errorf("error reading in conn %d: %v",
i, err)
}
time.Sleep(time.Duration(rng.Intn(100)) * time.Millisecond)
// This one is to make sure the connection
// hasn't closed underneath us
if _, err := wc.Read(buf); err != nil {
t.Errorf("error reading in conn %d: %v",
i, err)
}
MarkIdle(wc)
time.Sleep(time.Duration(rng.Intn(100)) * time.Millisecond)
}
}(int64(i))
time.Sleep(time.Duration(i) * time.Millisecond / 2)
}
}()
if testing.Short() {
time.Sleep(2 * time.Second)
} else {
time.Sleep(10 * time.Second)
}
start := time.Now()
atomic.StoreInt32(&flag, 1)
wl.Close()
wl.Drain()
end := time.Now()
if dt := end.Sub(start); dt > 300*time.Millisecond {
t.Errorf("took %v to drain; expected shorter", dt)
}
}

+71 -0 graceful/listener/shard.go

@ -0,0 +1,71 @@
package listener
import "sync"
type shard struct {
l *T
mu sync.Mutex
set map[*conn]struct{}
wg sync.WaitGroup
drain bool
// We pack shards together in an array, but we don't want them packed
// too closely, since we want to give each shard a dedicated CPU cache
// line. This amount of padding works out well for the common case of
// x64 processors (64-bit pointers with a 64-byte cache line).
_ [12]byte
}
// We pretty aggressively preallocate set entries in the hopes that we never
// have to allocate memory with the lock held. This is definitely a premature
// optimization and is probably misguided, but luckily it costs us essentially
// nothing.
const prealloc = 2048
func (s *shard) init(l *T) {
s.l = l
s.set = make(map[*conn]struct{}, prealloc)
}
func (s *shard) markIdle(c *conn) (shouldClose bool) {
s.mu.Lock()
if s.drain {
s.mu.Unlock()
return true
}
s.set[c] = struct{}{}
s.mu.Unlock()
return false
}
func (s *shard) markInUse(c *conn) {
s.mu.Lock()
delete(s.set, c)
s.mu.Unlock()
}
func (s *shard) closeIdle(drain bool) {
s.mu.Lock()
if drain {
s.drain = true
}
set := s.set
s.set = make(map[*conn]struct{}, prealloc)
// We have to drop the shard lock here to avoid deadlock: we cannot
// acquire the shard lock after the connection lock, and the closeIfIdle
// call below will grab a connection lock.
s.mu.Unlock()
for conn := range set {
// This might return an error (from Close), but I don't think we
// can do anything about it, so let's just pretend it didn't
// happen. (I also expect that most errors returned in this way
// are going to be pretty boring)
conn.closeIfIdle()
}
}
func (s *shard) wait() {
s.wg.Wait()
}

+6 -5 graceful/middleware.go

@ -7,6 +7,9 @@ import (
"io" "io"
"net" "net"
"net/http" "net/http"
"sync/atomic"
"github.com/zenazn/goji/graceful/listener"
)
/*
@ -62,10 +65,8 @@ type basicWriter struct {
func (b *basicWriter) maybeClose() {
b.headerWritten = true
select {
case <-kill:
if atomic.LoadInt32(&closing) != 0 {
b.ResponseWriter.Header().Set("Connection", "close")
default:
}
}
@ -103,8 +104,8 @@ func (f *fancyWriter) Hijack() (c net.Conn, b *bufio.ReadWriter, e error) {
hj := f.basicWriter.ResponseWriter.(http.Hijacker)
c, b, e = hj.Hijack()
if conn, ok := c.(*conn); ok {
c = conn.hijack()
if e == nil {
e = listener.Disown(c)
}
return


+8 -7 graceful/middleware_test.go

@ -4,6 +4,7 @@ package graceful
import (
"net/http"
"sync/atomic"
"testing" "testing"
) )
@ -36,7 +37,7 @@ func testClose(t *testing.T, h http.Handler, expectClose bool) {
}
func TestNormal(t *testing.T) {
kill = make(chan struct{})
atomic.StoreInt32(&closing, 0)
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte{})
})
@ -44,26 +45,26 @@ func TestNormal(t *testing.T) {
}
func TestClose(t *testing.T) {
kill = make(chan struct{})
atomic.StoreInt32(&closing, 0)
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(kill)
atomic.StoreInt32(&closing, 1)
})
testClose(t, h, true)
}
func TestCloseWriteHeader(t *testing.T) {
kill = make(chan struct{})
atomic.StoreInt32(&closing, 0)
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(kill)
atomic.StoreInt32(&closing, 1)
w.WriteHeader(200)
})
testClose(t, h, true)
}
func TestCloseWrite(t *testing.T) {
kill = make(chan struct{})
atomic.StoreInt32(&closing, 0)
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(kill)
atomic.StoreInt32(&closing, 1)
w.Write([]byte{})
})
testClose(t, h, true)


+0 -185 graceful/net.go

@ -1,185 +0,0 @@
package graceful
import (
"io"
"net"
"sync"
"sync/atomic"
"time"
)
type listener struct {
net.Listener
}
// WrapListener wraps an arbitrary net.Listener for use with graceful shutdowns.
// All net.Conn's Accept()ed by this listener will be auto-wrapped as if
// WrapConn() were called on them.
func WrapListener(l net.Listener) net.Listener {
return listener{l}
}
func (l listener) Accept() (net.Conn, error) {
conn, err := l.Listener.Accept()
return WrapConn(conn), err
}
type conn struct {
mu sync.Mutex
cs *connSet
net.Conn
id uint64
busy, die bool
dead bool
hijacked bool
}
/*
WrapConn wraps an arbitrary connection for use with graceful shutdowns. The
graceful shutdown process will ensure that this connection is closed before
terminating the process.
In order to use this function, you must call SetReadDeadline() before the call
to Read() you might make to read a new request off the wire. The connection is
eligible for abrupt closing at any point between when the call to
SetReadDeadline() returns and when the call to Read returns with new data. It
does not matter what deadline is given to SetReadDeadline()--if a deadline is
inappropriate, providing one extremely far into the future will suffice.
Unfortunately, this means that it's difficult to use SetReadDeadline() in a
great many perfectly reasonable circumstances, such as to extend a deadline
after more data has been read, without the connection being eligible for
"graceful" termination at an undesirable time. Since this package was written
explicitly to target net/http, which does not as of this writing do any of this,
fixing the semantics here does not seem especially urgent.
*/
func WrapConn(c net.Conn) net.Conn {
if c == nil {
return nil
}
// Avoid race with termination code.
wgLock.Lock()
defer wgLock.Unlock()
// Determine whether the app is shutting down.
if acceptingRequests {
wg.Add(1)
return &conn{
Conn: c,
id: atomic.AddUint64(&idleSet.id, 1),
}
} else {
return nil
}
}
func (c *conn) Read(b []byte) (n int, err error) {
c.mu.Lock()
if !c.hijacked {
defer func() {
c.mu.Lock()
if c.hijacked {
// It's a little unclear to me how this case
// would happen, but we *did* drop the lock, so
// let's play it safe.
return
}
if c.dead {
// Dead sockets don't tell tales. This is to
// prevent the case where a Read manages to suck
// an entire request off the wire in a race with
// someone trying to close idle connections.
// Whoever grabs the conn lock first wins, and
// if that's the closing process, we need to
// "take back" the read.
n = 0
err = io.EOF
} else {
idleSet.markBusy(c)
}
c.mu.Unlock()
}()
}
c.mu.Unlock()
return c.Conn.Read(b)
}
func (c *conn) SetReadDeadline(t time.Time) error {
c.mu.Lock()
if !c.hijacked {
defer c.markIdle()
}
c.mu.Unlock()
return c.Conn.SetReadDeadline(t)
}
func (c *conn) Close() error {
kill := false
c.mu.Lock()
kill, c.dead = !c.dead, true
idleSet.markBusy(c)
c.mu.Unlock()
if kill {
defer wg.Done()
}
return c.Conn.Close()
}
type writerOnly struct {
w io.Writer
}
func (w writerOnly) Write(buf []byte) (int, error) {
return w.w.Write(buf)
}
func (c *conn) ReadFrom(r io.Reader) (int64, error) {
if rf, ok := c.Conn.(io.ReaderFrom); ok {
return rf.ReadFrom(r)
}
return io.Copy(writerOnly{c}, r)
}
func (c *conn) markIdle() {
kill := false
c.mu.Lock()
idleSet.markIdle(c)
if c.die {
kill, c.dead = !c.dead, true
}
c.mu.Unlock()
if kill {
defer wg.Done()
c.Conn.Close()
}
}
func (c *conn) closeIfIdle() {
kill := false
c.mu.Lock()
c.die = true
if !c.busy && !c.hijacked {
kill, c.dead = !c.dead, true
}
c.mu.Unlock()
if kill {
defer wg.Done()
c.Conn.Close()
}
}
func (c *conn) hijack() net.Conn {
c.mu.Lock()
idleSet.markBusy(c)
c.hijacked = true
c.mu.Unlock()
return c.Conn
}

+0 -12 graceful/race_test.go

@ -1,12 +0,0 @@
// +build race
package graceful
import "testing"
func TestWaitGroupRace(t *testing.T) {
go func() {
go WrapConn(fakeConn{}).Close()
}()
Shutdown()
}

+6 -16 graceful/serve.go

@ -6,19 +6,14 @@ import (
"net" "net"
"net/http" "net/http"
"time" "time"
"github.com/zenazn/goji/graceful/listener"
)
// About 200 years, also known as "forever"
const forever time.Duration = 200 * 365 * 24 * time.Hour
func (srv *Server) Serve(l net.Listener) error {
go func() {
<-kill
l.Close()
idleSet.killall()
}()
l = WrapListener(l)
// Spawn a shadow http.Server to do the actual servering. We do this
// because we need to sketch on some of the parameters you passed in,
// and it's nice to keep our sketching to ourselves.
@ -29,14 +24,9 @@ func (srv *Server) Serve(l net.Listener) error {
}
shadow.Handler = Middleware(shadow.Handler)
err := shadow.Serve(l)
wrap := listener.Wrap(l, listener.Deadline)
appendListener(wrap)
// We expect an error when we close the listener, so we indiscriminately
// swallow Serve errors when we're in a shutdown state.
select {
case <-kill:
return nil
default:
return err
}
err := shadow.Serve(wrap)
return peacefulError(err)
}

+59 -51 graceful/serve13.go

@ -3,67 +3,75 @@
package graceful
import (
"log"
"net" "net"
"net/http" "net/http"
"github.com/zenazn/goji/graceful/listener"
)
func (srv *Server) Serve(l net.Listener) error {
l = WrapListener(l)
// This is a slightly hacky shim to disable keepalives when shutting a server
// down. We could have added extra functionality in listener or signal.go to
// deal with this case, but this seems simpler.
type gracefulServer struct {
net.Listener
s *http.Server
}
// Spawn a shadow http.Server to do the actual servering. We do this
// because we need to sketch on some of the parameters you passed in,
// and it's nice to keep our sketching to ourselves.
shadow := *(*http.Server)(srv)
func (g gracefulServer) Close() error {
g.s.SetKeepAlivesEnabled(false)
return g.Listener.Close()
}
// A chaining http.ConnState wrapper
type connState func(net.Conn, http.ConnState)
cs := shadow.ConnState
shadow.ConnState = func(nc net.Conn, s http.ConnState) {
if c, ok := nc.(*conn); ok {
// There are a few other states defined, most notably
// StateActive. Unfortunately it doesn't look like it's
// possible to make use of StateActive to implement
// graceful shutdown, since StateActive is set after a
// complete request has been read off the wire with an
// intent to process it. If we were to race a graceful
// shutdown against a connection that was just read off
// the wire (but not yet in StateActive), we would
// accidentally close the connection out from underneath
// an active request.
//
// We already needed to work around this for Go 1.2 by
// shimming out a full net.Conn object, so we can just
// fall back to the old behavior there.
//
// I started a golang-nuts thread about this here:
// https://groups.google.com/forum/#!topic/golang-nuts/Xi8yjBGWfCQ
// I'd be very eager to find a better way to do this, so
// reach out to me if you have any ideas.
switch s {
case http.StateIdle:
c.markIdle()
case http.StateHijacked:
c.hijack()
}
func (c connState) Wrap(nc net.Conn, s http.ConnState) {
// There are a few other states defined, most notably StateActive.
// Unfortunately it doesn't look like it's possible to make use of
// StateActive to implement graceful shutdown, since StateActive is set
// after a complete request has been read off the wire with an intent to
// process it. If we were to race a graceful shutdown against a
// connection that was just read off the wire (but not yet in
// StateActive), we would accidentally close the connection out from
// underneath an active request.
//
// We already needed to work around this for Go 1.2 by shimming out a
// full net.Conn object, so we can just fall back to the old behavior
// there.
//
// I started a golang-nuts thread about this here:
// https://groups.google.com/forum/#!topic/golang-nuts/Xi8yjBGWfCQ I'd
// be very eager to find a better way to do this, so reach out to me if
// you have any ideas.
switch s {
case http.StateIdle:
if err := listener.MarkIdle(nc); err != nil {
log.Printf("error marking conn as idle: %v",
err)
} }
if cs != nil {
cs(nc, s)
case http.StateHijacked:
if err := listener.Disown(nc); err != nil {
log.Printf("error disowning hijacked conn: %v",
err)
} }
} }
if c != nil {
c(nc, s)
}
}
go func() {
<-kill
l.Close()
shadow.SetKeepAlivesEnabled(false)
idleSet.killall()
}()
func (srv *Server) Serve(l net.Listener) error {
// Spawn a shadow http.Server to do the actual servering. We do this
// because we need to sketch on some of the parameters you passed in,
// and it's nice to keep our sketching to ourselves.
shadow := *(*http.Server)(srv)
shadow.ConnState = connState(shadow.ConnState).Wrap
err := shadow.Serve(l)
l = gracefulServer{l, &shadow}
wrap := listener.Wrap(l, listener.Automatic)
appendListener(wrap)
// We expect an error when we close the listener, so we indiscriminately
// swallow Serve errors when we're in a shutdown state.
select {
case <-kill:
return nil
default:
return err
}
err := shadow.Serve(wrap)
return peacefulError(err)
}

+53 -33 graceful/signal.go

@ -1,32 +1,22 @@
package graceful
import (
"net"
"os" "os"
"os/signal" "os/signal"
"sync" "sync"
)
// This is the channel that the connections select on. When it is closed, the
// connections should gracefully exit.
var kill = make(chan struct{})
// This is the channel that the Wait() function selects on. It should only be
// closed once all the posthooks have been called.
var wait = make(chan struct{})
// Whether new requests should be accepted. When false, new requests are refused.
var acceptingRequests bool = true
"sync/atomic"
// This is the WaitGroup that indicates when all the connections have gracefully
// shut down.
var wg sync.WaitGroup
var wgLock sync.Mutex
"github.com/zenazn/goji/graceful/listener"
)
// This lock protects the list of pre- and post- hooks below.
var hookLock sync.Mutex
var mu sync.Mutex // protects everything that follows
var listeners = make([]*listener.T, 0)
var prehooks = make([]func(), 0)
var posthooks = make([]func(), 0)
var closing int32
var wait = make(chan struct{})
var stdSignals = []os.Signal{os.Interrupt}
var sigchan = make(chan os.Signal, 1)
@ -71,8 +61,8 @@ func Shutdown() {
// shutdown actions. All listeners will be called in the order they were added,
// from a single goroutine.
func PreHook(f func()) {
hookLock.Lock()
defer hookLock.Unlock()
mu.Lock()
defer mu.Unlock()
prehooks = append(prehooks, f)
}
@ -82,12 +72,12 @@ func PreHook(f func()) {
// from a single goroutine, and are guaranteed to be called after all listening
// connections have been closed, but before Wait() returns.
//
// If you've Hijack()ed any connections that must be gracefully shut down in
// some other way (since this library disowns all hijacked connections), it's
// reasonable to use a PostHook() to signal and wait for them.
// If you've Hijacked any connections that must be gracefully shut down in some
// other way (since this library disowns all hijacked connections), it's
// reasonable to use a PostHook to signal and wait for them.
func PostHook(f func()) {
hookLock.Lock()
defer hookLock.Unlock()
mu.Lock()
defer mu.Unlock()
posthooks = append(posthooks, f)
}
@ -95,19 +85,23 @@ func PostHook(f func()) {
func waitForSignal() {
<-sigchan
// Prevent servicing of any new requests.
wgLock.Lock()
acceptingRequests = false
wgLock.Unlock()
hookLock.Lock()
defer hookLock.Unlock()
mu.Lock()
defer mu.Unlock()
for _, f := range prehooks {
f()
}
close(kill)
atomic.StoreInt32(&closing, 1)
var wg sync.WaitGroup
wg.Add(len(listeners))
for _, l := range listeners {
go func(l *listener.T) {
defer wg.Done()
l.Close()
l.Drain()
}(l)
}
wg.Wait()
for _, f := range posthooks {
@ -123,3 +117,29 @@ func waitForSignal() {
func Wait() {
<-wait
}
func appendListener(l *listener.T) {
mu.Lock()
defer mu.Unlock()
listeners = append(listeners, l)
}
const errClosing = "use of closed network connection"
// During graceful shutdown, calls to Accept will start returning errors. This
// is inconvenient, since we know these sorts of errors are peaceful, so we
// silently swallow them.
func peacefulError(err error) error {
if atomic.LoadInt32(&closing) == 0 {
return err
}
// Unfortunately Go doesn't really give us a better way to select errors
// than this, so *shrug*.
if oe, ok := err.(*net.OpError); ok {
if oe.Op == "accept" && oe.Err.Error() == errClosing {
return nil
}
}
return err
}
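As the PostHook comment notes, hijacked connections are disowned and need their own shutdown path, and a hook is a reasonable place to put it. A hypothetical sketch (the hijacked WaitGroup is application code, not part of this commit):

package app

import (
	"log"
	"sync"

	"github.com/zenazn/goji/graceful"
)

// hijacked tracks connections we took over via Hijack and now manage ourselves.
var hijacked sync.WaitGroup

func init() {
	graceful.PreHook(func() { log.Print("graceful: signal received, shutting down") })
	graceful.PostHook(func() {
		// By the time post-hooks run, every tracked listener has been
		// closed and drained; disowned (hijacked) connections have not.
		hijacked.Wait()
	})
}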
