Last active
April 21, 2022 09:55
-
-
Save nekomeowww/e8244e624b60719ddd2d25297ca5813f to your computer and use it in GitHub Desktop.
限流中间件
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package middleware | |
type responseBuffer struct { | |
gin.ResponseWriter // the actual ResponseWriter to flush to | |
Body *bytes.Buffer // the response content body | |
Flushed bool | |
} | |
func (w *responseBuffer) Write(b []byte) (int, error) { | |
return w.Body.Write(b) | |
} | |
func couldCountRate(method, endpoint string, rate int, perDuration time.Duration, clientIP string) (int, bool) { | |
// 设定 Redis 存储的键名 | |
keys := keys.RateLimitLock1.Format(method, endpoint, clientIP) | |
var err error | |
var countedRate int | |
// 获取已记录的请求频率 | |
countedRate, err = database.Redis.Get(keys).Int() | |
// 如果遇到错误 | |
if err != nil { | |
// 则初始化请求频率为 0 | |
countedRate = 0 | |
// 如果不是缓存不存在的错误,记录日志 | |
if !database.IsRedisNil(err) { | |
logger.Error(err) | |
} | |
} | |
// 如果频率超过限制 | |
if countedRate >= rate { | |
// 返回当前的请求频率和 false,表示限流 | |
return countedRate, false | |
} | |
// 否则,增加请求频率计数 | |
countedRate++ | |
// 将请求频率计数写入 Redis | |
err = database.Redis.Set(keys, countedRate, perDuration).Err() | |
if err != nil { | |
// 如果遇到错误,记录日志 | |
logger.Error(err) | |
} | |
return countedRate, true | |
} | |
// LimitRateFor 限流中间件 | |
func LimitRateFor(method, endpoint string, rate int, perDuration time.Duration, shoulfCountFunc func(userID int64, isAborted bool) bool) func(c *gin.Context) { | |
return func(c *gin.Context) { | |
// 创建 hanlder 包定义的 Context 实例 | |
ctx := &handler.Context{Context: c} | |
// 获取请求的 URL 和请求方法,并于限流的 URL 和请求方法进行比较 | |
if ctx.Request.Method != method || (ctx.Request.URL != nil && ctx.Request.URL.String() != endpoint) { | |
return | |
} | |
// 定义 responseBuffer 实例 | |
var bodyWriteBuffer *responseBuffer | |
// 获取原始的 ResponseWriter | |
originalWriter, ok := ctx.Writer.(gin.ResponseWriter) | |
if ok { | |
// 创建 responseBuffer 实例 | |
bodyWriteBuffer = &responseBuffer{ResponseWriter: originalWriter, Body: &bytes.Buffer{}} | |
// 覆盖原始的 ResponseWriter | |
c.Writer = bodyWriteBuffer | |
// 等待后续中间件进行处理 | |
ctx.Next() | |
} else { | |
// 如果不是 ResponseWriter 实例,则直接跳过,退出限流中间件 | |
ctx.Next() | |
return | |
} | |
// 执行回调函数来判断是否需要限流 | |
if !shoulfCountFunc(ctx.User().UserID, ctx.IsAborted()) { | |
bodyWriteBuffer.ResponseWriter.Write(bodyWriteBuffer.Body.Bytes()) | |
return | |
} | |
// 是否能够继续进行计数,即是否允许进行请求 | |
currentRate, ok := couldCountRate(method, endpoint, rate, perDuration, c.ClientIP()) | |
// 可以的话直接进行请求写入 | |
if ok { | |
bodyWriteBuffer.ResponseWriter.Write(bodyWriteBuffer.Body.Bytes()) | |
return | |
} | |
// 否则返回限流错误信息 | |
// 清空原有的 Body buffer | |
bodyWriteBuffer.Body.Reset() | |
// HTTP 报文码设定为 StatusTooManyRequests (429) | |
bodyWriteBuffer.WriteHeader(http.StatusTooManyRequests) | |
// 序列化限流错误信息 | |
jsonData, _ := json.Marshal(handler.FinalResponse{ | |
Code: apierror.CodeErrRequestRateLimitReached, // 错误码 | |
Data: nil, | |
Message: apierror.ErrRequestRateLimitReached.FormatMessage(ctx.Language()), // 错误消息 | |
}) | |
// 写入限流错误信息到响应 Body | |
bodyWriteBuffer.ResponseWriter.Write(jsonData) | |
// 记录日志 | |
logger.WithFields(logger.Fields{"endpoint": endpoint, "user_id": ctx.User().UserID, "client_ip": c.ClientIP(), "current_rate": currentRate, "rate": rate}).Warn("达到请求频率上限") | |
return | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment