From 57e752c3bcac3538485b6a74f854134ded80998a Mon Sep 17 00:00:00 2001 From: Carl Jackson Date: Tue, 2 Sep 2014 01:29:28 -0700 Subject: [PATCH] Refactor package graceful This is meant to accomplish a few things: 1. graceful no longer spawns an additional goroutine per connection. Instead, it maintains a sharded set of idle connections that a single reaper goroutine can go through when necessary. 2. graceful's connection struct has a more orthogonal set of connection state flags, replacing the harder-to-understand state machine. The underlying mechanics are largely the same, however. 3. graceful now uses the Go 1.3 ConnState API to avoid the "200-year SetReadDeadline hack." It still falls back on SetReadDeadline on Go 1.2 or where ConnState does not apply. --- graceful/conn_set.go | 80 ++++++++++++++ graceful/middleware.go | 2 +- graceful/net.go | 243 ++++++++++++++++++----------------------- graceful/net_test.go | 198 --------------------------------- graceful/serve.go | 1 + graceful/serve13.go | 41 +++++-- 6 files changed, 225 insertions(+), 340 deletions(-) create mode 100644 graceful/conn_set.go delete mode 100644 graceful/net_test.go diff --git a/graceful/conn_set.go b/graceful/conn_set.go new file mode 100644 index 0000000..c08099c --- /dev/null +++ b/graceful/conn_set.go @@ -0,0 +1,80 @@ +package graceful + +import ( + "runtime" + "sync" +) + +type connShard struct { + mu sync.Mutex + // We sort of abuse this field to also act as a "please shut down" flag. + // If it's nil, you should die at your earliest opportunity. + set map[*conn]struct{} +} + +type connSet struct { + // This is an incrementing connection counter so we round-robin + // connections across shards. Use atomic when touching it. + id uint64 + shards []*connShard +} + +var idleSet connSet + +// We pretty aggressively preallocate set entries in the hopes that we never +// have to allocate memory with the lock held. This is definitely a premature +// optimization and is probably misguided, but luckily it costs us essentially +// nothing. +const prealloc = 2048 + +func init() { + // To keep the expected contention rate constant we'd have to grow this + // as numcpu**2. In practice, CPU counts don't generally grow without + // bound, and contention is probably going to be small enough that + // nobody cares anyways. 
+ idleSet.shards = make([]*connShard, 2*runtime.NumCPU()) + for i := range idleSet.shards { + idleSet.shards[i] = &connShard{ + set: make(map[*conn]struct{}, prealloc), + } + } +} + +func (cs connSet) markIdle(c *conn) { + c.busy = false + shard := cs.shards[int(c.id%uint64(len(cs.shards)))] + shard.mu.Lock() + if shard.set == nil { + shard.mu.Unlock() + c.die = true + } else { + shard.set[c] = struct{}{} + shard.mu.Unlock() + } +} + +func (cs connSet) markBusy(c *conn) { + c.busy = true + shard := cs.shards[int(c.id%uint64(len(cs.shards)))] + shard.mu.Lock() + if shard.set == nil { + shard.mu.Unlock() + c.die = true + } else { + delete(shard.set, c) + shard.mu.Unlock() + } +} + +func (cs connSet) killall() { + for _, shard := range cs.shards { + shard.mu.Lock() + set := shard.set + shard.set = nil + shard.mu.Unlock() + + for conn := range set { + conn.closeIfIdle() + } + } +} diff --git a/graceful/middleware.go b/graceful/middleware.go index f169891..4aace7d 100644 --- a/graceful/middleware.go +++ b/graceful/middleware.go @@ -103,7 +103,7 @@ func (f *fancyWriter) Hijack() (c net.Conn, b *bufio.ReadWriter, e error) { hj := f.basicWriter.ResponseWriter.(http.Hijacker) c, b, e = hj.Hijack() - if conn, ok := c.(hijackConn); ok { + if conn, ok := c.(*conn); ok { c = conn.hijack() } diff --git a/graceful/net.go b/graceful/net.go index b3f124a..e667d05 100644 --- a/graceful/net.go +++ b/graceful/net.go @@ -4,6 +4,7 @@ import ( "io" "net" "sync" + "sync/atomic" "time" ) @@ -11,10 +12,6 @@ type listener struct { net.Listener } -type gracefulConn interface { - gracefulShutdown() -} - // WrapListener wraps an arbitrary net.Listener for use with graceful shutdowns. // All net.Conn's Accept()ed by this listener will be auto-wrapped as if // WrapConn() were called on them. @@ -24,11 +21,17 @@ func WrapListener(l net.Listener) net.Listener { func (l listener) Accept() (net.Conn, error) { conn, err := l.Listener.Accept() - if err != nil { - return nil, err - } + return WrapConn(conn), err +} - return WrapConn(conn), nil +type conn struct { + mu sync.Mutex + cs *connSet + net.Conn + id uint64 + busy, die bool + dead bool + hijacked bool } /* @@ -39,10 +42,9 @@ terminating the process. In order to use this function, you must call SetReadDeadline() before the call to Read() you might make to read a new request off the wire. The connection is eligible for abrupt closing at any point between when the call to -SetReadDeadline() returns and when the call to Read returns with new data. It -does not matter what deadline is given to SetReadDeadline()--the default HTTP -server provided by this package sets a deadline far into the future when a -deadline is not provided, for instance. +SetReadDeadline() returns and when the call to Read returns with new data. It +does not matter what deadline is given to SetReadDeadline()--if a deadline is +inappropriate, providing one extremely far into the future will suffice. Unfortunately, this means that it's difficult to use SetReadDeadline() in a great many perfectly reasonable circumstances, such as to extend a deadline @@ -50,152 +52,125 @@ after more data has been read, without the connection being eligible for "graceful" termination at an undesirable time. Since this package was written explicitly to target net/http, which does not as of this writing do any of this, fixing the semantics here does not seem especially urgent. - -As an optimization for net/http over TCP, if the input connection supports the -ReadFrom() function, the returned connection will as well. 
This allows the net
-package to use sendfile(2) on certain platforms in certain circumstances.
 */
 func WrapConn(c net.Conn) net.Conn {
-	wg.Add(1)
-
-	nc := conn{
-		Conn:    c,
-		closing: make(chan struct{}),
+	if c == nil {
+		return nil
 	}
-	if _, ok := c.(io.ReaderFrom); ok {
-		c = &sendfile{nc}
-	} else {
-		c = &nc
+	wg.Add(1)
+	return &conn{
+		Conn: c,
+		id:   atomic.AddUint64(&idleSet.id, 1),
 	}
+}
 
-	go c.(gracefulConn).gracefulShutdown()
+func (c *conn) Read(b []byte) (n int, err error) {
+	c.mu.Lock()
+	if !c.hijacked {
+		defer func() {
+			c.mu.Lock()
+			defer c.mu.Unlock()
+			if c.hijacked {
+				// It's a little unclear to me how this case
+				// would happen, but we *did* drop the lock, so
+				// let's play it safe.
+				return
+			}
+
+			if c.dead {
+				// Dead sockets don't tell tales. This is to
+				// prevent the case where a Read manages to suck
+				// an entire request off the wire in a race with
+				// someone trying to close idle connections.
+				// Whoever grabs the conn lock first wins, and
+				// if that's the closing process, we need to
+				// "take back" the read.
+				n = 0
+				err = io.EOF
+			} else {
+				idleSet.markBusy(c)
+			}
+		}()
+	}
+	c.mu.Unlock()
 
-	return c
+	return c.Conn.Read(b)
 }
 
-type connstate int
-
-/*
-State diagram. (Waiting) is the starting state.
-
-(Waiting) -----Read()-----> Working ---+
-   |  ^                    /  |  ^     Read()
-   |   \                  /   |  +----+
-  kill  SetReadDeadline()    kill
-   |     |                   +-----+
-   V     V                   V     Read()
-  Dead <-SetReadDeadline()-- Dying ----+
-   ^
-   |
-   +--Close()--- [from any state]
+func (c *conn) SetReadDeadline(t time.Time) error {
+	c.mu.Lock()
+	if !c.hijacked {
+		defer c.markIdle()
+	}
+	c.mu.Unlock()
+	return c.Conn.SetReadDeadline(t)
+}
 
-*/
+func (c *conn) Close() error {
+	kill := false
+	c.mu.Lock()
+	kill, c.dead = !c.dead, true
+	idleSet.markBusy(c)
+	c.mu.Unlock()
+
+	if kill {
+		defer wg.Done()
+	}
+	return c.Conn.Close()
+}
 
-const (
-	// Waiting for more data, and eligible for killing
-	csWaiting connstate = iota
-	// In the middle of a connection
-	csWorking
-	// Kill has been requested, but waiting on request to finish up
-	csDying
-	// Connection is gone forever. Also used when a connection gets hijacked
-	csDead
-)
+type writerOnly struct {
+	w io.Writer
+}
 
-type conn struct {
-	net.Conn
-	m       sync.Mutex
-	state   connstate
-	closing chan struct{}
+func (w writerOnly) Write(buf []byte) (int, error) {
+	return w.w.Write(buf)
 }
-type sendfile struct{ conn }
 
-func (c *conn) gracefulShutdown() {
-	select {
-	case <-kill:
-	case <-c.closing:
-		return
-	}
-	c.m.Lock()
-	defer c.m.Unlock()
-
-	switch c.state {
-	case csWaiting:
-		c.unlockedClose(true)
-	case csWorking:
-		c.state = csDying
+func (c *conn) ReadFrom(r io.Reader) (int64, error) {
+	if rf, ok := c.Conn.(io.ReaderFrom); ok {
+		return rf.ReadFrom(r)
 	}
+	return io.Copy(writerOnly{c}, r)
 }
 
-func (c *conn) unlockedClose(closeConn bool) {
-	if closeConn {
+func (c *conn) markIdle() {
+	kill := false
+	c.mu.Lock()
+	idleSet.markIdle(c)
+	if c.die {
+		kill, c.dead = !c.dead, true
+	}
+	c.mu.Unlock()
+
+	if kill {
+		defer wg.Done()
 		c.Conn.Close()
 	}
-	close(c.closing)
-	wg.Done()
-	c.state = csDead
-}
 
-// We do some hijinks to support hijacking. The semantics here is that any
-// connection that gets hijacked is dead to us: we return the raw net.Conn and
-// stop tracking the connection entirely.
-type hijackConn interface {
-	hijack() net.Conn
 }
 
-func (c *conn) hijack() net.Conn {
-	c.m.Lock()
-	defer c.m.Unlock()
-	if c.state != csDead {
-		close(c.closing)
-		wg.Done()
-		c.state = csDead
+func (c *conn) closeIfIdle() {
+	kill := false
+	c.mu.Lock()
+	c.die = true
+	if !c.busy && !c.hijacked {
+		kill, c.dead = !c.dead, true
 	}
-	return c.Conn
-}
-
-func (c *conn) Read(b []byte) (n int, err error) {
-	defer func() {
-		c.m.Lock()
-		defer c.m.Unlock()
-
-		if c.state == csWaiting {
-			c.state = csWorking
-		} else if c.state == csDead {
-			n = 0
-			err = io.EOF
-		}
-	}()
+	c.mu.Unlock()
 
-	return c.Conn.Read(b)
-}
-func (c *conn) Close() error {
-	defer func() {
-		c.m.Lock()
-		defer c.m.Unlock()
-
-		if c.state != csDead {
-			c.unlockedClose(false)
-		}
-	}()
-	return c.Conn.Close()
-}
-func (c *conn) SetReadDeadline(t time.Time) error {
-	defer func() {
-		c.m.Lock()
-		defer c.m.Unlock()
-		switch c.state {
-		case csDying:
-			c.unlockedClose(false)
-		case csWorking:
-			c.state = csWaiting
-		}
-	}()
-	return c.Conn.SetReadDeadline(t)
+	if kill {
+		defer wg.Done()
+		c.Conn.Close()
+	}
 }
 
-func (s *sendfile) ReadFrom(r io.Reader) (int64, error) {
-	// conn.Conn.KHAAAAAAAANNNNNN
-	return s.conn.Conn.(io.ReaderFrom).ReadFrom(r)
+func (c *conn) hijack() net.Conn {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	if !c.hijacked {
+		// The hijacker owns the raw net.Conn from here on: stop
+		// tracking the connection and release its slot in the
+		// shutdown WaitGroup, as the old implementation did. The
+		// guard also keeps this idempotent, since both the hijack
+		// middleware and the Go 1.3 ConnState hook can report the
+		// same hijack.
+		c.hijacked = true
+		c.dead = true
+		idleSet.markBusy(c)
+		wg.Done()
+	}
+	return c.Conn
 }
diff --git a/graceful/net_test.go b/graceful/net_test.go
deleted file mode 100644
index d6e7208..0000000
--- a/graceful/net_test.go
+++ /dev/null
@@ -1,198 +0,0 @@
-package graceful
-
-import (
-	"io"
-	"net"
-	"strings"
-	"testing"
-	"time"
-)
-
-var b = make([]byte, 0)
-
-func connify(c net.Conn) *conn {
-	switch c.(type) {
-	case (*conn):
-		return c.(*conn)
-	case (*sendfile):
-		return &c.(*sendfile).conn
-	default:
-		panic("IDK")
-	}
-}
-
-func assertState(t *testing.T, n net.Conn, st connstate) {
-	c := connify(n)
-	c.m.Lock()
-	defer c.m.Unlock()
-	if c.state != st {
-		t.Fatalf("conn was %v, but expected %v", c.state, st)
-	}
-}
-
-// Not super happy about making the tests dependent on the passing of time, but
-// I'm not really sure what else to do.
- -func expectCall(t *testing.T, ch <-chan struct{}, name string) { - select { - case <-ch: - case <-time.After(5 * time.Millisecond): - t.Fatalf("Expected call to %s", name) - } -} - -func TestCounting(t *testing.T) { - kill = make(chan struct{}) - c := WrapConn(fakeConn{}) - ch := make(chan struct{}) - - go func() { - wg.Wait() - ch <- struct{}{} - }() - - select { - case <-ch: - t.Fatal("Expected connection to keep us from quitting") - case <-time.After(5 * time.Millisecond): - } - - c.Close() - expectCall(t, ch, "wg.Wait()") -} - -func TestStateTransitions1(t *testing.T) { - kill = make(chan struct{}) - ch := make(chan struct{}) - - onclose := make(chan struct{}) - read := make(chan struct{}) - deadline := make(chan struct{}) - c := WrapConn(fakeConn{ - onClose: func() { - onclose <- struct{}{} - }, - onRead: func() { - read <- struct{}{} - }, - onSetReadDeadline: func() { - deadline <- struct{}{} - }, - }) - - go func() { - wg.Wait() - ch <- struct{}{} - }() - - assertState(t, c, csWaiting) - - // Waiting + Read() = Working - go c.Read(b) - expectCall(t, read, "c.Read()") - assertState(t, c, csWorking) - - // Working + SetReadDeadline() = Waiting - go c.SetReadDeadline(time.Now()) - expectCall(t, deadline, "c.SetReadDeadline()") - assertState(t, c, csWaiting) - - // Waiting + kill = Dead - close(kill) - expectCall(t, onclose, "c.Close()") - assertState(t, c, csDead) - - expectCall(t, ch, "wg.Wait()") -} - -func TestStateTransitions2(t *testing.T) { - kill = make(chan struct{}) - ch := make(chan struct{}) - onclose := make(chan struct{}) - read := make(chan struct{}) - deadline := make(chan struct{}) - c := WrapConn(fakeConn{ - onClose: func() { - onclose <- struct{}{} - }, - onRead: func() { - read <- struct{}{} - }, - onSetReadDeadline: func() { - deadline <- struct{}{} - }, - }) - - go func() { - wg.Wait() - ch <- struct{}{} - }() - - assertState(t, c, csWaiting) - - // Waiting + Read() = Working - go c.Read(b) - expectCall(t, read, "c.Read()") - assertState(t, c, csWorking) - - // Working + Read() = Working - go c.Read(b) - expectCall(t, read, "c.Read()") - assertState(t, c, csWorking) - - // Working + kill = Dying - close(kill) - time.Sleep(5 * time.Millisecond) - assertState(t, c, csDying) - - // Dying + Read() = Dying - go c.Read(b) - expectCall(t, read, "c.Read()") - assertState(t, c, csDying) - - // Dying + SetReadDeadline() = Dead - go c.SetReadDeadline(time.Now()) - expectCall(t, deadline, "c.SetReadDeadline()") - assertState(t, c, csDead) - - expectCall(t, ch, "wg.Wait()") -} - -func TestHijack(t *testing.T) { - kill = make(chan struct{}) - fake := fakeConn{} - c := WrapConn(fake) - ch := make(chan struct{}) - - go func() { - wg.Wait() - ch <- struct{}{} - }() - - cc := connify(c) - if _, ok := cc.hijack().(fakeConn); !ok { - t.Error("Expected original connection back out") - } - assertState(t, c, csDead) - expectCall(t, ch, "wg.Wait()") -} - -type fakeSendfile struct { - fakeConn -} - -func (f fakeSendfile) ReadFrom(r io.Reader) (int64, error) { - return 0, nil -} - -func TestReadFrom(t *testing.T) { - kill = make(chan struct{}) - c := WrapConn(fakeSendfile{}) - r := strings.NewReader("Hello world") - - if rf, ok := c.(io.ReaderFrom); ok { - rf.ReadFrom(r) - } else { - t.Fatal("Expected a ReaderFrom in return") - } -} diff --git a/graceful/serve.go b/graceful/serve.go index 8746075..626f807 100644 --- a/graceful/serve.go +++ b/graceful/serve.go @@ -15,6 +15,7 @@ func (srv *Server) Serve(l net.Listener) error { go func() { <-kill l.Close() + idleSet.killall() }() l = 
WrapListener(l) diff --git a/graceful/serve13.go b/graceful/serve13.go index 41e9f52..e0acd18 100644 --- a/graceful/serve13.go +++ b/graceful/serve13.go @@ -5,12 +5,8 @@ package graceful import ( "net" "net/http" - "time" ) -// About 200 years, also known as "forever" -const forever time.Duration = 200 * 365 * 24 * time.Hour - func (srv *Server) Serve(l net.Listener) error { l = WrapListener(l) @@ -19,14 +15,45 @@ func (srv *Server) Serve(l net.Listener) error { // and it's nice to keep our sketching to ourselves. shadow := *(*http.Server)(srv) - if shadow.ReadTimeout == 0 { - shadow.ReadTimeout = forever + cs := shadow.ConnState + shadow.ConnState = func(nc net.Conn, s http.ConnState) { + if c, ok := nc.(*conn); ok { + // There are a few other states defined, most notably + // StateActive. Unfortunately it doesn't look like it's + // possible to make use of StateActive to implement + // graceful shutdown, since StateActive is set after a + // complete request has been read off the wire with an + // intent to process it. If we were to race a graceful + // shutdown against a connection that was just read off + // the wire (but not yet in StateActive), we would + // accidentally close the connection out from underneath + // an active request. + // + // We already needed to work around this for Go 1.2 by + // shimming out a full net.Conn object, so we can just + // fall back to the old behavior there. + // + // I started a golang-nuts thread about this here: + // https://groups.google.com/forum/#!topic/golang-nuts/Xi8yjBGWfCQ + // I'd be very eager to find a better way to do this, so + // reach out to me if you have any ideas. + switch s { + case http.StateIdle: + c.markIdle() + case http.StateHijacked: + c.hijack() + } + } + if cs != nil { + cs(nc, s) + } } go func() { <-kill - shadow.SetKeepAlivesEnabled(false) l.Close() + shadow.SetKeepAlivesEnabled(false) + idleSet.killall() }() err := shadow.Serve(l)
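
A few illustrative sketches follow; none of them are part of the patch
itself.

First, the caller's-eye view. The import path below is this repository's,
but everything else (the port, the toy echo protocol) is invented. What it
demonstrates is the WrapConn() contract documented in net.go: call
SetReadDeadline() before each Read() that waits for a new request, since
that is what marks the connection idle and therefore reapable. net/http
does this on its own; hand-rolled servers have to do it themselves. (The
kill-driven reaping is wired up by this package's Serve; a bare wrapped
listener only gets the bookkeeping.)

	package main

	import (
		"net"
		"time"

		"github.com/zenazn/goji/graceful"
	)

	func main() {
		l, err := net.Listen("tcp", ":9000")
		if err != nil {
			panic(err)
		}
		// Conns Accept()ed from here on are wrapped and tracked.
		l = graceful.WrapListener(l)

		for {
			c, err := l.Accept()
			if err != nil {
				return
			}
			go func(c net.Conn) {
				defer c.Close()
				buf := make([]byte, 4096)
				for {
					// Mark the conn idle: between here and
					// the next Read returning data, graceful
					// may close it out from under us.
					c.SetReadDeadline(time.Now().Add(time.Hour))
					n, err := c.Read(buf)
					if err != nil {
						return
					}
					c.Write(buf[:n])
				}
			}(c)
		}
	}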
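
Second, conn_set.go is an instance of a generic pattern: hand each item an
id from an atomic counter, round-robin items across independently locked
shards by that id, and reuse the nil map as the "please shut down" flag. A
stripped-down, stdlib-only version of the same idea (every name here is
invented for illustration):

	package main

	import (
		"fmt"
		"runtime"
		"sync"
		"sync/atomic"
	)

	// item stands in for graceful's *conn; all it needs is the
	// shard-picking id assigned at creation time.
	type item struct{ id uint64 }

	type shard struct {
		mu  sync.Mutex
		set map[*item]struct{} // nil means "shutting down"
	}

	type shardedSet struct {
		nextID uint64 // bumped atomically; round-robins over shards
		shards []*shard
	}

	func newShardedSet() *shardedSet {
		s := &shardedSet{shards: make([]*shard, 2*runtime.NumCPU())}
		for i := range s.shards {
			s.shards[i] = &shard{set: make(map[*item]struct{})}
		}
		return s
	}

	func (s *shardedSet) newItem() *item {
		return &item{id: atomic.AddUint64(&s.nextID, 1)}
	}

	func (s *shardedSet) shardFor(it *item) *shard {
		return s.shards[int(it.id%uint64(len(s.shards)))]
	}

	// add reports false when the set is already draining, in which case
	// the caller must clean the item up itself (the moral equivalent of
	// conn.die).
	func (s *shardedSet) add(it *item) bool {
		sh := s.shardFor(it)
		sh.mu.Lock()
		defer sh.mu.Unlock()
		if sh.set == nil {
			return false
		}
		sh.set[it] = struct{}{}
		return true
	}

	func (s *shardedSet) remove(it *item) {
		sh := s.shardFor(it)
		sh.mu.Lock()
		defer sh.mu.Unlock()
		delete(sh.set, it)
	}

	// drain flips every shard into shutdown mode and returns whatever
	// was still in the set, like connSet.killall.
	func (s *shardedSet) drain() []*item {
		var all []*item
		for _, sh := range s.shards {
			sh.mu.Lock()
			set := sh.set
			sh.set = nil
			sh.mu.Unlock()
			for it := range set {
				all = append(all, it)
			}
		}
		return all
	}

	func main() {
		s := newShardedSet()
		a, b := s.newItem(), s.newItem()
		s.add(a)
		s.add(b)
		s.remove(b)
		fmt.Println(len(s.drain())) // 1
	}

Nil-ing the map under the shard lock is what makes shutdown race-free: an
add that loses the race observes nil and cleans up after itself, which is
exactly what markIdle's die handling does.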
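
Third, the Go 1.3 side. The sketch below shows the shape of net/http's
ConnState hook with a deliberately naive single-mutex tracker; every name
is invented. Note that it retains precisely the race this patch works
around: a connection can leave StateIdle (because a Read just pulled a
request off the wire) between closeIdle's check and its Close(), which is
why graceful instead routes these transitions through a per-conn lock plus
the "take back the read" logic in conn.Read.

	package main

	import (
		"net"
		"net/http"
		"sync"
	)

	// idleTracker remembers which connections are parked in StateIdle
	// so they can be closed at shutdown.
	type idleTracker struct {
		mu   sync.Mutex
		idle map[net.Conn]struct{}
	}

	func newIdleTracker() *idleTracker {
		return &idleTracker{idle: make(map[net.Conn]struct{})}
	}

	func (t *idleTracker) hook(nc net.Conn, s http.ConnState) {
		t.mu.Lock()
		defer t.mu.Unlock()
		switch s {
		case http.StateIdle:
			t.idle[nc] = struct{}{}
		case http.StateActive, http.StateHijacked, http.StateClosed:
			// Active conns are mid-request; hijacked and closed
			// conns are no longer ours. None may be reaped.
			delete(t.idle, nc)
		}
	}

	func (t *idleTracker) closeIdle() {
		t.mu.Lock()
		defer t.mu.Unlock()
		for nc := range t.idle {
			nc.Close()
			delete(t.idle, nc)
		}
	}

	func main() {
		tracker := newIdleTracker()
		srv := &http.Server{
			Addr:      ":8080",
			ConnState: tracker.hook,
			Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.Write([]byte("hello\n"))
			}),
		}
		// At shutdown one would close the listener, then:
		//   srv.SetKeepAlivesEnabled(false)
		//   tracker.closeIdle()
		srv.ListenAndServe()
	}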
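
Last, the writerOnly indirection in net.go. io.Copy type-asserts its
destination against io.ReaderFrom, and conn now has a ReadFrom method, so
the fallback copy has to hide that method or it would recurse into itself
forever; net/http uses the same device internally. A self-contained
illustration, with wrapped standing in for this package's conn:

	package main

	import (
		"io"
		"io/ioutil"
		"net"
		"strings"
	)

	// writerOnly hides any ReadFrom method on w, forcing io.Copy to use
	// its generic read/write loop.
	type writerOnly struct{ w io.Writer }

	func (w writerOnly) Write(p []byte) (int, error) { return w.w.Write(p) }

	// wrapped reproduces only the ReadFrom plumbing of graceful's conn.
	type wrapped struct{ net.Conn }

	// ReadFrom forwards to the underlying connection when it implements
	// io.ReaderFrom (*net.TCPConn does, which is what lets net/http use
	// sendfile(2)), and otherwise degrades to a plain copy.
	func (c *wrapped) ReadFrom(r io.Reader) (int64, error) {
		if rf, ok := c.Conn.(io.ReaderFrom); ok {
			return rf.ReadFrom(r)
		}
		return io.Copy(writerOnly{c}, r)
	}

	func main() {
		// net.Pipe conns don't implement io.ReaderFrom, so this
		// exercises the fallback path.
		a, b := net.Pipe()
		go io.Copy(ioutil.Discard, b) // drain the far end
		c := &wrapped{Conn: a}
		if _, err := c.ReadFrom(strings.NewReader("hello")); err != nil {
			panic(err)
		}
		a.Close()
	}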