diff --git a/graceful/conn_test.go b/graceful/conn_test.go new file mode 100644 index 0000000..d86f12e --- /dev/null +++ b/graceful/conn_test.go @@ -0,0 +1,63 @@ +package graceful + +import ( + "net" + "time" +) + +// Stub out a net.Conn. This is going to be painful. + +type fakeAddr struct{} + +func (f fakeAddr) Network() string { + return "fake" +} +func (f fakeAddr) String() string { + return "fake" +} + +type fakeConn struct { + onRead, onWrite, onClose, onLocalAddr, onRemoteAddr func() + onSetDeadline, onSetReadDeadline, onSetWriteDeadline func() +} + +// Here's my number, so... +func callMeMaybe(f func()) { + // I apologize for nothing. + if f != nil { + f() + } +} + +func (f fakeConn) Read(b []byte) (int, error) { + callMeMaybe(f.onRead) + return len(b), nil +} +func (f fakeConn) Write(b []byte) (int, error) { + callMeMaybe(f.onWrite) + return len(b), nil +} +func (f fakeConn) Close() error { + callMeMaybe(f.onClose) + return nil +} +func (f fakeConn) LocalAddr() net.Addr { + callMeMaybe(f.onLocalAddr) + return fakeAddr{} +} +func (f fakeConn) RemoteAddr() net.Addr { + callMeMaybe(f.onRemoteAddr) + return fakeAddr{} +} +func (f fakeConn) SetDeadline(t time.Time) error { + callMeMaybe(f.onSetDeadline) + return nil +} +func (f fakeConn) SetReadDeadline(t time.Time) error { + callMeMaybe(f.onSetReadDeadline) + return nil +} +func (f fakeConn) SetWriteDeadline(t time.Time) error { + callMeMaybe(f.onSetWriteDeadline) + return nil +} diff --git a/graceful/einhorn.go b/graceful/einhorn.go new file mode 100644 index 0000000..c8e7af2 --- /dev/null +++ b/graceful/einhorn.go @@ -0,0 +1,22 @@ +package graceful + +import ( + "log" + "os" + "strconv" + "syscall" +) + +func init() { + // This is a little unfortunate: goji/bind already knows whether we're + // running under einhorn, but we don't want to introduce a dependency + // between the two packages. Since the check is short enough, inlining + // it here seems "fine." + mpid, err := strconv.Atoi(os.Getenv("EINHORN_MASTER_PID")) + if err != nil || mpid != os.Getppid() { + return + } + + log.Print("graceful: Einhorn detected, adding SIGUSR2 handler") + AddSignal(syscall.SIGUSR2) +} diff --git a/graceful/graceful.go b/graceful/graceful.go new file mode 100644 index 0000000..121e307 --- /dev/null +++ b/graceful/graceful.go @@ -0,0 +1,136 @@ +/* +Package graceful implements graceful shutdown for HTTP servers by closing idle +connections after receiving a signal. By default, this package listens for +interrupts (i.e., SIGINT), but when it detects that it is running under Einhorn +it will additionally listen for SIGUSR2 as well, giving your application +automatic support for graceful upgrades. + +It's worth mentioning explicitly that this package is a hack to shim graceful +shutdown behavior into the net/http package provided in Go 1.2. It was written +by carefully reading the sequence of function calls net/http happened to use as +of this writing and finding enough surface area with which to add appropriate +behavior. There's a very good chance that this package will cease to work in +future versions of Go, but with any luck the standard library will add support +of its own by then. + +If you're interested in figuring out how this package works, we suggest you read +the documentation for WrapConn() and net.go. +*/ +package graceful + +import ( + "crypto/tls" + "net" + "net/http" + "time" +) + +// Exactly like net/http's Server. In fact, it *is* a net/http Server, just with +// different method implementations +type Server http.Server + +// About 200 years, also known as "forever" +const forever time.Duration = 200 * 365 * 24 * time.Hour + +/* +You might notice that these methods look awfully similar to the methods of the +same name from the go standard library--that's because they were stolen from +there! If go were more like, say, Ruby, it'd actually be possible to shim just +the Serve() method, since we can do everything we want from there. However, it's +not possible to get the other methods which call Serve() (ListenAndServe(), say) +to call your shimmed copy--they always call the original. + +Since I couldn't come up with a better idea, I just copy-and-pasted both +ListenAndServe and ListenAndServeTLS here more-or-less verbatim. "Oh well!" +*/ + +// Behaves exactly like the net/http function of the same name. +func (srv *Server) Serve(l net.Listener) (err error) { + go func() { + <-kill + l.Close() + }() + l = WrapListener(l) + + // Spawn a shadow http.Server to do the actual servering. We do this + // because we need to sketch on some of the parameters you passed in, + // and it's nice to keep our sketching to ourselves. + shadow := *(*http.Server)(srv) + + if shadow.ReadTimeout == 0 { + shadow.ReadTimeout = forever + } + shadow.Handler = Middleware(shadow.Handler) + + err = shadow.Serve(l) + + // We expect an error when we close the listener, so we indiscriminately + // swallow Serve errors when we're in a shutdown state. + select { + case <-kill: + return nil + default: + return err + } +} + +// Behaves exactly like the net/http function of the same name. +func (srv *Server) ListenAndServe() error { + addr := srv.Addr + if addr == "" { + addr = ":http" + } + l, e := net.Listen("tcp", addr) + if e != nil { + return e + } + return srv.Serve(l) +} + +// Behaves exactly like the net/http function of the same name. +func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error { + addr := srv.Addr + if addr == "" { + addr = ":https" + } + config := &tls.Config{} + if srv.TLSConfig != nil { + *config = *srv.TLSConfig + } + if config.NextProtos == nil { + config.NextProtos = []string{"http/1.1"} + } + + var err error + config.Certificates = make([]tls.Certificate, 1) + config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return err + } + + conn, err := net.Listen("tcp", addr) + if err != nil { + return err + } + + tlsListener := tls.NewListener(conn, config) + return srv.Serve(tlsListener) +} + +// Behaves exactly like the net/http function of the same name. +func ListenAndServe(addr string, handler http.Handler) error { + server := &Server{Addr: addr, Handler: handler} + return server.ListenAndServe() +} + +// Behaves exactly like the net/http function of the same name. +func ListenAndServeTLS(addr, certfile, keyfile string, handler http.Handler) error { + server := &Server{Addr: addr, Handler: handler} + return server.ListenAndServeTLS(certfile, keyfile) +} + +// Behaves exactly like the net/http function of the same name. +func Serve(l net.Listener, handler http.Handler) error { + server := &Server{Handler: handler} + return server.Serve(l) +} diff --git a/graceful/middleware.go b/graceful/middleware.go new file mode 100644 index 0000000..34a3368 --- /dev/null +++ b/graceful/middleware.go @@ -0,0 +1,106 @@ +package graceful + +import ( + "bufio" + "net" + "net/http" +) + +/* +Graceful shutdown middleware. When a graceful shutdown is in progress, this +middleware intercepts responses to add a "Connection: close" header to politely +inform the client that we are about to go away. + +This package creates a shim http.ResponseWriter that it passes to subsequent +handlers. Unfortunately, there's a great many optional interfaces that this +http.ResponseWriter might implement (e.g., http.CloseNotifier, http.Flusher, and +http.Hijacker), and in order to perfectly proxy all of these options we'd be +left with some kind of awful powerset of ResponseWriters, and that's not even +counting all the other custom interfaces you might be expecting. Instead of +doing that, we have implemented two kinds of proxies: one that contains no +additional methods (i.e., exactly corresponding to the http.ResponseWriter +interface), and one that supports all three of http.CloseNotifier, http.Flusher, +and http.Hijacker. If you find that this is not enough, the original +http.ResponseWriter can be retrieved by calling Unwrap() on the proxy object. + +This middleware is automatically applied to every http.Handler passed to this +package, and most users will not need to call this function directly. It is +exported primarily for documentation purposes and in the off chance that someone +really wants more control over their http.Server than we currently provide. +*/ +func Middleware(h http.Handler) http.Handler { + if h == nil { + return nil + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, cn := w.(http.CloseNotifier) + _, fl := w.(http.Flusher) + _, hj := w.(http.Hijacker) + + bw := basicWriter{ResponseWriter: w} + + if cn && fl && hj { + h.ServeHTTP(&fancyWriter{bw}, r) + } else { + h.ServeHTTP(&bw, r) + } + if !bw.headerWritten { + bw.maybeClose() + } + }) +} + +type basicWriter struct { + http.ResponseWriter + headerWritten bool +} + +func (b *basicWriter) maybeClose() { + b.headerWritten = true + select { + case <-kill: + b.ResponseWriter.Header().Add("Connection", "close") + default: + } +} + +func (b *basicWriter) WriteHeader(code int) { + b.maybeClose() + b.ResponseWriter.WriteHeader(code) +} + +func (b *basicWriter) Write(buf []byte) (int, error) { + if !b.headerWritten { + b.maybeClose() + } + return b.ResponseWriter.Write(buf) +} + +func (b *basicWriter) Unwrap() http.ResponseWriter { + return b.ResponseWriter +} + +// Optimize for the common case of a ResponseWriter that supports all three of +// CloseNotifier, Flusher, and Hijacker. +type fancyWriter struct { + basicWriter +} + +func (f *fancyWriter) CloseNotify() <-chan bool { + cn := f.basicWriter.ResponseWriter.(http.CloseNotifier) + return cn.CloseNotify() +} +func (f *fancyWriter) Flush() { + fl := f.basicWriter.ResponseWriter.(http.Flusher) + fl.Flush() +} +func (f *fancyWriter) Hijack() (c net.Conn, b *bufio.ReadWriter, e error) { + hj := f.basicWriter.ResponseWriter.(http.Hijacker) + c, b, e = hj.Hijack() + + if conn, ok := c.(hijackConn); ok { + c = conn.hijack() + } + + return +} diff --git a/graceful/middleware_test.go b/graceful/middleware_test.go new file mode 100644 index 0000000..ecec606 --- /dev/null +++ b/graceful/middleware_test.go @@ -0,0 +1,68 @@ +package graceful + +import ( + "net/http" + "testing" +) + +type fakeWriter http.Header + +func (f fakeWriter) Header() http.Header { + return http.Header(f) +} +func (f fakeWriter) Write(buf []byte) (int, error) { + return len(buf), nil +} +func (f fakeWriter) WriteHeader(status int) {} + +func testClose(t *testing.T, h http.Handler, expectClose bool) { + m := Middleware(h) + r, _ := http.NewRequest("GET", "/", nil) + w := make(fakeWriter) + m.ServeHTTP(w, r) + + c, ok := w["Connection"] + if expectClose { + if !ok || len(c) != 1 || c[0] != "close" { + t.Fatal("Expected 'Connection: close'") + } + } else { + if ok { + t.Fatal("Did not expect Connection header") + } + } +} + +func TestNormal(t *testing.T) { + kill = make(chan struct{}) + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte{}) + }) + testClose(t, h, false) +} + +func TestClose(t *testing.T) { + kill = make(chan struct{}) + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + close(kill) + }) + testClose(t, h, true) +} + +func TestCloseWriteHeader(t *testing.T) { + kill = make(chan struct{}) + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + close(kill) + w.WriteHeader(200) + }) + testClose(t, h, true) +} + +func TestCloseWrite(t *testing.T) { + kill = make(chan struct{}) + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + close(kill) + w.Write([]byte{}) + }) + testClose(t, h, true) +} diff --git a/graceful/net.go b/graceful/net.go new file mode 100644 index 0000000..b86af8c --- /dev/null +++ b/graceful/net.go @@ -0,0 +1,198 @@ +package graceful + +import ( + "io" + "net" + "sync" + "time" +) + +type listener struct { + net.Listener +} + +type gracefulConn interface { + gracefulShutdown() +} + +// Wrap an arbitrary net.Listener for use with graceful shutdowns. All +// net.Conn's Accept()ed by this listener will be auto-wrapped as if WrapConn() +// were called on them. +func WrapListener(l net.Listener) net.Listener { + return listener{l} +} + +func (l listener) Accept() (net.Conn, error) { + conn, err := l.Listener.Accept() + if err != nil { + return nil, err + } + + return WrapConn(conn), nil +} + +/* +Wrap an arbitrary connection for use with graceful shutdowns. The graceful +shutdown process will ensure that this connection is closed before terminating +the process. + +In order to use this function, you must call SetReadDeadline() before the call +to Read() you might make to read a new request off the wire. The connection is +eligible for abrupt closing at any point between when the call to +SetReadDeadline() returns and when the call to Read returns with new data. It +does not matter what deadline is given to SetReadDeadline()--the default HTTP +server provided by this package sets a deadline far into the future when a +deadline is not provided, for instance. + +Unfortunately, this means that it's difficult to use SetReadDeadline() in a +great many perfectly reasonable circumstances, such as to extend a deadline +after more data has been read, without the connection being eligible for +"graceful" termination at an undesirable time. Since this package was written +explicitly to target net/http, which does not as of this writing do any of this, +fixing the semantics here does not seem especially urgent. + +As an optimization for net/http over TCP, if the input connection supports the +ReadFrom() function, the returned connection will as well. This allows the net +package to use sendfile(2) on certain platforms in certain circumstances. +*/ +func WrapConn(c net.Conn) net.Conn { + wg.Add(1) + + nc := conn{ + Conn: c, + closing: make(chan struct{}), + } + + if _, ok := c.(io.ReaderFrom); ok { + c = &sendfile{nc} + } else { + c = &nc + } + + go c.(gracefulConn).gracefulShutdown() + + return c +} + +type connstate int + +/* +State diagram. (Waiting) is the starting state. + +(Waiting) -----Read()-----> Working ---+ + | ^ / | ^ Read() + | \ / | +----+ + kill SetReadDeadline() kill + | | +-----+ + V V V Read() + Dead <-SetReadDeadline()-- Dying ----+ + ^ + | + +--Close()--- [from any state] + +*/ + +const ( + // Waiting for more data, and eligible for killing + csWaiting connstate = iota + // In the middle of a connection + csWorking + // Kill has been requested, but waiting on request to finish up + csDying + // Connection is gone forever. Also used when a connection gets hijacked + csDead +) + +type conn struct { + net.Conn + m sync.Mutex + state connstate + closing chan struct{} +} +type sendfile struct{ conn } + +func (c *conn) gracefulShutdown() { + select { + case <-kill: + case <-c.closing: + return + } + c.m.Lock() + defer c.m.Unlock() + + switch c.state { + case csWaiting: + c.unlockedClose(true) + case csWorking: + c.state = csDying + } +} + +func (c *conn) unlockedClose(closeConn bool) { + if closeConn { + c.Conn.Close() + } + close(c.closing) + wg.Done() + c.state = csDead +} + +// We do some hijinks to support hijacking. The semantics here is that any +// connection that gets hijacked is dead to us: we return the raw net.Conn and +// stop tracking the connection entirely. +type hijackConn interface { + hijack() net.Conn +} + +func (c *conn) hijack() net.Conn { + c.m.Lock() + defer c.m.Unlock() + if c.state != csDead { + close(c.closing) + wg.Done() + c.state = csDead + } + return c.Conn +} + +func (c *conn) Read(b []byte) (n int, err error) { + defer func() { + c.m.Lock() + defer c.m.Unlock() + + if c.state == csWaiting { + c.state = csWorking + } + }() + + return c.Conn.Read(b) +} +func (c *conn) Close() error { + defer func() { + c.m.Lock() + defer c.m.Unlock() + + if c.state != csDead { + c.unlockedClose(false) + } + }() + return c.Conn.Close() +} +func (c *conn) SetReadDeadline(t time.Time) error { + defer func() { + c.m.Lock() + defer c.m.Unlock() + switch c.state { + case csDying: + c.unlockedClose(false) + case csWorking: + c.state = csWaiting + } + }() + return c.Conn.SetReadDeadline(t) +} + +func (s *sendfile) ReadFrom(r io.Reader) (int64, error) { + // conn.Conn.KHAAAAAAAANNNNNN + return s.conn.Conn.(io.ReaderFrom).ReadFrom(r) +} diff --git a/graceful/net_test.go b/graceful/net_test.go new file mode 100644 index 0000000..d6e7208 --- /dev/null +++ b/graceful/net_test.go @@ -0,0 +1,198 @@ +package graceful + +import ( + "io" + "net" + "strings" + "testing" + "time" +) + +var b = make([]byte, 0) + +func connify(c net.Conn) *conn { + switch c.(type) { + case (*conn): + return c.(*conn) + case (*sendfile): + return &c.(*sendfile).conn + default: + panic("IDK") + } +} + +func assertState(t *testing.T, n net.Conn, st connstate) { + c := connify(n) + c.m.Lock() + defer c.m.Unlock() + if c.state != st { + t.Fatalf("conn was %v, but expected %v", c.state, st) + } +} + +// Not super happy about making the tests dependent on the passing of time, but +// I'm not really sure what else to do. + +func expectCall(t *testing.T, ch <-chan struct{}, name string) { + select { + case <-ch: + case <-time.After(5 * time.Millisecond): + t.Fatalf("Expected call to %s", name) + } +} + +func TestCounting(t *testing.T) { + kill = make(chan struct{}) + c := WrapConn(fakeConn{}) + ch := make(chan struct{}) + + go func() { + wg.Wait() + ch <- struct{}{} + }() + + select { + case <-ch: + t.Fatal("Expected connection to keep us from quitting") + case <-time.After(5 * time.Millisecond): + } + + c.Close() + expectCall(t, ch, "wg.Wait()") +} + +func TestStateTransitions1(t *testing.T) { + kill = make(chan struct{}) + ch := make(chan struct{}) + + onclose := make(chan struct{}) + read := make(chan struct{}) + deadline := make(chan struct{}) + c := WrapConn(fakeConn{ + onClose: func() { + onclose <- struct{}{} + }, + onRead: func() { + read <- struct{}{} + }, + onSetReadDeadline: func() { + deadline <- struct{}{} + }, + }) + + go func() { + wg.Wait() + ch <- struct{}{} + }() + + assertState(t, c, csWaiting) + + // Waiting + Read() = Working + go c.Read(b) + expectCall(t, read, "c.Read()") + assertState(t, c, csWorking) + + // Working + SetReadDeadline() = Waiting + go c.SetReadDeadline(time.Now()) + expectCall(t, deadline, "c.SetReadDeadline()") + assertState(t, c, csWaiting) + + // Waiting + kill = Dead + close(kill) + expectCall(t, onclose, "c.Close()") + assertState(t, c, csDead) + + expectCall(t, ch, "wg.Wait()") +} + +func TestStateTransitions2(t *testing.T) { + kill = make(chan struct{}) + ch := make(chan struct{}) + onclose := make(chan struct{}) + read := make(chan struct{}) + deadline := make(chan struct{}) + c := WrapConn(fakeConn{ + onClose: func() { + onclose <- struct{}{} + }, + onRead: func() { + read <- struct{}{} + }, + onSetReadDeadline: func() { + deadline <- struct{}{} + }, + }) + + go func() { + wg.Wait() + ch <- struct{}{} + }() + + assertState(t, c, csWaiting) + + // Waiting + Read() = Working + go c.Read(b) + expectCall(t, read, "c.Read()") + assertState(t, c, csWorking) + + // Working + Read() = Working + go c.Read(b) + expectCall(t, read, "c.Read()") + assertState(t, c, csWorking) + + // Working + kill = Dying + close(kill) + time.Sleep(5 * time.Millisecond) + assertState(t, c, csDying) + + // Dying + Read() = Dying + go c.Read(b) + expectCall(t, read, "c.Read()") + assertState(t, c, csDying) + + // Dying + SetReadDeadline() = Dead + go c.SetReadDeadline(time.Now()) + expectCall(t, deadline, "c.SetReadDeadline()") + assertState(t, c, csDead) + + expectCall(t, ch, "wg.Wait()") +} + +func TestHijack(t *testing.T) { + kill = make(chan struct{}) + fake := fakeConn{} + c := WrapConn(fake) + ch := make(chan struct{}) + + go func() { + wg.Wait() + ch <- struct{}{} + }() + + cc := connify(c) + if _, ok := cc.hijack().(fakeConn); !ok { + t.Error("Expected original connection back out") + } + assertState(t, c, csDead) + expectCall(t, ch, "wg.Wait()") +} + +type fakeSendfile struct { + fakeConn +} + +func (f fakeSendfile) ReadFrom(r io.Reader) (int64, error) { + return 0, nil +} + +func TestReadFrom(t *testing.T) { + kill = make(chan struct{}) + c := WrapConn(fakeSendfile{}) + r := strings.NewReader("Hello world") + + if rf, ok := c.(io.ReaderFrom); ok { + rf.ReadFrom(r) + } else { + t.Fatal("Expected a ReaderFrom in return") + } +} diff --git a/graceful/signal.go b/graceful/signal.go new file mode 100644 index 0000000..9a764a5 --- /dev/null +++ b/graceful/signal.go @@ -0,0 +1,117 @@ +package graceful + +import ( + "log" + "os" + "os/signal" + "sync" +) + +// This is the channel that the connections select on. When it is closed, the +// connections should gracefully exit. +var kill = make(chan struct{}) + +// This is the channel that the Wait() function selects on. It should only be +// closed once all the posthooks have been called. +var wait = make(chan struct{}) + +// This is the WaitGroup that indicates when all the connections have gracefully +// shut down. +var wg sync.WaitGroup + +// This lock protects the list of pre- and post- hooks below. +var hookLock sync.Mutex +var prehooks = make([]func(), 0) +var posthooks = make([]func(), 0) + +var sigchan = make(chan os.Signal, 1) + +func init() { + AddSignal(os.Interrupt) + go waitForSignal() +} + +// Add the given signal to the set of signals that trigger a graceful shutdown. +// Note that for convenience the default interrupt (SIGINT) handler is installed +// at package load time, and unless you call ResetSignals() will be listened for +// in addition to any signals you provide by calling this function. +func AddSignal(sig ...os.Signal) { + signal.Notify(sigchan, sig...) +} + +// Reset the list of signals that trigger a graceful shutdown. Useful if, for +// instance, you don't want to use the default interrupt (SIGINT) handler. Since +// we necessarily install the SIGINT handler before you have a chance to call +// ResetSignals(), there will be a brief window during which the set of signals +// this package listens for will not be as you intend. Therefore, if you intend +// on using this function, we encourage you to call it as soon as possible. +func ResetSignals() { + signal.Stop(sigchan) +} + +type userShutdown struct{} + +func (u userShutdown) String() string { + return "application initiated shutdown" +} +func (u userShutdown) Signal() {} + +// Manually trigger a shutdown from your application. Like Wait(), blocks until +// all connections have gracefully shut down. +func Shutdown() { + sigchan <- userShutdown{} + <-wait +} + +// Register a function to be called before any of this package's normal shutdown +// actions. All listeners will be called in the order they were added, from a +// single goroutine. +func PreHook(f func()) { + hookLock.Lock() + defer hookLock.Unlock() + + prehooks = append(prehooks, f) +} + +// Register a function to be called after all of this package's normal shutdown +// actions. All listeners will be called in the order they were added, from a +// single goroutine, and are guaranteed to be called after all listening +// connections have been closed, but before Wait() returns. +// +// If you've Hijack()ed any connections that must be gracefully shut down in +// some other way (since this library disowns all hijacked connections), it's +// reasonable to use a PostHook() to signal and wait for them. +func PostHook(f func()) { + hookLock.Lock() + defer hookLock.Unlock() + + posthooks = append(posthooks, f) +} + +func waitForSignal() { + sig := <-sigchan + log.Printf("Received %v, gracefully shutting down!", sig) + + hookLock.Lock() + defer hookLock.Unlock() + + for _, f := range prehooks { + f() + } + + close(kill) + wg.Wait() + + for _, f := range posthooks { + f() + } + + close(wait) +} + +// Wait for all connections to gracefully shut down. This is commonly called at +// the bottom of the main() function to prevent the program from exiting +// prematurely. +func Wait() { + <-wait +}