This change refactors package graceful into two packages: one very well-tested package that handles graceful shutdown of arbitrary net.Listeners in the abstract, and one less-well-tested package that deals with the nitty-gritty details of net/http and signal handling. This is a breaking API change for advanced users of package graceful: the WrapConn function no longer exists. It shouldn't affect most users or use cases.
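For orientation, here is a minimal sketch of how the new low-level API fits together, assuming the import path github.com/zenazn/goji/graceful/listener (consistent with the package documentation below); the echo server and the shutdown-on-SIGINT wiring are illustrative, not part of this change:

package main

import (
	"log"
	"net"
	"os"
	"os/signal"

	"github.com/zenazn/goji/graceful/listener" // assumed import path
)

func main() {
	nl, err := net.Listen("tcp", ":8080")
	if err != nil {
		log.Fatal(err)
	}
	// Automatic mode: a Read marks a connection as in-use, and the server
	// calls listener.MarkIdle itself between requests.
	wl := listener.Wrap(nl, listener.Automatic)

	go func() {
		for {
			c, err := wl.Accept()
			if err != nil {
				return // listener was closed
			}
			go echo(c)
		}
	}()

	sig := make(chan os.Signal, 1)
	signal.Notify(sig, os.Interrupt)
	<-sig

	wl.Close() // stop accepting new connections
	wl.Drain() // close idle connections; wait for busy ones to finish
}

func echo(c net.Conn) {
	defer c.Close()
	buf := make([]byte, 1024)
	for {
		n, err := c.Read(buf) // marks c as in-use
		if err != nil {
			return
		}
		if _, err := c.Write(buf[:n]); err != nil {
			return
		}
		listener.MarkIdle(c) // c may be closed until the next Read
	}
}

The Wrap-then-Close-plus-Drain lifecycle shown here is presumably what the net/http-facing parent package builds on.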
@@ -1,80 +0,0 @@
package graceful

import (
	"runtime"
	"sync"
)

type connShard struct {
	mu sync.Mutex
	// We sort of abuse this field to also act as a "please shut down" flag.
	// If it's nil, you should die at your earliest opportunity.
	set map[*conn]struct{}
}

type connSet struct {
	// This is an incrementing connection counter so we round-robin
	// connections across shards. Use atomic when touching it.
	id     uint64
	shards []*connShard
}

var idleSet connSet

// We pretty aggressively preallocate set entries in the hopes that we never
// have to allocate memory with the lock held. This is definitely a premature
// optimization and is probably misguided, but luckily it costs us essentially
// nothing.
const prealloc = 2048

func init() {
	// To keep the expected contention rate constant we'd have to grow this
	// as numcpu**2. In practice, CPU counts don't generally grow without
	// bound, and contention is probably going to be small enough that
	// nobody cares anyways.
	idleSet.shards = make([]*connShard, 2*runtime.NumCPU())
	for i := range idleSet.shards {
		idleSet.shards[i] = &connShard{
			set: make(map[*conn]struct{}, prealloc),
		}
	}
}

func (cs connSet) markIdle(c *conn) {
	c.busy = false
	shard := cs.shards[int(c.id%uint64(len(cs.shards)))]
	shard.mu.Lock()
	if shard.set == nil {
		shard.mu.Unlock()
		c.die = true
	} else {
		shard.set[c] = struct{}{}
		shard.mu.Unlock()
	}
}

func (cs connSet) markBusy(c *conn) {
	c.busy = true
	shard := cs.shards[int(c.id%uint64(len(cs.shards)))]
	shard.mu.Lock()
	if shard.set == nil {
		shard.mu.Unlock()
		c.die = true
	} else {
		delete(shard.set, c)
		shard.mu.Unlock()
	}
}

func (cs connSet) killall() {
	for _, shard := range cs.shards {
		shard.mu.Lock()
		set := shard.set
		shard.set = nil
		shard.mu.Unlock()
		for conn := range set {
			conn.closeIfIdle()
		}
	}
}
@@ -1,63 +0,0 @@
package graceful

import (
	"net"
	"time"
)

// Stub out a net.Conn. This is going to be painful.
type fakeAddr struct{}

func (f fakeAddr) Network() string {
	return "fake"
}

func (f fakeAddr) String() string {
	return "fake"
}

type fakeConn struct {
	onRead, onWrite, onClose, onLocalAddr, onRemoteAddr  func()
	onSetDeadline, onSetReadDeadline, onSetWriteDeadline func()
}

// Here's my number, so...
func callMeMaybe(f func()) {
	// I apologize for nothing.
	if f != nil {
		f()
	}
}

func (f fakeConn) Read(b []byte) (int, error) {
	callMeMaybe(f.onRead)
	return len(b), nil
}

func (f fakeConn) Write(b []byte) (int, error) {
	callMeMaybe(f.onWrite)
	return len(b), nil
}

func (f fakeConn) Close() error {
	callMeMaybe(f.onClose)
	return nil
}

func (f fakeConn) LocalAddr() net.Addr {
	callMeMaybe(f.onLocalAddr)
	return fakeAddr{}
}

func (f fakeConn) RemoteAddr() net.Addr {
	callMeMaybe(f.onRemoteAddr)
	return fakeAddr{}
}

func (f fakeConn) SetDeadline(t time.Time) error {
	callMeMaybe(f.onSetDeadline)
	return nil
}

func (f fakeConn) SetReadDeadline(t time.Time) error {
	callMeMaybe(f.onSetReadDeadline)
	return nil
}

func (f fakeConn) SetWriteDeadline(t time.Time) error {
	callMeMaybe(f.onSetWriteDeadline)
	return nil
}
@@ -0,0 +1,147 @@
package listener

import (
	"errors"
	"io"
	"net"
	"sync"
	"time"
)

type conn struct {
	net.Conn
	shard    *shard
	mode     mode
	mu       sync.Mutex // Protects the state machine below
	busy     bool       // connection is in use (i.e., not idle)
	closed   bool       // connection is closed
	disowned bool       // if true, this connection is no longer under our management
}

// This intentionally looks a lot like the one in package net.
var errClosing = errors.New("use of closed network connection")

func (c *conn) init() error {
	c.shard.wg.Add(1)
	if shouldExit := c.shard.markIdle(c); shouldExit {
		c.Close()
		return errClosing
	}
	return nil
}

func (c *conn) Read(b []byte) (n int, err error) {
	defer func() {
		c.mu.Lock()
		defer c.mu.Unlock()
		if c.disowned {
			return
		}
		// This protects against a Close/Read race. We're not really
		// concerned about the general case (it's fundamentally racy),
		// but are mostly trying to prevent a race between a new request
		// getting read off the wire in one thread while the connection
		// is being gracefully shut down in another.
		if c.closed && err == nil {
			n = 0
			err = errClosing
			return
		}
		if c.mode != Manual && !c.busy && !c.closed {
			c.busy = true
			c.shard.markInUse(c)
		}
	}()
	return c.Conn.Read(b)
}

func (c *conn) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()
	if !c.closed && !c.disowned {
		c.closed = true
		c.shard.markInUse(c)
		defer c.shard.wg.Done()
	}
	return c.Conn.Close()
}

func (c *conn) SetReadDeadline(t time.Time) error {
	c.mu.Lock()
	if !c.disowned && c.mode == Deadline {
		defer c.markIdle()
	}
	c.mu.Unlock()
	return c.Conn.SetReadDeadline(t)
}

func (c *conn) ReadFrom(r io.Reader) (int64, error) {
	return io.Copy(c.Conn, r)
}

func (c *conn) markIdle() {
	c.mu.Lock()
	defer c.mu.Unlock()
	if !c.busy {
		return
	}
	c.busy = false
	if exit := c.shard.markIdle(c); exit && !c.closed && !c.disowned {
		c.closed = true
		c.shard.markInUse(c)
		defer c.shard.wg.Done()
		c.Conn.Close()
		return
	}
}

func (c *conn) markInUse() {
	c.mu.Lock()
	defer c.mu.Unlock()
	if !c.busy && !c.closed && !c.disowned {
		c.busy = true
		c.shard.markInUse(c)
	}
}

func (c *conn) closeIfIdle() error {
	c.mu.Lock()
	defer c.mu.Unlock()
	if !c.busy && !c.closed && !c.disowned {
		c.closed = true
		c.shard.markInUse(c)
		defer c.shard.wg.Done()
		return c.Conn.Close()
	}
	return nil
}

var errAlreadyDisowned = errors.New("listener: conn already disowned")

func (c *conn) disown() error {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.disowned {
		return errAlreadyDisowned
	}
	c.shard.markInUse(c)
	c.disowned = true
	c.shard.wg.Done()
	return nil
}
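The SetReadDeadline/Read interplay above is the heart of Deadline mode. As a sketch (readRequest, its buffer, and the 30-second deadline are hypothetical; assumes the net and time imports), a server's per-request read might look like:

// readRequest reads one request's worth of bytes from a connection that was
// Accept()ed off a Deadline-mode listener.
func readRequest(wc net.Conn, buf []byte) (int, error) {
	// Setting a read deadline marks the connection idle: from here until
	// the next successful Read, a graceful shutdown may close it out from
	// under us.
	if err := wc.SetReadDeadline(time.Now().Add(30 * time.Second)); err != nil {
		return 0, err
	}
	// A successful Read marks the connection in-use again; if we lost a
	// race with a shutdown, Read reports the "use of closed network
	// connection" error instead.
	return wc.Read(buf)
}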
@@ -0,0 +1,182 @@
package listener

import (
	"io"
	"strings"
	"testing"
	"time"
)

func TestManualRead(t *testing.T) {
	t.Parallel()
	l, c, wc := singleConn(t, Manual)
	go c.AllowRead()
	wc.Read(make([]byte, 1024))
	if err := l.CloseIdle(); err != nil {
		t.Fatalf("error closing idle connections: %v", err)
	}
	if !c.Closed() {
		t.Error("Read() should not mark a manual-mode connection as in-use")
	}
}
func TestAutomaticRead(t *testing.T) {
	t.Parallel()
	l, c, wc := singleConn(t, Automatic)
	go c.AllowRead()
	wc.Read(make([]byte, 1024))
	if err := l.CloseIdle(); err != nil {
		t.Fatalf("error closing idle connections: %v", err)
	}
	if c.Closed() {
		t.Error("expected Read() to mark connection as in-use")
	}
}

func TestDeadlineRead(t *testing.T) {
	t.Parallel()
	l, c, wc := singleConn(t, Deadline)
	go c.AllowRead()
	if _, err := wc.Read(make([]byte, 1024)); err != nil {
		t.Fatalf("error reading from connection: %v", err)
	}
	if err := l.CloseIdle(); err != nil {
		t.Fatalf("error closing idle connections: %v", err)
	}
	if c.Closed() {
		t.Error("expected Read() to mark connection as in-use")
	}
}

func TestDisownedRead(t *testing.T) {
	t.Parallel()
	l, c, wc := singleConn(t, Deadline)
	if err := Disown(wc); err != nil {
		t.Fatalf("unexpected error disowning conn: %v", err)
	}
	if err := l.Close(); err != nil {
		t.Fatalf("unexpected error closing listener: %v", err)
	}
	if err := l.Drain(); err != nil {
		t.Fatalf("unexpected error draining listener: %v", err)
	}
	go c.AllowRead()
	if _, err := wc.Read(make([]byte, 1024)); err != nil {
		t.Fatalf("error reading from connection: %v", err)
	}
}

func TestCloseConn(t *testing.T) {
	t.Parallel()
	l, _, wc := singleConn(t, Deadline)
	if err := MarkInUse(wc); err != nil {
		t.Fatalf("error marking conn in use: %v", err)
	}
	if err := wc.Close(); err != nil {
		t.Errorf("error closing connection: %v", err)
	}
	// This will hang if wc.Close() doesn't un-track the connection
	if err := l.Drain(); err != nil {
		t.Errorf("error draining listener: %v", err)
	}
}

func TestManualReadDeadline(t *testing.T) {
	t.Parallel()
	l, c, wc := singleConn(t, Manual)
	if err := MarkInUse(wc); err != nil {
		t.Fatalf("error marking connection in use: %v", err)
	}
	if err := wc.SetReadDeadline(time.Now()); err != nil {
		t.Fatalf("error setting read deadline: %v", err)
	}
	if err := l.CloseIdle(); err != nil {
		t.Fatalf("error closing idle connections: %v", err)
	}
	if c.Closed() {
		t.Error("SetReadDeadline() should not mark manual conn as idle")
	}
}

func TestAutomaticReadDeadline(t *testing.T) {
	t.Parallel()
	l, c, wc := singleConn(t, Automatic)
	if err := MarkInUse(wc); err != nil {
		t.Fatalf("error marking connection in use: %v", err)
	}
	if err := wc.SetReadDeadline(time.Now()); err != nil {
		t.Fatalf("error setting read deadline: %v", err)
	}
	if err := l.CloseIdle(); err != nil {
		t.Fatalf("error closing idle connections: %v", err)
	}
	if c.Closed() {
		t.Error("SetReadDeadline() should not mark automatic conn as idle")
	}
}

func TestDeadlineReadDeadline(t *testing.T) {
	t.Parallel()
	l, c, wc := singleConn(t, Deadline)
	if err := MarkInUse(wc); err != nil {
		t.Fatalf("error marking connection in use: %v", err)
	}
	if err := wc.SetReadDeadline(time.Now()); err != nil {
		t.Fatalf("error setting read deadline: %v", err)
	}
	if err := l.CloseIdle(); err != nil {
		t.Fatalf("error closing idle connections: %v", err)
	}
	if !c.Closed() {
		t.Error("SetReadDeadline() should mark deadline conn as idle")
	}
}

type readerConn struct {
	fakeConn
}

func (rc *readerConn) ReadFrom(r io.Reader) (int64, error) {
	return 123, nil
}

func TestReadFrom(t *testing.T) {
	t.Parallel()
	l := makeFakeListener("net.Listener")
	wl := Wrap(l, Manual)
	c := &readerConn{
		fakeConn{
			read:   make(chan struct{}),
			write:  make(chan struct{}),
			closed: make(chan struct{}),
			me:     fakeAddr{"tcp", "local"},
			you:    fakeAddr{"tcp", "remote"},
		},
	}
	go l.Enqueue(c)
	wc, err := wl.Accept()
	if err != nil {
		t.Fatalf("error accepting connection: %v", err)
	}
	// The io.MultiReader is a convenient hack to ensure that we're using
	// our ReadFrom, not strings.Reader's WriteTo.
	r := io.MultiReader(strings.NewReader("hello world"))
	if _, err := io.Copy(wc, r); err != nil {
		t.Fatalf("error copying: %v", err)
	}
}
@@ -0,0 +1,123 @@
package listener

import (
	"net"
	"time"
)

type fakeAddr struct {
	network, addr string
}

func (f fakeAddr) Network() string {
	return f.network
}

func (f fakeAddr) String() string {
	return f.addr
}

type fakeListener struct {
	ch     chan net.Conn
	closed chan struct{}
	addr   net.Addr
}

func makeFakeListener(addr string) *fakeListener {
	a := fakeAddr{"tcp", addr}
	return &fakeListener{
		ch:     make(chan net.Conn),
		closed: make(chan struct{}),
		addr:   a,
	}
}

func (f *fakeListener) Accept() (net.Conn, error) {
	select {
	case c := <-f.ch:
		return c, nil
	case <-f.closed:
		return nil, errClosing
	}
}

func (f *fakeListener) Close() error {
	close(f.closed)
	return nil
}

func (f *fakeListener) Addr() net.Addr {
	return f.addr
}

func (f *fakeListener) Enqueue(c net.Conn) {
	f.ch <- c
}

type fakeConn struct {
	read, write, closed chan struct{}
	me, you             net.Addr
}

func makeFakeConn(me, you string) *fakeConn {
	return &fakeConn{
		read:   make(chan struct{}),
		write:  make(chan struct{}),
		closed: make(chan struct{}),
		me:     fakeAddr{"tcp", me},
		you:    fakeAddr{"tcp", you},
	}
}

func (f *fakeConn) Read(buf []byte) (int, error) {
	select {
	case <-f.read:
		return len(buf), nil
	case <-f.closed:
		return 0, errClosing
	}
}

func (f *fakeConn) Write(buf []byte) (int, error) {
	select {
	case <-f.write:
		return len(buf), nil
	case <-f.closed:
		return 0, errClosing
	}
}

func (f *fakeConn) Close() error {
	close(f.closed)
	return nil
}

func (f *fakeConn) LocalAddr() net.Addr {
	return f.me
}

func (f *fakeConn) RemoteAddr() net.Addr {
	return f.you
}

func (f *fakeConn) SetDeadline(t time.Time) error {
	return nil
}

func (f *fakeConn) SetReadDeadline(t time.Time) error {
	return nil
}

func (f *fakeConn) SetWriteDeadline(t time.Time) error {
	return nil
}

func (f *fakeConn) Closed() bool {
	select {
	case <-f.closed:
		return true
	default:
		return false
	}
}

func (f *fakeConn) AllowRead() {
	f.read <- struct{}{}
}

func (f *fakeConn) AllowWrite() {
	f.write <- struct{}{}
}
@@ -0,0 +1,165 @@
/*
Package listener provides a way to incorporate graceful shutdown into any
net.Listener.

This package provides low-level primitives, not a high-level API. If you're
looking for a package that provides graceful shutdown for HTTP servers, I
recommend this package's parent package, github.com/zenazn/goji/graceful.
*/
package listener

import (
	"errors"
	"net"
	"runtime"
	"sync"
	"sync/atomic"
)

type mode int8

const (
	// Manual mode is completely manual: users must use MarkIdle and
	// MarkInUse to indicate when connections are busy servicing requests
	// or are eligible for termination.
	Manual mode = iota
	// Automatic mode is what most users probably want: calling Read on a
	// connection will mark it as in use, but users must manually call
	// MarkIdle to indicate when connections may be safely closed.
	Automatic
	// Deadline mode is like Automatic mode, except that calling
	// SetReadDeadline on a connection will also mark it as being idle.
	// This is useful for many servers like net/http, where SetReadDeadline
	// is used to implement read timeouts on new requests.
	Deadline
)
// Wrap wraps a net.Listener, returning one which supports idle connection
// tracking and shutdown. Listeners can be placed into one of three modes,
// exported as constants from this package; most users will probably want
// Automatic mode.
func Wrap(l net.Listener, m mode) *T {
	t := &T{
		l:    l,
		mode: m,
		// To keep the expected contention rate constant we'd have to
		// grow this as numcpu**2. In practice, CPU counts don't
		// generally grow without bound, and contention is probably
		// going to be small enough that nobody cares anyways.
		shards: make([]shard, 2*runtime.NumCPU()),
	}
	for i := range t.shards {
		t.shards[i].init(t)
	}
	return t
}
// T is the type of this package's graceful listeners.
type T struct {
	mu sync.Mutex
	l  net.Listener
	// TODO(carl): a count of currently outstanding connections.
	connCount uint64
	shards    []shard
	mode      mode
}

var _ net.Listener = &T{}

// Accept waits for and returns the next connection to the listener. The
// returned net.Conn's idleness is tracked, and idle connections can be closed
// from the associated T.
func (t *T) Accept() (net.Conn, error) {
	c, err := t.l.Accept()
	if err != nil {
		return nil, err
	}
	connID := atomic.AddUint64(&t.connCount, 1)
	shard := &t.shards[int(connID)%len(t.shards)]
	wc := &conn{
		Conn:  c,
		shard: shard,
		mode:  t.mode,
	}
	if err = wc.init(); err != nil {
		return nil, err
	}
	return wc, nil
}

// Addr returns the wrapped listener's network address.
func (t *T) Addr() net.Addr {
	return t.l.Addr()
}

// Close closes the wrapped listener.
func (t *T) Close() error {
	return t.l.Close()
}

// CloseIdle closes all connections that are currently marked as being idle.
// It, however, makes no attempt to wait for in-use connections to die, or to
// close connections which become idle in the future. Call this function if
// you're interested in shedding useless connections, but otherwise wish to
// continue serving requests.
func (t *T) CloseIdle() error {
	for i := range t.shards {
		t.shards[i].closeIdle(false)
	}
	// Not sure if returning errors is actually useful here :/
	return nil
}

// Drain immediately closes all idle connections, prevents new connections
// from being accepted, and waits for all outstanding connections to finish.
//
// Once a listener has been drained, there is no way to re-enable it. You
// probably want to Close the listener before draining it, otherwise new
// connections will be accepted and immediately closed.
func (t *T) Drain() error {
	for i := range t.shards {
		t.shards[i].closeIdle(true)
	}
	for i := range t.shards {
		t.shards[i].wait()
	}
	return nil
}

var notManagedErr = errors.New("listener: passed net.Conn is not managed by us")

// Disown causes a connection to no longer be tracked by the listener. The
// passed connection must have been returned by a call to Accept from this
// listener.
func Disown(c net.Conn) error {
	if cn, ok := c.(*conn); ok {
		return cn.disown()
	}
	return notManagedErr
}

// MarkIdle marks the given connection as being idle, and therefore eligible
// for closing at any time. The passed connection must have been returned by a
// call to Accept from this listener.
func MarkIdle(c net.Conn) error {
	if cn, ok := c.(*conn); ok {
		cn.markIdle()
		return nil
	}
	return notManagedErr
}

// MarkInUse marks this connection as being in use, removing it from the set
// of connections which are eligible for closing. The passed connection must
// have been returned by a call to Accept from this listener.
func MarkInUse(c net.Conn) error {
	if cn, ok := c.(*conn); ok {
		cn.markInUse()
		return nil
	}
	return notManagedErr
}
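Disown is presumably the hook the net/http-facing parent package needs for hijacked connections, the role served by hijack() in the old code removed further below. A hypothetical handler (handleWebSocket is illustrative; assumes the net/http and listener imports, and that the http.Server listens on a listener wrapped by this package so the hijacked net.Conn is the tracked wrapper):

func handleWebSocket(w http.ResponseWriter, r *http.Request) {
	hj, ok := w.(http.Hijacker)
	if !ok {
		http.Error(w, "hijacking unsupported", http.StatusInternalServerError)
		return
	}
	c, _, err := hj.Hijack()
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	// The hijacked connection must outlive a graceful shutdown, so stop
	// tracking it: Drain will neither wait on it nor close it.
	if err := listener.Disown(c); err != nil {
		c.Close()
		return
	}
	// ... the connection is now ours to manage ...
}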
@@ -0,0 +1,143 @@
package listener

import (
	"net"
	"testing"
	"time"
)

// Helper for tests acting on a single accepted connection.
func singleConn(t *testing.T, m mode) (*T, *fakeConn, net.Conn) {
	l := makeFakeListener("net.Listener")
	wl := Wrap(l, m)
	c := makeFakeConn("local", "remote")
	go l.Enqueue(c)
	wc, err := wl.Accept()
	if err != nil {
		t.Fatalf("error accepting connection: %v", err)
	}
	return wl, c, wc
}

func TestAddr(t *testing.T) {
	t.Parallel()
	l, c, wc := singleConn(t, Manual)
	if a := l.Addr(); a.String() != "net.Listener" {
		t.Errorf("addr was %v, wanted net.Listener", a)
	}
	if c.LocalAddr() != wc.LocalAddr() {
		t.Errorf("local addresses don't match: %v, %v", c.LocalAddr(),
			wc.LocalAddr())
	}
	if c.RemoteAddr() != wc.RemoteAddr() {
		t.Errorf("remote addresses don't match: %v, %v", c.RemoteAddr(),
			wc.RemoteAddr())
	}
}

func TestBasicCloseIdle(t *testing.T) {
	t.Parallel()
	l, c, _ := singleConn(t, Manual)
	if err := l.CloseIdle(); err != nil {
		t.Fatalf("error closing idle connections: %v", err)
	}
	if !c.Closed() {
		t.Error("idle connection not closed")
	}
}

func TestMark(t *testing.T) {
	t.Parallel()
	l, c, wc := singleConn(t, Manual)
	if err := MarkInUse(wc); err != nil {
		t.Fatalf("error marking %v in-use: %v", wc, err)
	}
	if err := l.CloseIdle(); err != nil {
		t.Fatalf("error closing idle connections: %v", err)
	}
	if c.Closed() {
		t.Errorf("manually in-use connection was closed")
	}
	if err := MarkIdle(wc); err != nil {
		t.Fatalf("error marking %v idle: %v", wc, err)
	}
	if err := l.CloseIdle(); err != nil {
		t.Fatalf("error closing idle connections: %v", err)
	}
	if !c.Closed() {
		t.Error("manually idle connection was not closed")
	}
}

func TestDisown(t *testing.T) {
	t.Parallel()
	l, c, wc := singleConn(t, Manual)
	if err := Disown(wc); err != nil {
		t.Fatalf("error disowning connection: %v", err)
	}
	if err := l.CloseIdle(); err != nil {
		t.Fatalf("error closing idle connections: %v", err)
	}
	if c.Closed() {
		t.Errorf("disowned connection got closed")
	}
}

func TestDrain(t *testing.T) {
	t.Parallel()
	l, _, wc := singleConn(t, Manual)
	MarkInUse(wc)
	start := time.Now()
	go func() {
		time.Sleep(50 * time.Millisecond)
		MarkIdle(wc)
	}()
	if err := l.Drain(); err != nil {
		t.Fatalf("error draining listener: %v", err)
	}
	end := time.Now()
	if dt := end.Sub(start); dt < 50*time.Millisecond {
		t.Errorf("expected at least 50ms wait, but got %v", dt)
	}
}

func TestErrors(t *testing.T) {
	t.Parallel()
	_, c, wc := singleConn(t, Manual)
	if err := Disown(c); err == nil {
		t.Error("expected error when disowning unmanaged net.Conn")
	}
	if err := MarkIdle(c); err == nil {
		t.Error("expected error when marking unmanaged net.Conn idle")
	}
	if err := MarkInUse(c); err == nil {
		t.Error("expected error when marking unmanaged net.Conn in use")
	}
	if err := Disown(wc); err != nil {
		t.Fatalf("unexpected error disowning socket: %v", err)
	}
	if err := Disown(wc); err == nil {
		t.Error("expected error disowning socket twice")
	}
}

func TestClose(t *testing.T) {
	t.Parallel()
	l, c, _ := singleConn(t, Manual)
	if err := l.Close(); err != nil {
		t.Fatalf("error while closing listener: %v", err)
	}
	if c.Closed() {
		t.Error("connection closed when listener was?")
	}
}
@@ -0,0 +1,103 @@
package listener

import (
	"fmt"
	"math/rand"
	"runtime"
	"sync/atomic"
	"testing"
	"time"
)

func init() {
	// Just to make sure we get some variety
	runtime.GOMAXPROCS(4 * runtime.NumCPU())
}

// Chosen by random die roll
const seed = 4611413766552969250

// This is mostly just fuzzing to see what happens.
func TestRace(t *testing.T) {
	t.Parallel()
	l := makeFakeListener("net.Listener")
	wl := Wrap(l, Automatic)
	var flag int32
	go func() {
		for i := 0; ; i++ {
			laddr := fmt.Sprintf("local%d", i)
			raddr := fmt.Sprintf("remote%d", i)
			c := makeFakeConn(laddr, raddr)
			go func() {
				defer func() {
					if r := recover(); r != nil {
						if atomic.LoadInt32(&flag) != 0 {
							return
						}
						panic(r)
					}
				}()
				l.Enqueue(c)
			}()
			wc, err := wl.Accept()
			if err != nil {
				if atomic.LoadInt32(&flag) != 0 {
					return
				}
				t.Fatalf("error accepting connection: %v", err)
			}
			go func() {
				for {
					time.Sleep(50 * time.Millisecond)
					c.AllowRead()
				}
			}()
			go func(i int64) {
				rng := rand.New(rand.NewSource(i + seed))
				buf := make([]byte, 1024)
				for j := 0; j < 1024; j++ {
					if _, err := wc.Read(buf); err != nil {
						if atomic.LoadInt32(&flag) != 0 {
							// Peaceful; the connection has
							// probably been closed while
							// idle
							return
						}
						t.Errorf("error reading in conn %d: %v",
							i, err)
					}
					time.Sleep(time.Duration(rng.Intn(100)) * time.Millisecond)
					// This one is to make sure the connection
					// hasn't closed underneath us
					if _, err := wc.Read(buf); err != nil {
						t.Errorf("error reading in conn %d: %v",
							i, err)
					}
					MarkIdle(wc)
					time.Sleep(time.Duration(rng.Intn(100)) * time.Millisecond)
				}
			}(int64(i))
			time.Sleep(time.Duration(i) * time.Millisecond / 2)
		}
	}()
	if testing.Short() {
		time.Sleep(2 * time.Second)
	} else {
		time.Sleep(10 * time.Second)
	}
	start := time.Now()
	atomic.StoreInt32(&flag, 1)
	wl.Close()
	wl.Drain()
	end := time.Now()
	if dt := end.Sub(start); dt > 300*time.Millisecond {
		t.Errorf("took %v to drain; expected shorter", dt)
	}
}
@@ -0,0 +1,71 @@
package listener

import "sync"

type shard struct {
	l     *T
	mu    sync.Mutex
	set   map[*conn]struct{}
	wg    sync.WaitGroup
	drain bool
	// We pack shards together in an array, but we don't want them packed
	// too closely, since we want to give each shard a dedicated CPU cache
	// line. This amount of padding works out well for the common case of
	// x64 processors (64-bit pointers with a 64-byte cache line).
	_ [12]byte
}

// We pretty aggressively preallocate set entries in the hopes that we never
// have to allocate memory with the lock held. This is definitely a premature
// optimization and is probably misguided, but luckily it costs us essentially
// nothing.
const prealloc = 2048

func (s *shard) init(l *T) {
	s.l = l
	s.set = make(map[*conn]struct{}, prealloc)
}

func (s *shard) markIdle(c *conn) (shouldClose bool) {
	s.mu.Lock()
	if s.drain {
		s.mu.Unlock()
		return true
	}
	s.set[c] = struct{}{}
	s.mu.Unlock()
	return false
}

func (s *shard) markInUse(c *conn) {
	s.mu.Lock()
	delete(s.set, c)
	s.mu.Unlock()
}

func (s *shard) closeIdle(drain bool) {
	s.mu.Lock()
	if drain {
		s.drain = true
	}
	set := s.set
	s.set = make(map[*conn]struct{}, prealloc)
	// We have to drop the shard lock here to avoid deadlock: we cannot
	// acquire the shard lock after the connection lock, and the closeIfIdle
	// call below will grab a connection lock.
	s.mu.Unlock()
	for conn := range set {
		// This might return an error (from Close), but I don't think we
		// can do anything about it, so let's just pretend it didn't
		// happen. (I also expect that most errors returned in this way
		// are going to be pretty boring)
		conn.closeIfIdle()
	}
}

func (s *shard) wait() {
	s.wg.Wait()
}
@@ -1,185 +0,0 @@
package graceful

import (
	"io"
	"net"
	"sync"
	"sync/atomic"
	"time"
)

type listener struct {
	net.Listener
}

// WrapListener wraps an arbitrary net.Listener for use with graceful
// shutdowns. All net.Conn's Accept()ed by this listener will be auto-wrapped
// as if WrapConn() were called on them.
func WrapListener(l net.Listener) net.Listener {
	return listener{l}
}

func (l listener) Accept() (net.Conn, error) {
	conn, err := l.Listener.Accept()
	return WrapConn(conn), err
}

type conn struct {
	mu sync.Mutex
	cs *connSet
	net.Conn
	id        uint64
	busy, die bool
	dead      bool
	hijacked  bool
}
/*
WrapConn wraps an arbitrary connection for use with graceful shutdowns. The
graceful shutdown process will ensure that this connection is closed before
terminating the process.

In order to use this function, you must call SetReadDeadline() before the call
to Read() you might make to read a new request off the wire. The connection is
eligible for abrupt closing at any point between when the call to
SetReadDeadline() returns and when the call to Read returns with new data. It
does not matter what deadline is given to SetReadDeadline()--if a deadline is
inappropriate, providing one extremely far into the future will suffice.

Unfortunately, this means that it's difficult to use SetReadDeadline() in a
great many perfectly reasonable circumstances, such as to extend a deadline
after more data has been read, without the connection being eligible for
"graceful" termination at an undesirable time. Since this package was written
explicitly to target net/http, which does not, as of this writing, do any of
this, fixing the semantics here does not seem especially urgent.
*/
func WrapConn(c net.Conn) net.Conn {
	if c == nil {
		return nil
	}
	// Avoid race with termination code.
	wgLock.Lock()
	defer wgLock.Unlock()
	// Determine whether the app is shutting down.
	if acceptingRequests {
		wg.Add(1)
		return &conn{
			Conn: c,
			id:   atomic.AddUint64(&idleSet.id, 1),
		}
	} else {
		return nil
	}
}

func (c *conn) Read(b []byte) (n int, err error) {
	c.mu.Lock()
	if !c.hijacked {
		defer func() {
			c.mu.Lock()
			if c.hijacked {
				// It's a little unclear to me how this case
				// would happen, but we *did* drop the lock, so
				// let's play it safe.
				return
			}
			if c.dead {
				// Dead sockets don't tell tales. This is to
				// prevent the case where a Read manages to suck
				// an entire request off the wire in a race with
				// someone trying to close idle connections.
				// Whoever grabs the conn lock first wins, and
				// if that's the closing process, we need to
				// "take back" the read.
				n = 0
				err = io.EOF
			} else {
				idleSet.markBusy(c)
			}
			c.mu.Unlock()
		}()
	}
	c.mu.Unlock()
	return c.Conn.Read(b)
}

func (c *conn) SetReadDeadline(t time.Time) error {
	c.mu.Lock()
	if !c.hijacked {
		defer c.markIdle()
	}
	c.mu.Unlock()
	return c.Conn.SetReadDeadline(t)
}

func (c *conn) Close() error {
	kill := false
	c.mu.Lock()
	kill, c.dead = !c.dead, true
	idleSet.markBusy(c)
	c.mu.Unlock()
	if kill {
		defer wg.Done()
	}
	return c.Conn.Close()
}

type writerOnly struct {
	w io.Writer
}

func (w writerOnly) Write(buf []byte) (int, error) {
	return w.w.Write(buf)
}

func (c *conn) ReadFrom(r io.Reader) (int64, error) {
	if rf, ok := c.Conn.(io.ReaderFrom); ok {
		return rf.ReadFrom(r)
	}
	return io.Copy(writerOnly{c}, r)
}

func (c *conn) markIdle() {
	kill := false
	c.mu.Lock()
	idleSet.markIdle(c)
	if c.die {
		kill, c.dead = !c.dead, true
	}
	c.mu.Unlock()
	if kill {
		defer wg.Done()
		c.Conn.Close()
	}
}

func (c *conn) closeIfIdle() {
	kill := false
	c.mu.Lock()
	c.die = true
	if !c.busy && !c.hijacked {
		kill, c.dead = !c.dead, true
	}
	c.mu.Unlock()
	if kill {
		defer wg.Done()
		c.Conn.Close()
	}
}

func (c *conn) hijack() net.Conn {
	c.mu.Lock()
	idleSet.markBusy(c)
	c.hijacked = true
	c.mu.Unlock()
	return c.Conn
}
@@ -1,12 +0,0 @@
// +build race

package graceful

import "testing"

func TestWaitGroupRace(t *testing.T) {
	go func() {
		go WrapConn(fakeConn{}).Close()
	}()
	Shutdown()
}