Browse Source

Graceful shutdown package

Package graceful provides graceful shutdown support for net/http servers,
net.Listeners and net.Conns. It does this through terrible, terrible hacks, but
"oh well!"
Carl Jackson 12 years ago
parent
commit
96b81e1930
8 changed files with 908 additions and 0 deletions
  1. +63
    -0
      graceful/conn_test.go
  2. +22
    -0
      graceful/einhorn.go
  3. +136
    -0
      graceful/graceful.go
  4. +106
    -0
      graceful/middleware.go
  5. +68
    -0
      graceful/middleware_test.go
  6. +198
    -0
      graceful/net.go
  7. +198
    -0
      graceful/net_test.go
  8. +117
    -0
      graceful/signal.go

+ 63
- 0
graceful/conn_test.go View File

@ -0,0 +1,63 @@
package graceful
import (
"net"
"time"
)
// Stub out a net.Conn. This is going to be painful.
type fakeAddr struct{}
func (f fakeAddr) Network() string {
return "fake"
}
func (f fakeAddr) String() string {
return "fake"
}
type fakeConn struct {
onRead, onWrite, onClose, onLocalAddr, onRemoteAddr func()
onSetDeadline, onSetReadDeadline, onSetWriteDeadline func()
}
// Here's my number, so...
func callMeMaybe(f func()) {
// I apologize for nothing.
if f != nil {
f()
}
}
func (f fakeConn) Read(b []byte) (int, error) {
callMeMaybe(f.onRead)
return len(b), nil
}
func (f fakeConn) Write(b []byte) (int, error) {
callMeMaybe(f.onWrite)
return len(b), nil
}
func (f fakeConn) Close() error {
callMeMaybe(f.onClose)
return nil
}
func (f fakeConn) LocalAddr() net.Addr {
callMeMaybe(f.onLocalAddr)
return fakeAddr{}
}
func (f fakeConn) RemoteAddr() net.Addr {
callMeMaybe(f.onRemoteAddr)
return fakeAddr{}
}
func (f fakeConn) SetDeadline(t time.Time) error {
callMeMaybe(f.onSetDeadline)
return nil
}
func (f fakeConn) SetReadDeadline(t time.Time) error {
callMeMaybe(f.onSetReadDeadline)
return nil
}
func (f fakeConn) SetWriteDeadline(t time.Time) error {
callMeMaybe(f.onSetWriteDeadline)
return nil
}

+ 22
- 0
graceful/einhorn.go View File

@ -0,0 +1,22 @@
package graceful
import (
"log"
"os"
"strconv"
"syscall"
)
func init() {
// This is a little unfortunate: goji/bind already knows whether we're
// running under einhorn, but we don't want to introduce a dependency
// between the two packages. Since the check is short enough, inlining
// it here seems "fine."
mpid, err := strconv.Atoi(os.Getenv("EINHORN_MASTER_PID"))
if err != nil || mpid != os.Getppid() {
return
}
log.Print("graceful: Einhorn detected, adding SIGUSR2 handler")
AddSignal(syscall.SIGUSR2)
}

+ 136
- 0
graceful/graceful.go View File

@ -0,0 +1,136 @@
/*
Package graceful implements graceful shutdown for HTTP servers by closing idle
connections after receiving a signal. By default, this package listens for
interrupts (i.e., SIGINT), but when it detects that it is running under Einhorn
it will additionally listen for SIGUSR2 as well, giving your application
automatic support for graceful upgrades.
It's worth mentioning explicitly that this package is a hack to shim graceful
shutdown behavior into the net/http package provided in Go 1.2. It was written
by carefully reading the sequence of function calls net/http happened to use as
of this writing and finding enough surface area with which to add appropriate
behavior. There's a very good chance that this package will cease to work in
future versions of Go, but with any luck the standard library will add support
of its own by then.
If you're interested in figuring out how this package works, we suggest you read
the documentation for WrapConn() and net.go.
*/
package graceful
import (
"crypto/tls"
"net"
"net/http"
"time"
)
// Exactly like net/http's Server. In fact, it *is* a net/http Server, just with
// different method implementations
type Server http.Server
// About 200 years, also known as "forever"
const forever time.Duration = 200 * 365 * 24 * time.Hour
/*
You might notice that these methods look awfully similar to the methods of the
same name from the go standard library--that's because they were stolen from
there! If go were more like, say, Ruby, it'd actually be possible to shim just
the Serve() method, since we can do everything we want from there. However, it's
not possible to get the other methods which call Serve() (ListenAndServe(), say)
to call your shimmed copy--they always call the original.
Since I couldn't come up with a better idea, I just copy-and-pasted both
ListenAndServe and ListenAndServeTLS here more-or-less verbatim. "Oh well!"
*/
// Behaves exactly like the net/http function of the same name.
func (srv *Server) Serve(l net.Listener) (err error) {
go func() {
<-kill
l.Close()
}()
l = WrapListener(l)
// Spawn a shadow http.Server to do the actual servering. We do this
// because we need to sketch on some of the parameters you passed in,
// and it's nice to keep our sketching to ourselves.
shadow := *(*http.Server)(srv)
if shadow.ReadTimeout == 0 {
shadow.ReadTimeout = forever
}
shadow.Handler = Middleware(shadow.Handler)
err = shadow.Serve(l)
// We expect an error when we close the listener, so we indiscriminately
// swallow Serve errors when we're in a shutdown state.
select {
case <-kill:
return nil
default:
return err
}
}
// Behaves exactly like the net/http function of the same name.
func (srv *Server) ListenAndServe() error {
addr := srv.Addr
if addr == "" {
addr = ":http"
}
l, e := net.Listen("tcp", addr)
if e != nil {
return e
}
return srv.Serve(l)
}
// Behaves exactly like the net/http function of the same name.
func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error {
addr := srv.Addr
if addr == "" {
addr = ":https"
}
config := &tls.Config{}
if srv.TLSConfig != nil {
*config = *srv.TLSConfig
}
if config.NextProtos == nil {
config.NextProtos = []string{"http/1.1"}
}
var err error
config.Certificates = make([]tls.Certificate, 1)
config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return err
}
conn, err := net.Listen("tcp", addr)
if err != nil {
return err
}
tlsListener := tls.NewListener(conn, config)
return srv.Serve(tlsListener)
}
// Behaves exactly like the net/http function of the same name.
func ListenAndServe(addr string, handler http.Handler) error {
server := &Server{Addr: addr, Handler: handler}
return server.ListenAndServe()
}
// Behaves exactly like the net/http function of the same name.
func ListenAndServeTLS(addr, certfile, keyfile string, handler http.Handler) error {
server := &Server{Addr: addr, Handler: handler}
return server.ListenAndServeTLS(certfile, keyfile)
}
// Behaves exactly like the net/http function of the same name.
func Serve(l net.Listener, handler http.Handler) error {
server := &Server{Handler: handler}
return server.Serve(l)
}

+ 106
- 0
graceful/middleware.go View File

@ -0,0 +1,106 @@
package graceful
import (
"bufio"
"net"
"net/http"
)
/*
Graceful shutdown middleware. When a graceful shutdown is in progress, this
middleware intercepts responses to add a "Connection: close" header to politely
inform the client that we are about to go away.
This package creates a shim http.ResponseWriter that it passes to subsequent
handlers. Unfortunately, there's a great many optional interfaces that this
http.ResponseWriter might implement (e.g., http.CloseNotifier, http.Flusher, and
http.Hijacker), and in order to perfectly proxy all of these options we'd be
left with some kind of awful powerset of ResponseWriters, and that's not even
counting all the other custom interfaces you might be expecting. Instead of
doing that, we have implemented two kinds of proxies: one that contains no
additional methods (i.e., exactly corresponding to the http.ResponseWriter
interface), and one that supports all three of http.CloseNotifier, http.Flusher,
and http.Hijacker. If you find that this is not enough, the original
http.ResponseWriter can be retrieved by calling Unwrap() on the proxy object.
This middleware is automatically applied to every http.Handler passed to this
package, and most users will not need to call this function directly. It is
exported primarily for documentation purposes and in the off chance that someone
really wants more control over their http.Server than we currently provide.
*/
func Middleware(h http.Handler) http.Handler {
if h == nil {
return nil
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, cn := w.(http.CloseNotifier)
_, fl := w.(http.Flusher)
_, hj := w.(http.Hijacker)
bw := basicWriter{ResponseWriter: w}
if cn && fl && hj {
h.ServeHTTP(&fancyWriter{bw}, r)
} else {
h.ServeHTTP(&bw, r)
}
if !bw.headerWritten {
bw.maybeClose()
}
})
}
type basicWriter struct {
http.ResponseWriter
headerWritten bool
}
func (b *basicWriter) maybeClose() {
b.headerWritten = true
select {
case <-kill:
b.ResponseWriter.Header().Add("Connection", "close")
default:
}
}
func (b *basicWriter) WriteHeader(code int) {
b.maybeClose()
b.ResponseWriter.WriteHeader(code)
}
func (b *basicWriter) Write(buf []byte) (int, error) {
if !b.headerWritten {
b.maybeClose()
}
return b.ResponseWriter.Write(buf)
}
func (b *basicWriter) Unwrap() http.ResponseWriter {
return b.ResponseWriter
}
// Optimize for the common case of a ResponseWriter that supports all three of
// CloseNotifier, Flusher, and Hijacker.
type fancyWriter struct {
basicWriter
}
func (f *fancyWriter) CloseNotify() <-chan bool {
cn := f.basicWriter.ResponseWriter.(http.CloseNotifier)
return cn.CloseNotify()
}
func (f *fancyWriter) Flush() {
fl := f.basicWriter.ResponseWriter.(http.Flusher)
fl.Flush()
}
func (f *fancyWriter) Hijack() (c net.Conn, b *bufio.ReadWriter, e error) {
hj := f.basicWriter.ResponseWriter.(http.Hijacker)
c, b, e = hj.Hijack()
if conn, ok := c.(hijackConn); ok {
c = conn.hijack()
}
return
}

+ 68
- 0
graceful/middleware_test.go View File

@ -0,0 +1,68 @@
package graceful
import (
"net/http"
"testing"
)
type fakeWriter http.Header
func (f fakeWriter) Header() http.Header {
return http.Header(f)
}
func (f fakeWriter) Write(buf []byte) (int, error) {
return len(buf), nil
}
func (f fakeWriter) WriteHeader(status int) {}
func testClose(t *testing.T, h http.Handler, expectClose bool) {
m := Middleware(h)
r, _ := http.NewRequest("GET", "/", nil)
w := make(fakeWriter)
m.ServeHTTP(w, r)
c, ok := w["Connection"]
if expectClose {
if !ok || len(c) != 1 || c[0] != "close" {
t.Fatal("Expected 'Connection: close'")
}
} else {
if ok {
t.Fatal("Did not expect Connection header")
}
}
}
func TestNormal(t *testing.T) {
kill = make(chan struct{})
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte{})
})
testClose(t, h, false)
}
func TestClose(t *testing.T) {
kill = make(chan struct{})
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(kill)
})
testClose(t, h, true)
}
func TestCloseWriteHeader(t *testing.T) {
kill = make(chan struct{})
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(kill)
w.WriteHeader(200)
})
testClose(t, h, true)
}
func TestCloseWrite(t *testing.T) {
kill = make(chan struct{})
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(kill)
w.Write([]byte{})
})
testClose(t, h, true)
}

+ 198
- 0
graceful/net.go View File

@ -0,0 +1,198 @@
package graceful
import (
"io"
"net"
"sync"
"time"
)
type listener struct {
net.Listener
}
type gracefulConn interface {
gracefulShutdown()
}
// Wrap an arbitrary net.Listener for use with graceful shutdowns. All
// net.Conn's Accept()ed by this listener will be auto-wrapped as if WrapConn()
// were called on them.
func WrapListener(l net.Listener) net.Listener {
return listener{l}
}
func (l listener) Accept() (net.Conn, error) {
conn, err := l.Listener.Accept()
if err != nil {
return nil, err
}
return WrapConn(conn), nil
}
/*
Wrap an arbitrary connection for use with graceful shutdowns. The graceful
shutdown process will ensure that this connection is closed before terminating
the process.
In order to use this function, you must call SetReadDeadline() before the call
to Read() you might make to read a new request off the wire. The connection is
eligible for abrupt closing at any point between when the call to
SetReadDeadline() returns and when the call to Read returns with new data. It
does not matter what deadline is given to SetReadDeadline()--the default HTTP
server provided by this package sets a deadline far into the future when a
deadline is not provided, for instance.
Unfortunately, this means that it's difficult to use SetReadDeadline() in a
great many perfectly reasonable circumstances, such as to extend a deadline
after more data has been read, without the connection being eligible for
"graceful" termination at an undesirable time. Since this package was written
explicitly to target net/http, which does not as of this writing do any of this,
fixing the semantics here does not seem especially urgent.
As an optimization for net/http over TCP, if the input connection supports the
ReadFrom() function, the returned connection will as well. This allows the net
package to use sendfile(2) on certain platforms in certain circumstances.
*/
func WrapConn(c net.Conn) net.Conn {
wg.Add(1)
nc := conn{
Conn: c,
closing: make(chan struct{}),
}
if _, ok := c.(io.ReaderFrom); ok {
c = &sendfile{nc}
} else {
c = &nc
}
go c.(gracefulConn).gracefulShutdown()
return c
}
type connstate int
/*
State diagram. (Waiting) is the starting state.
(Waiting) -----Read()-----> Working ---+
| ^ / | ^ Read()
| \ / | +----+
kill SetReadDeadline() kill
| | +-----+
V V V Read()
Dead <-SetReadDeadline()-- Dying ----+
^
|
+--Close()--- [from any state]
*/
const (
// Waiting for more data, and eligible for killing
csWaiting connstate = iota
// In the middle of a connection
csWorking
// Kill has been requested, but waiting on request to finish up
csDying
// Connection is gone forever. Also used when a connection gets hijacked
csDead
)
type conn struct {
net.Conn
m sync.Mutex
state connstate
closing chan struct{}
}
type sendfile struct{ conn }
func (c *conn) gracefulShutdown() {
select {
case <-kill:
case <-c.closing:
return
}
c.m.Lock()
defer c.m.Unlock()
switch c.state {
case csWaiting:
c.unlockedClose(true)
case csWorking:
c.state = csDying
}
}
func (c *conn) unlockedClose(closeConn bool) {
if closeConn {
c.Conn.Close()
}
close(c.closing)
wg.Done()
c.state = csDead
}
// We do some hijinks to support hijacking. The semantics here is that any
// connection that gets hijacked is dead to us: we return the raw net.Conn and
// stop tracking the connection entirely.
type hijackConn interface {
hijack() net.Conn
}
func (c *conn) hijack() net.Conn {
c.m.Lock()
defer c.m.Unlock()
if c.state != csDead {
close(c.closing)
wg.Done()
c.state = csDead
}
return c.Conn
}
func (c *conn) Read(b []byte) (n int, err error) {
defer func() {
c.m.Lock()
defer c.m.Unlock()
if c.state == csWaiting {
c.state = csWorking
}
}()
return c.Conn.Read(b)
}
func (c *conn) Close() error {
defer func() {
c.m.Lock()
defer c.m.Unlock()
if c.state != csDead {
c.unlockedClose(false)
}
}()
return c.Conn.Close()
}
func (c *conn) SetReadDeadline(t time.Time) error {
defer func() {
c.m.Lock()
defer c.m.Unlock()
switch c.state {
case csDying:
c.unlockedClose(false)
case csWorking:
c.state = csWaiting
}
}()
return c.Conn.SetReadDeadline(t)
}
func (s *sendfile) ReadFrom(r io.Reader) (int64, error) {
// conn.Conn.KHAAAAAAAANNNNNN
return s.conn.Conn.(io.ReaderFrom).ReadFrom(r)
}

+ 198
- 0
graceful/net_test.go View File

@ -0,0 +1,198 @@
package graceful
import (
"io"
"net"
"strings"
"testing"
"time"
)
var b = make([]byte, 0)
func connify(c net.Conn) *conn {
switch c.(type) {
case (*conn):
return c.(*conn)
case (*sendfile):
return &c.(*sendfile).conn
default:
panic("IDK")
}
}
func assertState(t *testing.T, n net.Conn, st connstate) {
c := connify(n)
c.m.Lock()
defer c.m.Unlock()
if c.state != st {
t.Fatalf("conn was %v, but expected %v", c.state, st)
}
}
// Not super happy about making the tests dependent on the passing of time, but
// I'm not really sure what else to do.
func expectCall(t *testing.T, ch <-chan struct{}, name string) {
select {
case <-ch:
case <-time.After(5 * time.Millisecond):
t.Fatalf("Expected call to %s", name)
}
}
func TestCounting(t *testing.T) {
kill = make(chan struct{})
c := WrapConn(fakeConn{})
ch := make(chan struct{})
go func() {
wg.Wait()
ch <- struct{}{}
}()
select {
case <-ch:
t.Fatal("Expected connection to keep us from quitting")
case <-time.After(5 * time.Millisecond):
}
c.Close()
expectCall(t, ch, "wg.Wait()")
}
func TestStateTransitions1(t *testing.T) {
kill = make(chan struct{})
ch := make(chan struct{})
onclose := make(chan struct{})
read := make(chan struct{})
deadline := make(chan struct{})
c := WrapConn(fakeConn{
onClose: func() {
onclose <- struct{}{}
},
onRead: func() {
read <- struct{}{}
},
onSetReadDeadline: func() {
deadline <- struct{}{}
},
})
go func() {
wg.Wait()
ch <- struct{}{}
}()
assertState(t, c, csWaiting)
// Waiting + Read() = Working
go c.Read(b)
expectCall(t, read, "c.Read()")
assertState(t, c, csWorking)
// Working + SetReadDeadline() = Waiting
go c.SetReadDeadline(time.Now())
expectCall(t, deadline, "c.SetReadDeadline()")
assertState(t, c, csWaiting)
// Waiting + kill = Dead
close(kill)
expectCall(t, onclose, "c.Close()")
assertState(t, c, csDead)
expectCall(t, ch, "wg.Wait()")
}
func TestStateTransitions2(t *testing.T) {
kill = make(chan struct{})
ch := make(chan struct{})
onclose := make(chan struct{})
read := make(chan struct{})
deadline := make(chan struct{})
c := WrapConn(fakeConn{
onClose: func() {
onclose <- struct{}{}
},
onRead: func() {
read <- struct{}{}
},
onSetReadDeadline: func() {
deadline <- struct{}{}
},
})
go func() {
wg.Wait()
ch <- struct{}{}
}()
assertState(t, c, csWaiting)
// Waiting + Read() = Working
go c.Read(b)
expectCall(t, read, "c.Read()")
assertState(t, c, csWorking)
// Working + Read() = Working
go c.Read(b)
expectCall(t, read, "c.Read()")
assertState(t, c, csWorking)
// Working + kill = Dying
close(kill)
time.Sleep(5 * time.Millisecond)
assertState(t, c, csDying)
// Dying + Read() = Dying
go c.Read(b)
expectCall(t, read, "c.Read()")
assertState(t, c, csDying)
// Dying + SetReadDeadline() = Dead
go c.SetReadDeadline(time.Now())
expectCall(t, deadline, "c.SetReadDeadline()")
assertState(t, c, csDead)
expectCall(t, ch, "wg.Wait()")
}
func TestHijack(t *testing.T) {
kill = make(chan struct{})
fake := fakeConn{}
c := WrapConn(fake)
ch := make(chan struct{})
go func() {
wg.Wait()
ch <- struct{}{}
}()
cc := connify(c)
if _, ok := cc.hijack().(fakeConn); !ok {
t.Error("Expected original connection back out")
}
assertState(t, c, csDead)
expectCall(t, ch, "wg.Wait()")
}
type fakeSendfile struct {
fakeConn
}
func (f fakeSendfile) ReadFrom(r io.Reader) (int64, error) {
return 0, nil
}
func TestReadFrom(t *testing.T) {
kill = make(chan struct{})
c := WrapConn(fakeSendfile{})
r := strings.NewReader("Hello world")
if rf, ok := c.(io.ReaderFrom); ok {
rf.ReadFrom(r)
} else {
t.Fatal("Expected a ReaderFrom in return")
}
}

+ 117
- 0
graceful/signal.go View File

@ -0,0 +1,117 @@
package graceful
import (
"log"
"os"
"os/signal"
"sync"
)
// This is the channel that the connections select on. When it is closed, the
// connections should gracefully exit.
var kill = make(chan struct{})
// This is the channel that the Wait() function selects on. It should only be
// closed once all the posthooks have been called.
var wait = make(chan struct{})
// This is the WaitGroup that indicates when all the connections have gracefully
// shut down.
var wg sync.WaitGroup
// This lock protects the list of pre- and post- hooks below.
var hookLock sync.Mutex
var prehooks = make([]func(), 0)
var posthooks = make([]func(), 0)
var sigchan = make(chan os.Signal, 1)
func init() {
AddSignal(os.Interrupt)
go waitForSignal()
}
// Add the given signal to the set of signals that trigger a graceful shutdown.
// Note that for convenience the default interrupt (SIGINT) handler is installed
// at package load time, and unless you call ResetSignals() will be listened for
// in addition to any signals you provide by calling this function.
func AddSignal(sig ...os.Signal) {
signal.Notify(sigchan, sig...)
}
// Reset the list of signals that trigger a graceful shutdown. Useful if, for
// instance, you don't want to use the default interrupt (SIGINT) handler. Since
// we necessarily install the SIGINT handler before you have a chance to call
// ResetSignals(), there will be a brief window during which the set of signals
// this package listens for will not be as you intend. Therefore, if you intend
// on using this function, we encourage you to call it as soon as possible.
func ResetSignals() {
signal.Stop(sigchan)
}
type userShutdown struct{}
func (u userShutdown) String() string {
return "application initiated shutdown"
}
func (u userShutdown) Signal() {}
// Manually trigger a shutdown from your application. Like Wait(), blocks until
// all connections have gracefully shut down.
func Shutdown() {
sigchan <- userShutdown{}
<-wait
}
// Register a function to be called before any of this package's normal shutdown
// actions. All listeners will be called in the order they were added, from a
// single goroutine.
func PreHook(f func()) {
hookLock.Lock()
defer hookLock.Unlock()
prehooks = append(prehooks, f)
}
// Register a function to be called after all of this package's normal shutdown
// actions. All listeners will be called in the order they were added, from a
// single goroutine, and are guaranteed to be called after all listening
// connections have been closed, but before Wait() returns.
//
// If you've Hijack()ed any connections that must be gracefully shut down in
// some other way (since this library disowns all hijacked connections), it's
// reasonable to use a PostHook() to signal and wait for them.
func PostHook(f func()) {
hookLock.Lock()
defer hookLock.Unlock()
posthooks = append(posthooks, f)
}
func waitForSignal() {
sig := <-sigchan
log.Printf("Received %v, gracefully shutting down!", sig)
hookLock.Lock()
defer hookLock.Unlock()
for _, f := range prehooks {
f()
}
close(kill)
wg.Wait()
for _, f := range posthooks {
f()
}
close(wait)
}
// Wait for all connections to gracefully shut down. This is commonly called at
// the bottom of the main() function to prevent the program from exiting
// prematurely.
func Wait() {
<-wait
}

Loading…
Cancel
Save