Skip to content

Instantly share code, notes, and snippets.

@dtjm
Created July 30, 2014 22:55
Show Gist options
Save dtjm/ed5ee637d7d9841af78a to your computer and use it in GitHub Desktop.
Go in-memory cache
package cache
import (
"errors"
"math"
"sync"
"time"
)
// ByteCacher is a key/value cache whose values are raw byte slices.
// *MemoryCache in this package satisfies it.
type ByteCacher interface {
	// Set stores data under key with a time-to-live of ttl.
	Set(key string, data []byte, ttl time.Duration) error

	// Get returns the bytes currently cached under key.
	Get(key string) ([]byte, error)
}
// ErrTooLarge is the error returned when a single value is larger than the
// cache's total size limit and therefore can never be stored.
var ErrTooLarge = errors.New("data exceeds cache size limit")

// tooLargeError is the original unexported name of ErrTooLarge. It is the same
// sentinel value, kept so existing references inside the package still compile.
var tooLargeError = ErrTooLarge

// Forever is a TTL that effectively never expires: it is the largest
// representable time.Duration (roughly 292 years). Entries stored with it can
// still be evicted to make room for new data.
const Forever time.Duration = math.MaxInt64
// MemoryCache is a size-bounded, in-process byte cache backed by Go maps;
// *MemoryCache implements ByteCacher. Construct it with NewMemoryCache: the
// zero value has nil maps, and storing into them would panic.
type MemoryCache struct {
	// writeLock is held while data, expirations and size are modified.
	writeLock sync.Mutex

	// data maps each key to its cached value.
	data map[string][]byte
	// expirations maps each key to the instant after which it is stale.
	expirations map[string]time.Time

	// sizeLimit is the maximum total number of value bytes (keys are not
	// counted) the cache is meant to hold.
	sizeLimit int
	// size is the running total of len(value) over every entry in data,
	// including expired entries that have not been flushed yet.
	size int
}
// NewMemoryCache builds an empty cache that will hold at most sizeLimit bytes
// of values (only len(data) of each value is counted, not the key).
func NewMemoryCache(sizeLimit int) *MemoryCache {
	return &MemoryCache{
		sizeLimit:   sizeLimit,
		data:        map[string][]byte{},
		expirations: map[string]time.Time{},
	}
}
// Get returns the value stored under key.
//
// A missing or expired key yields an empty, non-nil slice and a nil error;
// Get never returns a non-nil error. The returned slice is the same slice that
// was passed to Set, not a copy.
//
// Get holds writeLock for its whole duration. Reading the maps unguarded while
// Set or a flush writes them is a data race, and Go aborts the process on
// concurrent map read/write. The lock is also needed because an expired entry
// is removed on access.
func (c *MemoryCache) Get(key string) ([]byte, error) {
	c.writeLock.Lock()
	defer c.writeLock.Unlock()

	val, exists := c.data[key]
	if !exists {
		return []byte{}, nil
	}

	if c.expirations[key].Before(time.Now()) {
		// Drop just this entry; flushExpired cannot be called here because it
		// takes writeLock itself and sync.Mutex is not reentrant. Other stale
		// entries are swept by Set when it needs room.
		c.flush(key)
		return []byte{}, nil
	}
	return val, nil
}
// Set stores data under key for the given ttl. TTLs are strictly followed:
// once ttl has elapsed Get no longer returns the entry. If you would like the
// cached value to exist forever, use the Forever constant.
//
// It returns tooLargeError, without storing anything, if data alone exceeds
// the cache size limit. Replacing an existing key first releases the bytes of
// the old value, so they are not double-counted. When the new value does not
// fit, expired entries are flushed first; if that is still not enough, live
// entries are evicted in Go's (unspecified) map iteration order until it fits.
//
// The size check, eviction and insert all happen under a single writeLock
// critical section, so concurrent callers cannot push the cache past its
// limit. flushExpired/flushBytes are not used here because they acquire
// writeLock themselves and sync.Mutex is not reentrant.
func (c *MemoryCache) Set(key string, data []byte, ttl time.Duration) error {
	size := len(data)
	// We can't store items larger than the size of the cache. sizeLimit is
	// only assigned at construction, so reading it before locking is safe.
	if size > c.sizeLimit {
		return tooLargeError
	}

	c.writeLock.Lock()
	defer c.writeLock.Unlock()

	// Release the previous value's bytes before accounting for the new one.
	if _, exists := c.data[key]; exists {
		c.flush(key)
	}

	// If we don't have enough space, flush expired items.
	if c.size+size > c.sizeLimit {
		now := time.Now()
		for k, exp := range c.expirations {
			if exp.Before(now) {
				c.flush(k) // deleting during range is permitted by the Go spec
			}
		}
	}

	// If we still don't have enough space, evict entries until it fits.
	for k := range c.data {
		if c.size+size <= c.sizeLimit {
			break
		}
		c.flush(k)
	}

	c.data[key] = data
	c.expirations[key] = time.Now().Add(ttl)
	c.size += size
	return nil
}
// flush removes key from both the data and expirations maps and subtracts the
// length of its stored value from c.size. A key that is not present is a no-op
// (its value length reads as 0). flush does not acquire writeLock; call it only
// while the lock is held.
func (c *MemoryCache) flush(key string) {
	released := len(c.data[key])
	delete(c.data, key)
	delete(c.expirations, key)
	c.size -= released
}
// flushExpired removes every entry whose expiration time is before the moment
// of the call and returns the total number of value bytes released. It
// acquires writeLock itself, so it must not be called with the lock held.
func (c *MemoryCache) flushExpired() int {
	cutoff := time.Now()
	freed := 0

	c.writeLock.Lock()
	defer c.writeLock.Unlock()

	for key, val := range c.data {
		if !c.expirations[key].Before(cutoff) {
			continue // still fresh
		}
		freed += len(val)
		c.flush(key)
	}
	return freed
}
// flushBytes evicts entries until at least n bytes of values have been
// released or the cache is empty, and returns the number of bytes released.
// Expiration times are ignored.
//
// Victims are chosen in Go's map iteration order. The language leaves that
// order unspecified and the runtime deliberately varies it, so this behaves as
// an approximately random eviction strategy, but not a uniformly random one.
//
// n <= 0 evicts nothing: the budget is checked before each eviction rather than
// after it. flushBytes acquires writeLock itself, so it must not be called with
// the lock held.
func (c *MemoryCache) flushBytes(n int) (bytesFlushed int) {
	c.writeLock.Lock()
	defer c.writeLock.Unlock()

	for k, v := range c.data {
		if bytesFlushed >= n {
			break
		}
		c.flush(k)
		bytesFlushed += len(v)
	}
	return bytesFlushed
}
package cache
import (
"bytes"
"testing"
"time"
)
// TestSetAndGet checks that a stored value round-trips through Get and that a
// key which was never set comes back empty without an error.
func TestSetAndGet(t *testing.T) {
	t.Parallel()
	cache := NewMemoryCache(10)
	key := "foo"
	data := []byte("bar")

	if err := cache.Set(key, data, 1000*time.Millisecond); err != nil {
		t.Fatalf("Unable to set cache: %s", err)
	}

	val, err := cache.Get(key)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(data, val) {
		t.Errorf("Expected %+q == %+q", data, val)
	}

	val, err = cache.Get("not here")
	if err != nil {
		t.Errorf("Get of missing key returned error: %s", err)
	}
	if len(val) != 0 {
		t.Errorf("Expected empty value for missing key, got %+q", val)
	}
}
// TestExpiration checks that a value is served before its TTL elapses and is
// gone afterwards.
func TestExpiration(t *testing.T) {
	t.Parallel()
	cache := NewMemoryCache(10)
	key := "foo"
	data := []byte("bar")

	// A 1ms TTL could elapse between Set and the first Get on a busy machine,
	// making the "still present" check flaky. Use a wider window and sleep past it.
	const ttl = 50 * time.Millisecond
	if err := cache.Set(key, data, ttl); err != nil {
		t.Fatalf("Unable to set cache: %s", err)
	}

	val, err := cache.Get(key)
	if err != nil {
		t.Fatalf("Unable to get cache: %s", err)
	}
	if !bytes.Equal(val, data) {
		t.Fatalf("Expected %+q == %+q", data, val)
	}

	time.Sleep(2 * ttl)

	val, err = cache.Get(key)
	if err != nil {
		t.Fatalf("Unable to get cache: %s", err)
	}
	if len(val) != 0 {
		t.Fatalf("Expected expired key to be empty, got %+q", val)
	}
}
// TestEviction checks that oversized values are rejected with the sentinel
// error, and that storing a value which does not fit evicts an older entry.
func TestEviction(t *testing.T) {
	t.Parallel()
	cache := NewMemoryCache(10)

	// A value larger than the whole cache must be rejected and not stored.
	data := make([]byte, 11)
	err := cache.Set("foo", data, 1*time.Second)
	if err != tooLargeError {
		t.Errorf("Should not cache items larger than size limit, got err=%v", err)
	}
	val, err := cache.Get("foo")
	if err != nil {
		t.Errorf("Unable to get cache: %s", err)
	}
	if len(val) != 0 {
		t.Errorf("oversized item should not be stored, got %d bytes", len(val))
	}

	// Two 6-byte values cannot both fit in 10 bytes, so storing foo3 must evict
	// foo2. The TTL is generous so neither entry can expire on its own during
	// the test, which would either mask the eviction or make foo3 disappear.
	data = make([]byte, 6)
	if err := cache.Set("foo2", data, 10*time.Second); err != nil {
		t.Errorf("Unable to set cache: %s", err)
	}
	if err := cache.Set("foo3", data, 10*time.Second); err != nil {
		t.Errorf("Unable to set cache: %s", err)
	}

	val, err = cache.Get("foo2")
	if err != nil {
		t.Errorf("Unable to get cache: %s", err)
	}
	if len(val) != 0 {
		t.Error("foo2 should have been evicted")
	}

	val, err = cache.Get("foo3")
	if err != nil {
		t.Errorf("Unable to get cache: %s", err)
	}
	if !bytes.Equal(val, data) {
		t.Error("foo3 should be in the cache")
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment