Skip to content

Instantly share code, notes, and snippets.

@CAFxX
Created April 22, 2022 09:40
Show Gist options
  • Save CAFxX/4923e8ea2dbb574625f9e57b0cd93636 to your computer and use it in GitHub Desktop.
Save CAFxX/4923e8ea2dbb574625f9e57b0cd93636 to your computer and use it in GitHub Desktop.
Golang map with custom-compare/hash keys
package main
import (
"encoding/binary"
"fmt"
"hash/maphash"
"math/rand"
)
// Comparable is the constraint Map places on its keys. Instead of relying on
// Go's built-in == and hashing, a key type supplies its own equality and hash,
// so Map can hold keys that are not comparable or that need a custom notion of
// equality. T is normally the implementing type itself, as in Map's
// constraint K Comparable[K].
//
// Implementations must be consistent: whenever a.Equal(b) is true,
// a.Hash(seed) must equal b.Hash(seed) for every seed, otherwise Map lookups
// will miss entries.
type Comparable[T any] interface {
	// Equal reports whether the receiver and other denote the same key.
	Equal(other T) (isEqual bool)
	// Hash returns a hash of the receiver; it should mix in seed so that
	// bucket placement depends on the Map's seed.
	Hash(seed uintptr) (hash uintptr)
}

// Map is a hash map whose keys provide their own Equal and Hash methods.
// Keys whose hashes coincide share a bucket and are distinguished with Equal.
//
// The zero Map is ready to use; New is equivalent but allocates eagerly.
// A Map has no internal locking and is not safe for concurrent use without
// external synchronization.
type Map[K Comparable[K], V any] struct {
	m    map[uintptr][]entry[K, V] // hash -> bucket of entries with that hash
	seed uintptr                   // per-Map seed passed to K.Hash
	len  int                       // number of entries across all buckets
}

// entry is one key/value pair stored in a Map bucket.
type entry[K Comparable[K], V any] struct {
	key   K
	value V
}

// New returns an empty Map with a freshly chosen hash seed.
func New[K Comparable[K], V any]() *Map[K, V] {
	return &Map[K, V]{
		m:    make(map[uintptr][]entry[K, V]),
		seed: newSeed(),
	}
}

// newSeed returns a hash seed for a new Map.
//
// Before Go 1.20 the global math/rand source is deterministically seeded
// unless the program calls rand.Seed, so rand.Uint64 alone would hand every
// process the same seed sequence and make bucket placement predictable. A zero
// maphash.Hash picks a random seed on first use, so its Sum64 contributes
// per-process randomness; XOR keeps the result unpredictable if either source is.
func newSeed() uintptr {
	return uintptr(rand.Uint64() ^ new(maphash.Hash).Sum64())
}

// Get returns the value stored under a key Equal to key and true, or the
// zero V and false if no such key is present.
func (m *Map[K, V]) Get(key K) (V, bool) {
	for _, e := range m.m[key.Hash(m.seed)] {
		if key.Equal(e.key) {
			return e.value, true
		}
	}
	var zero V
	return zero, false
}

// Put stores value under key. If an Equal key was already present, its value
// is replaced (the stored key is kept) and Put returns the previous value and
// true; otherwise a new entry is added and Put returns the zero V and false.
func (m *Map[K, V]) Put(key K, value V) (V, bool) {
	if m.m == nil {
		// Lazily initialize a zero Map. The seed must be chosen before the
		// first hash is computed; the map is empty, so nothing needs rehashing.
		m.m = make(map[uintptr][]entry[K, V])
		m.seed = newSeed()
	}
	hash := key.Hash(m.seed)
	bucket := m.m[hash]
	for idx := range bucket {
		if key.Equal(bucket[idx].key) {
			oldValue := bucket[idx].value
			// bucket shares its backing array with m.m[hash], so this
			// updates the stored entry in place.
			bucket[idx].value = value
			return oldValue, true
		}
	}
	m.m[hash] = append(bucket, entry[K, V]{key, value})
	m.len++
	var zero V
	return zero, false
}

// Delete removes the entry whose key is Equal to key. It returns the removed
// value and true, or the zero V and false if no such key is present.
func (m *Map[K, V]) Delete(key K) (V, bool) {
	hash := key.Hash(m.seed)
	bucket := m.m[hash]
	for idx := range bucket {
		if !key.Equal(bucket[idx].key) {
			continue
		}
		oldValue := bucket[idx].value
		last := len(bucket) - 1
		// Order within a bucket does not matter: move the last entry into
		// the hole instead of shifting.
		bucket[idx] = bucket[last]
		// Clear the vacated slot so the backing array no longer keeps the
		// removed key/value reachable for the garbage collector.
		bucket[last] = entry[K, V]{}
		bucket = bucket[:last]
		if len(bucket) == 0 {
			delete(m.m, hash)
		} else {
			m.m[hash] = bucket
		}
		m.len--
		return oldValue, true
	}
	var zero V
	return zero, false
}

// ForEach calls callback for each entry, in unspecified order. It stops and
// returns false as soon as callback returns false, and returns true if
// callback returned true for every entry.
func (m *Map[K, V]) ForEach(callback func(key K, value V) bool) bool {
	for _, bucket := range m.m {
		for _, e := range bucket {
			if !callback(e.key, e.value) {
				return false
			}
		}
	}
	return true
}

// Len returns the number of entries in m.
func (m *Map[K, V]) Len() int {
	return m.len
}
// _seed is the maphash seed shared by every str hash. It is chosen once when
// the program starts, so equal strings hash identically within one process.
var _seed = maphash.MakeSeed()

// str is a demonstration key type for Map: a string wrapped in a struct that
// implements Comparable[str].
type str struct {
	s string
}

// Equal reports whether both keys wrap the same string.
func (k str) Equal(other str) bool {
	return other.s == k.s
}

// Hash hashes the wrapped string with maphash under _seed. It first feeds in
// the 8 little-endian bytes of the caller-supplied seed, so the result depends
// on the Map's seed as well as on the string.
func (k str) Hash(seed uintptr) uintptr {
	var prefix [8]byte
	binary.LittleEndian.PutUint64(prefix[:], uint64(seed))

	var h maphash.Hash
	h.SetSeed(_seed)
	// maphash.Hash writes never fail; the results exist only for io.Writer.
	_, _ = h.Write(prefix[:])
	_, _ = h.WriteString(k.s)
	return uintptr(h.Sum64())
}

// String returns the wrapped string, so a str prints as its plain text.
func (k str) String() string {
	return k.s
}
// main demonstrates Map with str keys. It inserts two keys, overwrites the
// value of one of them, and prints every entry in unspecified order.
func main() {
	pairs := []struct{ key, value string }{
		{"hello", "world"},
		{"key", "value"},
		{"hello", "world!"}, // replaces the value stored for "hello"
	}

	m := New[str, string]()
	for _, p := range pairs {
		m.Put(str{p.key}, p.value)
	}

	printEntry := func(k str, v string) bool {
		fmt.Println(k, v)
		return true
	}
	m.ForEach(printEntry)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment