diff --git a/context.go b/context.go index bcc3faa..a7f7d85 100644 --- a/context.go +++ b/context.go @@ -11,7 +11,7 @@ import ( ) var ( - mutex sync.Mutex + mutex sync.RWMutex data = make(map[*http.Request]map[interface{}]interface{}) datat = make(map[*http.Request]int64) ) @@ -19,63 +19,64 @@ var ( // Set stores a value for a given key in a given request. func Set(r *http.Request, key, val interface{}) { mutex.Lock() - defer mutex.Unlock() if data[r] == nil { data[r] = make(map[interface{}]interface{}) datat[r] = time.Now().Unix() } data[r][key] = val + mutex.Unlock() } // Get returns a value stored for a given key in a given request. func Get(r *http.Request, key interface{}) interface{} { - mutex.Lock() - defer mutex.Unlock() + mutex.RLock() if data[r] != nil { + mutex.RUnlock() return data[r][key] } + mutex.RUnlock() return nil } // GetOk returns stored value and presence state like multi-value return of map access. func GetOk(r *http.Request, key interface{}) (interface{}, bool) { - mutex.Lock() - defer mutex.Unlock() + mutex.RLock() if _, ok := data[r]; ok { value, ok := data[r][key] + mutex.RUnlock() return value, ok } + mutex.RUnlock() return nil, false } // GetAll returns all stored values for the request as a map. Nil is returned for invalid requests. func GetAll(r *http.Request) map[interface{}]interface{} { - mutex.Lock() - defer mutex.Unlock() - + mutex.RLock() if context, ok := data[r]; ok { + mutex.RUnlock() return context } + mutex.RUnlock() return nil } // GetAllOk returns all stored values for the request as a map. It returns not // ok if the request was never registered. func GetAllOk(r *http.Request) (map[interface{}]interface{}, bool) { - mutex.Lock() - defer mutex.Unlock() - + mutex.RLock() context, ok := data[r] + mutex.RUnlock() return context, ok } // Delete removes a value stored for a given key in a given request. func Delete(r *http.Request, key interface{}) { mutex.Lock() - defer mutex.Unlock() if data[r] != nil { delete(data[r], key) } + mutex.Unlock() } // Clear removes all values stored for a given request. @@ -84,8 +85,8 @@ func Delete(r *http.Request, key interface{}) { // variables at the end of a request lifetime. See ClearHandler(). func Clear(r *http.Request) { mutex.Lock() - defer mutex.Unlock() clear(r) + mutex.Unlock() } // clear is Clear without the lock. @@ -105,7 +106,6 @@ func clear(r *http.Request) { // periodically until the problem is fixed. func Purge(maxAge int) int { mutex.Lock() - defer mutex.Unlock() count := 0 if maxAge <= 0 { count = len(data) @@ -120,6 +120,7 @@ func Purge(maxAge int) int { } } } + mutex.Unlock() return count } diff --git a/context_test.go b/context_test.go index e1dca3b..6ada8ec 100644 --- a/context_test.go +++ b/context_test.go @@ -85,3 +85,77 @@ func TestContext(t *testing.T) { Clear(r) assertEqual(len(data), 0) } + +func parallelReader(r *http.Request, key string, iterations int, wait, done chan struct{}) { + <-wait + for i := 0; i < iterations; i++ { + Get(r, key) + } + done <- struct{}{} + +} + +func parallelWriter(r *http.Request, key, value string, iterations int, wait, done chan struct{}) { + <-wait + for i := 0; i < iterations; i++ { + Get(r, key) + } + done <- struct{}{} + +} + +func benchmarkMutex(b *testing.B, numReaders, numWriters, iterations int) { + + b.StopTimer() + r, _ := http.NewRequest("GET", "http://localhost:8080/", nil) + done := make(chan struct{}) + b.StartTimer() + + for i := 0; i < b.N; i++ { + wait := make(chan struct{}) + + for i := 0; i < numReaders; i++ { + go parallelReader(r, "test", iterations, wait, done) + } + + for i := 0; i < numWriters; i++ { + go parallelWriter(r, "test", "123", iterations, wait, done) + } + + close(wait) + + for i := 0; i < numReaders+numWriters; i++ { + <-done + } + + } + +} + +func BenchmarkMutexSameReadWrite1(b *testing.B) { + benchmarkMutex(b, 1, 1, 32) +} +func BenchmarkMutexSameReadWrite2(b *testing.B) { + benchmarkMutex(b, 2, 2, 32) +} +func BenchmarkMutexSameReadWrite4(b *testing.B) { + benchmarkMutex(b, 4, 4, 32) +} +func BenchmarkMutex1(b *testing.B) { + benchmarkMutex(b, 2, 8, 32) +} +func BenchmarkMutex2(b *testing.B) { + benchmarkMutex(b, 16, 4, 64) +} +func BenchmarkMutex3(b *testing.B) { + benchmarkMutex(b, 1, 2, 128) +} +func BenchmarkMutex4(b *testing.B) { + benchmarkMutex(b, 128, 32, 256) +} +func BenchmarkMutex5(b *testing.B) { + benchmarkMutex(b, 1024, 2048, 64) +} +func BenchmarkMutex6(b *testing.B) { + benchmarkMutex(b, 2048, 1024, 512) +} diff --git a/doc.go b/doc.go index 2976064..73c7400 100644 --- a/doc.go +++ b/doc.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. /* -Package gorilla/context stores values shared during a request lifetime. +Package context stores values shared during a request lifetime. For example, a router can set variables extracted from the URL and later application handlers can access those values, or it can be used to store