Package graceful provides graceful shutdown support for net/http servers, net.Listeners and net.Conns. It does this through terrible, terrible hacks, but "oh well!"
| @ -0,0 +1,63 @@ | |||
| package graceful | |||
| import ( | |||
| "net" | |||
| "time" | |||
| ) | |||
| // Stub out a net.Conn. This is going to be painful. | |||
| type fakeAddr struct{} | |||
| func (f fakeAddr) Network() string { | |||
| return "fake" | |||
| } | |||
| func (f fakeAddr) String() string { | |||
| return "fake" | |||
| } | |||
| type fakeConn struct { | |||
| onRead, onWrite, onClose, onLocalAddr, onRemoteAddr func() | |||
| onSetDeadline, onSetReadDeadline, onSetWriteDeadline func() | |||
| } | |||
| // Here's my number, so... | |||
| func callMeMaybe(f func()) { | |||
| // I apologize for nothing. | |||
| if f != nil { | |||
| f() | |||
| } | |||
| } | |||
| func (f fakeConn) Read(b []byte) (int, error) { | |||
| callMeMaybe(f.onRead) | |||
| return len(b), nil | |||
| } | |||
| func (f fakeConn) Write(b []byte) (int, error) { | |||
| callMeMaybe(f.onWrite) | |||
| return len(b), nil | |||
| } | |||
| func (f fakeConn) Close() error { | |||
| callMeMaybe(f.onClose) | |||
| return nil | |||
| } | |||
| func (f fakeConn) LocalAddr() net.Addr { | |||
| callMeMaybe(f.onLocalAddr) | |||
| return fakeAddr{} | |||
| } | |||
| func (f fakeConn) RemoteAddr() net.Addr { | |||
| callMeMaybe(f.onRemoteAddr) | |||
| return fakeAddr{} | |||
| } | |||
| func (f fakeConn) SetDeadline(t time.Time) error { | |||
| callMeMaybe(f.onSetDeadline) | |||
| return nil | |||
| } | |||
| func (f fakeConn) SetReadDeadline(t time.Time) error { | |||
| callMeMaybe(f.onSetReadDeadline) | |||
| return nil | |||
| } | |||
| func (f fakeConn) SetWriteDeadline(t time.Time) error { | |||
| callMeMaybe(f.onSetWriteDeadline) | |||
| return nil | |||
| } | |||
| @ -0,0 +1,22 @@ | |||
| package graceful | |||
| import ( | |||
| "log" | |||
| "os" | |||
| "strconv" | |||
| "syscall" | |||
| ) | |||
| func init() { | |||
| // This is a little unfortunate: goji/bind already knows whether we're | |||
| // running under einhorn, but we don't want to introduce a dependency | |||
| // between the two packages. Since the check is short enough, inlining | |||
| // it here seems "fine." | |||
| mpid, err := strconv.Atoi(os.Getenv("EINHORN_MASTER_PID")) | |||
| if err != nil || mpid != os.Getppid() { | |||
| return | |||
| } | |||
| log.Print("graceful: Einhorn detected, adding SIGUSR2 handler") | |||
| AddSignal(syscall.SIGUSR2) | |||
| } | |||
| @ -0,0 +1,136 @@ | |||
| /* | |||
| Package graceful implements graceful shutdown for HTTP servers by closing idle | |||
| connections after receiving a signal. By default, this package listens for | |||
| interrupts (i.e., SIGINT), but when it detects that it is running under Einhorn | |||
| it will additionally listen for SIGUSR2 as well, giving your application | |||
| automatic support for graceful upgrades. | |||
| It's worth mentioning explicitly that this package is a hack to shim graceful | |||
| shutdown behavior into the net/http package provided in Go 1.2. It was written | |||
| by carefully reading the sequence of function calls net/http happened to use as | |||
| of this writing and finding enough surface area with which to add appropriate | |||
| behavior. There's a very good chance that this package will cease to work in | |||
| future versions of Go, but with any luck the standard library will add support | |||
| of its own by then. | |||
| If you're interested in figuring out how this package works, we suggest you read | |||
| the documentation for WrapConn() and net.go. | |||
| */ | |||
| package graceful | |||
| import ( | |||
| "crypto/tls" | |||
| "net" | |||
| "net/http" | |||
| "time" | |||
| ) | |||
| // Exactly like net/http's Server. In fact, it *is* a net/http Server, just with | |||
| // different method implementations | |||
| type Server http.Server | |||
| // About 200 years, also known as "forever" | |||
| const forever time.Duration = 200 * 365 * 24 * time.Hour | |||
| /* | |||
| You might notice that these methods look awfully similar to the methods of the | |||
| same name from the go standard library--that's because they were stolen from | |||
| there! If go were more like, say, Ruby, it'd actually be possible to shim just | |||
| the Serve() method, since we can do everything we want from there. However, it's | |||
| not possible to get the other methods which call Serve() (ListenAndServe(), say) | |||
| to call your shimmed copy--they always call the original. | |||
| Since I couldn't come up with a better idea, I just copy-and-pasted both | |||
| ListenAndServe and ListenAndServeTLS here more-or-less verbatim. "Oh well!" | |||
| */ | |||
| // Behaves exactly like the net/http function of the same name. | |||
| func (srv *Server) Serve(l net.Listener) (err error) { | |||
| go func() { | |||
| <-kill | |||
| l.Close() | |||
| }() | |||
| l = WrapListener(l) | |||
| // Spawn a shadow http.Server to do the actual servering. We do this | |||
| // because we need to sketch on some of the parameters you passed in, | |||
| // and it's nice to keep our sketching to ourselves. | |||
| shadow := *(*http.Server)(srv) | |||
| if shadow.ReadTimeout == 0 { | |||
| shadow.ReadTimeout = forever | |||
| } | |||
| shadow.Handler = Middleware(shadow.Handler) | |||
| err = shadow.Serve(l) | |||
| // We expect an error when we close the listener, so we indiscriminately | |||
| // swallow Serve errors when we're in a shutdown state. | |||
| select { | |||
| case <-kill: | |||
| return nil | |||
| default: | |||
| return err | |||
| } | |||
| } | |||
| // Behaves exactly like the net/http function of the same name. | |||
| func (srv *Server) ListenAndServe() error { | |||
| addr := srv.Addr | |||
| if addr == "" { | |||
| addr = ":http" | |||
| } | |||
| l, e := net.Listen("tcp", addr) | |||
| if e != nil { | |||
| return e | |||
| } | |||
| return srv.Serve(l) | |||
| } | |||
| // Behaves exactly like the net/http function of the same name. | |||
| func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error { | |||
| addr := srv.Addr | |||
| if addr == "" { | |||
| addr = ":https" | |||
| } | |||
| config := &tls.Config{} | |||
| if srv.TLSConfig != nil { | |||
| *config = *srv.TLSConfig | |||
| } | |||
| if config.NextProtos == nil { | |||
| config.NextProtos = []string{"http/1.1"} | |||
| } | |||
| var err error | |||
| config.Certificates = make([]tls.Certificate, 1) | |||
| config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| conn, err := net.Listen("tcp", addr) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| tlsListener := tls.NewListener(conn, config) | |||
| return srv.Serve(tlsListener) | |||
| } | |||
| // Behaves exactly like the net/http function of the same name. | |||
| func ListenAndServe(addr string, handler http.Handler) error { | |||
| server := &Server{Addr: addr, Handler: handler} | |||
| return server.ListenAndServe() | |||
| } | |||
| // Behaves exactly like the net/http function of the same name. | |||
| func ListenAndServeTLS(addr, certfile, keyfile string, handler http.Handler) error { | |||
| server := &Server{Addr: addr, Handler: handler} | |||
| return server.ListenAndServeTLS(certfile, keyfile) | |||
| } | |||
| // Behaves exactly like the net/http function of the same name. | |||
| func Serve(l net.Listener, handler http.Handler) error { | |||
| server := &Server{Handler: handler} | |||
| return server.Serve(l) | |||
| } | |||
| @ -0,0 +1,106 @@ | |||
| package graceful | |||
| import ( | |||
| "bufio" | |||
| "net" | |||
| "net/http" | |||
| ) | |||
| /* | |||
| Graceful shutdown middleware. When a graceful shutdown is in progress, this | |||
| middleware intercepts responses to add a "Connection: close" header to politely | |||
| inform the client that we are about to go away. | |||
| This package creates a shim http.ResponseWriter that it passes to subsequent | |||
| handlers. Unfortunately, there's a great many optional interfaces that this | |||
| http.ResponseWriter might implement (e.g., http.CloseNotifier, http.Flusher, and | |||
| http.Hijacker), and in order to perfectly proxy all of these options we'd be | |||
| left with some kind of awful powerset of ResponseWriters, and that's not even | |||
| counting all the other custom interfaces you might be expecting. Instead of | |||
| doing that, we have implemented two kinds of proxies: one that contains no | |||
| additional methods (i.e., exactly corresponding to the http.ResponseWriter | |||
| interface), and one that supports all three of http.CloseNotifier, http.Flusher, | |||
| and http.Hijacker. If you find that this is not enough, the original | |||
| http.ResponseWriter can be retrieved by calling Unwrap() on the proxy object. | |||
| This middleware is automatically applied to every http.Handler passed to this | |||
| package, and most users will not need to call this function directly. It is | |||
| exported primarily for documentation purposes and in the off chance that someone | |||
| really wants more control over their http.Server than we currently provide. | |||
| */ | |||
| func Middleware(h http.Handler) http.Handler { | |||
| if h == nil { | |||
| return nil | |||
| } | |||
| return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |||
| _, cn := w.(http.CloseNotifier) | |||
| _, fl := w.(http.Flusher) | |||
| _, hj := w.(http.Hijacker) | |||
| bw := basicWriter{ResponseWriter: w} | |||
| if cn && fl && hj { | |||
| h.ServeHTTP(&fancyWriter{bw}, r) | |||
| } else { | |||
| h.ServeHTTP(&bw, r) | |||
| } | |||
| if !bw.headerWritten { | |||
| bw.maybeClose() | |||
| } | |||
| }) | |||
| } | |||
| type basicWriter struct { | |||
| http.ResponseWriter | |||
| headerWritten bool | |||
| } | |||
| func (b *basicWriter) maybeClose() { | |||
| b.headerWritten = true | |||
| select { | |||
| case <-kill: | |||
| b.ResponseWriter.Header().Add("Connection", "close") | |||
| default: | |||
| } | |||
| } | |||
| func (b *basicWriter) WriteHeader(code int) { | |||
| b.maybeClose() | |||
| b.ResponseWriter.WriteHeader(code) | |||
| } | |||
| func (b *basicWriter) Write(buf []byte) (int, error) { | |||
| if !b.headerWritten { | |||
| b.maybeClose() | |||
| } | |||
| return b.ResponseWriter.Write(buf) | |||
| } | |||
| func (b *basicWriter) Unwrap() http.ResponseWriter { | |||
| return b.ResponseWriter | |||
| } | |||
| // Optimize for the common case of a ResponseWriter that supports all three of | |||
| // CloseNotifier, Flusher, and Hijacker. | |||
| type fancyWriter struct { | |||
| basicWriter | |||
| } | |||
| func (f *fancyWriter) CloseNotify() <-chan bool { | |||
| cn := f.basicWriter.ResponseWriter.(http.CloseNotifier) | |||
| return cn.CloseNotify() | |||
| } | |||
| func (f *fancyWriter) Flush() { | |||
| fl := f.basicWriter.ResponseWriter.(http.Flusher) | |||
| fl.Flush() | |||
| } | |||
| func (f *fancyWriter) Hijack() (c net.Conn, b *bufio.ReadWriter, e error) { | |||
| hj := f.basicWriter.ResponseWriter.(http.Hijacker) | |||
| c, b, e = hj.Hijack() | |||
| if conn, ok := c.(hijackConn); ok { | |||
| c = conn.hijack() | |||
| } | |||
| return | |||
| } | |||
| @ -0,0 +1,68 @@ | |||
| package graceful | |||
| import ( | |||
| "net/http" | |||
| "testing" | |||
| ) | |||
| type fakeWriter http.Header | |||
| func (f fakeWriter) Header() http.Header { | |||
| return http.Header(f) | |||
| } | |||
| func (f fakeWriter) Write(buf []byte) (int, error) { | |||
| return len(buf), nil | |||
| } | |||
| func (f fakeWriter) WriteHeader(status int) {} | |||
| func testClose(t *testing.T, h http.Handler, expectClose bool) { | |||
| m := Middleware(h) | |||
| r, _ := http.NewRequest("GET", "/", nil) | |||
| w := make(fakeWriter) | |||
| m.ServeHTTP(w, r) | |||
| c, ok := w["Connection"] | |||
| if expectClose { | |||
| if !ok || len(c) != 1 || c[0] != "close" { | |||
| t.Fatal("Expected 'Connection: close'") | |||
| } | |||
| } else { | |||
| if ok { | |||
| t.Fatal("Did not expect Connection header") | |||
| } | |||
| } | |||
| } | |||
| func TestNormal(t *testing.T) { | |||
| kill = make(chan struct{}) | |||
| h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |||
| w.Write([]byte{}) | |||
| }) | |||
| testClose(t, h, false) | |||
| } | |||
| func TestClose(t *testing.T) { | |||
| kill = make(chan struct{}) | |||
| h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |||
| close(kill) | |||
| }) | |||
| testClose(t, h, true) | |||
| } | |||
| func TestCloseWriteHeader(t *testing.T) { | |||
| kill = make(chan struct{}) | |||
| h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |||
| close(kill) | |||
| w.WriteHeader(200) | |||
| }) | |||
| testClose(t, h, true) | |||
| } | |||
| func TestCloseWrite(t *testing.T) { | |||
| kill = make(chan struct{}) | |||
| h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |||
| close(kill) | |||
| w.Write([]byte{}) | |||
| }) | |||
| testClose(t, h, true) | |||
| } | |||
| @ -0,0 +1,198 @@ | |||
| package graceful | |||
| import ( | |||
| "io" | |||
| "net" | |||
| "sync" | |||
| "time" | |||
| ) | |||
| type listener struct { | |||
| net.Listener | |||
| } | |||
| type gracefulConn interface { | |||
| gracefulShutdown() | |||
| } | |||
| // Wrap an arbitrary net.Listener for use with graceful shutdowns. All | |||
| // net.Conn's Accept()ed by this listener will be auto-wrapped as if WrapConn() | |||
| // were called on them. | |||
| func WrapListener(l net.Listener) net.Listener { | |||
| return listener{l} | |||
| } | |||
| func (l listener) Accept() (net.Conn, error) { | |||
| conn, err := l.Listener.Accept() | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| return WrapConn(conn), nil | |||
| } | |||
| /* | |||
| Wrap an arbitrary connection for use with graceful shutdowns. The graceful | |||
| shutdown process will ensure that this connection is closed before terminating | |||
| the process. | |||
| In order to use this function, you must call SetReadDeadline() before the call | |||
| to Read() you might make to read a new request off the wire. The connection is | |||
| eligible for abrupt closing at any point between when the call to | |||
| SetReadDeadline() returns and when the call to Read returns with new data. It | |||
| does not matter what deadline is given to SetReadDeadline()--the default HTTP | |||
| server provided by this package sets a deadline far into the future when a | |||
| deadline is not provided, for instance. | |||
| Unfortunately, this means that it's difficult to use SetReadDeadline() in a | |||
| great many perfectly reasonable circumstances, such as to extend a deadline | |||
| after more data has been read, without the connection being eligible for | |||
| "graceful" termination at an undesirable time. Since this package was written | |||
| explicitly to target net/http, which does not as of this writing do any of this, | |||
| fixing the semantics here does not seem especially urgent. | |||
| As an optimization for net/http over TCP, if the input connection supports the | |||
| ReadFrom() function, the returned connection will as well. This allows the net | |||
| package to use sendfile(2) on certain platforms in certain circumstances. | |||
| */ | |||
| func WrapConn(c net.Conn) net.Conn { | |||
| wg.Add(1) | |||
| nc := conn{ | |||
| Conn: c, | |||
| closing: make(chan struct{}), | |||
| } | |||
| if _, ok := c.(io.ReaderFrom); ok { | |||
| c = &sendfile{nc} | |||
| } else { | |||
| c = &nc | |||
| } | |||
| go c.(gracefulConn).gracefulShutdown() | |||
| return c | |||
| } | |||
| type connstate int | |||
| /* | |||
| State diagram. (Waiting) is the starting state. | |||
| (Waiting) -----Read()-----> Working ---+ | |||
| | ^ / | ^ Read() | |||
| | \ / | +----+ | |||
| kill SetReadDeadline() kill | |||
| | | +-----+ | |||
| V V V Read() | |||
| Dead <-SetReadDeadline()-- Dying ----+ | |||
| ^ | |||
| | | |||
| +--Close()--- [from any state] | |||
| */ | |||
| const ( | |||
| // Waiting for more data, and eligible for killing | |||
| csWaiting connstate = iota | |||
| // In the middle of a connection | |||
| csWorking | |||
| // Kill has been requested, but waiting on request to finish up | |||
| csDying | |||
| // Connection is gone forever. Also used when a connection gets hijacked | |||
| csDead | |||
| ) | |||
| type conn struct { | |||
| net.Conn | |||
| m sync.Mutex | |||
| state connstate | |||
| closing chan struct{} | |||
| } | |||
| type sendfile struct{ conn } | |||
| func (c *conn) gracefulShutdown() { | |||
| select { | |||
| case <-kill: | |||
| case <-c.closing: | |||
| return | |||
| } | |||
| c.m.Lock() | |||
| defer c.m.Unlock() | |||
| switch c.state { | |||
| case csWaiting: | |||
| c.unlockedClose(true) | |||
| case csWorking: | |||
| c.state = csDying | |||
| } | |||
| } | |||
| func (c *conn) unlockedClose(closeConn bool) { | |||
| if closeConn { | |||
| c.Conn.Close() | |||
| } | |||
| close(c.closing) | |||
| wg.Done() | |||
| c.state = csDead | |||
| } | |||
| // We do some hijinks to support hijacking. The semantics here is that any | |||
| // connection that gets hijacked is dead to us: we return the raw net.Conn and | |||
| // stop tracking the connection entirely. | |||
| type hijackConn interface { | |||
| hijack() net.Conn | |||
| } | |||
| func (c *conn) hijack() net.Conn { | |||
| c.m.Lock() | |||
| defer c.m.Unlock() | |||
| if c.state != csDead { | |||
| close(c.closing) | |||
| wg.Done() | |||
| c.state = csDead | |||
| } | |||
| return c.Conn | |||
| } | |||
| func (c *conn) Read(b []byte) (n int, err error) { | |||
| defer func() { | |||
| c.m.Lock() | |||
| defer c.m.Unlock() | |||
| if c.state == csWaiting { | |||
| c.state = csWorking | |||
| } | |||
| }() | |||
| return c.Conn.Read(b) | |||
| } | |||
| func (c *conn) Close() error { | |||
| defer func() { | |||
| c.m.Lock() | |||
| defer c.m.Unlock() | |||
| if c.state != csDead { | |||
| c.unlockedClose(false) | |||
| } | |||
| }() | |||
| return c.Conn.Close() | |||
| } | |||
| func (c *conn) SetReadDeadline(t time.Time) error { | |||
| defer func() { | |||
| c.m.Lock() | |||
| defer c.m.Unlock() | |||
| switch c.state { | |||
| case csDying: | |||
| c.unlockedClose(false) | |||
| case csWorking: | |||
| c.state = csWaiting | |||
| } | |||
| }() | |||
| return c.Conn.SetReadDeadline(t) | |||
| } | |||
| func (s *sendfile) ReadFrom(r io.Reader) (int64, error) { | |||
| // conn.Conn.KHAAAAAAAANNNNNN | |||
| return s.conn.Conn.(io.ReaderFrom).ReadFrom(r) | |||
| } | |||
| @ -0,0 +1,198 @@ | |||
| package graceful | |||
| import ( | |||
| "io" | |||
| "net" | |||
| "strings" | |||
| "testing" | |||
| "time" | |||
| ) | |||
| var b = make([]byte, 0) | |||
| func connify(c net.Conn) *conn { | |||
| switch c.(type) { | |||
| case (*conn): | |||
| return c.(*conn) | |||
| case (*sendfile): | |||
| return &c.(*sendfile).conn | |||
| default: | |||
| panic("IDK") | |||
| } | |||
| } | |||
| func assertState(t *testing.T, n net.Conn, st connstate) { | |||
| c := connify(n) | |||
| c.m.Lock() | |||
| defer c.m.Unlock() | |||
| if c.state != st { | |||
| t.Fatalf("conn was %v, but expected %v", c.state, st) | |||
| } | |||
| } | |||
| // Not super happy about making the tests dependent on the passing of time, but | |||
| // I'm not really sure what else to do. | |||
| func expectCall(t *testing.T, ch <-chan struct{}, name string) { | |||
| select { | |||
| case <-ch: | |||
| case <-time.After(5 * time.Millisecond): | |||
| t.Fatalf("Expected call to %s", name) | |||
| } | |||
| } | |||
| func TestCounting(t *testing.T) { | |||
| kill = make(chan struct{}) | |||
| c := WrapConn(fakeConn{}) | |||
| ch := make(chan struct{}) | |||
| go func() { | |||
| wg.Wait() | |||
| ch <- struct{}{} | |||
| }() | |||
| select { | |||
| case <-ch: | |||
| t.Fatal("Expected connection to keep us from quitting") | |||
| case <-time.After(5 * time.Millisecond): | |||
| } | |||
| c.Close() | |||
| expectCall(t, ch, "wg.Wait()") | |||
| } | |||
| func TestStateTransitions1(t *testing.T) { | |||
| kill = make(chan struct{}) | |||
| ch := make(chan struct{}) | |||
| onclose := make(chan struct{}) | |||
| read := make(chan struct{}) | |||
| deadline := make(chan struct{}) | |||
| c := WrapConn(fakeConn{ | |||
| onClose: func() { | |||
| onclose <- struct{}{} | |||
| }, | |||
| onRead: func() { | |||
| read <- struct{}{} | |||
| }, | |||
| onSetReadDeadline: func() { | |||
| deadline <- struct{}{} | |||
| }, | |||
| }) | |||
| go func() { | |||
| wg.Wait() | |||
| ch <- struct{}{} | |||
| }() | |||
| assertState(t, c, csWaiting) | |||
| // Waiting + Read() = Working | |||
| go c.Read(b) | |||
| expectCall(t, read, "c.Read()") | |||
| assertState(t, c, csWorking) | |||
| // Working + SetReadDeadline() = Waiting | |||
| go c.SetReadDeadline(time.Now()) | |||
| expectCall(t, deadline, "c.SetReadDeadline()") | |||
| assertState(t, c, csWaiting) | |||
| // Waiting + kill = Dead | |||
| close(kill) | |||
| expectCall(t, onclose, "c.Close()") | |||
| assertState(t, c, csDead) | |||
| expectCall(t, ch, "wg.Wait()") | |||
| } | |||
| func TestStateTransitions2(t *testing.T) { | |||
| kill = make(chan struct{}) | |||
| ch := make(chan struct{}) | |||
| onclose := make(chan struct{}) | |||
| read := make(chan struct{}) | |||
| deadline := make(chan struct{}) | |||
| c := WrapConn(fakeConn{ | |||
| onClose: func() { | |||
| onclose <- struct{}{} | |||
| }, | |||
| onRead: func() { | |||
| read <- struct{}{} | |||
| }, | |||
| onSetReadDeadline: func() { | |||
| deadline <- struct{}{} | |||
| }, | |||
| }) | |||
| go func() { | |||
| wg.Wait() | |||
| ch <- struct{}{} | |||
| }() | |||
| assertState(t, c, csWaiting) | |||
| // Waiting + Read() = Working | |||
| go c.Read(b) | |||
| expectCall(t, read, "c.Read()") | |||
| assertState(t, c, csWorking) | |||
| // Working + Read() = Working | |||
| go c.Read(b) | |||
| expectCall(t, read, "c.Read()") | |||
| assertState(t, c, csWorking) | |||
| // Working + kill = Dying | |||
| close(kill) | |||
| time.Sleep(5 * time.Millisecond) | |||
| assertState(t, c, csDying) | |||
| // Dying + Read() = Dying | |||
| go c.Read(b) | |||
| expectCall(t, read, "c.Read()") | |||
| assertState(t, c, csDying) | |||
| // Dying + SetReadDeadline() = Dead | |||
| go c.SetReadDeadline(time.Now()) | |||
| expectCall(t, deadline, "c.SetReadDeadline()") | |||
| assertState(t, c, csDead) | |||
| expectCall(t, ch, "wg.Wait()") | |||
| } | |||
| func TestHijack(t *testing.T) { | |||
| kill = make(chan struct{}) | |||
| fake := fakeConn{} | |||
| c := WrapConn(fake) | |||
| ch := make(chan struct{}) | |||
| go func() { | |||
| wg.Wait() | |||
| ch <- struct{}{} | |||
| }() | |||
| cc := connify(c) | |||
| if _, ok := cc.hijack().(fakeConn); !ok { | |||
| t.Error("Expected original connection back out") | |||
| } | |||
| assertState(t, c, csDead) | |||
| expectCall(t, ch, "wg.Wait()") | |||
| } | |||
| type fakeSendfile struct { | |||
| fakeConn | |||
| } | |||
| func (f fakeSendfile) ReadFrom(r io.Reader) (int64, error) { | |||
| return 0, nil | |||
| } | |||
| func TestReadFrom(t *testing.T) { | |||
| kill = make(chan struct{}) | |||
| c := WrapConn(fakeSendfile{}) | |||
| r := strings.NewReader("Hello world") | |||
| if rf, ok := c.(io.ReaderFrom); ok { | |||
| rf.ReadFrom(r) | |||
| } else { | |||
| t.Fatal("Expected a ReaderFrom in return") | |||
| } | |||
| } | |||
| @ -0,0 +1,117 @@ | |||
| package graceful | |||
| import ( | |||
| "log" | |||
| "os" | |||
| "os/signal" | |||
| "sync" | |||
| ) | |||
| // This is the channel that the connections select on. When it is closed, the | |||
| // connections should gracefully exit. | |||
| var kill = make(chan struct{}) | |||
| // This is the channel that the Wait() function selects on. It should only be | |||
| // closed once all the posthooks have been called. | |||
| var wait = make(chan struct{}) | |||
| // This is the WaitGroup that indicates when all the connections have gracefully | |||
| // shut down. | |||
| var wg sync.WaitGroup | |||
| // This lock protects the list of pre- and post- hooks below. | |||
| var hookLock sync.Mutex | |||
| var prehooks = make([]func(), 0) | |||
| var posthooks = make([]func(), 0) | |||
| var sigchan = make(chan os.Signal, 1) | |||
| func init() { | |||
| AddSignal(os.Interrupt) | |||
| go waitForSignal() | |||
| } | |||
| // Add the given signal to the set of signals that trigger a graceful shutdown. | |||
| // Note that for convenience the default interrupt (SIGINT) handler is installed | |||
| // at package load time, and unless you call ResetSignals() will be listened for | |||
| // in addition to any signals you provide by calling this function. | |||
| func AddSignal(sig ...os.Signal) { | |||
| signal.Notify(sigchan, sig...) | |||
| } | |||
| // Reset the list of signals that trigger a graceful shutdown. Useful if, for | |||
| // instance, you don't want to use the default interrupt (SIGINT) handler. Since | |||
| // we necessarily install the SIGINT handler before you have a chance to call | |||
| // ResetSignals(), there will be a brief window during which the set of signals | |||
| // this package listens for will not be as you intend. Therefore, if you intend | |||
| // on using this function, we encourage you to call it as soon as possible. | |||
| func ResetSignals() { | |||
| signal.Stop(sigchan) | |||
| } | |||
| type userShutdown struct{} | |||
| func (u userShutdown) String() string { | |||
| return "application initiated shutdown" | |||
| } | |||
| func (u userShutdown) Signal() {} | |||
| // Manually trigger a shutdown from your application. Like Wait(), blocks until | |||
| // all connections have gracefully shut down. | |||
| func Shutdown() { | |||
| sigchan <- userShutdown{} | |||
| <-wait | |||
| } | |||
| // Register a function to be called before any of this package's normal shutdown | |||
| // actions. All listeners will be called in the order they were added, from a | |||
| // single goroutine. | |||
| func PreHook(f func()) { | |||
| hookLock.Lock() | |||
| defer hookLock.Unlock() | |||
| prehooks = append(prehooks, f) | |||
| } | |||
| // Register a function to be called after all of this package's normal shutdown | |||
| // actions. All listeners will be called in the order they were added, from a | |||
| // single goroutine, and are guaranteed to be called after all listening | |||
| // connections have been closed, but before Wait() returns. | |||
| // | |||
| // If you've Hijack()ed any connections that must be gracefully shut down in | |||
| // some other way (since this library disowns all hijacked connections), it's | |||
| // reasonable to use a PostHook() to signal and wait for them. | |||
| func PostHook(f func()) { | |||
| hookLock.Lock() | |||
| defer hookLock.Unlock() | |||
| posthooks = append(posthooks, f) | |||
| } | |||
| func waitForSignal() { | |||
| sig := <-sigchan | |||
| log.Printf("Received %v, gracefully shutting down!", sig) | |||
| hookLock.Lock() | |||
| defer hookLock.Unlock() | |||
| for _, f := range prehooks { | |||
| f() | |||
| } | |||
| close(kill) | |||
| wg.Wait() | |||
| for _, f := range posthooks { | |||
| f() | |||
| } | |||
| close(wait) | |||
| } | |||
| // Wait for all connections to gracefully shut down. This is commonly called at | |||
| // the bottom of the main() function to prevent the program from exiting | |||
| // prematurely. | |||
| func Wait() { | |||
| <-wait | |||
| } | |||