Package graceful provides graceful shutdown support for net/http servers, net.Listeners and net.Conns. It does this through terrible, terrible hacks, but "oh well!"
| @ -0,0 +1,63 @@ | |||||
| package graceful | |||||
| import ( | |||||
| "net" | |||||
| "time" | |||||
| ) | |||||
| // Stub out a net.Conn. This is going to be painful. | |||||
| type fakeAddr struct{} | |||||
| func (f fakeAddr) Network() string { | |||||
| return "fake" | |||||
| } | |||||
| func (f fakeAddr) String() string { | |||||
| return "fake" | |||||
| } | |||||
| type fakeConn struct { | |||||
| onRead, onWrite, onClose, onLocalAddr, onRemoteAddr func() | |||||
| onSetDeadline, onSetReadDeadline, onSetWriteDeadline func() | |||||
| } | |||||
| // Here's my number, so... | |||||
| func callMeMaybe(f func()) { | |||||
| // I apologize for nothing. | |||||
| if f != nil { | |||||
| f() | |||||
| } | |||||
| } | |||||
| func (f fakeConn) Read(b []byte) (int, error) { | |||||
| callMeMaybe(f.onRead) | |||||
| return len(b), nil | |||||
| } | |||||
| func (f fakeConn) Write(b []byte) (int, error) { | |||||
| callMeMaybe(f.onWrite) | |||||
| return len(b), nil | |||||
| } | |||||
| func (f fakeConn) Close() error { | |||||
| callMeMaybe(f.onClose) | |||||
| return nil | |||||
| } | |||||
| func (f fakeConn) LocalAddr() net.Addr { | |||||
| callMeMaybe(f.onLocalAddr) | |||||
| return fakeAddr{} | |||||
| } | |||||
| func (f fakeConn) RemoteAddr() net.Addr { | |||||
| callMeMaybe(f.onRemoteAddr) | |||||
| return fakeAddr{} | |||||
| } | |||||
| func (f fakeConn) SetDeadline(t time.Time) error { | |||||
| callMeMaybe(f.onSetDeadline) | |||||
| return nil | |||||
| } | |||||
| func (f fakeConn) SetReadDeadline(t time.Time) error { | |||||
| callMeMaybe(f.onSetReadDeadline) | |||||
| return nil | |||||
| } | |||||
| func (f fakeConn) SetWriteDeadline(t time.Time) error { | |||||
| callMeMaybe(f.onSetWriteDeadline) | |||||
| return nil | |||||
| } | |||||
| @ -0,0 +1,22 @@ | |||||
| package graceful | |||||
| import ( | |||||
| "log" | |||||
| "os" | |||||
| "strconv" | |||||
| "syscall" | |||||
| ) | |||||
| func init() { | |||||
| // This is a little unfortunate: goji/bind already knows whether we're | |||||
| // running under einhorn, but we don't want to introduce a dependency | |||||
| // between the two packages. Since the check is short enough, inlining | |||||
| // it here seems "fine." | |||||
| mpid, err := strconv.Atoi(os.Getenv("EINHORN_MASTER_PID")) | |||||
| if err != nil || mpid != os.Getppid() { | |||||
| return | |||||
| } | |||||
| log.Print("graceful: Einhorn detected, adding SIGUSR2 handler") | |||||
| AddSignal(syscall.SIGUSR2) | |||||
| } | |||||
| @ -0,0 +1,136 @@ | |||||
| /* | |||||
| Package graceful implements graceful shutdown for HTTP servers by closing idle | |||||
| connections after receiving a signal. By default, this package listens for | |||||
| interrupts (i.e., SIGINT), but when it detects that it is running under Einhorn | |||||
| it will additionally listen for SIGUSR2 as well, giving your application | |||||
| automatic support for graceful upgrades. | |||||
| It's worth mentioning explicitly that this package is a hack to shim graceful | |||||
| shutdown behavior into the net/http package provided in Go 1.2. It was written | |||||
| by carefully reading the sequence of function calls net/http happened to use as | |||||
| of this writing and finding enough surface area with which to add appropriate | |||||
| behavior. There's a very good chance that this package will cease to work in | |||||
| future versions of Go, but with any luck the standard library will add support | |||||
| of its own by then. | |||||
| If you're interested in figuring out how this package works, we suggest you read | |||||
| the documentation for WrapConn() and net.go. | |||||
| */ | |||||
| package graceful | |||||
| import ( | |||||
| "crypto/tls" | |||||
| "net" | |||||
| "net/http" | |||||
| "time" | |||||
| ) | |||||
| // Exactly like net/http's Server. In fact, it *is* a net/http Server, just with | |||||
| // different method implementations | |||||
| type Server http.Server | |||||
| // About 200 years, also known as "forever" | |||||
| const forever time.Duration = 200 * 365 * 24 * time.Hour | |||||
| /* | |||||
| You might notice that these methods look awfully similar to the methods of the | |||||
| same name from the go standard library--that's because they were stolen from | |||||
| there! If go were more like, say, Ruby, it'd actually be possible to shim just | |||||
| the Serve() method, since we can do everything we want from there. However, it's | |||||
| not possible to get the other methods which call Serve() (ListenAndServe(), say) | |||||
| to call your shimmed copy--they always call the original. | |||||
| Since I couldn't come up with a better idea, I just copy-and-pasted both | |||||
| ListenAndServe and ListenAndServeTLS here more-or-less verbatim. "Oh well!" | |||||
| */ | |||||
| // Behaves exactly like the net/http function of the same name. | |||||
| func (srv *Server) Serve(l net.Listener) (err error) { | |||||
| go func() { | |||||
| <-kill | |||||
| l.Close() | |||||
| }() | |||||
| l = WrapListener(l) | |||||
| // Spawn a shadow http.Server to do the actual servering. We do this | |||||
| // because we need to sketch on some of the parameters you passed in, | |||||
| // and it's nice to keep our sketching to ourselves. | |||||
| shadow := *(*http.Server)(srv) | |||||
| if shadow.ReadTimeout == 0 { | |||||
| shadow.ReadTimeout = forever | |||||
| } | |||||
| shadow.Handler = Middleware(shadow.Handler) | |||||
| err = shadow.Serve(l) | |||||
| // We expect an error when we close the listener, so we indiscriminately | |||||
| // swallow Serve errors when we're in a shutdown state. | |||||
| select { | |||||
| case <-kill: | |||||
| return nil | |||||
| default: | |||||
| return err | |||||
| } | |||||
| } | |||||
| // Behaves exactly like the net/http function of the same name. | |||||
| func (srv *Server) ListenAndServe() error { | |||||
| addr := srv.Addr | |||||
| if addr == "" { | |||||
| addr = ":http" | |||||
| } | |||||
| l, e := net.Listen("tcp", addr) | |||||
| if e != nil { | |||||
| return e | |||||
| } | |||||
| return srv.Serve(l) | |||||
| } | |||||
| // Behaves exactly like the net/http function of the same name. | |||||
| func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error { | |||||
| addr := srv.Addr | |||||
| if addr == "" { | |||||
| addr = ":https" | |||||
| } | |||||
| config := &tls.Config{} | |||||
| if srv.TLSConfig != nil { | |||||
| *config = *srv.TLSConfig | |||||
| } | |||||
| if config.NextProtos == nil { | |||||
| config.NextProtos = []string{"http/1.1"} | |||||
| } | |||||
| var err error | |||||
| config.Certificates = make([]tls.Certificate, 1) | |||||
| config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| conn, err := net.Listen("tcp", addr) | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| tlsListener := tls.NewListener(conn, config) | |||||
| return srv.Serve(tlsListener) | |||||
| } | |||||
| // Behaves exactly like the net/http function of the same name. | |||||
| func ListenAndServe(addr string, handler http.Handler) error { | |||||
| server := &Server{Addr: addr, Handler: handler} | |||||
| return server.ListenAndServe() | |||||
| } | |||||
| // Behaves exactly like the net/http function of the same name. | |||||
| func ListenAndServeTLS(addr, certfile, keyfile string, handler http.Handler) error { | |||||
| server := &Server{Addr: addr, Handler: handler} | |||||
| return server.ListenAndServeTLS(certfile, keyfile) | |||||
| } | |||||
| // Behaves exactly like the net/http function of the same name. | |||||
| func Serve(l net.Listener, handler http.Handler) error { | |||||
| server := &Server{Handler: handler} | |||||
| return server.Serve(l) | |||||
| } | |||||
| @ -0,0 +1,106 @@ | |||||
| package graceful | |||||
| import ( | |||||
| "bufio" | |||||
| "net" | |||||
| "net/http" | |||||
| ) | |||||
| /* | |||||
| Graceful shutdown middleware. When a graceful shutdown is in progress, this | |||||
| middleware intercepts responses to add a "Connection: close" header to politely | |||||
| inform the client that we are about to go away. | |||||
| This package creates a shim http.ResponseWriter that it passes to subsequent | |||||
| handlers. Unfortunately, there's a great many optional interfaces that this | |||||
| http.ResponseWriter might implement (e.g., http.CloseNotifier, http.Flusher, and | |||||
| http.Hijacker), and in order to perfectly proxy all of these options we'd be | |||||
| left with some kind of awful powerset of ResponseWriters, and that's not even | |||||
| counting all the other custom interfaces you might be expecting. Instead of | |||||
| doing that, we have implemented two kinds of proxies: one that contains no | |||||
| additional methods (i.e., exactly corresponding to the http.ResponseWriter | |||||
| interface), and one that supports all three of http.CloseNotifier, http.Flusher, | |||||
| and http.Hijacker. If you find that this is not enough, the original | |||||
| http.ResponseWriter can be retrieved by calling Unwrap() on the proxy object. | |||||
| This middleware is automatically applied to every http.Handler passed to this | |||||
| package, and most users will not need to call this function directly. It is | |||||
| exported primarily for documentation purposes and in the off chance that someone | |||||
| really wants more control over their http.Server than we currently provide. | |||||
| */ | |||||
| func Middleware(h http.Handler) http.Handler { | |||||
| if h == nil { | |||||
| return nil | |||||
| } | |||||
| return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |||||
| _, cn := w.(http.CloseNotifier) | |||||
| _, fl := w.(http.Flusher) | |||||
| _, hj := w.(http.Hijacker) | |||||
| bw := basicWriter{ResponseWriter: w} | |||||
| if cn && fl && hj { | |||||
| h.ServeHTTP(&fancyWriter{bw}, r) | |||||
| } else { | |||||
| h.ServeHTTP(&bw, r) | |||||
| } | |||||
| if !bw.headerWritten { | |||||
| bw.maybeClose() | |||||
| } | |||||
| }) | |||||
| } | |||||
| type basicWriter struct { | |||||
| http.ResponseWriter | |||||
| headerWritten bool | |||||
| } | |||||
| func (b *basicWriter) maybeClose() { | |||||
| b.headerWritten = true | |||||
| select { | |||||
| case <-kill: | |||||
| b.ResponseWriter.Header().Add("Connection", "close") | |||||
| default: | |||||
| } | |||||
| } | |||||
| func (b *basicWriter) WriteHeader(code int) { | |||||
| b.maybeClose() | |||||
| b.ResponseWriter.WriteHeader(code) | |||||
| } | |||||
| func (b *basicWriter) Write(buf []byte) (int, error) { | |||||
| if !b.headerWritten { | |||||
| b.maybeClose() | |||||
| } | |||||
| return b.ResponseWriter.Write(buf) | |||||
| } | |||||
| func (b *basicWriter) Unwrap() http.ResponseWriter { | |||||
| return b.ResponseWriter | |||||
| } | |||||
| // Optimize for the common case of a ResponseWriter that supports all three of | |||||
| // CloseNotifier, Flusher, and Hijacker. | |||||
| type fancyWriter struct { | |||||
| basicWriter | |||||
| } | |||||
| func (f *fancyWriter) CloseNotify() <-chan bool { | |||||
| cn := f.basicWriter.ResponseWriter.(http.CloseNotifier) | |||||
| return cn.CloseNotify() | |||||
| } | |||||
| func (f *fancyWriter) Flush() { | |||||
| fl := f.basicWriter.ResponseWriter.(http.Flusher) | |||||
| fl.Flush() | |||||
| } | |||||
| func (f *fancyWriter) Hijack() (c net.Conn, b *bufio.ReadWriter, e error) { | |||||
| hj := f.basicWriter.ResponseWriter.(http.Hijacker) | |||||
| c, b, e = hj.Hijack() | |||||
| if conn, ok := c.(hijackConn); ok { | |||||
| c = conn.hijack() | |||||
| } | |||||
| return | |||||
| } | |||||
| @ -0,0 +1,68 @@ | |||||
| package graceful | |||||
| import ( | |||||
| "net/http" | |||||
| "testing" | |||||
| ) | |||||
| type fakeWriter http.Header | |||||
| func (f fakeWriter) Header() http.Header { | |||||
| return http.Header(f) | |||||
| } | |||||
| func (f fakeWriter) Write(buf []byte) (int, error) { | |||||
| return len(buf), nil | |||||
| } | |||||
| func (f fakeWriter) WriteHeader(status int) {} | |||||
| func testClose(t *testing.T, h http.Handler, expectClose bool) { | |||||
| m := Middleware(h) | |||||
| r, _ := http.NewRequest("GET", "/", nil) | |||||
| w := make(fakeWriter) | |||||
| m.ServeHTTP(w, r) | |||||
| c, ok := w["Connection"] | |||||
| if expectClose { | |||||
| if !ok || len(c) != 1 || c[0] != "close" { | |||||
| t.Fatal("Expected 'Connection: close'") | |||||
| } | |||||
| } else { | |||||
| if ok { | |||||
| t.Fatal("Did not expect Connection header") | |||||
| } | |||||
| } | |||||
| } | |||||
| func TestNormal(t *testing.T) { | |||||
| kill = make(chan struct{}) | |||||
| h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |||||
| w.Write([]byte{}) | |||||
| }) | |||||
| testClose(t, h, false) | |||||
| } | |||||
| func TestClose(t *testing.T) { | |||||
| kill = make(chan struct{}) | |||||
| h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |||||
| close(kill) | |||||
| }) | |||||
| testClose(t, h, true) | |||||
| } | |||||
| func TestCloseWriteHeader(t *testing.T) { | |||||
| kill = make(chan struct{}) | |||||
| h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |||||
| close(kill) | |||||
| w.WriteHeader(200) | |||||
| }) | |||||
| testClose(t, h, true) | |||||
| } | |||||
| func TestCloseWrite(t *testing.T) { | |||||
| kill = make(chan struct{}) | |||||
| h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |||||
| close(kill) | |||||
| w.Write([]byte{}) | |||||
| }) | |||||
| testClose(t, h, true) | |||||
| } | |||||
| @ -0,0 +1,198 @@ | |||||
| package graceful | |||||
| import ( | |||||
| "io" | |||||
| "net" | |||||
| "sync" | |||||
| "time" | |||||
| ) | |||||
| type listener struct { | |||||
| net.Listener | |||||
| } | |||||
| type gracefulConn interface { | |||||
| gracefulShutdown() | |||||
| } | |||||
| // Wrap an arbitrary net.Listener for use with graceful shutdowns. All | |||||
| // net.Conn's Accept()ed by this listener will be auto-wrapped as if WrapConn() | |||||
| // were called on them. | |||||
| func WrapListener(l net.Listener) net.Listener { | |||||
| return listener{l} | |||||
| } | |||||
| func (l listener) Accept() (net.Conn, error) { | |||||
| conn, err := l.Listener.Accept() | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| return WrapConn(conn), nil | |||||
| } | |||||
| /* | |||||
| Wrap an arbitrary connection for use with graceful shutdowns. The graceful | |||||
| shutdown process will ensure that this connection is closed before terminating | |||||
| the process. | |||||
| In order to use this function, you must call SetReadDeadline() before the call | |||||
| to Read() you might make to read a new request off the wire. The connection is | |||||
| eligible for abrupt closing at any point between when the call to | |||||
| SetReadDeadline() returns and when the call to Read returns with new data. It | |||||
| does not matter what deadline is given to SetReadDeadline()--the default HTTP | |||||
| server provided by this package sets a deadline far into the future when a | |||||
| deadline is not provided, for instance. | |||||
| Unfortunately, this means that it's difficult to use SetReadDeadline() in a | |||||
| great many perfectly reasonable circumstances, such as to extend a deadline | |||||
| after more data has been read, without the connection being eligible for | |||||
| "graceful" termination at an undesirable time. Since this package was written | |||||
| explicitly to target net/http, which does not as of this writing do any of this, | |||||
| fixing the semantics here does not seem especially urgent. | |||||
| As an optimization for net/http over TCP, if the input connection supports the | |||||
| ReadFrom() function, the returned connection will as well. This allows the net | |||||
| package to use sendfile(2) on certain platforms in certain circumstances. | |||||
| */ | |||||
| func WrapConn(c net.Conn) net.Conn { | |||||
| wg.Add(1) | |||||
| nc := conn{ | |||||
| Conn: c, | |||||
| closing: make(chan struct{}), | |||||
| } | |||||
| if _, ok := c.(io.ReaderFrom); ok { | |||||
| c = &sendfile{nc} | |||||
| } else { | |||||
| c = &nc | |||||
| } | |||||
| go c.(gracefulConn).gracefulShutdown() | |||||
| return c | |||||
| } | |||||
| type connstate int | |||||
| /* | |||||
| State diagram. (Waiting) is the starting state. | |||||
| (Waiting) -----Read()-----> Working ---+ | |||||
| | ^ / | ^ Read() | |||||
| | \ / | +----+ | |||||
| kill SetReadDeadline() kill | |||||
| | | +-----+ | |||||
| V V V Read() | |||||
| Dead <-SetReadDeadline()-- Dying ----+ | |||||
| ^ | |||||
| | | |||||
| +--Close()--- [from any state] | |||||
| */ | |||||
| const ( | |||||
| // Waiting for more data, and eligible for killing | |||||
| csWaiting connstate = iota | |||||
| // In the middle of a connection | |||||
| csWorking | |||||
| // Kill has been requested, but waiting on request to finish up | |||||
| csDying | |||||
| // Connection is gone forever. Also used when a connection gets hijacked | |||||
| csDead | |||||
| ) | |||||
| type conn struct { | |||||
| net.Conn | |||||
| m sync.Mutex | |||||
| state connstate | |||||
| closing chan struct{} | |||||
| } | |||||
| type sendfile struct{ conn } | |||||
| func (c *conn) gracefulShutdown() { | |||||
| select { | |||||
| case <-kill: | |||||
| case <-c.closing: | |||||
| return | |||||
| } | |||||
| c.m.Lock() | |||||
| defer c.m.Unlock() | |||||
| switch c.state { | |||||
| case csWaiting: | |||||
| c.unlockedClose(true) | |||||
| case csWorking: | |||||
| c.state = csDying | |||||
| } | |||||
| } | |||||
| func (c *conn) unlockedClose(closeConn bool) { | |||||
| if closeConn { | |||||
| c.Conn.Close() | |||||
| } | |||||
| close(c.closing) | |||||
| wg.Done() | |||||
| c.state = csDead | |||||
| } | |||||
| // We do some hijinks to support hijacking. The semantics here is that any | |||||
| // connection that gets hijacked is dead to us: we return the raw net.Conn and | |||||
| // stop tracking the connection entirely. | |||||
| type hijackConn interface { | |||||
| hijack() net.Conn | |||||
| } | |||||
| func (c *conn) hijack() net.Conn { | |||||
| c.m.Lock() | |||||
| defer c.m.Unlock() | |||||
| if c.state != csDead { | |||||
| close(c.closing) | |||||
| wg.Done() | |||||
| c.state = csDead | |||||
| } | |||||
| return c.Conn | |||||
| } | |||||
| func (c *conn) Read(b []byte) (n int, err error) { | |||||
| defer func() { | |||||
| c.m.Lock() | |||||
| defer c.m.Unlock() | |||||
| if c.state == csWaiting { | |||||
| c.state = csWorking | |||||
| } | |||||
| }() | |||||
| return c.Conn.Read(b) | |||||
| } | |||||
| func (c *conn) Close() error { | |||||
| defer func() { | |||||
| c.m.Lock() | |||||
| defer c.m.Unlock() | |||||
| if c.state != csDead { | |||||
| c.unlockedClose(false) | |||||
| } | |||||
| }() | |||||
| return c.Conn.Close() | |||||
| } | |||||
| func (c *conn) SetReadDeadline(t time.Time) error { | |||||
| defer func() { | |||||
| c.m.Lock() | |||||
| defer c.m.Unlock() | |||||
| switch c.state { | |||||
| case csDying: | |||||
| c.unlockedClose(false) | |||||
| case csWorking: | |||||
| c.state = csWaiting | |||||
| } | |||||
| }() | |||||
| return c.Conn.SetReadDeadline(t) | |||||
| } | |||||
| func (s *sendfile) ReadFrom(r io.Reader) (int64, error) { | |||||
| // conn.Conn.KHAAAAAAAANNNNNN | |||||
| return s.conn.Conn.(io.ReaderFrom).ReadFrom(r) | |||||
| } | |||||
| @ -0,0 +1,198 @@ | |||||
| package graceful | |||||
| import ( | |||||
| "io" | |||||
| "net" | |||||
| "strings" | |||||
| "testing" | |||||
| "time" | |||||
| ) | |||||
| var b = make([]byte, 0) | |||||
| func connify(c net.Conn) *conn { | |||||
| switch c.(type) { | |||||
| case (*conn): | |||||
| return c.(*conn) | |||||
| case (*sendfile): | |||||
| return &c.(*sendfile).conn | |||||
| default: | |||||
| panic("IDK") | |||||
| } | |||||
| } | |||||
| func assertState(t *testing.T, n net.Conn, st connstate) { | |||||
| c := connify(n) | |||||
| c.m.Lock() | |||||
| defer c.m.Unlock() | |||||
| if c.state != st { | |||||
| t.Fatalf("conn was %v, but expected %v", c.state, st) | |||||
| } | |||||
| } | |||||
| // Not super happy about making the tests dependent on the passing of time, but | |||||
| // I'm not really sure what else to do. | |||||
| func expectCall(t *testing.T, ch <-chan struct{}, name string) { | |||||
| select { | |||||
| case <-ch: | |||||
| case <-time.After(5 * time.Millisecond): | |||||
| t.Fatalf("Expected call to %s", name) | |||||
| } | |||||
| } | |||||
| func TestCounting(t *testing.T) { | |||||
| kill = make(chan struct{}) | |||||
| c := WrapConn(fakeConn{}) | |||||
| ch := make(chan struct{}) | |||||
| go func() { | |||||
| wg.Wait() | |||||
| ch <- struct{}{} | |||||
| }() | |||||
| select { | |||||
| case <-ch: | |||||
| t.Fatal("Expected connection to keep us from quitting") | |||||
| case <-time.After(5 * time.Millisecond): | |||||
| } | |||||
| c.Close() | |||||
| expectCall(t, ch, "wg.Wait()") | |||||
| } | |||||
| func TestStateTransitions1(t *testing.T) { | |||||
| kill = make(chan struct{}) | |||||
| ch := make(chan struct{}) | |||||
| onclose := make(chan struct{}) | |||||
| read := make(chan struct{}) | |||||
| deadline := make(chan struct{}) | |||||
| c := WrapConn(fakeConn{ | |||||
| onClose: func() { | |||||
| onclose <- struct{}{} | |||||
| }, | |||||
| onRead: func() { | |||||
| read <- struct{}{} | |||||
| }, | |||||
| onSetReadDeadline: func() { | |||||
| deadline <- struct{}{} | |||||
| }, | |||||
| }) | |||||
| go func() { | |||||
| wg.Wait() | |||||
| ch <- struct{}{} | |||||
| }() | |||||
| assertState(t, c, csWaiting) | |||||
| // Waiting + Read() = Working | |||||
| go c.Read(b) | |||||
| expectCall(t, read, "c.Read()") | |||||
| assertState(t, c, csWorking) | |||||
| // Working + SetReadDeadline() = Waiting | |||||
| go c.SetReadDeadline(time.Now()) | |||||
| expectCall(t, deadline, "c.SetReadDeadline()") | |||||
| assertState(t, c, csWaiting) | |||||
| // Waiting + kill = Dead | |||||
| close(kill) | |||||
| expectCall(t, onclose, "c.Close()") | |||||
| assertState(t, c, csDead) | |||||
| expectCall(t, ch, "wg.Wait()") | |||||
| } | |||||
| func TestStateTransitions2(t *testing.T) { | |||||
| kill = make(chan struct{}) | |||||
| ch := make(chan struct{}) | |||||
| onclose := make(chan struct{}) | |||||
| read := make(chan struct{}) | |||||
| deadline := make(chan struct{}) | |||||
| c := WrapConn(fakeConn{ | |||||
| onClose: func() { | |||||
| onclose <- struct{}{} | |||||
| }, | |||||
| onRead: func() { | |||||
| read <- struct{}{} | |||||
| }, | |||||
| onSetReadDeadline: func() { | |||||
| deadline <- struct{}{} | |||||
| }, | |||||
| }) | |||||
| go func() { | |||||
| wg.Wait() | |||||
| ch <- struct{}{} | |||||
| }() | |||||
| assertState(t, c, csWaiting) | |||||
| // Waiting + Read() = Working | |||||
| go c.Read(b) | |||||
| expectCall(t, read, "c.Read()") | |||||
| assertState(t, c, csWorking) | |||||
| // Working + Read() = Working | |||||
| go c.Read(b) | |||||
| expectCall(t, read, "c.Read()") | |||||
| assertState(t, c, csWorking) | |||||
| // Working + kill = Dying | |||||
| close(kill) | |||||
| time.Sleep(5 * time.Millisecond) | |||||
| assertState(t, c, csDying) | |||||
| // Dying + Read() = Dying | |||||
| go c.Read(b) | |||||
| expectCall(t, read, "c.Read()") | |||||
| assertState(t, c, csDying) | |||||
| // Dying + SetReadDeadline() = Dead | |||||
| go c.SetReadDeadline(time.Now()) | |||||
| expectCall(t, deadline, "c.SetReadDeadline()") | |||||
| assertState(t, c, csDead) | |||||
| expectCall(t, ch, "wg.Wait()") | |||||
| } | |||||
| func TestHijack(t *testing.T) { | |||||
| kill = make(chan struct{}) | |||||
| fake := fakeConn{} | |||||
| c := WrapConn(fake) | |||||
| ch := make(chan struct{}) | |||||
| go func() { | |||||
| wg.Wait() | |||||
| ch <- struct{}{} | |||||
| }() | |||||
| cc := connify(c) | |||||
| if _, ok := cc.hijack().(fakeConn); !ok { | |||||
| t.Error("Expected original connection back out") | |||||
| } | |||||
| assertState(t, c, csDead) | |||||
| expectCall(t, ch, "wg.Wait()") | |||||
| } | |||||
| type fakeSendfile struct { | |||||
| fakeConn | |||||
| } | |||||
| func (f fakeSendfile) ReadFrom(r io.Reader) (int64, error) { | |||||
| return 0, nil | |||||
| } | |||||
| func TestReadFrom(t *testing.T) { | |||||
| kill = make(chan struct{}) | |||||
| c := WrapConn(fakeSendfile{}) | |||||
| r := strings.NewReader("Hello world") | |||||
| if rf, ok := c.(io.ReaderFrom); ok { | |||||
| rf.ReadFrom(r) | |||||
| } else { | |||||
| t.Fatal("Expected a ReaderFrom in return") | |||||
| } | |||||
| } | |||||
| @ -0,0 +1,117 @@ | |||||
| package graceful | |||||
| import ( | |||||
| "log" | |||||
| "os" | |||||
| "os/signal" | |||||
| "sync" | |||||
| ) | |||||
| // This is the channel that the connections select on. When it is closed, the | |||||
| // connections should gracefully exit. | |||||
| var kill = make(chan struct{}) | |||||
| // This is the channel that the Wait() function selects on. It should only be | |||||
| // closed once all the posthooks have been called. | |||||
| var wait = make(chan struct{}) | |||||
| // This is the WaitGroup that indicates when all the connections have gracefully | |||||
| // shut down. | |||||
| var wg sync.WaitGroup | |||||
| // This lock protects the list of pre- and post- hooks below. | |||||
| var hookLock sync.Mutex | |||||
| var prehooks = make([]func(), 0) | |||||
| var posthooks = make([]func(), 0) | |||||
| var sigchan = make(chan os.Signal, 1) | |||||
| func init() { | |||||
| AddSignal(os.Interrupt) | |||||
| go waitForSignal() | |||||
| } | |||||
| // Add the given signal to the set of signals that trigger a graceful shutdown. | |||||
| // Note that for convenience the default interrupt (SIGINT) handler is installed | |||||
| // at package load time, and unless you call ResetSignals() will be listened for | |||||
| // in addition to any signals you provide by calling this function. | |||||
| func AddSignal(sig ...os.Signal) { | |||||
| signal.Notify(sigchan, sig...) | |||||
| } | |||||
| // Reset the list of signals that trigger a graceful shutdown. Useful if, for | |||||
| // instance, you don't want to use the default interrupt (SIGINT) handler. Since | |||||
| // we necessarily install the SIGINT handler before you have a chance to call | |||||
| // ResetSignals(), there will be a brief window during which the set of signals | |||||
| // this package listens for will not be as you intend. Therefore, if you intend | |||||
| // on using this function, we encourage you to call it as soon as possible. | |||||
| func ResetSignals() { | |||||
| signal.Stop(sigchan) | |||||
| } | |||||
| type userShutdown struct{} | |||||
| func (u userShutdown) String() string { | |||||
| return "application initiated shutdown" | |||||
| } | |||||
| func (u userShutdown) Signal() {} | |||||
| // Manually trigger a shutdown from your application. Like Wait(), blocks until | |||||
| // all connections have gracefully shut down. | |||||
| func Shutdown() { | |||||
| sigchan <- userShutdown{} | |||||
| <-wait | |||||
| } | |||||
| // Register a function to be called before any of this package's normal shutdown | |||||
| // actions. All listeners will be called in the order they were added, from a | |||||
| // single goroutine. | |||||
| func PreHook(f func()) { | |||||
| hookLock.Lock() | |||||
| defer hookLock.Unlock() | |||||
| prehooks = append(prehooks, f) | |||||
| } | |||||
| // Register a function to be called after all of this package's normal shutdown | |||||
| // actions. All listeners will be called in the order they were added, from a | |||||
| // single goroutine, and are guaranteed to be called after all listening | |||||
| // connections have been closed, but before Wait() returns. | |||||
| // | |||||
| // If you've Hijack()ed any connections that must be gracefully shut down in | |||||
| // some other way (since this library disowns all hijacked connections), it's | |||||
| // reasonable to use a PostHook() to signal and wait for them. | |||||
| func PostHook(f func()) { | |||||
| hookLock.Lock() | |||||
| defer hookLock.Unlock() | |||||
| posthooks = append(posthooks, f) | |||||
| } | |||||
| func waitForSignal() { | |||||
| sig := <-sigchan | |||||
| log.Printf("Received %v, gracefully shutting down!", sig) | |||||
| hookLock.Lock() | |||||
| defer hookLock.Unlock() | |||||
| for _, f := range prehooks { | |||||
| f() | |||||
| } | |||||
| close(kill) | |||||
| wg.Wait() | |||||
| for _, f := range posthooks { | |||||
| f() | |||||
| } | |||||
| close(wait) | |||||
| } | |||||
| // Wait for all connections to gracefully shut down. This is commonly called at | |||||
| // the bottom of the main() function to prevent the program from exiting | |||||
| // prematurely. | |||||
| func Wait() { | |||||
| <-wait | |||||
| } | |||||