Created
March 29, 2023 14:24
-
-
Save kmpm/42ef1f9b08cd312973f968f1e7128960 to your computer and use it in GitHub Desktop.
Using NATS KV for creating locks.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package lock | |
import ( | |
"context" | |
"errors" | |
"strings" | |
"sync" | |
"time" | |
"testapp/logger" | |
"github.com/nats-io/nats.go" | |
"go.uber.org/zap" | |
) | |
type Options struct { | |
onReleased func(*Lock) | |
onAcquired func(*Lock) | |
onLost func(*Lock) | |
} | |
type Option func(*Options) error | |
func OnAcquired(cb func(*Lock)) Option { | |
return func(o *Options) error { | |
o.onAcquired = cb | |
return nil | |
} | |
} | |
func OnLost(cb func(*Lock)) Option { | |
return func(o *Options) error { | |
o.onLost = cb | |
return nil | |
} | |
} | |
type Lock struct { | |
name string | |
kv nats.KeyValue | |
ttl time.Duration | |
w nats.KeyWatcher | |
cRefresh chan struct{} | |
cWatch chan struct{} | |
state string | |
ticket []byte | |
rev uint64 | |
opts *Options | |
log *zap.Logger | |
mu sync.Mutex | |
} | |
// reregisterDelay is how long the watcher waits, after seeing the lock key
// deleted, before it tries to register again. It is a variable rather than a
// constant — presumably so tests can shorten it (TODO confirm).
var reregisterDelay = 2 * time.Second
func New(ctx context.Context, kv nats.KeyValue, name, ticket string, opts ...Option) (*Lock, error) { | |
if name == "" { | |
return nil, errors.New("key can not be empty") | |
} | |
if kv == nil { | |
return nil, errors.New("kv can not be nil") | |
} | |
ctx = logger.NewContext(ctx, "lock", zap.String("lock", name)) | |
l := &Lock{ | |
name: name, | |
kv: kv, | |
ttl: 30 * time.Second, | |
state: "released", | |
ticket: []byte(ticket), | |
log: logger.L(ctx), | |
opts: &Options{}, | |
// OnReleased: make(chan string), | |
} | |
l.SetOptions(opts...) | |
return l, nil | |
} | |
func (l *Lock) SetOptions(opts ...Option) { | |
l.mu.Lock() | |
defer l.mu.Unlock() | |
for _, opt := range opts { | |
opt(l.opts) | |
} | |
} | |
func (l *Lock) HasAcquisition() bool { | |
return l.state == "registerd" && l.rev > 0 | |
} | |
// setState changes state and returns true if there was a change and false if no change | |
func (l *Lock) setState(v string) bool { | |
if l.state != v { | |
l.log.Debug("changing state", zap.String("state", l.state), zap.String("to_state", v)) | |
l.state = v | |
return true | |
} | |
return false | |
} | |
func (l *Lock) register() error { | |
log := l.log | |
oldstate := l.state | |
var err error | |
var rev uint64 | |
if !l.setState("registering") { | |
return errors.New("already registering") | |
} | |
defer func() { | |
go l.refresh() | |
}() | |
rev, err = l.kv.Create(l.name, l.ticket) | |
if err != nil { | |
// logger.Warning("failed to register %s: %v", l.name, err) | |
l.setState(oldstate) | |
return err | |
} | |
l.setState("registerd") | |
l.rev = rev | |
if l.opts.onAcquired != nil { | |
log.Info("pushing on acquired") | |
l.opts.onAcquired(l) | |
} else { | |
log.Info("no one listens for acquisition") | |
} | |
return nil | |
} | |
func (l *Lock) registerDelayed(d time.Duration) { | |
go func() { | |
time.Sleep(d) | |
l.register() | |
}() | |
} | |
// watch is made to be run as a gorutine | |
func (l *Lock) watch() { | |
log := l.log | |
if l.w != nil { | |
log.Warn("watcher is already running") | |
return | |
} | |
l.cWatch = make(chan struct{}) | |
var err error | |
l.w, err = l.kv.Watch(l.name) | |
if err != nil { | |
log.Error("error getting watcher", zap.Error(err)) | |
return | |
} | |
defer log.Debug("watcher exited") | |
for { | |
select { | |
case kve := <-l.w.Updates(): | |
if kve != nil { | |
switch kve.Operation() { | |
case nats.KeyValueDelete: | |
log.Debug("delete", zap.String("state", l.state), zap.String("value", string(kve.Value()))) | |
if strings.HasPrefix(l.state, "reg") { | |
l.lost() | |
} | |
if !strings.HasPrefix(l.state, "rel") { | |
l.registerDelayed(reregisterDelay) | |
} | |
// default: | |
// logger.Debug("%s @ %d -> %q (op: %s)\n", kve.Key(), kve.Revision(), string(kve.Value()), kve.Operation()) | |
} | |
} | |
case <-l.cWatch: | |
log.Debug("stop watching") | |
if l.w != nil { | |
l.w.Stop() | |
l.w = nil | |
} | |
return | |
} | |
} | |
} | |
func (l *Lock) deleteKV() { | |
if l.rev != 0 { | |
l.log.Debug("deleting kv") | |
err := l.kv.Delete(l.name, nats.LastRevision(l.rev)) | |
if err != nil { | |
l.log.Warn("could not delete kv", zap.Error(err)) | |
} else { | |
l.log.Debug("deleted") | |
} | |
l.rev = 0 | |
} else { | |
l.log.Debug("no kv to delete") | |
} | |
} | |
func (l *Lock) lost() { | |
log := l.log | |
l.setState("lost") | |
log.Info("lost lock") | |
l.deleteKV() | |
l.rev = 0 | |
if l.opts.onLost != nil { | |
l.opts.onLost(l) | |
} | |
} | |
func (l *Lock) refresh() { | |
log := l.log | |
if l.cRefresh != nil { | |
return | |
} | |
timeout := l.ttl / 2 | |
log.Debug("refresh started", zap.Duration("timeout", timeout)) | |
l.cRefresh = make(chan struct{}) | |
ticker := time.NewTicker(timeout) | |
bvalue := l.ticket | |
defer func() { | |
log.Debug("refresh exited") | |
l.cRefresh = nil | |
}() | |
for { | |
select { | |
case <-ticker.C: | |
log.Debug("on tick", | |
zap.String("state", l.state), | |
zap.Uint64("rev", l.rev), | |
zap.ByteString("ticket", l.ticket)) | |
switch l.state { | |
case "registerd": | |
if nrev, err := l.kv.Update(l.name, bvalue, l.rev); err != nil { | |
log.Warn("update error", zap.Error(err)) | |
l.lost() | |
} else { | |
// logger.Debug("refreshed '%s' with rev %d", l.name, l.rev) | |
l.rev = nrev | |
} | |
case "lost", "released": | |
if err := l.register(); err != nil { | |
if errors.Is(err, nats.ErrKeyExists) { | |
log.Debug("someone else has lock", zap.Error(err)) | |
} else { | |
log.Warn("error registering", zap.Error(err)) | |
} | |
} | |
default: | |
log.Debug("unhandled state", zap.String("state", l.state)) | |
} | |
case <-l.cRefresh: | |
log.Debug("cRefresh closed") | |
return | |
} | |
} | |
} | |
func (l *Lock) Release() { | |
log := l.log | |
if l.state != "released" { | |
l.setState("releasing") | |
} else { | |
log.Debug("already released", zap.Bool("cWatch", l.cWatch != nil), zap.Bool("cRefresh", l.cRefresh != nil)) | |
return | |
} | |
log.Debug("releaseing", zap.String("state", l.state)) | |
if l.cRefresh != nil { | |
log.Debug("closing cRefresh channel") | |
close(l.cRefresh) | |
l.cRefresh = nil | |
} | |
if l.cWatch != nil { | |
log.Debug("closing cWatch channel") | |
close(l.cWatch) | |
l.cWatch = nil | |
} | |
l.setState("released") | |
l.deleteKV() | |
if l.opts.onReleased != nil { | |
log.Debug("sending on released") | |
l.opts.onReleased(l) | |
} | |
log.Debug("release done") | |
} | |
func (l *Lock) Acquire() { | |
// go func() { | |
if err := l.register(); err != nil { | |
if !errors.Is(err, nats.ErrKeyExists) { | |
l.log.Warn("could not register immediately", zap.Error(err)) | |
} | |
} | |
// }() | |
go l.watch() | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment