Skip to content

Instantly share code, notes, and snippets.

@kmpm
Created March 29, 2023 14:24
Show Gist options
  • Save kmpm/42ef1f9b08cd312973f968f1e7128960 to your computer and use it in GitHub Desktop.
Save kmpm/42ef1f9b08cd312973f968f1e7128960 to your computer and use it in GitHub Desktop.
Using NATS KV for creating locks.
package lock
import (
"context"
"errors"
"strings"
"sync"
"time"
"testapp/logger"
"github.com/nats-io/nats.go"
"go.uber.org/zap"
)
type (
	// Options holds the optional callbacks a Lock invokes on state changes.
	// Each callback receives the Lock that changed state.
	Options struct {
		onReleased, onAcquired, onLost func(*Lock)
	}

	// Option configures an Options value. It may return an error to reject
	// the configuration.
	Option func(*Options) error
)
// OnAcquired returns an Option that installs cb as the callback invoked, with
// the lock, each time the lock is successfully acquired.
func OnAcquired(cb func(*Lock)) Option {
	setter := func(o *Options) error {
		o.onAcquired = cb
		return nil
	}
	return setter
}
// OnLost returns an Option that installs cb as the callback invoked, with the
// lock, when a held lock is lost (for example because the key was deleted or
// could not be renewed).
func OnLost(cb func(*Lock)) Option {
	return func(o *Options) error {
		o.onLost = cb
		return nil
	}
}

// OnReleased returns an Option that installs cb as the callback invoked, with
// the lock, at the end of Release. Before this option existed, the
// onReleased callback could not be set.
func OnReleased(cb func(*Lock)) Option {
	return func(o *Options) error {
		o.onReleased = cb
		return nil
	}
}
// Lock is a distributed lock built on a single key in a NATS KeyValue bucket.
// The instance that creates the key holds the lock and keeps it by
// periodically updating the key; other instances watch the key and retry.
type Lock struct {
	name string          // key used as the lock
	kv   nats.KeyValue   // bucket that stores the key
	ttl  time.Duration   // the refresh loop ticks every ttl/2
	w    nats.KeyWatcher // watcher on name; nil when not watching

	// Stop signals for the refresh and watch loops; closed by Release.
	cRefresh, cWatch chan struct{}

	state  string // "released", "registering", "registerd", "lost" or "releasing"
	ticket []byte // value written to the key while the lock is held
	rev    uint64 // key revision last written by this instance; 0 if none
	opts   *Options
	log    *zap.Logger
	mu     sync.Mutex // held by SetOptions while applying options
}
// reregisterDelay is how long watch waits after the key disappears before
// trying to register again.
var reregisterDelay = 2 * time.Second
// New creates a Lock on key name in kv. ticket is the value written to the key
// while this instance holds the lock. The lock starts in the "released" state
// with a 30s ttl; call Acquire to start competing for it.
//
// An error is returned if name is empty, kv is nil, or any option returns an
// error. Nil options are ignored.
func New(ctx context.Context, kv nats.KeyValue, name, ticket string, opts ...Option) (*Lock, error) {
	if name == "" {
		return nil, errors.New("key can not be empty")
	}
	if kv == nil {
		return nil, errors.New("kv can not be nil")
	}
	ctx = logger.NewContext(ctx, "lock", zap.String("lock", name))
	l := &Lock{
		name:   name,
		kv:     kv,
		ttl:    30 * time.Second,
		state:  "released",
		ticket: []byte(ticket),
		log:    logger.L(ctx),
		opts:   &Options{},
	}
	// Apply the options here rather than through SetOptions so that a failing
	// option is reported to the caller instead of being silently dropped. The
	// lock is not shared yet, so l.mu is not needed.
	for _, opt := range opts {
		if opt == nil {
			continue
		}
		if err := opt(l.opts); err != nil {
			return nil, err
		}
	}
	return l, nil
}
// SetOptions applies opts to the lock's options while holding l.mu.
// SetOptions has no error result, so an option that returns an error is
// logged and skipped, and the remaining options are still applied. Nil
// options are ignored.
func (l *Lock) SetOptions(opts ...Option) {
	l.mu.Lock()
	defer l.mu.Unlock()
	for _, opt := range opts {
		if opt == nil {
			continue
		}
		if err := opt(l.opts); err != nil {
			l.log.Warn("could not apply option", zap.Error(err))
		}
	}
}
// HasAcquisition reports whether this instance currently holds the lock. That
// means it is in the registered state ("registerd") and has a non-zero key
// revision.
func (l *Lock) HasAcquisition() bool {
	held := l.state == "registerd"
	return held && l.rev != 0
}
// setState moves the lock to state v and reports whether the state actually
// changed. Setting the current state again is a no-op and returns false.
func (l *Lock) setState(v string) bool {
	if l.state == v {
		return false
	}
	l.log.Debug("changing state", zap.String("state", l.state), zap.String("to_state", v))
	l.state = v
	return true
}
// register tries to acquire the lock by creating the key with this instance's
// ticket.
//
// It returns an error in two cases:
//   - the state is already "registering";
//   - the create fails. When another instance holds the key, callers expect
//     nats.ErrKeyExists.
//
// On a failed create, the previous state is restored. On success, the state
// becomes "registerd" (spelled that way everywhere in this file), l.rev
// records the key revision, and the OnAcquired callback, if set, is called
// synchronously.
//
// Once past the "already registering" check, register always starts the
// refresh loop in a new goroutine, whether the create succeeds or fails. That
// loop calls register again while the state is "lost" or "released".
func (l *Lock) register() error {
	log := l.log
	oldstate := l.state
	var err error
	var rev uint64
	if !l.setState("registering") {
		return errors.New("already registering")
	}
	// Runs on both success and failure. refresh returns immediately if a loop
	// is already running (l.cRefresh != nil).
	defer func() {
		go l.refresh()
	}()
	rev, err = l.kv.Create(l.name, l.ticket)
	if err != nil {
		// logger.Warning("failed to register %s: %v", l.name, err)
		l.setState(oldstate)
		return err
	}
	l.setState("registerd")
	l.rev = rev
	if l.opts.onAcquired != nil {
		log.Info("pushing on acquired")
		l.opts.onAcquired(l)
	} else {
		log.Info("no one listens for acquisition")
	}
	return nil
}
// registerDelayed tries to register the lock again after d, in a new
// goroutine.
//
// If the lock is being released or has been released ("releasing"/"released")
// when the delay expires, nothing is done. Otherwise a Release issued during
// the delay would be silently undone by re-acquiring the lock. Registration
// errors are only logged: losing the race to another instance is expected,
// and refresh keeps retrying.
func (l *Lock) registerDelayed(d time.Duration) {
	go func() {
		time.Sleep(d)
		if strings.HasPrefix(l.state, "rel") {
			l.log.Debug("skipping delayed register", zap.String("state", l.state))
			return
		}
		if err := l.register(); err != nil {
			l.log.Debug("delayed register failed",
				zap.Bool("key_exists", errors.Is(err, nats.ErrKeyExists)),
				zap.Error(err))
		}
	}()
}
// watch follows changes to the lock key. It is meant to be run as a goroutine.
//
// When the key is deleted or purged while this instance is registering or
// holds the lock, the lock is marked lost. Unless the lock is being or has
// been released, a new registration attempt is then scheduled after
// reregisterDelay.
//
// watch returns in three cases:
//   - the stop channel it stores in l.cWatch is closed (see Release);
//   - the watcher cannot be created;
//   - the watcher's update channel is closed.
//
// Only one watcher runs at a time: if l.w is already set, watch logs a warning
// and returns.
func (l *Lock) watch() {
	log := l.log
	if l.w != nil {
		log.Warn("watcher is already running")
		return
	}
	// Select on a private copy of the stop channel. Release closes l.cWatch and
	// then sets the field to nil, and a receive from a nil channel never
	// proceeds, so reading the field on every iteration could block forever.
	stop := make(chan struct{})
	l.cWatch = stop
	w, err := l.kv.Watch(l.name)
	if err != nil {
		log.Error("error getting watcher", zap.Error(err))
		return
	}
	l.w = w
	defer func() {
		if err := w.Stop(); err != nil {
			log.Debug("stopping watcher", zap.Error(err))
		}
		if l.w == w {
			l.w = nil
		}
		log.Debug("watcher exited")
	}()
	updates := w.Updates()
	for {
		select {
		case kve, ok := <-updates:
			if !ok {
				// A closed channel yields nil forever; without this check the
				// loop would busy-spin.
				log.Debug("watcher updates closed")
				return
			}
			if kve == nil {
				// nil marks the end of the initial values; nothing to do.
				continue
			}
			switch kve.Operation() {
			case nats.KeyValueDelete, nats.KeyValuePurge:
				// A purge removes the key just like a delete does.
				log.Debug("delete", zap.String("state", l.state), zap.String("value", string(kve.Value())))
				if strings.HasPrefix(l.state, "reg") {
					l.lost()
				}
				if !strings.HasPrefix(l.state, "rel") {
					l.registerDelayed(reregisterDelay)
				}
			}
		case <-stop:
			log.Debug("stop watching")
			return
		}
	}
}
// deleteKV deletes the lock key, but only when this instance has written it
// (l.rev != 0). The delete is conditioned on nats.LastRevision(l.rev), so it
// applies only if the key is still at that revision. Failures are logged, not
// returned. Afterwards l.rev is 0.
func (l *Lock) deleteKV() {
	if l.rev == 0 {
		l.log.Debug("no kv to delete")
		return
	}
	l.log.Debug("deleting kv")
	if err := l.kv.Delete(l.name, nats.LastRevision(l.rev)); err != nil {
		l.log.Warn("could not delete kv", zap.Error(err))
	} else {
		l.log.Debug("deleted")
	}
	l.rev = 0
}
// lost moves the lock to the "lost" state and deletes this instance's key
// revision (best effort). It then calls the OnLost callback, if one is set.
func (l *Lock) lost() {
	l.setState("lost")
	l.log.Info("lost lock")
	l.deleteKV()
	l.rev = 0 // deleteKV already clears it; kept as a safeguard
	if cb := l.opts.onLost; cb != nil {
		cb(l)
	}
}
// refresh is the periodic maintenance loop for the lock and is meant to be
// started as a goroutine. Every ttl/2 it does one of two things:
//   - when the lock is held, it renews the key (an update failure marks the
//     lock lost);
//   - when the state is "lost" or "released", it tries to register.
//
// It returns when the stop channel it stores in l.cRefresh is closed (see
// Release). Only one loop runs at a time: if l.cRefresh is already set,
// refresh returns immediately.
func (l *Lock) refresh() {
	log := l.log
	if l.cRefresh != nil {
		return
	}
	timeout := l.ttl / 2
	log.Debug("refresh started", zap.Duration("timeout", timeout))
	// Select on a private copy of the stop channel. Release closes l.cRefresh
	// and then sets the field to nil. If this loop read the field again after
	// that, it would never stop, and the next tick in state "released" would
	// re-acquire a lock that was just released.
	stop := make(chan struct{})
	l.cRefresh = stop
	ticker := time.NewTicker(timeout)
	defer ticker.Stop()
	bvalue := l.ticket
	defer func() {
		log.Debug("refresh exited")
		// Only clear the field if it still belongs to this loop.
		if l.cRefresh == stop {
			l.cRefresh = nil
		}
	}()
	for {
		select {
		case <-ticker.C:
			// select picks randomly when both cases are ready, so make sure
			// Release has not already asked us to stop before doing any work.
			select {
			case <-stop:
				log.Debug("cRefresh closed")
				return
			default:
			}
			log.Debug("on tick",
				zap.String("state", l.state),
				zap.Uint64("rev", l.rev),
				zap.ByteString("ticket", l.ticket))
			switch l.state {
			case "registerd":
				if nrev, err := l.kv.Update(l.name, bvalue, l.rev); err != nil {
					log.Warn("update error", zap.Error(err))
					l.lost()
				} else {
					l.rev = nrev
				}
			case "lost", "released":
				if err := l.register(); err != nil {
					if errors.Is(err, nats.ErrKeyExists) {
						log.Debug("someone else has lock", zap.Error(err))
					} else {
						log.Warn("error registering", zap.Error(err))
					}
				}
			default:
				log.Debug("unhandled state", zap.String("state", l.state))
			}
		case <-stop:
			log.Debug("cRefresh closed")
			return
		}
	}
}
// Release gives up the lock. It closes the refresh and watch stop channels,
// moves the state to "released", deletes the key if this instance wrote it,
// and finally calls the OnReleased callback, if set. Releasing a lock that is
// already released only logs and returns.
func (l *Lock) Release() {
	log := l.log
	if l.state == "released" {
		log.Debug("already released", zap.Bool("cWatch", l.cWatch != nil), zap.Bool("cRefresh", l.cRefresh != nil))
		return
	}
	l.setState("releasing")
	log.Debug("releaseing", zap.String("state", l.state))

	if refreshStop := l.cRefresh; refreshStop != nil {
		log.Debug("closing cRefresh channel")
		close(refreshStop)
		l.cRefresh = nil
	}
	if watchStop := l.cWatch; watchStop != nil {
		log.Debug("closing cWatch channel")
		close(watchStop)
		l.cWatch = nil
	}

	l.setState("released")
	l.deleteKV()
	if cb := l.opts.onReleased; cb != nil {
		log.Debug("sending on released")
		cb(l)
	}
	log.Debug("release done")
}
// Acquire makes one synchronous attempt to take the lock and then starts the
// key watcher in the background. Losing the race to another holder
// (nats.ErrKeyExists) is expected and not logged; any other registration
// error is logged as a warning.
func (l *Lock) Acquire() {
	switch err := l.register(); {
	case err == nil, errors.Is(err, nats.ErrKeyExists):
		// Acquired, or another instance currently holds the lock.
	default:
		l.log.Warn("could not register immediately", zap.Error(err))
	}
	go l.watch()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment