Skip to content

Instantly share code, notes, and snippets.

@AspireOne
Created October 28, 2024 15:02
Show Gist options
  • Save AspireOne/621b25182f0ca246ef5c73504e0d48eb to your computer and use it in GitHub Desktop.
Save AspireOne/621b25182f0ca246ef5c73504e0d48eb to your computer and use it in GitHub Desktop.
NestJS Request Throttling
import { Injectable, OnModuleInit, Optional, Logger } from "@nestjs/common";
import { ThrottleMetrics } from "./throttle.metrics";
import {
ThrottleConfig,
ThrottleRecord,
ThrottleCheckResult,
ThrottleRule,
} from "./throttle.types";
import { ThrottleException } from "./throttle.exception";
/**
 * In-memory, sliding-window request throttler.
 *
 * Request timestamps are kept per `userId:action` key in a process-local Map, so limits
 * are not shared between processes. A background job (started in onModuleInit) trims
 * expired history and caps the number of tracked keys.
 */
@Injectable()
export class ThrottleService implements OnModuleInit {
  private readonly logger = new Logger(ThrottleService.name);
  /** Request history keyed by `userId:action` (see getStorageKey). */
  private readonly storage = new Map<string, ThrottleRecord>();

  constructor(@Optional() private readonly metrics?: ThrottleMetrics) {}

  /** Interval between background cleanup runs. */
  private readonly CLEANUP_INTERVAL = 1000 * 60 * 15; // 15 minutes
  /** Records trimmed more recently than this are skipped by the background cleanup. */
  private readonly CLEANUP_THRESHOLD = 1000 * 60 * 5; // 5 minutes
  /**
   * Maximum number of keys kept in storage. This is one global cap across all actions;
   * ThrottleConfig.maxRecords is not consulted.
   */
  private readonly DEFAULT_MAX_RECORDS = 10000;
  /** Fallback for ThrottleConfig.maxTimestampsPerUser. */
  private readonly DEFAULT_MAX_TIMESTAMPS = 1000;
  /** Rule windows longer than this are accepted but trigger a warning. */
  private readonly LONG_WINDOW_WARNING_MS = 24 * 60 * 60 * 1000; // 24 hours

  private intervalId?: NodeJS.Timeout;

  /** Current time in ms, clamped to leave a one-day margin below Number.MAX_SAFE_INTEGER. */
  private safeNow(): number {
    return Math.min(Date.now(), Number.MAX_SAFE_INTEGER - 24 * 60 * 60 * 1000);
  }

  onModuleInit(): void {
    this.intervalId = setInterval(() => this.cleanup(), this.CLEANUP_INTERVAL);
    // The cleanup timer alone should not keep the Node.js process alive.
    this.intervalId.unref();
  }

  onModuleDestroy(): void {
    if (this.intervalId) {
      clearInterval(this.intervalId);
      this.intervalId = undefined;
    }
  }

  /**
   * Periodic maintenance: evicts the least recently active keys above the global cap,
   * trims expired timestamps, and deletes records that end up empty.
   */
  private cleanup(): void {
    const now = this.safeNow();

    this.evictExcessRecords();

    // Deleting from a Map while iterating it with for..of is safe in JS.
    for (const [key, record] of this.storage) {
      // Records trimmed recently (here or in checkThrottle) are left alone.
      if (now - record.lastCleanup < this.CLEANUP_THRESHOLD) {
        continue;
      }
      this.trimExpired(record, now);
      if (record.timestamps.length === 0) {
        this.storage.delete(key);
        continue;
      }
      record.lastCleanup = now;
    }

    // Report the size after eviction/deletion so the metric reflects what is retained.
    this.metrics?.setStorageSize(this.storage.size);
  }

  /** Drops the least recently active records until at most DEFAULT_MAX_RECORDS remain. */
  private evictExcessRecords(): void {
    const excess = this.storage.size - this.DEFAULT_MAX_RECORDS;
    if (excess <= 0) {
      return;
    }
    const lastSeen = (record: ThrottleRecord): number =>
      record.timestamps[record.timestamps.length - 1] ?? 0;
    const oldestFirst = Array.from(this.storage.entries()).sort(
      ([, a], [, b]) => lastSeen(a) - lastSeen(b),
    );
    for (const [key] of oldestFirst.slice(0, excess)) {
      this.storage.delete(key);
    }
  }

  /** Removes timestamps older than the record's longest window; updates lastCleanup if any were removed. */
  private trimExpired(record: ThrottleRecord, now: number): void {
    const cutoffIndex = this.binarySearch(record.timestamps, now - record.maxWindowMs);
    if (cutoffIndex > 0) {
      record.timestamps = record.timestamps.slice(cutoffIndex);
      record.lastCleanup = now;
    }
  }

  /**
   * Upper-bound search on an ascending array.
   *
   * @returns the index of the first timestamp strictly greater than `target`
   *   (`timestamps.length` when there is none), i.e. the number of entries <= target.
   */
  private binarySearch(timestamps: number[], target: number): number {
    let left = 0;
    let right = timestamps.length;
    while (left < right) {
      const mid = Math.floor((left + right) / 2);
      if (timestamps[mid] <= target) {
        left = mid + 1;
      } else {
        right = mid;
      }
    }
    return left;
  }

  /** Inserts `timestamp` so `timestamps` stays ascending even if the clock moved backwards. */
  private insertSorted(timestamps: number[], timestamp: number): void {
    timestamps.splice(this.binarySearch(timestamps, timestamp), 0, timestamp);
  }

  private getStorageKey(userId: string, action?: string): string {
    return `${userId}:${action || "default"}`;
  }

  /**
   * Evaluates one rule against the request history recorded so far, which must NOT
   * include the request being decided.
   *
   * @param timestamps ascending timestamps of previous requests
   * @returns `allowed` when fewer than `maxRequests` requests fall inside the window.
   *   `remainingRequests` is the number still available. When denied, `retryAfter`
   *   is the number of seconds until enough requests have left the window for one more to fit.
   */
  private checkRule(
    timestamps: number[],
    rule: ThrottleRule,
    now: number = this.safeNow(),
  ): ThrottleCheckResult {
    const windowStart = now - rule.windowMs;
    // Entries <= windowStart are outside the window.
    const windowIndex = this.binarySearch(timestamps, windowStart);
    const requestsInWindow = timestamps.length - windowIndex;
    const remainingRequests = Math.max(0, rule.maxRequests - requestsInWindow);

    if (requestsInWindow < rule.maxRequests) {
      return { allowed: true, remainingRequests };
    }

    // floor(requestsInWindow - maxRequests) + 1 requests must expire before one more fits;
    // the last of those to expire sits at this index (always within bounds).
    const freeingTimestamp =
      timestamps[windowIndex + Math.floor(requestsInWindow - rule.maxRequests)] ?? now;
    return {
      allowed: false,
      remainingRequests: 0,
      retryAfter: Math.ceil((freeingTimestamp + rule.windowMs - now) / 1000),
      rule,
    };
  }

  /** Throws when `rules` is not an array of rules with finite, positive windowMs/maxRequests. */
  private validateRules(rules: ThrottleRule[]): void {
    if (!Array.isArray(rules)) {
      throw new Error("Invalid throttle configuration: rules must be an array");
    }
    for (const rule of rules) {
      if (!Number.isFinite(rule.windowMs) || !Number.isFinite(rule.maxRequests)) {
        throw new Error(
          "Invalid throttle rule: windowMs and maxRequests must be finite numbers",
        );
      }
      if (rule.windowMs <= 0 || rule.maxRequests <= 0) {
        throw new Error(
          "Invalid throttle rule: windowMs and maxRequests must be positive",
        );
      }
      if (rule.windowMs > this.LONG_WINDOW_WARNING_MS) {
        this.logger.warn(
          `Throttle window larger than 24 hours may cause issues: ${JSON.stringify(rule)}`,
        );
      }
    }
  }

  /** Returns the stored record for `key`, replacing it when missing or structurally corrupted. */
  private getOrCreateRecord(
    key: string,
    userId: string,
    now: number,
    maxWindowMs: number,
  ): ThrottleRecord {
    const existing = this.storage.get(key);
    if (
      existing &&
      Array.isArray(existing.timestamps) &&
      typeof existing.lastCleanup === "number"
    ) {
      return existing;
    }
    const record: ThrottleRecord = { userId, timestamps: [], lastCleanup: now, maxWindowMs };
    this.storage.set(key, record);
    return record;
  }

  /**
   * Decides whether `userId` may perform `action` under `config` and, if allowed, records
   * the request.
   *
   * @param userId identifier the limits are tracked under (non-empty string)
   * @param action optional action name; each action is tracked separately
   * @param config throttle rules; missing or empty rules mean "no limit"
   * @returns `{ allowed: true, remainingRequests }` after recording the request, or the
   *   result of the first rule that blocks it (nothing is recorded in that case)
   * @throws Error when `userId` or the rules are invalid
   * @throws ThrottleException when the retained history for this key reaches
   *   `maxTimestampsPerUser`
   */
  async checkThrottle(
    userId: string,
    action?: string,
    config?: ThrottleConfig,
  ): Promise<ThrottleCheckResult> {
    if (!userId || typeof userId !== "string") {
      throw new Error("Invalid user ID provided");
    }
    if (!config?.rules?.length) {
      return { allowed: true, remainingRequests: Infinity };
    }
    this.validateRules(config.rules);

    const key = this.getStorageKey(userId, action);
    const now = this.safeNow();
    const configMaxWindowMs = Math.max(...config.rules.map((rule) => rule.windowMs));
    const record = this.getOrCreateRecord(key, userId, now, configMaxWindowMs);

    // Retain enough history for the longest window any config has used for this key.
    record.maxWindowMs = Math.max(record.maxWindowMs, configMaxWindowMs);
    this.trimExpired(record, now);

    // Hard cap on retained history, evaluated after expired entries were dropped.
    const maxTimestamps = config.maxTimestampsPerUser || this.DEFAULT_MAX_TIMESTAMPS;
    if (record.timestamps.length >= maxTimestamps) {
      const oldest = record.timestamps[0] ?? now;
      throw new ThrottleException(
        `Rate limit exceeded: maximum request history (${maxTimestamps}) reached`,
        Math.ceil((oldest + record.maxWindowMs - now) / 1000),
      );
    }

    // Rules see only the history *before* this request, so maxRequests = N admits
    // exactly N requests per window.
    for (const rule of config.rules) {
      const result = this.checkRule(record.timestamps, rule, now);
      if (!result.allowed) {
        if (config.logToConsole) {
          console.log(`Throttle exceeded for user ${userId} on action ${action}`, result);
        }
        return result;
      }
    }

    // This method never awaits, so nothing can modify the record between the checks
    // above and this insert; no re-check or retry is needed.
    this.insertSorted(record.timestamps, now);

    // The most restrictive rule determines what the caller has left.
    const remainingRequests = Math.min(
      ...config.rules.map(
        (rule) => this.checkRule(record.timestamps, rule, now).remainingRequests,
      ),
    );
    return { allowed: true, remainingRequests };
  }

  /**
   * Records a request, or rejects it when a limit is exceeded, updating metrics either way.
   *
   * @returns the most restrictive number of remaining requests (Infinity without rules)
   * @throws ThrottleException when the request is throttled
   */
  async consumeToken(
    userId: string,
    action?: string,
    config?: ThrottleConfig,
  ): Promise<{ remainingRequests: number }> {
    const actionName = action || "default";
    if (!config?.rules?.length) {
      this.metrics?.incrementAllowed(userId, actionName);
      return { remainingRequests: Infinity }; // No throttling rules defined
    }

    let result: ThrottleCheckResult;
    try {
      result = await this.checkThrottle(userId, action, config);
    } catch (error: unknown) {
      // The history cap is signalled by exception; count it as a throttled request too.
      if (error instanceof ThrottleException) {
        this.metrics?.incrementThrottled(userId, actionName);
      }
      throw error;
    }

    if (!result.allowed) {
      this.metrics?.incrementThrottled(userId, actionName);
      this.logger.warn(
        `Rate limit exceeded for user ${userId} on action ${action}. ` +
          `Retry after: ${result.retryAfter}s`,
      );
      throw new ThrottleException(
        result.rule?.errorMessage || config.defaultErrorMessage || "Too Many Requests",
        result.retryAfter,
      );
    }

    this.metrics?.incrementAllowed(userId, actionName);
    return { remainingRequests: result.remainingRequests };
  }
}
import { createHash } from "crypto";
import { Injectable, CanActivate, ExecutionContext, Inject } from "@nestjs/common";
import { Reflector } from "@nestjs/core";
import { FastifyRequest } from "fastify";
import { ThrottleService } from "./throttle.service";
import { ThrottleConfig } from "./throttle.types";
import { THROTTLE_METADATA, SKIP_THROTTLE } from "./throttle.decorator";
import { THROTTLE_CONFIG_KEY } from "./throttle.module";
/** Guard responsible for managing rate limiting in the backend. */
/**
 * Guard that rate-limits authenticated requests via ThrottleService.
 *
 * Limits are tracked per bearer token and per `Controller:handler` action. Only a
 * SHA-256 digest of the token is passed on (see toThrottleKey), so the raw credential
 * never reaches throttle storage, logs or metrics. Requests without a usable bearer
 * token are not throttled.
 * Config precedence: handler decorator -> controller decorator -> global default.
 */
@Injectable()
export class ThrottleGuard implements CanActivate {
  constructor(
    private readonly reflector: Reflector,
    private readonly throttlerService: ThrottleService,
    // Global throttle config. This is set in ThrottleModule and injected here for consistency.
    @Inject(THROTTLE_CONFIG_KEY) private readonly defaultConfig?: ThrottleConfig,
  ) {}

  async canActivate(context: ExecutionContext): Promise<boolean> {
    const request = context.switchToHttp().getRequest<FastifyRequest>();

    const token = this.extractTokenFromHeader(request);
    if (!token) {
      // Anonymous requests pass unthrottled; reject or key them differently here if needed.
      return true;
    }

    const skipThrottle = this.reflector.getAllAndOverride<boolean>(SKIP_THROTTLE, [
      context.getHandler(),
      context.getClass(),
    ]);
    if (skipThrottle) {
      return true;
    }

    const config = this.getThrottleConfig(context);
    if (!config) {
      return true;
    }

    // consumeToken throws (ThrottleException) when a limit is hit; let it propagate to Nest.
    await this.throttlerService.consumeToken(
      this.toThrottleKey(token),
      this.getActionFromContext(context),
      config,
    );
    return true;
  }

  /** Returns the bearer token from the Authorization header, or null when absent/malformed. */
  private extractTokenFromHeader(request: FastifyRequest): string | null {
    const authHeader = request.headers.authorization;
    if (!authHeader) {
      return null;
    }
    const [scheme, token] = authHeader.trim().split(/\s+/);
    // Authentication schemes are case-insensitive (RFC 7235 §2.1); very short tokens are ignored.
    if (scheme?.toLowerCase() !== "bearer" || !token || token.length < 10) {
      return null;
    }
    return token;
  }

  /** Stable, non-reversible identifier for a token, used as the throttle user id. */
  private toThrottleKey(token: string): string {
    return createHash("sha256").update(token).digest("hex");
  }

  /** Retrieves the throttling configuration from the current method -> controller -> global config. */
  private getThrottleConfig(context: ExecutionContext): ThrottleConfig | undefined {
    const decoratorConfig = this.reflector.getAllAndOverride<ThrottleConfig | undefined>(
      THROTTLE_METADATA,
      [context.getHandler(), context.getClass()],
    );
    return decoratorConfig ?? this.defaultConfig;
  }

  /** Returns a string representation of the current action/endpoint (controller:method). */
  private getActionFromContext(context: ExecutionContext): string {
    return `${context.getClass().name}:${context.getHandler().name}`;
  }
}
/**
 * One sliding-window limit: at most `maxRequests` requests per `windowMs` milliseconds.
 * ThrottleService rejects rules whose two numeric fields are not finite and positive.
 */
export type ThrottleRule = {
  /** Upper bound on the requests counted inside a single window. */
  maxRequests: number;
  /** Length of the sliding window, in milliseconds. */
  windowMs: number;
  /** Message used in place of the config-level default when this rule blocks a request. */
  errorMessage?: string;
};
/**
 * Configuration for the throttling behavior of one endpoint, controller or the global default.
 */
export type ThrottleConfig = {
  /**
   * Throttling rules to apply; every rule must pass for a request to be allowed.
   * A missing or empty array disables throttling.
   */
  rules: ThrottleRule[];
  /** Fallback error message when the blocking rule has no errorMessage of its own. */
  defaultErrorMessage?: string;
  /** Whether to console.log denied requests for debugging. */
  logToConsole?: boolean;
  /**
   * Intended maximum number of stored records.
   * NOTE(review): not currently honored — ThrottleService caps stored records with its
   * own fixed, global (not per-action) limit instead.
   */
  maxRecords?: number;
  /**
   * Maximum number of request timestamps retained per user+action key. When the retained
   * history reaches this size, ThrottleService rejects further requests with a
   * ThrottleException. ThrottleService falls back to its own default when omitted.
   */
  maxTimestampsPerUser?: number;
};
/**
 * Request history for a single user+action key, as kept in ThrottleService's in-memory store.
 */
export type ThrottleRecord = {
  /** Longest rule window (ms) this record must retain history for. */
  maxWindowMs: number;
  /** When expired timestamps were last trimmed from this record (ms since epoch). */
  lastCleanup: number;
  /** Times of recorded requests, kept in ascending order. */
  timestamps: number[];
  /** Caller identifier the record belongs to (the guard derives it from the bearer token). */
  userId: string;
};
/**
 * Outcome of evaluating a request against a throttle configuration.
 */
export type ThrottleCheckResult = {
  /** The rule that blocked the request; absent when it was allowed. */
  rule?: ThrottleRule;
  /** Seconds the caller should wait before retrying; only meaningful when blocked. */
  retryAfter?: number;
  /** Requests still available in the current window (Infinity when no rules apply). */
  remainingRequests: number;
  /** True when the request may proceed. */
  allowed: boolean;
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment