Skip to content

Instantly share code, notes, and snippets.

@hulucc
Created December 16, 2021 12:12
Show Gist options
  • Save hulucc/e229011578d1561ac04f3a213799d734 to your computer and use it in GitHub Desktop.
Save hulucc/e229011578d1561ac04f3a213799d734 to your computer and use it in GitHub Desktop.
package main
import (
"net"
"fmt"
"context"
"time"
"sync"
"go.uber.org/yarpc/api/peer"
"go.uber.org/yarpc/api/transport"
"go.uber.org/yarpc/pkg/lifecycle"
)
// Identify is a plain string that satisfies peer.Identifier by returning
// itself as the identifier (a "host:port" address in this file).
type Identify string

// Identifier returns the underlying string unchanged.
func (id Identify) Identifier() string {
	s := string(id)
	return s
}
// BindHostPort returns a peer.Binder that, when bound to a peer.List,
// produces a HostPortPeersUpdater keeping that list populated with the
// addresses hostport resolves to.
func BindHostPort(hostport string) peer.Binder {
	bind := func(pl peer.List) transport.Lifecycle {
		updater := &HostPortPeersUpdater{
			hostport: hostport,
			pl:       pl,
			once:     lifecycle.NewOnce(),
		}
		return updater
	}
	return bind
}
type HostPortPeersUpdater struct {
once *lifecycle.Once
pl peer.List
hostport string
last []peer.Identifier
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
}
// resolve expands it.hostport into one peer identifier per address.
//
// If the host part is already an IP literal, the original hostport string is
// returned as the single peer. Otherwise the host is looked up with
// net.LookupIP, and every resulting IP is paired with the original port.
// It returns an error if hostport cannot be split or the lookup fails.
func (it *HostPortPeersUpdater) resolve() ([]peer.Identifier, error) {
	host, port, err := net.SplitHostPort(it.hostport)
	if err != nil {
		return nil, fmt.Errorf("net.SplitHostPort(%s) err: %w", it.hostport, err)
	}
	// IP literals need no lookup; keep the caller's spelling verbatim.
	if net.ParseIP(host) != nil {
		return []peer.Identifier{Identify(it.hostport)}, nil
	}
	ips, err := net.LookupIP(host)
	if err != nil {
		return nil, fmt.Errorf("net.LookupIP(%s) err: %w", host, err)
	}
	plist := make([]peer.Identifier, 0, len(ips))
	for _, ip := range ips {
		// JoinHostPort brackets IPv6 addresses ("[::1]:80"). A plain
		// "%s:%s" would produce the unparseable "::1:80" for AAAA results.
		plist = append(plist, Identify(net.JoinHostPort(ip.String(), port)))
	}
	return plist, nil
}
// dict indexes plist by each element's Identifier() string. When two entries
// share an identifier, the later one wins.
func (it *HostPortPeersUpdater) dict(plist []peer.Identifier) map[string]peer.Identifier {
	byID := map[string]peer.Identifier{}
	for i := range plist {
		byID[plist[i].Identifier()] = plist[i]
	}
	return byID
}
// diff compares plist against the last applied snapshot (it.last) by
// identifier. It returns identifiers that disappeared as Removals and new
// ones as Additions. Both slices are non-nil, and their order follows map
// iteration, so it is unspecified.
func (it *HostPortPeersUpdater) diff(plist []peer.Identifier) peer.ListUpdates {
	before, after := it.dict(it.last), it.dict(plist)

	removals := []peer.Identifier{}
	for id, p := range before {
		if _, kept := after[id]; kept {
			continue
		}
		removals = append(removals, p)
	}

	additions := []peer.Identifier{}
	for id, p := range after {
		if _, existed := before[id]; existed {
			continue
		}
		additions = append(additions, p)
	}

	return peer.ListUpdates{Additions: additions, Removals: removals}
}
// update re-resolves it.hostport and pushes only the difference from the
// previous snapshot into the peer list. The snapshot (it.last) is replaced
// only when the list accepts the change. When nothing changed, the list is
// left untouched and nothing is logged.
func (it *HostPortPeersUpdater) update() error {
	plist, err := it.resolve()
	if err != nil {
		return fmt.Errorf("resolve() err: %w", err)
	}
	changes := it.diff(plist)
	if len(changes.Additions) == 0 && len(changes.Removals) == 0 {
		// Steady state: avoid a no-op Update call and a log line every tick.
		it.last = plist
		return nil
	}
	// Printf (not Println): Println does not interpret %v and go vet flags it.
	fmt.Printf("updating peers %v\n", changes)
	if err := it.pl.Update(changes); err != nil {
		return fmt.Errorf("plist.Update() err: %w", err)
	}
	it.last = plist
	return nil
}
// clear removes every peer from the last snapshot from the peer list (diff
// against an empty set). It resets the snapshot only if the list accepts
// the removal.
func (it *HostPortPeersUpdater) clear() error {
	removeAll := it.diff(nil)
	err := it.pl.Update(removeAll)
	if err != nil {
		return fmt.Errorf("plist.Update() err: %w", err)
	}
	it.last = nil
	return nil
}
// watch refreshes the peer list once per second until ctx is cancelled, then
// returns the wrapped ctx.Err(). It calls it.wg.Done on exit so stop can wait
// for it.
//
// A failed refresh is reported and retried on the next tick rather than
// ending the loop. start launches this with `go` and discards the return
// value, so returning on the first transient error (e.g. a DNS hiccup)
// would silently freeze the peer list until Stop.
func (it *HostPortPeersUpdater) watch(ctx context.Context) error {
	defer it.wg.Done()
	ticker := time.NewTicker(time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			if err := it.update(); err != nil {
				fmt.Printf("update err: %v\n", err)
			}
		case <-ctx.Done():
			return fmt.Errorf("watch err: %w", ctx.Err())
		}
	}
}
// start does one synchronous refresh, so a resolution or list-update error
// is returned before any goroutine exists. It then launches the background
// watch loop with a fresh cancellable context and WaitGroup.
func (it *HostPortPeersUpdater) start() error {
	if err := it.update(); err != nil {
		return fmt.Errorf("update() err: %w", err)
	}
	ctx, cancel := context.WithCancel(context.Background())
	it.ctx, it.cancel = ctx, cancel
	it.wg = sync.WaitGroup{}
	it.wg.Add(1)
	go it.watch(ctx)
	return nil
}
// stop cancels the watch loop, waits for it to exit, and then removes all
// previously added peers from the list.
func (it *HostPortPeersUpdater) stop() error {
	it.cancel()
	it.wg.Wait()
	err := it.clear()
	if err == nil {
		return nil
	}
	return fmt.Errorf("clear() err: %w", err)
}
// Start implements transport.Lifecycle by handing start to the
// lifecycle.Once guard and returning its result.
func (it *HostPortPeersUpdater) Start() error {
	startFn := it.start
	return it.once.Start(startFn)
}
// Stop implements transport.Lifecycle by handing stop to the
// lifecycle.Once guard and returning its result.
func (it *HostPortPeersUpdater) Stop() error {
	stopFn := it.stop
	return it.once.Stop(stopFn)
}
// IsRunning reports the running state tracked by the lifecycle.Once guard.
func (it *HostPortPeersUpdater) IsRunning() bool {
	running := it.once.IsRunning()
	return running
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment