Created
October 24, 2024 04:38
-
-
Save aloysb/d285b0966498d4e8678945ef99778d09 to your computer and use it in GitHub Desktop.
Basic rate limiter in Go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"log" | |
"net/http" | |
"net/http/httputil" | |
"net/url" | |
"sync" | |
"time" | |
) | |
type RateLimiter struct { | |
postRate int | |
getRate int | |
interval time.Duration | |
postTokens map[string]int | |
getTokens map[string]int | |
mutex sync.Mutex | |
} | |
func NewRateLimiter(postRate, getRate int, interval time.Duration) *RateLimiter { | |
rl := &RateLimiter{ | |
postRate: postRate, | |
getRate: getRate, | |
interval: interval, | |
postTokens: make(map[string]int), | |
getTokens: make(map[string]int), | |
} | |
go rl.refillTokens() | |
return rl | |
} | |
func (rl *RateLimiter) refillTokens() { | |
ticker := time.NewTicker(rl.interval) | |
for range ticker.C { | |
rl.mutex.Lock() | |
rl.postTokens = make(map[string]int) // Reset POST tokens | |
rl.getTokens = make(map[string]int) // Reset GET tokens | |
rl.mutex.Unlock() | |
} | |
} | |
func (rl *RateLimiter) Allow(ip string, method string) bool { | |
rl.mutex.Lock() | |
defer rl.mutex.Unlock() | |
switch method { | |
case http.MethodPost: | |
if rl.postTokens[ip] < rl.postRate { | |
log.Println("Allowed POST request", "ip", ip) | |
log.Println("POST tokens", "ip", ip, "tokens", rl.postTokens[ip], "/", rl.postRate) | |
rl.postTokens[ip]++ | |
return true | |
} | |
case http.MethodGet: | |
log.Println("Allowed GET request", "ip", ip) | |
log.Println("GET tokens", "ip", ip, "tokens", rl.getTokens[ip], "/", rl.getRate) | |
if rl.getTokens[ip] < rl.getRate { | |
rl.getTokens[ip]++ | |
return true | |
} | |
} | |
log.Println("Denied request", "ip", ip) | |
return false | |
} | |
func main() { | |
backendURL, err := url.Parse("http://caddy:80") | |
if err != nil { | |
log.Fatal(err) | |
} | |
proxy := httputil.NewSingleHostReverseProxy(backendURL) | |
rateLimiter := NewRateLimiter(6, 30, 1*time.Minute) // 6 POSTs, 30 GETs per minute | |
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { | |
clientIP := r.RemoteAddr | |
if rateLimiter.Allow(clientIP, r.Method) { | |
log.Println("Allowed request", "ip", clientIP) | |
proxy.ServeHTTP(w, r) | |
} else { | |
log.Println("Denied request", "ip", clientIP) | |
http.Error(w, "Too Many Requests", http.StatusTooManyRequests) | |
} | |
}) | |
log.Println("Proxy with rate limiting running on :8081") | |
log.Fatal(http.ListenAndServe(":8081", nil)) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment