This change refactors package graceful into two packages: one very well tested package that deals with graceful shutdown of arbitrary net.Listeners in the abstract, and one less-well-tested package that works with the nitty-gritty details of net/http and signal handling. This is a breaking API change for advanced users of package graceful: the WrapConn function no longer exists. This shouldn't affect most users or use cases.
| @ -1,80 +0,0 @@ | |||||
| package graceful | |||||
| import ( | |||||
| "runtime" | |||||
| "sync" | |||||
| ) | |||||
| type connShard struct { | |||||
| mu sync.Mutex | |||||
| // We sort of abuse this field to also act as a "please shut down" flag. | |||||
| // If it's nil, you should die at your earliest opportunity. | |||||
| set map[*conn]struct{} | |||||
| } | |||||
| type connSet struct { | |||||
| // This is an incrementing connection counter so we round-robin | |||||
| // connections across shards. Use atomic when touching it. | |||||
| id uint64 | |||||
| shards []*connShard | |||||
| } | |||||
| var idleSet connSet | |||||
| // We pretty aggressively preallocate set entries in the hopes that we never | |||||
| // have to allocate memory with the lock held. This is definitely a premature | |||||
| // optimization and is probably misguided, but luckily it costs us essentially | |||||
| // nothing. | |||||
| const prealloc = 2048 | |||||
| func init() { | |||||
| // To keep the expected contention rate constant we'd have to grow this | |||||
| // as numcpu**2. In practice, CPU counts don't generally grow without | |||||
| // bound, and contention is probably going to be small enough that | |||||
| // nobody cares anyways. | |||||
| idleSet.shards = make([]*connShard, 2*runtime.NumCPU()) | |||||
| for i := range idleSet.shards { | |||||
| idleSet.shards[i] = &connShard{ | |||||
| set: make(map[*conn]struct{}, prealloc), | |||||
| } | |||||
| } | |||||
| } | |||||
| func (cs connSet) markIdle(c *conn) { | |||||
| c.busy = false | |||||
| shard := cs.shards[int(c.id%uint64(len(cs.shards)))] | |||||
| shard.mu.Lock() | |||||
| if shard.set == nil { | |||||
| shard.mu.Unlock() | |||||
| c.die = true | |||||
| } else { | |||||
| shard.set[c] = struct{}{} | |||||
| shard.mu.Unlock() | |||||
| } | |||||
| } | |||||
| func (cs connSet) markBusy(c *conn) { | |||||
| c.busy = true | |||||
| shard := cs.shards[int(c.id%uint64(len(cs.shards)))] | |||||
| shard.mu.Lock() | |||||
| if shard.set == nil { | |||||
| shard.mu.Unlock() | |||||
| c.die = true | |||||
| } else { | |||||
| delete(shard.set, c) | |||||
| shard.mu.Unlock() | |||||
| } | |||||
| } | |||||
| func (cs connSet) killall() { | |||||
| for _, shard := range cs.shards { | |||||
| shard.mu.Lock() | |||||
| set := shard.set | |||||
| shard.set = nil | |||||
| shard.mu.Unlock() | |||||
| for conn := range set { | |||||
| conn.closeIfIdle() | |||||
| } | |||||
| } | |||||
| } | |||||
| @ -1,63 +0,0 @@ | |||||
| package graceful | |||||
| import ( | |||||
| "net" | |||||
| "time" | |||||
| ) | |||||
| // Stub out a net.Conn. This is going to be painful. | |||||
| type fakeAddr struct{} | |||||
| func (f fakeAddr) Network() string { | |||||
| return "fake" | |||||
| } | |||||
| func (f fakeAddr) String() string { | |||||
| return "fake" | |||||
| } | |||||
| type fakeConn struct { | |||||
| onRead, onWrite, onClose, onLocalAddr, onRemoteAddr func() | |||||
| onSetDeadline, onSetReadDeadline, onSetWriteDeadline func() | |||||
| } | |||||
| // Here's my number, so... | |||||
| func callMeMaybe(f func()) { | |||||
| // I apologize for nothing. | |||||
| if f != nil { | |||||
| f() | |||||
| } | |||||
| } | |||||
| func (f fakeConn) Read(b []byte) (int, error) { | |||||
| callMeMaybe(f.onRead) | |||||
| return len(b), nil | |||||
| } | |||||
| func (f fakeConn) Write(b []byte) (int, error) { | |||||
| callMeMaybe(f.onWrite) | |||||
| return len(b), nil | |||||
| } | |||||
| func (f fakeConn) Close() error { | |||||
| callMeMaybe(f.onClose) | |||||
| return nil | |||||
| } | |||||
| func (f fakeConn) LocalAddr() net.Addr { | |||||
| callMeMaybe(f.onLocalAddr) | |||||
| return fakeAddr{} | |||||
| } | |||||
| func (f fakeConn) RemoteAddr() net.Addr { | |||||
| callMeMaybe(f.onRemoteAddr) | |||||
| return fakeAddr{} | |||||
| } | |||||
| func (f fakeConn) SetDeadline(t time.Time) error { | |||||
| callMeMaybe(f.onSetDeadline) | |||||
| return nil | |||||
| } | |||||
| func (f fakeConn) SetReadDeadline(t time.Time) error { | |||||
| callMeMaybe(f.onSetReadDeadline) | |||||
| return nil | |||||
| } | |||||
| func (f fakeConn) SetWriteDeadline(t time.Time) error { | |||||
| callMeMaybe(f.onSetWriteDeadline) | |||||
| return nil | |||||
| } | |||||
| @ -0,0 +1,147 @@ | |||||
| package listener | |||||
| import ( | |||||
| "errors" | |||||
| "io" | |||||
| "net" | |||||
| "sync" | |||||
| "time" | |||||
| ) | |||||
| type conn struct { | |||||
| net.Conn | |||||
| shard *shard | |||||
| mode mode | |||||
| mu sync.Mutex // Protects the state machine below | |||||
| busy bool // connection is in use (i.e., not idle) | |||||
| closed bool // connection is closed | |||||
| disowned bool // if true, this connection is no longer under our management | |||||
| } | |||||
| // This intentionally looks a lot like the one in package net. | |||||
| var errClosing = errors.New("use of closed network connection") | |||||
| func (c *conn) init() error { | |||||
| c.shard.wg.Add(1) | |||||
| if shouldExit := c.shard.markIdle(c); shouldExit { | |||||
| c.Close() | |||||
| return errClosing | |||||
| } | |||||
| return nil | |||||
| } | |||||
| func (c *conn) Read(b []byte) (n int, err error) { | |||||
| defer func() { | |||||
| c.mu.Lock() | |||||
| defer c.mu.Unlock() | |||||
| if c.disowned { | |||||
| return | |||||
| } | |||||
| // This protects against a Close/Read race. We're not really | |||||
| // concerned about the general case (it's fundamentally racy), | |||||
| // but are mostly trying to prevent a race between a new request | |||||
| // getting read off the wire in one thread while the connection | |||||
| // is being gracefully shut down in another. | |||||
| if c.closed && err == nil { | |||||
| n = 0 | |||||
| err = errClosing | |||||
| return | |||||
| } | |||||
| if c.mode != Manual && !c.busy && !c.closed { | |||||
| c.busy = true | |||||
| c.shard.markInUse(c) | |||||
| } | |||||
| }() | |||||
| return c.Conn.Read(b) | |||||
| } | |||||
| func (c *conn) Close() error { | |||||
| c.mu.Lock() | |||||
| defer c.mu.Unlock() | |||||
| if !c.closed && !c.disowned { | |||||
| c.closed = true | |||||
| c.shard.markInUse(c) | |||||
| defer c.shard.wg.Done() | |||||
| } | |||||
| return c.Conn.Close() | |||||
| } | |||||
| func (c *conn) SetReadDeadline(t time.Time) error { | |||||
| c.mu.Lock() | |||||
| if !c.disowned && c.mode == Deadline { | |||||
| defer c.markIdle() | |||||
| } | |||||
| c.mu.Unlock() | |||||
| return c.Conn.SetReadDeadline(t) | |||||
| } | |||||
| func (c *conn) ReadFrom(r io.Reader) (int64, error) { | |||||
| return io.Copy(c.Conn, r) | |||||
| } | |||||
| func (c *conn) markIdle() { | |||||
| c.mu.Lock() | |||||
| defer c.mu.Unlock() | |||||
| if !c.busy { | |||||
| return | |||||
| } | |||||
| c.busy = false | |||||
| if exit := c.shard.markIdle(c); exit && !c.closed && !c.disowned { | |||||
| c.closed = true | |||||
| c.shard.markInUse(c) | |||||
| defer c.shard.wg.Done() | |||||
| c.Conn.Close() | |||||
| return | |||||
| } | |||||
| } | |||||
| func (c *conn) markInUse() { | |||||
| c.mu.Lock() | |||||
| defer c.mu.Unlock() | |||||
| if !c.busy && !c.closed && !c.disowned { | |||||
| c.busy = true | |||||
| c.shard.markInUse(c) | |||||
| } | |||||
| } | |||||
| func (c *conn) closeIfIdle() error { | |||||
| c.mu.Lock() | |||||
| defer c.mu.Unlock() | |||||
| if !c.busy && !c.closed && !c.disowned { | |||||
| c.closed = true | |||||
| c.shard.markInUse(c) | |||||
| defer c.shard.wg.Done() | |||||
| return c.Conn.Close() | |||||
| } | |||||
| return nil | |||||
| } | |||||
| var errAlreadyDisowned = errors.New("listener: conn already disowned") | |||||
| func (c *conn) disown() error { | |||||
| c.mu.Lock() | |||||
| defer c.mu.Unlock() | |||||
| if c.disowned { | |||||
| return errAlreadyDisowned | |||||
| } | |||||
| c.shard.markInUse(c) | |||||
| c.disowned = true | |||||
| c.shard.wg.Done() | |||||
| return nil | |||||
| } | |||||
| @ -0,0 +1,182 @@ | |||||
| package listener | |||||
| import ( | |||||
| "io" | |||||
| "strings" | |||||
| "testing" | |||||
| "time" | |||||
| ) | |||||
| func TestManualRead(t *testing.T) { | |||||
| t.Parallel() | |||||
| l, c, wc := singleConn(t, Manual) | |||||
| go c.AllowRead() | |||||
| wc.Read(make([]byte, 1024)) | |||||
| if err := l.CloseIdle(); err != nil { | |||||
| t.Fatalf("error closing idle connections: %v", err) | |||||
| } | |||||
| if !c.Closed() { | |||||
| t.Error("Read() should not make connection not-idle") | |||||
| } | |||||
| } | |||||
| func TestAutomaticRead(t *testing.T) { | |||||
| t.Parallel() | |||||
| l, c, wc := singleConn(t, Automatic) | |||||
| go c.AllowRead() | |||||
| wc.Read(make([]byte, 1024)) | |||||
| if err := l.CloseIdle(); err != nil { | |||||
| t.Fatalf("error closing idle connections: %v", err) | |||||
| } | |||||
| if c.Closed() { | |||||
| t.Error("expected Read() to mark connection as in-use") | |||||
| } | |||||
| } | |||||
| func TestDeadlineRead(t *testing.T) { | |||||
| t.Parallel() | |||||
| l, c, wc := singleConn(t, Deadline) | |||||
| go c.AllowRead() | |||||
| if _, err := wc.Read(make([]byte, 1024)); err != nil { | |||||
| t.Fatalf("error reading from connection: %v", err) | |||||
| } | |||||
| if err := l.CloseIdle(); err != nil { | |||||
| t.Fatalf("error closing idle connections: %v", err) | |||||
| } | |||||
| if c.Closed() { | |||||
| t.Error("expected Read() to mark connection as in-use") | |||||
| } | |||||
| } | |||||
| func TestDisownedRead(t *testing.T) { | |||||
| t.Parallel() | |||||
| l, c, wc := singleConn(t, Deadline) | |||||
| if err := Disown(wc); err != nil { | |||||
| t.Fatalf("unexpected error disowning conn: %v", err) | |||||
| } | |||||
| if err := l.Close(); err != nil { | |||||
| t.Fatalf("unexpected error closing listener: %v", err) | |||||
| } | |||||
| if err := l.Drain(); err != nil { | |||||
| t.Fatalf("unexpected error draining listener: %v", err) | |||||
| } | |||||
| go c.AllowRead() | |||||
| if _, err := wc.Read(make([]byte, 1024)); err != nil { | |||||
| t.Fatalf("error reading from connection: %v", err) | |||||
| } | |||||
| } | |||||
| func TestCloseConn(t *testing.T) { | |||||
| t.Parallel() | |||||
| l, _, wc := singleConn(t, Deadline) | |||||
| if err := MarkInUse(wc); err != nil { | |||||
| t.Fatalf("error marking conn in use: %v", err) | |||||
| } | |||||
| if err := wc.Close(); err != nil { | |||||
| t.Errorf("error closing connection: %v", err) | |||||
| } | |||||
| // This will hang if wc.Close() doesn't un-track the connection | |||||
| if err := l.Drain(); err != nil { | |||||
| t.Errorf("error draining listener: %v", err) | |||||
| } | |||||
| } | |||||
| func TestManualReadDeadline(t *testing.T) { | |||||
| t.Parallel() | |||||
| l, c, wc := singleConn(t, Manual) | |||||
| if err := MarkInUse(wc); err != nil { | |||||
| t.Fatalf("error marking connection in use: %v", err) | |||||
| } | |||||
| if err := wc.SetReadDeadline(time.Now()); err != nil { | |||||
| t.Fatalf("error setting read deadline: %v", err) | |||||
| } | |||||
| if err := l.CloseIdle(); err != nil { | |||||
| t.Fatalf("error closing idle connections: %v", err) | |||||
| } | |||||
| if c.Closed() { | |||||
| t.Error("SetReadDeadline() should not mark manual conn as idle") | |||||
| } | |||||
| } | |||||
| func TestAutomaticReadDeadline(t *testing.T) { | |||||
| t.Parallel() | |||||
| l, c, wc := singleConn(t, Automatic) | |||||
| if err := MarkInUse(wc); err != nil { | |||||
| t.Fatalf("error marking connection in use: %v", err) | |||||
| } | |||||
| if err := wc.SetReadDeadline(time.Now()); err != nil { | |||||
| t.Fatalf("error setting read deadline: %v", err) | |||||
| } | |||||
| if err := l.CloseIdle(); err != nil { | |||||
| t.Fatalf("error closing idle connections: %v", err) | |||||
| } | |||||
| if c.Closed() { | |||||
| t.Error("SetReadDeadline() should not mark automatic conn as idle") | |||||
| } | |||||
| } | |||||
| func TestDeadlineReadDeadline(t *testing.T) { | |||||
| t.Parallel() | |||||
| l, c, wc := singleConn(t, Deadline) | |||||
| if err := MarkInUse(wc); err != nil { | |||||
| t.Fatalf("error marking connection in use: %v", err) | |||||
| } | |||||
| if err := wc.SetReadDeadline(time.Now()); err != nil { | |||||
| t.Fatalf("error setting read deadline: %v", err) | |||||
| } | |||||
| if err := l.CloseIdle(); err != nil { | |||||
| t.Fatalf("error closing idle connections: %v", err) | |||||
| } | |||||
| if !c.Closed() { | |||||
| t.Error("SetReadDeadline() should mark deadline conn as idle") | |||||
| } | |||||
| } | |||||
| type readerConn struct { | |||||
| fakeConn | |||||
| } | |||||
| func (rc *readerConn) ReadFrom(r io.Reader) (int64, error) { | |||||
| return 123, nil | |||||
| } | |||||
| func TestReadFrom(t *testing.T) { | |||||
| t.Parallel() | |||||
| l := makeFakeListener("net.Listener") | |||||
| wl := Wrap(l, Manual) | |||||
| c := &readerConn{ | |||||
| fakeConn{ | |||||
| read: make(chan struct{}), | |||||
| write: make(chan struct{}), | |||||
| closed: make(chan struct{}), | |||||
| me: fakeAddr{"tcp", "local"}, | |||||
| you: fakeAddr{"tcp", "remote"}, | |||||
| }, | |||||
| } | |||||
| go l.Enqueue(c) | |||||
| wc, err := wl.Accept() | |||||
| if err != nil { | |||||
| t.Fatalf("error accepting connection: %v", err) | |||||
| } | |||||
| // The io.MultiReader is a convenient hack to ensure that we're using | |||||
| // our ReadFrom, not strings.Reader's WriteTo. | |||||
| r := io.MultiReader(strings.NewReader("hello world")) | |||||
| if _, err := io.Copy(wc, r); err != nil { | |||||
| t.Fatalf("error copying: %v", err) | |||||
| } | |||||
| } | |||||
| @ -0,0 +1,123 @@ | |||||
| package listener | |||||
| import ( | |||||
| "net" | |||||
| "time" | |||||
| ) | |||||
| type fakeAddr struct { | |||||
| network, addr string | |||||
| } | |||||
| func (f fakeAddr) Network() string { | |||||
| return f.network | |||||
| } | |||||
| func (f fakeAddr) String() string { | |||||
| return f.addr | |||||
| } | |||||
| type fakeListener struct { | |||||
| ch chan net.Conn | |||||
| closed chan struct{} | |||||
| addr net.Addr | |||||
| } | |||||
| func makeFakeListener(addr string) *fakeListener { | |||||
| a := fakeAddr{"tcp", addr} | |||||
| return &fakeListener{ | |||||
| ch: make(chan net.Conn), | |||||
| closed: make(chan struct{}), | |||||
| addr: a, | |||||
| } | |||||
| } | |||||
| func (f *fakeListener) Accept() (net.Conn, error) { | |||||
| select { | |||||
| case c := <-f.ch: | |||||
| return c, nil | |||||
| case <-f.closed: | |||||
| return nil, errClosing | |||||
| } | |||||
| } | |||||
| func (f *fakeListener) Close() error { | |||||
| close(f.closed) | |||||
| return nil | |||||
| } | |||||
| func (f *fakeListener) Addr() net.Addr { | |||||
| return f.addr | |||||
| } | |||||
| func (f *fakeListener) Enqueue(c net.Conn) { | |||||
| f.ch <- c | |||||
| } | |||||
| type fakeConn struct { | |||||
| read, write, closed chan struct{} | |||||
| me, you net.Addr | |||||
| } | |||||
| func makeFakeConn(me, you string) *fakeConn { | |||||
| return &fakeConn{ | |||||
| read: make(chan struct{}), | |||||
| write: make(chan struct{}), | |||||
| closed: make(chan struct{}), | |||||
| me: fakeAddr{"tcp", me}, | |||||
| you: fakeAddr{"tcp", you}, | |||||
| } | |||||
| } | |||||
| func (f *fakeConn) Read(buf []byte) (int, error) { | |||||
| select { | |||||
| case <-f.read: | |||||
| return len(buf), nil | |||||
| case <-f.closed: | |||||
| return 0, errClosing | |||||
| } | |||||
| } | |||||
| func (f *fakeConn) Write(buf []byte) (int, error) { | |||||
| select { | |||||
| case <-f.write: | |||||
| return len(buf), nil | |||||
| case <-f.closed: | |||||
| return 0, errClosing | |||||
| } | |||||
| } | |||||
| func (f *fakeConn) Close() error { | |||||
| close(f.closed) | |||||
| return nil | |||||
| } | |||||
| func (f *fakeConn) LocalAddr() net.Addr { | |||||
| return f.me | |||||
| } | |||||
| func (f *fakeConn) RemoteAddr() net.Addr { | |||||
| return f.you | |||||
| } | |||||
| func (f *fakeConn) SetDeadline(t time.Time) error { | |||||
| return nil | |||||
| } | |||||
| func (f *fakeConn) SetReadDeadline(t time.Time) error { | |||||
| return nil | |||||
| } | |||||
| func (f *fakeConn) SetWriteDeadline(t time.Time) error { | |||||
| return nil | |||||
| } | |||||
| func (f *fakeConn) Closed() bool { | |||||
| select { | |||||
| case <-f.closed: | |||||
| return true | |||||
| default: | |||||
| return false | |||||
| } | |||||
| } | |||||
| func (f *fakeConn) AllowRead() { | |||||
| f.read <- struct{}{} | |||||
| } | |||||
| func (f *fakeConn) AllowWrite() { | |||||
| f.write <- struct{}{} | |||||
| } | |||||
| @ -0,0 +1,165 @@ | |||||
| /* | |||||
| Package listener provides a way to incorporate graceful shutdown to any | |||||
| net.Listener. | |||||
| This package provides low-level primitives, not a high-level API. If you're | |||||
| looking for a package that provides graceful shutdown for HTTP servers, I | |||||
| recommend this package's parent package, github.com/zenazn/goji/graceful. | |||||
| */ | |||||
| package listener | |||||
| import ( | |||||
| "errors" | |||||
| "net" | |||||
| "runtime" | |||||
| "sync" | |||||
| "sync/atomic" | |||||
| ) | |||||
| type mode int8 | |||||
| const ( | |||||
| // Manual mode is completely manual: users must use use MarkIdle and | |||||
| // MarkInUse to indicate when connections are busy servicing requests or | |||||
| // are eligible for termination. | |||||
| Manual mode = iota | |||||
| // Automatic mode is what most users probably want: calling Read on a | |||||
| // connection will mark it as in use, but users must manually call | |||||
| // MarkIdle to indicate when connections may be safely closed. | |||||
| Automatic | |||||
| // Deadline mode is like automatic mode, except that calling | |||||
| // SetReadDeadline on a connection will also mark it as being idle. This | |||||
| // is useful for many servers like net/http, where SetReadDeadline is | |||||
| // used to implement read timeouts on new requests. | |||||
| Deadline | |||||
| ) | |||||
| // Wrap a net.Listener, returning a net.Listener which supports idle connection | |||||
| // tracking and shutdown. Listeners can be placed in to one of three modes, | |||||
| // exported as variables from this package: most users will probably want the | |||||
| // "Automatic" mode. | |||||
| func Wrap(l net.Listener, m mode) *T { | |||||
| t := &T{ | |||||
| l: l, | |||||
| mode: m, | |||||
| // To keep the expected contention rate constant we'd have to | |||||
| // grow this as numcpu**2. In practice, CPU counts don't | |||||
| // generally grow without bound, and contention is probably | |||||
| // going to be small enough that nobody cares anyways. | |||||
| shards: make([]shard, 2*runtime.NumCPU()), | |||||
| } | |||||
| for i := range t.shards { | |||||
| t.shards[i].init(t) | |||||
| } | |||||
| return t | |||||
| } | |||||
| // T is the type of this package's graceful listeners. | |||||
| type T struct { | |||||
| mu sync.Mutex | |||||
| l net.Listener | |||||
| // TODO(carl): a count of currently outstanding connections. | |||||
| connCount uint64 | |||||
| shards []shard | |||||
| mode mode | |||||
| } | |||||
| var _ net.Listener = &T{} | |||||
| // Accept waits for and returns the next connection to the listener. The | |||||
| // returned net.Conn's idleness is tracked, and idle connections can be closed | |||||
| // from the associated T. | |||||
| func (t *T) Accept() (net.Conn, error) { | |||||
| c, err := t.l.Accept() | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| connID := atomic.AddUint64(&t.connCount, 1) | |||||
| shard := &t.shards[int(connID)%len(t.shards)] | |||||
| wc := &conn{ | |||||
| Conn: c, | |||||
| shard: shard, | |||||
| mode: t.mode, | |||||
| } | |||||
| if err = wc.init(); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| return wc, nil | |||||
| } | |||||
| // Addr returns the wrapped listener's network address. | |||||
| func (t *T) Addr() net.Addr { | |||||
| return t.l.Addr() | |||||
| } | |||||
| // Close closes the wrapped listener. | |||||
| func (t *T) Close() error { | |||||
| return t.l.Close() | |||||
| } | |||||
| // CloseIdle closes all connections that are currently marked as being idle. It, | |||||
| // however, makes no attempt to wait for in-use connections to die, or to close | |||||
| // connections which become idle in the future. Call this function if you're | |||||
| // interested in shedding useless connections, but otherwise wish to continue | |||||
| // serving requests. | |||||
| func (t *T) CloseIdle() error { | |||||
| for i := range t.shards { | |||||
| t.shards[i].closeIdle(false) | |||||
| } | |||||
| // Not sure if returning errors is actually useful here :/ | |||||
| return nil | |||||
| } | |||||
| // Drain immediately closes all idle connections, prevents new connections from | |||||
| // being accepted, and waits for all outstanding connections to finish. | |||||
| // | |||||
| // Once a listener has been drained, there is no way to re-enable it. You | |||||
| // probably want to Close the listener before draining it, otherwise new | |||||
| // connections will be accepted and immediately closed. | |||||
| func (t *T) Drain() error { | |||||
| for i := range t.shards { | |||||
| t.shards[i].closeIdle(true) | |||||
| } | |||||
| for i := range t.shards { | |||||
| t.shards[i].wait() | |||||
| } | |||||
| return nil | |||||
| } | |||||
| var notManagedErr = errors.New("listener: passed net.Conn is not managed by us") | |||||
| // Disown causes a connection to no longer be tracked by the listener. The | |||||
| // passed connection must have been returned by a call to Accept from this | |||||
| // listener. | |||||
| func Disown(c net.Conn) error { | |||||
| if cn, ok := c.(*conn); ok { | |||||
| return cn.disown() | |||||
| } | |||||
| return notManagedErr | |||||
| } | |||||
| // MarkIdle marks the given connection as being idle, and therefore eligible for | |||||
| // closing at any time. The passed connection must have been returned by a call | |||||
| // to Accept from this listener. | |||||
| func MarkIdle(c net.Conn) error { | |||||
| if cn, ok := c.(*conn); ok { | |||||
| cn.markIdle() | |||||
| return nil | |||||
| } | |||||
| return notManagedErr | |||||
| } | |||||
| // MarkInUse marks this connection as being in use, removing it from the set of | |||||
| // connections which are eligible for closing. The passed connection must have | |||||
| // been returned by a call to Accept from this listener. | |||||
| func MarkInUse(c net.Conn) error { | |||||
| if cn, ok := c.(*conn); ok { | |||||
| cn.markInUse() | |||||
| return nil | |||||
| } | |||||
| return notManagedErr | |||||
| } | |||||
| @ -0,0 +1,143 @@ | |||||
| package listener | |||||
| import ( | |||||
| "net" | |||||
| "testing" | |||||
| "time" | |||||
| ) | |||||
| // Helper for tests acting on a single accepted connection | |||||
| func singleConn(t *testing.T, m mode) (*T, *fakeConn, net.Conn) { | |||||
| l := makeFakeListener("net.Listener") | |||||
| wl := Wrap(l, m) | |||||
| c := makeFakeConn("local", "remote") | |||||
| go l.Enqueue(c) | |||||
| wc, err := wl.Accept() | |||||
| if err != nil { | |||||
| t.Fatalf("error accepting connection: %v", err) | |||||
| } | |||||
| return wl, c, wc | |||||
| } | |||||
| func TestAddr(t *testing.T) { | |||||
| t.Parallel() | |||||
| l, c, wc := singleConn(t, Manual) | |||||
| if a := l.Addr(); a.String() != "net.Listener" { | |||||
| t.Errorf("addr was %v, wanted net.Listener", a) | |||||
| } | |||||
| if c.LocalAddr() != wc.LocalAddr() { | |||||
| t.Errorf("local addresses don't match: %v, %v", c.LocalAddr(), | |||||
| wc.LocalAddr()) | |||||
| } | |||||
| if c.RemoteAddr() != wc.RemoteAddr() { | |||||
| t.Errorf("remote addresses don't match: %v, %v", c.RemoteAddr(), | |||||
| wc.RemoteAddr()) | |||||
| } | |||||
| } | |||||
| func TestBasicCloseIdle(t *testing.T) { | |||||
| t.Parallel() | |||||
| l, c, _ := singleConn(t, Manual) | |||||
| if err := l.CloseIdle(); err != nil { | |||||
| t.Fatalf("error closing idle connections: %v", err) | |||||
| } | |||||
| if !c.Closed() { | |||||
| t.Error("idle connection not closed") | |||||
| } | |||||
| } | |||||
| func TestMark(t *testing.T) { | |||||
| t.Parallel() | |||||
| l, c, wc := singleConn(t, Manual) | |||||
| if err := MarkInUse(wc); err != nil { | |||||
| t.Fatalf("error marking %v in-use: %v", wc, err) | |||||
| } | |||||
| if err := l.CloseIdle(); err != nil { | |||||
| t.Fatalf("error closing idle connections: %v", err) | |||||
| } | |||||
| if c.Closed() { | |||||
| t.Errorf("manually in-use connection was closed") | |||||
| } | |||||
| if err := MarkIdle(wc); err != nil { | |||||
| t.Fatalf("error marking %v idle: %v", wc, err) | |||||
| } | |||||
| if err := l.CloseIdle(); err != nil { | |||||
| t.Fatalf("error closing idle connections: %v", err) | |||||
| } | |||||
| if !c.Closed() { | |||||
| t.Error("manually idle connection was not closed") | |||||
| } | |||||
| } | |||||
| func TestDisown(t *testing.T) { | |||||
| t.Parallel() | |||||
| l, c, wc := singleConn(t, Manual) | |||||
| if err := Disown(wc); err != nil { | |||||
| t.Fatalf("error disowning connection: %v", err) | |||||
| } | |||||
| if err := l.CloseIdle(); err != nil { | |||||
| t.Fatalf("error closing idle connections: %v", err) | |||||
| } | |||||
| if c.Closed() { | |||||
| t.Errorf("disowned connection got closed") | |||||
| } | |||||
| } | |||||
| func TestDrain(t *testing.T) { | |||||
| t.Parallel() | |||||
| l, _, wc := singleConn(t, Manual) | |||||
| MarkInUse(wc) | |||||
| start := time.Now() | |||||
| go func() { | |||||
| time.Sleep(50 * time.Millisecond) | |||||
| MarkIdle(wc) | |||||
| }() | |||||
| if err := l.Drain(); err != nil { | |||||
| t.Fatalf("error draining listener: %v", err) | |||||
| } | |||||
| end := time.Now() | |||||
| if dt := end.Sub(start); dt < 50*time.Millisecond { | |||||
| t.Errorf("expected at least 50ms wait, but got %v", dt) | |||||
| } | |||||
| } | |||||
| func TestErrors(t *testing.T) { | |||||
| t.Parallel() | |||||
| _, c, wc := singleConn(t, Manual) | |||||
| if err := Disown(c); err == nil { | |||||
| t.Error("expected error when disowning unmanaged net.Conn") | |||||
| } | |||||
| if err := MarkIdle(c); err == nil { | |||||
| t.Error("expected error when marking unmanaged net.Conn idle") | |||||
| } | |||||
| if err := MarkInUse(c); err == nil { | |||||
| t.Error("expected error when marking unmanaged net.Conn in use") | |||||
| } | |||||
| if err := Disown(wc); err != nil { | |||||
| t.Fatalf("unexpected error disowning socket: %v", err) | |||||
| } | |||||
| if err := Disown(wc); err == nil { | |||||
| t.Error("expected error disowning socket twice") | |||||
| } | |||||
| } | |||||
| func TestClose(t *testing.T) { | |||||
| t.Parallel() | |||||
| l, c, _ := singleConn(t, Manual) | |||||
| if err := l.Close(); err != nil { | |||||
| t.Fatalf("error while closing listener: %v", err) | |||||
| } | |||||
| if c.Closed() { | |||||
| t.Error("connection closed when listener was?") | |||||
| } | |||||
| } | |||||
| @ -0,0 +1,103 @@ | |||||
| package listener | |||||
| import ( | |||||
| "fmt" | |||||
| "math/rand" | |||||
| "runtime" | |||||
| "sync/atomic" | |||||
| "testing" | |||||
| "time" | |||||
| ) | |||||
| func init() { | |||||
| // Just to make sure we get some variety | |||||
| runtime.GOMAXPROCS(4 * runtime.NumCPU()) | |||||
| } | |||||
| // Chosen by random die roll | |||||
| const seed = 4611413766552969250 | |||||
| // This is mostly just fuzzing to see what happens. | |||||
| func TestRace(t *testing.T) { | |||||
| t.Parallel() | |||||
| l := makeFakeListener("net.Listener") | |||||
| wl := Wrap(l, Automatic) | |||||
| var flag int32 | |||||
| go func() { | |||||
| for i := 0; ; i++ { | |||||
| laddr := fmt.Sprintf("local%d", i) | |||||
| raddr := fmt.Sprintf("remote%d", i) | |||||
| c := makeFakeConn(laddr, raddr) | |||||
| go func() { | |||||
| defer func() { | |||||
| if r := recover(); r != nil { | |||||
| if atomic.LoadInt32(&flag) != 0 { | |||||
| return | |||||
| } | |||||
| panic(r) | |||||
| } | |||||
| }() | |||||
| l.Enqueue(c) | |||||
| }() | |||||
| wc, err := wl.Accept() | |||||
| if err != nil { | |||||
| if atomic.LoadInt32(&flag) != 0 { | |||||
| return | |||||
| } | |||||
| t.Fatalf("error accepting connection: %v", err) | |||||
| } | |||||
| go func() { | |||||
| for { | |||||
| time.Sleep(50 * time.Millisecond) | |||||
| c.AllowRead() | |||||
| } | |||||
| }() | |||||
| go func(i int64) { | |||||
| rng := rand.New(rand.NewSource(i + seed)) | |||||
| buf := make([]byte, 1024) | |||||
| for j := 0; j < 1024; j++ { | |||||
| if _, err := wc.Read(buf); err != nil { | |||||
| if atomic.LoadInt32(&flag) != 0 { | |||||
| // Peaceful; the connection has | |||||
| // probably been closed while | |||||
| // idle | |||||
| return | |||||
| } | |||||
| t.Errorf("error reading in conn %d: %v", | |||||
| i, err) | |||||
| } | |||||
| time.Sleep(time.Duration(rng.Intn(100)) * time.Millisecond) | |||||
| // This one is to make sure the connection | |||||
| // hasn't closed underneath us | |||||
| if _, err := wc.Read(buf); err != nil { | |||||
| t.Errorf("error reading in conn %d: %v", | |||||
| i, err) | |||||
| } | |||||
| MarkIdle(wc) | |||||
| time.Sleep(time.Duration(rng.Intn(100)) * time.Millisecond) | |||||
| } | |||||
| }(int64(i)) | |||||
| time.Sleep(time.Duration(i) * time.Millisecond / 2) | |||||
| } | |||||
| }() | |||||
| if testing.Short() { | |||||
| time.Sleep(2 * time.Second) | |||||
| } else { | |||||
| time.Sleep(10 * time.Second) | |||||
| } | |||||
| start := time.Now() | |||||
| atomic.StoreInt32(&flag, 1) | |||||
| wl.Close() | |||||
| wl.Drain() | |||||
| end := time.Now() | |||||
| if dt := end.Sub(start); dt > 300*time.Millisecond { | |||||
| t.Errorf("took %v to drain; expected shorter", dt) | |||||
| } | |||||
| } | |||||
| @ -0,0 +1,71 @@ | |||||
| package listener | |||||
| import "sync" | |||||
| type shard struct { | |||||
| l *T | |||||
| mu sync.Mutex | |||||
| set map[*conn]struct{} | |||||
| wg sync.WaitGroup | |||||
| drain bool | |||||
| // We pack shards together in an array, but we don't want them packed | |||||
| // too closely, since we want to give each shard a dedicated CPU cache | |||||
| // line. This amount of padding works out well for the common case of | |||||
| // x64 processors (64-bit pointers with a 64-byte cache line). | |||||
| _ [12]byte | |||||
| } | |||||
| // We pretty aggressively preallocate set entries in the hopes that we never | |||||
| // have to allocate memory with the lock held. This is definitely a premature | |||||
| // optimization and is probably misguided, but luckily it costs us essentially | |||||
| // nothing. | |||||
| const prealloc = 2048 | |||||
| func (s *shard) init(l *T) { | |||||
| s.l = l | |||||
| s.set = make(map[*conn]struct{}, prealloc) | |||||
| } | |||||
| func (s *shard) markIdle(c *conn) (shouldClose bool) { | |||||
| s.mu.Lock() | |||||
| if s.drain { | |||||
| s.mu.Unlock() | |||||
| return true | |||||
| } | |||||
| s.set[c] = struct{}{} | |||||
| s.mu.Unlock() | |||||
| return false | |||||
| } | |||||
| func (s *shard) markInUse(c *conn) { | |||||
| s.mu.Lock() | |||||
| delete(s.set, c) | |||||
| s.mu.Unlock() | |||||
| } | |||||
| func (s *shard) closeIdle(drain bool) { | |||||
| s.mu.Lock() | |||||
| if drain { | |||||
| s.drain = true | |||||
| } | |||||
| set := s.set | |||||
| s.set = make(map[*conn]struct{}, prealloc) | |||||
| // We have to drop the shard lock here to avoid deadlock: we cannot | |||||
| // acquire the shard lock after the connection lock, and the closeIfIdle | |||||
| // call below will grab a connection lock. | |||||
| s.mu.Unlock() | |||||
| for conn := range set { | |||||
| // This might return an error (from Close), but I don't think we | |||||
| // can do anything about it, so let's just pretend it didn't | |||||
| // happen. (I also expect that most errors returned in this way | |||||
| // are going to be pretty boring) | |||||
| conn.closeIfIdle() | |||||
| } | |||||
| } | |||||
| func (s *shard) wait() { | |||||
| s.wg.Wait() | |||||
| } | |||||
| @ -1,185 +0,0 @@ | |||||
| package graceful | |||||
| import ( | |||||
| "io" | |||||
| "net" | |||||
| "sync" | |||||
| "sync/atomic" | |||||
| "time" | |||||
| ) | |||||
| type listener struct { | |||||
| net.Listener | |||||
| } | |||||
| // WrapListener wraps an arbitrary net.Listener for use with graceful shutdowns. | |||||
| // All net.Conn's Accept()ed by this listener will be auto-wrapped as if | |||||
| // WrapConn() were called on them. | |||||
| func WrapListener(l net.Listener) net.Listener { | |||||
| return listener{l} | |||||
| } | |||||
| func (l listener) Accept() (net.Conn, error) { | |||||
| conn, err := l.Listener.Accept() | |||||
| return WrapConn(conn), err | |||||
| } | |||||
| type conn struct { | |||||
| mu sync.Mutex | |||||
| cs *connSet | |||||
| net.Conn | |||||
| id uint64 | |||||
| busy, die bool | |||||
| dead bool | |||||
| hijacked bool | |||||
| } | |||||
| /* | |||||
| WrapConn wraps an arbitrary connection for use with graceful shutdowns. The | |||||
| graceful shutdown process will ensure that this connection is closed before | |||||
| terminating the process. | |||||
| In order to use this function, you must call SetReadDeadline() before the call | |||||
| to Read() you might make to read a new request off the wire. The connection is | |||||
| eligible for abrupt closing at any point between when the call to | |||||
| SetReadDeadline() returns and when the call to Read returns with new data. It | |||||
| does not matter what deadline is given to SetReadDeadline()--if a deadline is | |||||
| inappropriate, providing one extremely far into the future will suffice. | |||||
| Unfortunately, this means that it's difficult to use SetReadDeadline() in a | |||||
| great many perfectly reasonable circumstances, such as to extend a deadline | |||||
| after more data has been read, without the connection being eligible for | |||||
| "graceful" termination at an undesirable time. Since this package was written | |||||
| explicitly to target net/http, which does not as of this writing do any of this, | |||||
| fixing the semantics here does not seem especially urgent. | |||||
| */ | |||||
| func WrapConn(c net.Conn) net.Conn { | |||||
| if c == nil { | |||||
| return nil | |||||
| } | |||||
| // Avoid race with termination code. | |||||
| wgLock.Lock() | |||||
| defer wgLock.Unlock() | |||||
| // Determine whether the app is shutting down. | |||||
| if acceptingRequests { | |||||
| wg.Add(1) | |||||
| return &conn{ | |||||
| Conn: c, | |||||
| id: atomic.AddUint64(&idleSet.id, 1), | |||||
| } | |||||
| } else { | |||||
| return nil | |||||
| } | |||||
| } | |||||
| func (c *conn) Read(b []byte) (n int, err error) { | |||||
| c.mu.Lock() | |||||
| if !c.hijacked { | |||||
| defer func() { | |||||
| c.mu.Lock() | |||||
| if c.hijacked { | |||||
| // It's a little unclear to me how this case | |||||
| // would happen, but we *did* drop the lock, so | |||||
| // let's play it safe. | |||||
| return | |||||
| } | |||||
| if c.dead { | |||||
| // Dead sockets don't tell tales. This is to | |||||
| // prevent the case where a Read manages to suck | |||||
| // an entire request off the wire in a race with | |||||
| // someone trying to close idle connections. | |||||
| // Whoever grabs the conn lock first wins, and | |||||
| // if that's the closing process, we need to | |||||
| // "take back" the read. | |||||
| n = 0 | |||||
| err = io.EOF | |||||
| } else { | |||||
| idleSet.markBusy(c) | |||||
| } | |||||
| c.mu.Unlock() | |||||
| }() | |||||
| } | |||||
| c.mu.Unlock() | |||||
| return c.Conn.Read(b) | |||||
| } | |||||
| func (c *conn) SetReadDeadline(t time.Time) error { | |||||
| c.mu.Lock() | |||||
| if !c.hijacked { | |||||
| defer c.markIdle() | |||||
| } | |||||
| c.mu.Unlock() | |||||
| return c.Conn.SetReadDeadline(t) | |||||
| } | |||||
| func (c *conn) Close() error { | |||||
| kill := false | |||||
| c.mu.Lock() | |||||
| kill, c.dead = !c.dead, true | |||||
| idleSet.markBusy(c) | |||||
| c.mu.Unlock() | |||||
| if kill { | |||||
| defer wg.Done() | |||||
| } | |||||
| return c.Conn.Close() | |||||
| } | |||||
| type writerOnly struct { | |||||
| w io.Writer | |||||
| } | |||||
| func (w writerOnly) Write(buf []byte) (int, error) { | |||||
| return w.w.Write(buf) | |||||
| } | |||||
| func (c *conn) ReadFrom(r io.Reader) (int64, error) { | |||||
| if rf, ok := c.Conn.(io.ReaderFrom); ok { | |||||
| return rf.ReadFrom(r) | |||||
| } | |||||
| return io.Copy(writerOnly{c}, r) | |||||
| } | |||||
| func (c *conn) markIdle() { | |||||
| kill := false | |||||
| c.mu.Lock() | |||||
| idleSet.markIdle(c) | |||||
| if c.die { | |||||
| kill, c.dead = !c.dead, true | |||||
| } | |||||
| c.mu.Unlock() | |||||
| if kill { | |||||
| defer wg.Done() | |||||
| c.Conn.Close() | |||||
| } | |||||
| } | |||||
| func (c *conn) closeIfIdle() { | |||||
| kill := false | |||||
| c.mu.Lock() | |||||
| c.die = true | |||||
| if !c.busy && !c.hijacked { | |||||
| kill, c.dead = !c.dead, true | |||||
| } | |||||
| c.mu.Unlock() | |||||
| if kill { | |||||
| defer wg.Done() | |||||
| c.Conn.Close() | |||||
| } | |||||
| } | |||||
| func (c *conn) hijack() net.Conn { | |||||
| c.mu.Lock() | |||||
| idleSet.markBusy(c) | |||||
| c.hijacked = true | |||||
| c.mu.Unlock() | |||||
| return c.Conn | |||||
| } | |||||
| @ -1,12 +0,0 @@ | |||||
| // +build race | |||||
| package graceful | |||||
| import "testing" | |||||
| func TestWaitGroupRace(t *testing.T) { | |||||
| go func() { | |||||
| go WrapConn(fakeConn{}).Close() | |||||
| }() | |||||
| Shutdown() | |||||
| } | |||||