Skip to content

Instantly share code, notes, and snippets.

@schwarzeni
Last active January 26, 2021 08:53
Show Gist options
  • Save schwarzeni/546df3232b5ba9d5f5c3a0e35ae95c0d to your computer and use it in GitHub Desktop.
Save schwarzeni/546df3232b5ba9d5f5c3a0e35ae95c0d to your computer and use it in GitHub Desktop.
一个简单的拥有 LoadBalance 的生产者消费者代码
package main
import (
"context"
"errors"
"fmt"
"log"
"math/rand"
"sync"
"time"
)
type (
	// Task is a unit of work executed by a worker goroutine.
	Task func()

	// Selector picks the worker that should receive the next Task from the
	// given candidates, or returns an error when none can be chosen.
	Selector interface {
		Select(workers []*WorkerInfo) (*WorkerInfo, error)
	}

	// NormalSelector is a stateless Selector that prefers the worker with the
	// lowest CurrentTasksCount/MaxTasksCount ratio.
	NormalSelector struct{}
)
// Select returns the worker with the lowest load ratio
// (CurrentTasksCount / MaxTasksCount) among workerInfos.
//
// Only workers whose ratio is strictly below 1.0 are eligible, i.e. workers
// with spare capacity. On ties, the earliest worker in the slice wins.
// Nil entries and workers with a non-positive MaxTasksCount are skipped
// because they can never accept a task. Previously a negative capacity
// produced a ratio of -0, and that worker was always chosen.
//
// It returns an error when no eligible worker exists, including when
// workerInfos is empty.
//
// Each worker's counters are read under its CountLock. The choice is not
// reserved, though: a caller that increments CurrentTasksCount afterwards
// may observe a count that changed in between.
func (ns *NormalSelector) Select(workerInfos []*WorkerInfo) (*WorkerInfo, error) {
	minWorkLoad := 1.0
	var minW *WorkerInfo
	for _, wi := range workerInfos {
		if wi == nil {
			continue
		}
		// Copy the counters under the read lock, then compute outside it.
		wi.CountLock.RLock()
		current, capacity := wi.CurrentTasksCount, wi.MaxTasksCount
		wi.CountLock.RUnlock()

		if capacity <= 0 {
			continue
		}
		if wl := float64(current) / float64(capacity); wl < minWorkLoad {
			minWorkLoad = wl
			minW = wi
		}
	}
	if minW == nil {
		return nil, errors.New("worker not found!")
	}
	return minW, nil
}
// WorkerInfo holds the identity and load bookkeeping for a single worker.
type WorkerInfo struct {
	// ID uniquely identifies the worker.
	ID int
	// CurrentTasksCount is the number of tasks handed to this worker that
	// have not finished yet.
	CurrentTasksCount int
	// MaxTasksCount is the most tasks this worker is allowed to hold.
	MaxTasksCount int
	// CountLock guards changes to CurrentTasksCount.
	CountLock sync.RWMutex
}
// LoadBalancer dispatches Tasks to a fixed set of worker goroutines.
type LoadBalancer struct {
	// workerInfos holds the load bookkeeping of every worker.
	workerInfos []*WorkerInfo
	// workerChannel maps a worker ID to the channel that worker reads tasks from.
	workerChannel map[int]chan Task
	// wg counts submitted tasks that have not completed yet.
	wg sync.WaitGroup
	// selector chooses which worker receives each submitted task.
	selector Selector
	// cancel cancels the context shared by the worker goroutines.
	cancel context.CancelFunc
}
// submit hands task to the worker chosen by lb.selector.
//
// It increments lb.wg, so that wait blocks until the task has run. It also
// increments the chosen worker's CurrentTasksCount under its CountLock, and
// then sends the task on that worker's channel.
//
// If the selector returns an error, submit panics with that error. In that
// case neither lb.wg nor any worker count is modified.
//
// NOTE(review): selection and the increment are separate critical sections.
// If submit is called from several goroutines at once, two callers may pick
// the same worker and push it past MaxTasksCount.
func (lb *LoadBalancer) submit(task Task) {
	worker, err := lb.selector.Select(lb.workerInfos)
	if err != nil {
		// TODO: panicking is the simplest policy; a production balancer would
		// queue, retry or return the error instead.
		panic(err)
	}
	// Register with the WaitGroup only once dispatch is certain. Adding before
	// Select left the counter permanently raised when Select failed and the
	// panic was recovered, so wait() would block forever. It must still come
	// before the send so the worker's wg.Done can never run first.
	lb.wg.Add(1)

	worker.CountLock.Lock()
	worker.CurrentTasksCount++
	worker.CountLock.Unlock()

	lb.workerChannel[worker.ID] <- task
}
// wait blocks until every submitted task has finished, and then cancels the
// shared context so the worker goroutines return.
func (lb *LoadBalancer) wait() {
	pending := &lb.wg
	pending.Wait()

	stopWorkers := lb.cancel
	stopWorkers()
}
// LB builds a LoadBalancer with lbCount workers and starts one goroutine per
// worker.
//
// Worker i gets ID i and MaxTasksCount i+1. The selector is hard-coded to
// NormalSelector. All workers are started immediately (no slow start), and
// they run until the balancer's context is cancelled through lb.cancel.
// A non-positive lbCount yields a balancer with no workers.
func LB(lbCount int) *LoadBalancer {
	ctx, cancel := context.WithCancel(context.Background())
	lb := &LoadBalancer{
		selector:      &NormalSelector{},
		workerChannel: make(map[int](chan Task)),
		cancel:        cancel,
	}
	for i := 0; i < lbCount; i++ {
		w := &WorkerInfo{ID: i, MaxTasksCount: i + 1}
		lb.workerInfos = append(lb.workerInfos, w)
		// Buffer the channel to the worker's capacity. The selector only picks
		// a worker whose CurrentTasksCount is below MaxTasksCount, so, as long
		// as submit is not called concurrently, the send in submit never
		// blocks.
		// The previous fixed buffer of 1 stalled the producer whenever a worker
		// already had one task running and one queued. That happened even
		// though the worker had capacity left and other workers might be idle.
		ch := make(chan Task, w.MaxTasksCount)
		lb.workerChannel[w.ID] = ch
		go runWorker(ctx, w, ch, &lb.wg)
	}
	return lb
}

// runWorker is the body of one worker goroutine. It runs tasks received on
// tasks one at a time. After each task it decrements w.CurrentTasksCount
// under w.CountLock and calls wg.Done. It returns when ctx is cancelled, so
// the goroutine does not leak.
func runWorker(ctx context.Context, w *WorkerInfo, tasks <-chan Task, wg *sync.WaitGroup) {
	for {
		select {
		case fn := <-tasks:
			id := w.ID
			fmt.Printf("[BEGIN %d] worker %d got task\n", id, id)
			fn()
			fmt.Printf("[DONE %d] worker %d finish task\n", id, id)
			w.CountLock.Lock()
			w.CurrentTasksCount--
			w.CountLock.Unlock()
			wg.Done()
		case <-ctx.Done():
			return
		}
	}
}
// main starts a balancer with four workers and submits ten demo jobs. Each
// job logs its number and sleeps for a random 0-9 seconds. main then waits
// for all jobs to complete.
func main() {
	const workers, jobs = 4, 10

	balancer := LB(workers)
	for n := 0; n < jobs; n++ {
		job := n // per-iteration copy for the closure (needed before Go 1.22)
		balancer.submit(func() {
			log.Printf("doing job %d\n", job)
			// Simulate a task of random duration.
			pause := time.Duration(rand.Intn(10)) * time.Second
			time.Sleep(pause)
		})
	}
	balancer.wait()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment