diff --git a/graceful/listener/conn.go b/graceful/listener/conn.go index 22e986f..3f797c5 100644 --- a/graceful/listener/conn.go +++ b/graceful/listener/conn.go @@ -25,7 +25,7 @@ var errClosing = errors.New("use of closed network connection") func (c *conn) init() error { c.shard.wg.Add(1) - if shouldExit := c.shard.markIdle(c); shouldExit { + if shouldExit := c.shard.track(c); shouldExit { c.Close() return errClosing } @@ -65,12 +65,14 @@ func (c *conn) Close() error { c.mu.Lock() defer c.mu.Unlock() - if !c.closed && !c.disowned { - c.closed = true - c.shard.markInUse(c) - defer c.shard.wg.Done() + if c.closed || c.disowned { + return errClosing } + c.closed = true + c.shard.disown(c) + defer c.shard.wg.Done() + return c.Conn.Close() } @@ -98,7 +100,7 @@ func (c *conn) markIdle() { if exit := c.shard.markIdle(c); exit && !c.closed && !c.disowned { c.closed = true - c.shard.markInUse(c) + c.shard.disown(c) defer c.shard.wg.Done() c.Conn.Close() return @@ -121,7 +123,7 @@ func (c *conn) closeIfIdle() error { if !c.busy && !c.closed && !c.disowned { c.closed = true - c.shard.markInUse(c) + c.shard.disown(c) defer c.shard.wg.Done() return c.Conn.Close() } @@ -139,7 +141,7 @@ func (c *conn) disown() error { return errAlreadyDisowned } - c.shard.markInUse(c) + c.shard.disown(c) c.disowned = true c.shard.wg.Done() diff --git a/graceful/listener/listener.go b/graceful/listener/listener.go index bebfb31..1a0e59e 100644 --- a/graceful/listener/listener.go +++ b/graceful/listener/listener.go @@ -108,7 +108,7 @@ func (t *T) Close() error { // serving requests. func (t *T) CloseIdle() error { for i := range t.shards { - t.shards[i].closeIdle(false) + t.shards[i].closeConns(false, false) } // Not sure if returning errors is actually useful here :/ return nil @@ -122,7 +122,7 @@ func (t *T) CloseIdle() error { // connections will be accepted and immediately closed. func (t *T) Drain() error { for i := range t.shards { - t.shards[i].closeIdle(true) + t.shards[i].closeConns(false, true) } for i := range t.shards { t.shards[i].wait() @@ -130,7 +130,20 @@ func (t *T) Drain() error { return nil } -var notManagedErr = errors.New("listener: passed net.Conn is not managed by us") +// DrainAll closes all connections currently tracked by this listener (both idle +// and in-use connections), and prevents new connections from being accepted. +// Disowned connections are not closed. +func (t *T) DrainAll() error { + for i := range t.shards { + t.shards[i].closeConns(true, true) + } + for i := range t.shards { + t.shards[i].wait() + } + return nil +} + +var notManagedErr = errors.New("listener: passed net.Conn is not managed by this package") // Disown causes a connection to no longer be tracked by the listener. The // passed connection must have been returned by a call to Accept from this diff --git a/graceful/listener/listener_test.go b/graceful/listener/listener_test.go index 160434b..dcefdaa 100644 --- a/graceful/listener/listener_test.go +++ b/graceful/listener/listener_test.go @@ -110,6 +110,19 @@ func TestDrain(t *testing.T) { } } +func TestDrainAll(t *testing.T) { + t.Parallel() + l, c, wc := singleConn(t, Manual) + + MarkInUse(wc) + if err := l.DrainAll(); err != nil { + t.Fatalf("error draining listener: %v", err) + } + if !c.Closed() { + t.Error("expected in-use connection to be closed") + } +} + func TestErrors(t *testing.T) { t.Parallel() _, c, wc := singleConn(t, Manual) diff --git a/graceful/listener/shard.go b/graceful/listener/shard.go index e47deac..a9addad 100644 --- a/graceful/listener/shard.go +++ b/graceful/listener/shard.go @@ -6,15 +6,10 @@ type shard struct { l *T mu sync.Mutex - set map[*conn]struct{} + idle map[*conn]struct{} + all map[*conn]struct{} wg sync.WaitGroup drain bool - - // We pack shards together in an array, but we don't want them packed - // too closely, since we want to give each shard a dedicated CPU cache - // line. This amount of padding works out well for the common case of - // x64 processors (64-bit pointers with a 64-byte cache line). - _ [12]byte } // We pretty aggressively preallocate set entries in the hopes that we never @@ -25,7 +20,27 @@ const prealloc = 2048 func (s *shard) init(l *T) { s.l = l - s.set = make(map[*conn]struct{}, prealloc) + s.idle = make(map[*conn]struct{}, prealloc) + s.all = make(map[*conn]struct{}, prealloc) +} + +func (s *shard) track(c *conn) (shouldClose bool) { + s.mu.Lock() + if s.drain { + s.mu.Unlock() + return true + } + s.all[c] = struct{}{} + s.idle[c] = struct{}{} + s.mu.Unlock() + return false +} + +func (s *shard) disown(c *conn) { + s.mu.Lock() + delete(s.all, c) + delete(s.idle, c) + s.mu.Unlock() } func (s *shard) markIdle(c *conn) (shouldClose bool) { @@ -34,35 +49,47 @@ func (s *shard) markIdle(c *conn) (shouldClose bool) { s.mu.Unlock() return true } - s.set[c] = struct{}{} + s.idle[c] = struct{}{} s.mu.Unlock() return false } func (s *shard) markInUse(c *conn) { s.mu.Lock() - delete(s.set, c) + delete(s.idle, c) s.mu.Unlock() } -func (s *shard) closeIdle(drain bool) { +func (s *shard) closeConns(all, drain bool) { s.mu.Lock() if drain { s.drain = true } - set := s.set - s.set = make(map[*conn]struct{}, prealloc) + set := make(map[*conn]struct{}, len(s.all)) + if all { + for c := range s.all { + set[c] = struct{}{} + } + } else { + for c := range s.idle { + set[c] = struct{}{} + } + } // We have to drop the shard lock here to avoid deadlock: we cannot // acquire the shard lock after the connection lock, and the closeIfIdle // call below will grab a connection lock. s.mu.Unlock() - for conn := range set { + for c := range set { // This might return an error (from Close), but I don't think we // can do anything about it, so let's just pretend it didn't // happen. (I also expect that most errors returned in this way // are going to be pretty boring) - conn.closeIfIdle() + if all { + c.Close() + } else { + c.closeIfIdle() + } } } diff --git a/graceful/listener/shard_test.go b/graceful/listener/shard_test.go new file mode 100644 index 0000000..9b99394 --- /dev/null +++ b/graceful/listener/shard_test.go @@ -0,0 +1,21 @@ +// +build amd64 + +package listener + +import ( + "testing" + "unsafe" +) + +// We pack shards together in an array, but we don't want them packed too +// closely, since we want to give each shard a dedicated CPU cache line. This +// test checks this property for x64 (which has a 64-byte cache line), which +// probably covers the majority of deployments. +// +// As always, this is probably a premature optimization. +func TestShardSize(t *testing.T) { + s := unsafe.Sizeof(shard{}) + if s < 64 { + t.Errorf("sizeof(shard) = %d; expected >64", s) + } +}