Created
October 28, 2024 15:02
-
-
Save AspireOne/621b25182f0ca246ef5c73504e0d48eb to your computer and use it in GitHub Desktop.
NestJS Request Throttling
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { Injectable, OnModuleInit, Optional, Logger } from "@nestjs/common"; | |
import { ThrottleMetrics } from "./throttle.metrics"; | |
import { | |
ThrottleConfig, | |
ThrottleRecord, | |
ThrottleCheckResult, | |
ThrottleRule, | |
} from "./throttle.types"; | |
import { ThrottleException } from "./throttle.exception"; | |
/**
 * In-memory sliding-window rate limiter.
 *
 * Each `(userId, action)` pair owns a {@link ThrottleRecord} holding the sorted
 * timestamps of its accepted requests. A request is accepted only when every
 * rule of the supplied {@link ThrottleConfig} still has room inside its window.
 * State lives in a `Map` on this instance, so it is per-process and is lost
 * when the instance is discarded.
 */
@Injectable()
export class ThrottleService implements OnModuleInit {
  private readonly logger = new Logger(ThrottleService.name);

  /** Records keyed by `getStorageKey(userId, action)`. */
  private storage = new Map<string, ThrottleRecord>();

  /** How often the background cleanup runs (15 minutes). */
  private readonly CLEANUP_INTERVAL = 1000 * 60 * 15;
  /** Upper bound on the total number of records kept in `storage`. */
  private readonly DEFAULT_MAX_RECORDS = 10000;
  /** Default upper bound on the timestamps kept in a single record. */
  private readonly DEFAULT_MAX_TIMESTAMPS = 1000;

  /** Rules already reported as having an oversized window, so the warning is emitted once per rule object. */
  private readonly warnedRules = new WeakSet<ThrottleRule>();

  private intervalId?: NodeJS.Timeout;

  constructor(@Optional() private readonly metrics?: ThrottleMetrics) {}

  /** Current time in ms, clamped so that adding a window to it stays a safe integer. */
  private safeNow(): number {
    return Math.min(Date.now(), Number.MAX_SAFE_INTEGER - 24 * 60 * 60 * 1000);
  }

  /** Starts the periodic cleanup timer, replacing any timer that is already running. */
  onModuleInit(): void {
    if (this.intervalId) {
      clearInterval(this.intervalId);
    }
    this.intervalId = setInterval(() => {
      this.cleanup();
    }, this.CLEANUP_INTERVAL);
  }

  /** Stops the periodic cleanup timer. */
  onModuleDestroy(): void {
    if (this.intervalId) {
      clearInterval(this.intervalId);
      this.intervalId = undefined;
    }
  }

  /**
   * Background maintenance: reports the storage size, evicts the least
   * recently used records when the store is over capacity, drops empty
   * records and prunes expired timestamps.
   */
  private cleanup(): void {
    const now = this.safeNow();
    const CLEANUP_THRESHOLD = 1000 * 60 * 5; // Leave records touched within the last 5 minutes alone

    this.metrics?.setStorageSize(this.storage.size);

    // Enforce the maximum number of records: evict those whose most recent
    // request is oldest. Records without any timestamp sort first.
    const excess = this.storage.size - this.DEFAULT_MAX_RECORDS;
    if (excess > 0) {
      const lastSeen = (record: ThrottleRecord): number =>
        record.timestamps[record.timestamps.length - 1] ?? 0;
      const oldestFirst = Array.from(this.storage.entries()).sort(
        ([, a], [, b]) => lastSeen(a) - lastSeen(b),
      );
      for (const [key] of oldestFirst.slice(0, excess)) {
        this.storage.delete(key);
      }
    }

    // Deleting from a Map while iterating over it is safe in JavaScript.
    for (const [key, record] of this.storage.entries()) {
      // Skip records that were cleaned (or created) recently
      if (now - record.lastCleanup < CLEANUP_THRESHOLD) {
        continue;
      }

      // Records with no history left carry no information; drop them
      if (record.timestamps.length === 0) {
        this.storage.delete(key);
        continue;
      }

      // Drop every timestamp that fell out of the record's largest window
      const cutoffIndex = this.binarySearch(record.timestamps, now - record.maxWindowMs);
      if (cutoffIndex > 0) {
        record.timestamps = record.timestamps.slice(cutoffIndex);
      }
      record.lastCleanup = now;
    }
  }

  /**
   * Upper-bound search on an ascending array.
   *
   * @returns the index of the first element strictly greater than `target`
   *          (`timestamps.length` when there is none, `0` for an empty array).
   */
  private binarySearch(timestamps: number[], target: number): number {
    let left = 0;
    let right = timestamps.length;
    while (left < right) {
      const mid = Math.floor((left + right) / 2);
      if (timestamps[mid] <= target) {
        left = mid + 1;
      } else {
        right = mid;
      }
    }
    return left;
  }

  /** Builds the storage key; a missing or empty action maps to `"default"`. */
  private getStorageKey(userId: string, action?: string): string {
    return `${userId}:${action || "default"}`;
  }

  /**
   * Evaluates one rule against the requests already recorded.
   *
   * `allowed` means there is room for one more request inside the rule's
   * window; `remainingRequests` is how many more requests fit.
   *
   * @param timestamps ascending timestamps of recorded requests
   * @param rule       rule to evaluate
   * @param now        reference time in ms (defaults to the current time)
   */
  private checkRule(
    timestamps: number[],
    rule: ThrottleRule,
    now: number = this.safeNow(),
  ): ThrottleCheckResult {
    const windowStart = now - rule.windowMs;
    const firstInWindow = this.binarySearch(timestamps, windowStart);
    const requestsInWindow = timestamps.length - firstInWindow;

    if (requestsInWindow < rule.maxRequests) {
      return {
        allowed: true,
        remainingRequests: rule.maxRequests - requestsInWindow,
      };
    }

    // A slot frees up once enough of the oldest in-window requests expire for
    // the count to drop below the limit. The last of those to expire sits
    // `ceil(maxRequests)` entries from the end of the array.
    const blockingIndex = Math.max(0, timestamps.length - Math.ceil(rule.maxRequests));
    const freesAt = (timestamps[blockingIndex] ?? now) + rule.windowMs;

    return {
      allowed: false,
      remainingRequests: 0,
      retryAfter: Math.max(1, Math.ceil((freesAt - now) / 1000)),
      rule,
    };
  }

  /**
   * Validates every rule of a configuration.
   *
   * @throws Error when `windowMs` or `maxRequests` is not a finite, positive number
   */
  private validateRules(rules: ThrottleRule[]): void {
    for (const rule of rules) {
      if (!Number.isFinite(rule.windowMs) || !Number.isFinite(rule.maxRequests)) {
        throw new Error(
          "Invalid throttle rule: windowMs and maxRequests must be finite numbers",
        );
      }
      if (rule.windowMs <= 0 || rule.maxRequests <= 0) {
        throw new Error(
          "Invalid throttle rule: windowMs and maxRequests must be positive",
        );
      }
      // Validation runs on every request, so only warn the first time a given rule is seen.
      if (rule.windowMs > 24 * 60 * 60 * 1000 && !this.warnedRules.has(rule)) {
        this.warnedRules.add(rule);
        this.logger.warn(
          `Throttle window larger than 24 hours may cause issues: ${JSON.stringify(rule)}`,
        );
      }
    }
  }

  /** Returns the record stored under `key`, creating (or resetting a malformed) one as needed. */
  private getOrCreateRecord(
    key: string,
    userId: string,
    now: number,
    maxWindowMs: number,
  ): ThrottleRecord {
    const existing = this.storage.get(key);
    const usable =
      existing !== undefined &&
      Array.isArray(existing.timestamps) &&
      typeof existing.lastCleanup === "number" &&
      Number.isFinite(existing.maxWindowMs);
    if (existing !== undefined && usable) {
      return existing;
    }

    const fresh: ThrottleRecord = {
      userId,
      timestamps: [],
      lastCleanup: now,
      maxWindowMs,
    };
    this.storage.set(key, fresh);
    return fresh;
  }

  /**
   * Checks whether a request is allowed and, if so, records it.
   *
   * @param userId identifier the limit is tracked for
   * @param action optional action name; limits are tracked per `(userId, action)`
   * @param config rules to enforce; with no rules the request is always allowed
   * @returns the outcome; when blocked it carries the violated rule and the
   *          number of seconds until a retry can succeed
   * @throws Error when `userId` or the configuration is invalid
   * @throws ThrottleException when the record already holds the maximum number
   *         of timestamps and none of them has expired
   */
  async checkThrottle(
    userId: string,
    action?: string,
    config?: ThrottleConfig,
  ): Promise<ThrottleCheckResult> {
    if (!userId || typeof userId !== "string") {
      throw new Error("Invalid user ID provided");
    }
    if (!config?.rules?.length) {
      return { allowed: true, remainingRequests: Infinity };
    }
    if (!Array.isArray(config.rules)) {
      throw new Error("Invalid throttle configuration: rules must be an array");
    }
    this.validateRules(config.rules);

    const key = this.getStorageKey(userId, action);
    const now = this.safeNow();
    const maxWindowMs = Math.max(...config.rules.map((rule) => rule.windowMs));

    const record = this.getOrCreateRecord(key, userId, now, maxWindowMs);

    // The retention window only ever grows, so history needed by the widest
    // rule seen for this key is never discarded early.
    record.maxWindowMs = Math.max(record.maxWindowMs, maxWindowMs);

    // Drop timestamps that no rule can still see
    const expiredCount = this.binarySearch(record.timestamps, now - record.maxWindowMs);
    if (expiredCount > 0) {
      record.timestamps = record.timestamps.slice(expiredCount);
      record.lastCleanup = now;
    }

    // Enforce the maximum history size per record. A missing, zero or
    // otherwise non-positive setting falls back to the default.
    const configuredCap = config.maxTimestampsPerUser;
    const maxTimestamps =
      configuredCap !== undefined && configuredCap > 0
        ? configuredCap
        : this.DEFAULT_MAX_TIMESTAMPS;
    if (record.timestamps.length >= maxTimestamps) {
      const oldest = record.timestamps[0] ?? now;
      throw new ThrottleException(
        `Rate limit exceeded: maximum request history (${maxTimestamps}) reached`,
        Math.ceil((oldest + record.maxWindowMs - now) / 1000),
      );
    }

    // Every rule must have room for one more request
    for (const rule of config.rules) {
      const result = this.checkRule(record.timestamps, rule, now);
      if (!result.allowed) {
        if (config.logToConsole) {
          console.log(`Throttle exceeded for user ${userId} on action ${action}`, result);
        }
        return result;
      }
    }

    // Nothing above awaits, so the record cannot have changed since it was
    // read; record the request directly.
    const newest = record.timestamps[record.timestamps.length - 1];
    record.timestamps.push(now);
    if (newest !== undefined && newest > now) {
      // The clock moved backwards; restore the ascending order binarySearch relies on.
      record.timestamps.sort((a, b) => a - b);
    }

    // Report the most restrictive remaining budget across all rules
    const remainingRequests = Math.min(
      ...config.rules.map(
        (rule) => this.checkRule(record.timestamps, rule, now).remainingRequests,
      ),
    );
    return { allowed: true, remainingRequests };
  }

  /**
   * Records a request or rejects it, updating metrics either way.
   *
   * @returns the number of requests still available under the tightest rule
   *          (`Infinity` when no rules are configured)
   * @throws ThrottleException when the request is rate limited; built from the
   *         violated rule's message, then the config default, then
   *         "Too Many Requests", together with the retry delay in seconds
   */
  async consumeToken(
    userId: string,
    action?: string,
    config?: ThrottleConfig,
  ): Promise<{ remainingRequests: number }> {
    const actionLabel = action || "default";

    if (!config || !config.rules || config.rules.length === 0) {
      this.metrics?.incrementAllowed(userId, actionLabel);
      return { remainingRequests: Infinity }; // No throttling rules defined
    }

    let result: ThrottleCheckResult;
    try {
      result = await this.checkThrottle(userId, action, config);
    } catch (error) {
      // checkThrottle itself rejects requests that hit the history cap; count those as throttled too.
      if (error instanceof ThrottleException) {
        this.metrics?.incrementThrottled(userId, actionLabel);
      }
      throw error;
    }

    if (!result.allowed) {
      this.metrics?.incrementThrottled(userId, actionLabel);
      this.logger.warn(
        `Rate limit exceeded for user ${userId} on action ${action}. ` +
          `Retry after: ${result.retryAfter}s`,
      );
      throw new ThrottleException(
        result.rule?.errorMessage || config?.defaultErrorMessage || "Too Many Requests",
        result.retryAfter,
      );
    }

    this.metrics?.incrementAllowed(userId, actionLabel);
    return { remainingRequests: result.remainingRequests };
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { Injectable, CanActivate, ExecutionContext, Inject } from "@nestjs/common"; | |
import { Reflector } from "@nestjs/core"; | |
import { ThrottleService } from "./throttle.service"; | |
import { ThrottleConfig } from "./throttle.types"; | |
import { THROTTLE_METADATA, SKIP_THROTTLE } from "./throttle.decorator"; | |
import { FastifyRequest } from "fastify"; | |
import { THROTTLE_CONFIG_KEY } from "./throttle.module"; | |
/**
 * Guard responsible for managing rate limiting in the backend.
 *
 * The bearer token of the request is used as the identity the limit is
 * tracked for. Requests without a usable bearer token are let through
 * without any throttling.
 *
 * NOTE(review): the raw token is handed to ThrottleService as the user id,
 * which writes that id into its log output — confirm that this is acceptable
 * or derive a non-secret identifier instead.
 */
@Injectable()
export class ThrottleGuard implements CanActivate {
  constructor(
    private readonly reflector: Reflector,
    private readonly throttlerService: ThrottleService,
    // Global throttle config. This is set in ThrottleModule and injected here for consistency.
    @Inject(THROTTLE_CONFIG_KEY) private readonly defaultConfig?: ThrottleConfig,
  ) {}

  /**
   * Decides whether the current request may proceed.
   *
   * @returns `true` when the request is not subject to throttling or is within its limits
   * @throws whatever `ThrottleService.consumeToken` throws when the limit is exceeded
   */
  async canActivate(context: ExecutionContext): Promise<boolean> {
    // Handlers/controllers can opt out entirely
    const skipThrottle = this.reflector.getAllAndOverride<boolean>(SKIP_THROTTLE, [
      context.getHandler(),
      context.getClass(),
    ]);
    if (skipThrottle) {
      return true;
    }

    // Get user token from request
    const request = context.switchToHttp().getRequest<FastifyRequest>();
    const token = this.extractTokenFromHeader(request);
    if (!token) {
      // Optional: throw an error or handle anonymous users differently
      return true;
    }

    // Get throttle config from decorator or use default
    const config = this.getThrottleConfig(context);
    if (!config) {
      return true;
    }

    // Limits are tracked per endpoint; errors from the service propagate unchanged
    const action = this.getActionFromContext(context);
    await this.throttlerService.consumeToken(token, action, config);
    return true;
  }

  /**
   * Returns the user's id token extracted from the request, or null.
   *
   * The scheme name is matched case-insensitively and any amount of
   * whitespace may separate it from the token. Tokens shorter than 10
   * characters are rejected.
   */
  private extractTokenFromHeader(request: FastifyRequest): string | null {
    const authHeader = request.headers.authorization;
    if (!authHeader) {
      return null;
    }

    const match = /^Bearer\s+(\S+)/i.exec(authHeader.trim());
    const token = match?.[1];
    if (!token || token.length < 10) {
      return null;
    }
    return token;
  }

  /** Retrieves the throttling configuration from the current method -> controller -> global config. */
  private getThrottleConfig(context: ExecutionContext): ThrottleConfig | undefined {
    // Most specific target first: the method-level decorator wins over the controller-level one
    for (const target of [context.getHandler(), context.getClass()]) {
      const decoratorConfig = this.reflector.get<ThrottleConfig>(THROTTLE_METADATA, target);
      if (decoratorConfig) {
        return decoratorConfig;
      }
    }

    // Fall back to default config
    return this.defaultConfig;
  }

  /** Returns a string representation of the current action/endpoint (controller:method). */
  private getActionFromContext(context: ExecutionContext): string {
    const handler = context.getHandler();
    const controller = context.getClass();
    return `${controller.name}:${handler.name}`;
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Defines a single rate limiting rule with a time window and request limit.
 */
export type ThrottleRule = {
  /** Time window in milliseconds during which requests are counted. */
  windowMs: number;
  /** Maximum number of requests allowed within the time window. */
  maxRequests: number;
  /** Custom error message when this specific rule is violated. */
  errorMessage?: string;
};
/**
 * Configuration for the throttling behavior.
 */
export type ThrottleConfig = {
  /** Array of throttling rules to apply. All rules must pass for a request to be allowed. */
  rules: ThrottleRule[];
  /** Fallback error message when the violated rule does not provide its own. */
  defaultErrorMessage?: string;
  /** Whether to log throttling events to console for debugging. */
  logToConsole?: boolean;
  /**
   * Maximum number of records to store.
   *
   * NOTE(review): the ThrottleService shown alongside this type never reads
   * this value; it applies a fixed cap to its whole store instead. Confirm
   * whether the option is meant to be honored, and whether the limit is
   * really intended to be per action.
   */
  maxRecords?: number;
  /** Maximum number of timestamps to store per user. */
  maxTimestampsPerUser?: number;
};
/**
 * Internal storage record for tracking requests per user.
 */
export type ThrottleRecord = {
  /** User identifier (from auth token). */
  userId: string;
  /** Timestamps (in milliseconds) of the recorded requests, kept in ascending order. */
  timestamps: number[];
  /** Timestamp (in milliseconds) of the last cleanup of this record. */
  lastCleanup: number;
  /** Largest window size in milliseconds from all rules. */
  maxWindowMs: number;
};
/**
 * Result of checking if a request should be throttled.
 */
export type ThrottleCheckResult = {
  /** Whether the request is allowed to proceed. */
  allowed: boolean;
  /** Number of requests remaining in the current window. */
  remainingRequests: number;
  /** Seconds until the user can retry if the request was blocked. */
  retryAfter?: number;
  /** The specific rule that caused the throttling, if any. */
  rule?: ThrottleRule;
};
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment