95 lines
2.0 KiB
Go
95 lines
2.0 KiB
Go
|
|
package middleware
|
||
|
|
|
||
|
|
import (
|
||
|
|
"net/http"
|
||
|
|
"sync"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"github.com/gin-gonic/gin"
|
||
|
|
)
|
||
|
|
|
||
|
|
// ipBucket holds the timestamps of the requests accepted for one client IP and
// decides whether another request still fits inside the trailing window.
// With a zero limit every request is rejected.
type ipBucket struct {
	mu     sync.Mutex    // guards times
	times  []time.Time   // accepted-request timestamps, in the order they were accepted
	limit  int           // maximum number of accepted requests per window
	window time.Duration // length of the trailing window
}

// allow reports whether a request arriving now may proceed. It first drops
// every stored timestamp that is not strictly after now-window. Then, if fewer
// than limit timestamps remain, it records now and returns true. Otherwise it
// records nothing and returns false. The bucket mutex is held for the whole call.
func (b *ipBucket) allow() bool {
	b.mu.Lock()
	defer b.mu.Unlock()

	now := time.Now()
	oldest := now.Add(-b.window)

	// Compact in place so the backing array (and its capacity) is reused.
	kept := 0
	for _, ts := range b.times {
		if !ts.After(oldest) {
			continue
		}
		b.times[kept] = ts
		kept++
	}
	b.times = b.times[:kept]

	if len(b.times) < b.limit {
		b.times = append(b.times, now)
		return true
	}
	return false
}
|
||
|
|
|
||
|
|
// RateLimiter throttles requests per client IP using a sliding window of
// accepted-request timestamps. Each IP gets its own *ipBucket, created on
// first use and kept in a sync.Map. Bucket state is guarded by the bucket's
// own mutex, so a RateLimiter is safe for concurrent use. NewRateLimiter also
// starts a background goroutine that periodically sweeps out buckets it
// considers idle.
//
// A RateLimiter must not be copied after first use because it contains a
// sync.Map. Always handle it through the pointer returned by NewRateLimiter.
type RateLimiter struct {
	buckets sync.Map      // client IP (string) -> *ipBucket
	limit   int           // accepted requests allowed per IP within window
	window  time.Duration // length of the sliding window
}
|
||
|
|
|
||
|
|
// NewRateLimiter creates a limiter that allows up to limit requests per window
|
||
|
|
// per IP address.
|
||
|
|
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
|
||
|
|
rl := &RateLimiter{limit: limit, window: window}
|
||
|
|
go rl.gc()
|
||
|
|
return rl
|
||
|
|
}
|
||
|
|
|
||
|
|
// gc periodically removes buckets that have had no activity for one full window.
|
||
|
|
func (rl *RateLimiter) gc() {
|
||
|
|
ticker := time.NewTicker(rl.window)
|
||
|
|
defer ticker.Stop()
|
||
|
|
for range ticker.C {
|
||
|
|
rl.buckets.Range(func(k, v any) bool {
|
||
|
|
b := v.(*ipBucket)
|
||
|
|
b.mu.Lock()
|
||
|
|
if len(b.times) == 0 {
|
||
|
|
rl.buckets.Delete(k)
|
||
|
|
}
|
||
|
|
b.mu.Unlock()
|
||
|
|
return true
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Middleware returns a gin.HandlerFunc that enforces the rate limit.
|
||
|
|
func (rl *RateLimiter) Middleware() gin.HandlerFunc {
|
||
|
|
return func(c *gin.Context) {
|
||
|
|
ip := c.ClientIP()
|
||
|
|
v, _ := rl.buckets.LoadOrStore(ip, &ipBucket{
|
||
|
|
times: make([]time.Time, 0, rl.limit),
|
||
|
|
limit: rl.limit,
|
||
|
|
window: rl.window,
|
||
|
|
})
|
||
|
|
b := v.(*ipBucket)
|
||
|
|
if !b.allow() {
|
||
|
|
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
|
||
|
|
"code": "rate_limited",
|
||
|
|
"message": "too many requests — slow down",
|
||
|
|
})
|
||
|
|
return
|
||
|
|
}
|
||
|
|
c.Next()
|
||
|
|
}
|
||
|
|
}
|