Skip to content

Instantly share code, notes, and snippets.

@haozibi
Last active November 18, 2019 10:38
Show Gist options
  • Save haozibi/801edeb1434906d5a05a99d1b1a47973 to your computer and use it in GitHub Desktop.
Save haozibi/801edeb1434906d5a05a99d1b1a47973 to your computer and use it in GitHub Desktop.
多个并发请求只让其中一个真正执行,其余阻塞等待到执行的那个请求完成,将结果传递给阻塞的其他请求达到防止雪崩的效果
package main
// by haozibi
// https://play.golang.org/p/QiQ4sBoVbCu
//
// 参考:
// https://segmentfault.com/a/1190000018464029
// https://github.com/golang/sync/tree/master/singleflight
// https://github.com/golang/groupcache/tree/master/singleflight
import (
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/golang/groupcache/singleflight"
)
func main() {
err := single()
if err != nil {
panic(err)
}
}
// "github.com/golang/groupcache/singleflight" 源码相对简单,结构清晰
// "golang.org/x/sync/singleflight" 比较完善,原理和👆的一致
func single() error {
var g singleflight.Group
c := make(chan string)
t1 := time.Now()
var calls int32
fn := func() (interface{}, error) {
// calls 统计 fn 运行次数,原子操作
atomic.AddInt32(&calls, 1)
// 阻塞,直到 chan c 接收到值
return <-c, nil
}
const n = 10
var wg sync.WaitGroup
for i := 0; i < n; i++ {
wg.Add(1)
// 小技巧
// 如果不设置 i:=i,则多个协程读取的都是同一个 i,也是唯一的 i,即 i:=0;i<n;i++ 中的 i
// 也可以通过传送 go func(i int){} 的方式解决此问题
i := i
// 启动多个协程同时执行
go func() {
v, err := g.Do("key", fn)
if err != nil {
fmt.Printf("Do error: %v", err)
return
}
// if v.(string) != "bar" {
// fmt.Printf("got %q; want %q", v, "bar")
// return
// }
fmt.Println(i, v)
wg.Done()
}()
}
// sleep 100ms 有两个作用
// 1. 等待,让多个协程都启动成功,处于等待第一个协程完成的状态
// 2. 第一个启动成功的协程会 “真正” 执行 fn, 100ms 也相当于 fn 执行耗时,第一个协程需要等 100ms 才能接收到 chan c 的输入
// ps: 在 "github.com/golang/groupcache/singleflight" 包中的 Group
// 当 “真正执行” fn 的协程完成了会立即从 map 中删除标致位(对应的 key)
// 所以只有在 “真正执行” fn 执行的过程中(此例为 100ms)进入到 Group 的协程才能获取到共同的返回值
// ps: 等第一个 fn 执行完毕后缓存就可以查到,下一次再执行到这个方法的场景只有缓存失效之后了
time.Sleep(100 * time.Millisecond) // let goroutines above block
c <- "bar"
// 等待所有协程完成
wg.Wait()
// 检查 fn 的运行次数
got := atomic.LoadInt32(&calls)
fmt.Printf("calls: %d,time: %v\n", got, time.Since(t1))
return nil
}
// output:
// 4 bar
// 1 bar
// 8 bar
// 3 bar
// 2 bar
// 0 bar
// 7 bar
// 9 bar
// 6 bar
// 5 bar
// calls: 1,time: 100.576411ms
@haozibi
Copy link
Author

haozibi commented Nov 18, 2019

http 中间件实现请求的 http.Handler

package singleflight

import (
	"bytes"
	"crypto/md5"
	"encoding/hex"
	"io"
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"sync"

	"github.com/pkg/errors"
)

// 多个相同请求,其中一个访问真正的后端,其他请求等待其完成
// 反向代理

type call struct {
	mu sync.Mutex
	wg sync.WaitGroup
	w  *httptest.ResponseRecorder
	mw *multiWriter
}

type multiWriter struct {
	writers []http.ResponseWriter
}

func (c *call) addResponseWriter(w http.ResponseWriter) {
	c.mu.Lock()
	if c.mw.writers == nil {
		c.mw.writers = make([]http.ResponseWriter, 0)
	}
	c.mw.writers = append(c.mw.writers, w)
	c.mu.Unlock()
}

func (c *call) reset() {
	c.w = nil
	c.mw = nil
}

func (c *call) flush() {

	resp := c.w.Result()
	body, _ := ioutil.ReadAll(resp.Body)
	header := map[string][]string(resp.Header.Clone())
	statusCode := resp.StatusCode

	for _, w := range c.mw.writers {
		w.Write(body)
		w.WriteHeader(statusCode)
		for k, v := range header {
			for _, vv := range v {
				w.Header().Add(k, vv)
			}
		}
	}
}

// Group group
type Group struct {
	mu       sync.Mutex
	m        map[string]*call
	callPool *sync.Pool
	buffPool *sync.Pool
}

// Do do
func (g *Group) Do(next http.Handler, opts ...Option) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

		opt := options{"", ""}

		for _, o := range opts {
			o.apply(&opt)
		}

		if len(opt.headerName) == 0 ||
			len(opt.headerValue) == 0 {
			next.ServeHTTP(w, r)
			return
		}

		if r.Header.Get(opt.headerName) !=
			opt.headerValue {
			next.ServeHTTP(w, r)
			return
		}

		err := g.do(w, r, next)
		if err != nil {
			panic(err)
		}
	})
}

func (g *Group) do(w http.ResponseWriter, r *http.Request, next http.Handler) error {
	g.mu.Lock()
	if g.m == nil {
		g.m = make(map[string]*call)
	}
	if g.callPool == nil {
		g.callPool = &sync.Pool{
			New: func() interface{} {
				return new(call)
			},
		}
	}
	if g.buffPool == nil {
		g.buffPool = &sync.Pool{
			New: func() interface{} {
				return bytes.NewBuffer(make([]byte, 4096))
			},
		}
	}
	key, err := g.requestKey(r)
	if err != nil {
		return err
	}

	// 只保证在首个请求进行请求时进入的相同请求会被阻塞
	// 返回相同响应
	if c, ok := g.m[key]; ok {
		g.mu.Unlock()
		// 是否并发安全,遗漏响应
		c.addResponseWriter(w)
		c.wg.Wait()
		return nil
	}

	c := g.callPool.Get().(*call)
	c.wg.Add(1)
	g.m[key] = c
	c.mw = new(multiWriter)
	c.w = httptest.NewRecorder()
	c.addResponseWriter(w)
	g.mu.Unlock()

	// 唯一请求转发
	// todo: 利用 channel 把这段分离,避免耦合
	next.ServeHTTP(c.w, r)

	// 把请求响应统一复制到所有等待的响应
	c.flush()
	c.wg.Done()

	g.mu.Lock()
	delete(g.m, key)
	// c.reset()
	g.callPool.Put(c)
	g.mu.Unlock()

	return nil
}

func (g *Group) requestKey(r *http.Request) (string, error) {

	var (
		key    string
		method string
		path   string
		body   string
		header string
		proto  string
	)

	proto = r.Proto
	path = r.URL.RawPath
	method = r.Method

	headers := map[string][]string(r.Header)
	for k, v := range headers {
		s := ""
		for _, vv := range v {
			s += vv
		}
		header += k + "=" + s + ";"
	}

	if r.Body != nil {
		// buffer := bytes.NewBuffer(make([]byte, 4096))
		buffer := g.buffPool.Get().(*bytes.Buffer)
		// 不知为啥,在 Put 前会造成阻塞
		buffer.Reset()
		_, err := io.Copy(buffer, r.Body)
		if err != nil {
			g.buffPool.Put(buffer)
			return "", errors.Wrap(err, "get request key")
		}
		body = mmd5(buffer.Bytes())
		g.buffPool.Put(buffer)
	}

	key = method + path + proto + header + body

	return key, nil
}

func mmd5(body []byte) string {
	hasher := md5.New()
	hasher.Write(body)
	return hex.EncodeToString(hasher.Sum(nil))
}

// Option option
type Option interface {
	apply(*options)
}

type optionFunc func(*options)

func (f optionFunc) apply(o *options) {
	f(o)
}

// WithHeader 只有 name 和 value 都符合的请求才会被执行
// name 和 value 任意为空都不会执行
func WithHeader(name, value string) Option {
	return optionFunc(func(o *options) {
		o.headerName = name
		o.headerValue = value
	})
}

type options struct {
	headerName  string
	headerValue string
}

go test

package singleflight

import (
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"sync"
	"sync/atomic"
	"testing"
	"time"
)

func TestSingleGroup(t *testing.T) {

	var (
		g           Group
		calls       int32
		wg          sync.WaitGroup
		n           = 10
		c           = make(chan string)
		t1          = time.Now()
		headerName  = "abc"
		headerValue = "abcabc"
	)

	req := httptest.NewRequest("GET", "http://example.com/foo", nil)
	req.Header.Set(headerName, headerValue)

	var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		atomic.AddInt32(&calls, 1)
		io.WriteString(w, <-c)
	})
	handler = g.Do(handler, WithHeader(headerName, headerValue))

	for i := 0; i < n; i++ {
		wg.Add(1)
		i := i
		go func() {
			fmt.Println("req num:", i)
			w := httptest.NewRecorder()

			handler.ServeHTTP(w, req)

			resp := w.Result()
			respBody, _ := ioutil.ReadAll(resp.Body)

			fmt.Println(i, resp.StatusCode, string(respBody))
			wg.Done()
		}()

	}

	time.Sleep(100 * time.Millisecond)
	fmt.Println("===")
	c <- "bar"
	// close(c)
	wg.Wait()

	got := atomic.LoadInt32(&calls)

	fmt.Printf("calls: %d,time: %v\n", got, time.Since(t1))
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment