Skip to content

Instantly share code, notes, and snippets.

@tarampampam
Last active December 2, 2024 08:00
Show Gist options
  • Save tarampampam/f96538257ff125ab71785710d48b3118 to your computer and use it in GitHub Desktop.
Save tarampampam/f96538257ff125ab71785710d48b3118 to your computer and use it in GitHub Desktop.
Golang SyncMap (sync.Map with generics, type-safe)
package syncmap
import "sync"
// SyncMap is like a Go sync.Map but type-safe using generics.
//
// The zero SyncMap is empty and ready for use. A SyncMap must not be copied after first use.
//
// Every exported method takes mu before touching m, so a single SyncMap may be shared between goroutines.
type SyncMap[K comparable, V any] struct {
// mu serializes all access to m.
mu sync.Mutex
// m is the underlying storage. It stays nil until Grow or the first Store/LoadOrStore allocates it.
m map[K]V
}
// Grow pre-allocates the underlying map with room for size elements.
//
// It only has an effect while the map has not been allocated yet, i.e. before the first Grow, Store or
// LoadOrStore call that allocates it; afterwards the call is a no-op.
func (s *SyncMap[K, V]) Grow(size int) {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.grow(size)
}
// grow lazily allocates the underlying map and does nothing once the map exists.
//
// When at least one size is passed, the first one is used as the capacity hint; otherwise the runtime picks
// the initial size. grow does not lock s.mu itself: the methods in this file call it with the mutex held.
func (s *SyncMap[K, V]) grow(size ...int) {
	if s.m != nil {
		return // already allocated
	}

	if len(size) > 0 {
		s.m = make(map[K]V, size[0])

		return
	}

	s.m = make(map[K]V) // let runtime decide the needed map size
}
// Clone returns a new SyncMap holding a copy of the entries present at the time of the call.
//
// Keys and values are copied by plain assignment. The returned SyncMap has its own (zero) mutex and an
// allocated map, even when the receiver has never been written to.
func (s *SyncMap[K, V]) Clone() SyncMap[K, V] {
	s.mu.Lock()

	snapshot := make(map[K]V, len(s.m))
	for key, val := range s.m {
		snapshot[key] = val
	}

	s.mu.Unlock()

	return SyncMap[K, V]{m: snapshot}
}
// Load returns the value stored in the map for a key, or the zero value of V if no value is present.
// The loaded result indicates whether the key was found in the map.
//
// Load never allocates the underlying map.
func (s *SyncMap[K, V]) Load(key K) (value V, loaded bool) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.m != nil { // nothing has been stored yet otherwise
		value, loaded = s.m[key]
	}

	return value, loaded
}
// Store sets the value for a key, replacing any value previously stored under it.
// The underlying map is allocated on the first write.
func (s *SyncMap[K, V]) Store(key K, value V) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.m == nil {
		s.grow() // first write: allocate with the runtime-chosen size
	}

	s.m[key] = value
}
// LoadOrStore returns the existing value for the key if present. Otherwise, it stores and returns the given value.
// The loaded result is true if the value was loaded, false if stored.
func (s *SyncMap[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Reading from a not-yet-allocated (nil) map simply reports a miss.
	if existing, found := s.m[key]; found {
		return existing, true
	}

	s.grow() // make sure the map exists before the first write
	s.m[key] = value

	return value, false
}
// LoadAndDelete deletes the value for a key, returning the previous value if any. The loaded result reports whether
// the key was present. When the key is absent the zero value of V is returned and the map is left untouched.
func (s *SyncMap[K, V]) LoadAndDelete(key K) (value V, loaded bool) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.m == nil { // nothing was ever stored, so there is nothing to load or delete
		return value, false
	}

	value, loaded = s.m[key]
	if loaded {
		delete(s.m, key)
	}

	return value, loaded
}
// Delete deletes the value for a key. Deleting a key that is not present is a no-op, and Delete never
// allocates the underlying map.
func (s *SyncMap[K, V]) Delete(key K) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.m != nil { // nothing to remove while the map has not been allocated
		delete(s.m, key)
	}
}
// Range calls f sequentially for each key and value present in the map. If f returns false, range stops the iteration.
//
// Range does not necessarily correspond to any consistent snapshot of the Map's contents: no key will be visited more
// than once. Range does not block other methods on the receiver; even f itself may call any method on m.
//
// The mutex is held only while the iteration advances and is released for the duration of every call to f.
func (s *SyncMap[K, V]) Range(f func(key K, value V) (shouldContinue bool)) {
	s.mu.Lock()

	// Ranging over a map that was never allocated (nil) performs zero iterations.
	for key, val := range s.m {
		s.mu.Unlock() // let f, and other goroutines, use the map while user code runs

		if proceed := f(key, val); !proceed {
			return // the mutex is already released at this point
		}

		s.mu.Lock() // re-acquire before the iterator moves on
	}

	s.mu.Unlock()
}
// Len returns the number of entries currently stored in the map (0 for a SyncMap that was never written to).
func (s *SyncMap[K, V]) Len() (l int) {
	s.mu.Lock()
	defer s.mu.Unlock()

	return len(s.m)
}
// Keys returns a newly allocated slice with all map keys, in unspecified (map iteration) order.
// The result is never nil: an empty map yields an empty slice.
func (s *SyncMap[K, V]) Keys() []K {
	s.mu.Lock()
	defer s.mu.Unlock()

	keys := make([]K, 0, len(s.m))
	for key := range s.m {
		keys = append(keys, key)
	}

	return keys
}
// Values returns a newly allocated slice with all map values, in unspecified (map iteration) order.
// The result is never nil: an empty map yields an empty slice.
func (s *SyncMap[K, V]) Values() []V {
	s.mu.Lock()
	defer s.mu.Unlock()

	vals := make([]V, 0, len(s.m))
	for _, val := range s.m {
		vals = append(vals, val)
	}

	return vals
}
package syncmap_test
import (
"sort"
"sync"
"testing"
"github.com/stretchr/testify/require"
"app/internal/syncmap"
)
// TestSyncMap_InitialState checks that a zero-value SyncMap reports no entries and yields empty
// (non-nil) key and value slices.
func TestSyncMap_InitialState(t *testing.T) {
	var sm syncmap.SyncMap[string, int]

	require.Equal(t, 0, sm.Len())
	require.EqualValues(t, []string{}, sm.Keys())
	require.EqualValues(t, []int{}, sm.Values())
}
// TestSyncMap_Clone checks that a clone starts with the same entries as its source and is not affected
// by later writes to the source.
func TestSyncMap_Clone(t *testing.T) {
	const key1, key2, val1, val2 = "a", "b", 111, 222

	// expectLoad asserts the outcome of a single Load call on the given map.
	expectLoad := func(sm *syncmap.SyncMap[string, int], key string, want int, found bool) {
		t.Helper()

		got, loaded := sm.Load(key)

		require.Equal(t, found, loaded)
		require.EqualValues(t, want, got)
	}

	var src syncmap.SyncMap[string, int]

	src.Store(key1, val1)

	require.Equal(t, 1, src.Len())
	require.EqualValues(t, []string{key1}, src.Keys())
	require.EqualValues(t, []int{val1}, src.Values())

	expectLoad(&src, key1, val1, true)
	expectLoad(&src, key2, 0, false)

	clone := src.Clone()

	require.EqualValues(t, src.Len(), clone.Len())
	expectLoad(&clone, key1, val1, true)

	src.Store(key1, val2) // overwrite in original map

	expectLoad(&clone, key1, val1, true) // the clone keeps the old value
	expectLoad(&src, key1, val2, true)   // the source sees the new one
}
// TestSyncMap_Grow checks that pre-allocating the map does not add any entries.
func TestSyncMap_Grow(t *testing.T) {
	var sm syncmap.SyncMap[string, int]

	require.EqualValues(t, 0, sm.Len()) // empty before pre-allocation

	sm.Grow(3)

	require.EqualValues(t, 0, sm.Len()) // still empty afterwards
}
// TestSyncMap_Load checks Load for a missing key and for a key that was stored several times.
func TestSyncMap_Load(t *testing.T) {
	const key = "a"

	var sm syncmap.SyncMap[string, int]

	got, found := sm.Load(key) // not exists
	require.False(t, found)
	require.EqualValues(t, 0, got)

	for i := 0; i < 3; i++ { // repeated calls
		sm.Store(key, 111)
	}

	got, found = sm.Load(key) // exists
	require.True(t, found)
	require.EqualValues(t, 111, got)
}
// TestSyncMap_Store checks that stored entries show up in Len, Keys and Values, and that storing the
// same key twice does not create a duplicate.
func TestSyncMap_Store(t *testing.T) {
	const (
		key1, key2 = "a", "b"
		val1, val2 = 123, 321
	)

	var sm syncmap.SyncMap[string, int]

	sm.Grow(2)

	// First key.
	sm.Store(key1, val1)

	require.Equal(t, 1, sm.Len())
	require.EqualValues(t, []string{key1}, sm.Keys())
	require.EqualValues(t, []int{val1}, sm.Values())

	// Second key, stored twice.
	for i := 0; i < 2; i++ {
		sm.Store(key2, val2)
	}

	require.Equal(t, 2, sm.Len())

	// Map iteration order is random, so compare sorted copies.
	gotKeys, wantKeys := sm.Keys(), []string{key2, key1}
	sort.Strings(gotKeys)
	sort.Strings(wantKeys)
	require.EqualValues(t, wantKeys, gotKeys)

	gotValues, wantValues := sm.Values(), []int{val1, val2}
	sort.Ints(gotValues)
	sort.Ints(wantValues)
	require.EqualValues(t, wantValues, gotValues)
}
// TestSyncMap_LoadOrStore checks that the first call stores the value and that a later call with a
// different value returns the one stored first.
func TestSyncMap_LoadOrStore(t *testing.T) {
	const (
		key        = "a"
		val1, val2 = 123.123, 321.321
	)

	var sm syncmap.SyncMap[string, float64]

	// The key is absent: the value is stored.
	actual, loaded := sm.LoadOrStore(key, val1)
	require.False(t, loaded)
	require.EqualValues(t, val1, actual)
	require.Equal(t, 1, sm.Len())

	// The key is present: another value is passed, the first one is returned.
	actual, loaded = sm.LoadOrStore(key, val2)
	require.True(t, loaded)
	require.EqualValues(t, val1, actual)
	require.Equal(t, 1, sm.Len())
}
// TestSyncMap_LoadAndDelete checks LoadAndDelete on a missing key, on an existing key, and when it is
// called repeatedly after the key is gone.
func TestSyncMap_LoadAndDelete(t *testing.T) {
	const key, val = "a", 123

	var sm syncmap.SyncMap[string, int]

	// requireEmpty asserts that the map currently holds no entries.
	requireEmpty := func() {
		t.Helper()

		require.Equal(t, 0, sm.Len())
		require.EqualValues(t, []string{}, sm.Keys())
		require.EqualValues(t, []int{}, sm.Values())
	}

	got, loaded := sm.LoadAndDelete(key) // nothing stored yet
	require.False(t, loaded)
	require.EqualValues(t, 0, got)

	sm.Store(key, val)

	require.Equal(t, 1, sm.Len())
	require.EqualValues(t, []string{key}, sm.Keys())
	require.EqualValues(t, []int{val}, sm.Values())

	got, loaded = sm.LoadAndDelete(key) // removes the entry
	require.True(t, loaded)
	require.EqualValues(t, val, got)
	requireEmpty()

	for i := 0; i < 3; i++ { // repeated calls, only the last result is asserted
		got, loaded = sm.LoadAndDelete(key)
	}

	require.False(t, loaded)
	require.EqualValues(t, 0, got)
	requireEmpty()
}
// TestSyncMap_Delete checks that Delete is safe on a missing key, removes an existing key, and can be
// called repeatedly.
func TestSyncMap_Delete(t *testing.T) {
	const key, val = "a", 123

	var sm syncmap.SyncMap[string, int]

	// requireEmpty asserts that the map currently holds no entries.
	requireEmpty := func() {
		t.Helper()

		require.Equal(t, 0, sm.Len())
		require.EqualValues(t, []string{}, sm.Keys())
		require.EqualValues(t, []int{}, sm.Values())
	}

	for i := 0; i < 2; i++ { // repeated calls on a map that was never written to
		sm.Delete(key)
	}

	requireEmpty()

	sm.Store(key, val)

	require.Equal(t, 1, sm.Len())
	require.EqualValues(t, []string{key}, sm.Keys())
	require.EqualValues(t, []int{val}, sm.Values())

	for i := 0; i < 3; i++ { // the first call removes the entry, the others are repeated calls
		sm.Delete(key)
	}

	requireEmpty()
}
// TestSyncMap_Range checks that Range does not call the callback on an empty map, visits every entry
// when the callback returns true, and stops after the first entry when it returns false.
func TestSyncMap_Range(t *testing.T) {
	const (
		key1, key2 = "a", "b"
		val1, val2 = 123, 321
	)

	var (
		sm    syncmap.SyncMap[string, int]
		calls uint
	)

	// Empty map: the callback must never run.
	sm.Range(func(string, int) bool {
		calls++

		return false
	})

	require.EqualValues(t, 0, calls)

	for i := 0; i < 2; i++ { // each key is stored twice (repeated call)
		sm.Store(key1, val1)
	}

	for i := 0; i < 2; i++ {
		sm.Store(key2, val2)
	}

	require.Equal(t, 2, sm.Len())

	// Full iteration: every entry is visited with its own value.
	calls = 0

	sm.Range(func(k string, v int) bool {
		switch k {
		case key1:
			require.EqualValues(t, val1, v)
		case key2:
			require.EqualValues(t, val2, v)
		}

		calls++

		return true
	})

	require.EqualValues(t, 2, calls)

	// Early stop: returning false ends the iteration after the first entry.
	calls = 0

	sm.Range(func(string, int) bool {
		calls++

		return false
	})

	require.EqualValues(t, 1, calls)
}
// TestSyncMap_Struct checks the map with an array key and a struct value: a miss yields the zero
// struct, a hit yields the stored one.
func TestSyncMap_Struct(t *testing.T) {
	type some struct{ foo string }

	var (
		sm  syncmap.SyncMap[[2]int, some]
		key = [2]int{1, 2}
	)

	require.Equal(t, 0, sm.Len())

	got, found := sm.Load(key)
	require.False(t, found)
	require.EqualValues(t, some{}, got) // NOT nil

	sm.Store(key, some{"bar"})
	require.Equal(t, 1, sm.Len())

	got, found = sm.Load(key)
	require.True(t, found)
	require.EqualValues(t, some{"bar"}, got)
}
// TestSyncMap_Map checks the map with a pointer value type: a miss yields a nil pointer, a hit yields
// a pointer to an equal value.
func TestSyncMap_Map(t *testing.T) {
	type some map[uint]sync.Mutex

	var (
		sm  syncmap.SyncMap[uint, *some]
		key uint = 1
	)

	require.Equal(t, 0, sm.Len())

	got, found := sm.Load(key)
	require.False(t, found)
	require.Nil(t, got) // nil here is correct

	var mu sync.Mutex

	sm.Store(key, &some{1: mu}) //nolint:govet
	require.Equal(t, 1, sm.Len())

	got, found = sm.Load(key)
	require.True(t, found)
	require.EqualValues(t, &some{1: mu}, got) //nolint:govet
}
// TestNoCopy_ConcurrentUsage runs every SyncMap method from many goroutines at once. It makes no
// assertions of its own; it exists to give the race detector something to catch.
//
//go:noinline
func TestNoCopy_ConcurrentUsage(t *testing.T) { // race detector provocation
	var (
		sm syncmap.SyncMap[string, int]
		wg sync.WaitGroup
	)

	// One entry per operation; each of them is started once per round.
	ops := [...]func(){
		func() { sm.Grow(3) },
		func() { _, _ = sm.LoadOrStore("foo", 1) }, // +
		func() { sm.Store("foo", 1) },              // +
		func() { _, _ = sm.Load("foo") },
		func() { _, _ = sm.LoadAndDelete("foo") }, // -
		func() { sm.Delete("foo") },               // -
		func() { sm.Range(func(_ string, _ int) bool { return true }) },
		func() { sm.Range(func(_ string, _ int) bool { return false }) },
		func() { _ = sm.Len() },
		func() { _ = sm.Keys() },
		func() { _ = sm.Values() },
		func() { _ = sm.Clone() },
	}

	for round := 0; round < 100; round++ {
		wg.Add(len(ops))

		for _, op := range ops {
			// The operation is passed as an argument so every goroutine gets its own copy.
			go func(run func()) {
				defer wg.Done()

				run()
			}(op)
		}
	}

	wg.Wait()
}
// BenchmarkSyncMap_NativeMap measures one Store + Load + Delete cycle on a single key of SyncMap.
//
// Recorded reference results (benchmark names as they appear in the original recording):
//
// BenchmarkSyncMap_NativeMapMutex-8 20247865 60.29 ns/op 0 B/op 0 allocs/op
// BenchmarkSyncMap_SyncMapUnderTheHood-8 9286198 131.2 ns/op 32 B/op 2 allocs/op
func BenchmarkSyncMap_NativeMap(b *testing.B) {
	const key = "a"

	var (
		sm    syncmap.SyncMap[string, int]
		got   int
		found bool
	)

	b.ReportAllocs()

	for n := b.N; n > 0; n-- {
		sm.Store(key, 1)
		got, found = sm.Load(key)
		sm.Delete(key)
	}

	// Checked once, after the loop, to keep assertions out of the measured path.
	require.True(b, found)
	require.EqualValues(b, 1, got)
}
// BenchmarkSyncMap_Stdlib measures the same Store + Load + Delete cycle on the standard library
// sync.Map, as a baseline for BenchmarkSyncMap_NativeMap.
//
// Recorded reference result:
//
// BenchmarkSyncMap_Stdlib-8 13189734 93.28 ns/op 16 B/op 1 allocs/op
func BenchmarkSyncMap_Stdlib(b *testing.B) {
	const key = "a"

	var (
		std   sync.Map
		got   any = 0
		found any = false
	)

	b.ReportAllocs()

	for n := b.N; n > 0; n-- {
		std.Store(key, 1)
		got, found = std.Load(key)
		std.Delete(key)
	}

	// Checked once, after the loop, to keep assertions out of the measured path.
	require.True(b, found.(bool))
	require.EqualValues(b, 1, got)
}
@escalopa
Copy link

Why did you use Mutex and not RWMutex?

@tarampampam
Copy link
Author

https://github.com/golang/go/blob/c5adb8216968be46bd11f7b7360a7c8bde1258d9/src/sync/map.go#L43

Actually, it depends on which operations are more frequent for your map. RWMutex uses two mutexes under the hood, which might slightly slow down the code. I believe the best choice for your use case can be determined through benchmarking and profiling.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment