Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save oguzhaneren/8920d6b9b0eb2258340905f03494bb81 to your computer and use it in GitHub Desktop.
Save oguzhaneren/8920d6b9b0eb2258340905f03494bb81 to your computer and use it in GitHub Desktop.
Rate limiter middleware for requests that contain a RequestVerificationToken header
using System;
using System.Collections.Generic;
using System.Linq;
using System.Security.Cryptography;
using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Mvc.Internal;
using Microsoft.AspNetCore.Routing;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.Primitives;
namespace Test.Infrastructure
{
/// <summary>
/// Tunable settings for the anti-forgery-token rate limiter middleware.
/// </summary>
public class RateLimiterByAntiForgeryTokenMiddlewareOptions
{
    /// <summary>
    /// Minimum time that must elapse between two accepted requests that carry
    /// the same token. Defaults to 5 seconds.
    /// </summary>
    public TimeSpan PeriodOffsetForOneRequest { get; set; } = new TimeSpan(0, 0, 5);

    /// <summary>
    /// Sliding expiration applied to the cache entry kept for each token.
    /// Defaults to 30 seconds.
    /// </summary>
    public TimeSpan CacheExpireOffset { get; set; } = new TimeSpan(0, 0, 30);
}
/// <summary>
/// Middleware that throttles requests carrying a <c>RequestVerificationToken</c>
/// header. A token is accepted at most once per
/// <c>PeriodOffsetForOneRequest</c>; a repeat inside that window is answered
/// with HTTP 429 and is not forwarded to the rest of the pipeline. Requests
/// without the header (or with an empty value) pass through untouched.
/// </summary>
public class RateLimiterByAntiForgeryTokenMiddleware
{
    private const string TokenHeaderName = "RequestVerificationToken";

    // The injected IMemoryCache may be shared with other components, so the
    // key is namespaced to avoid clashing with their entries.
    private const string CacheKeyPrefix = "RateLimiterByAntiForgeryToken:";

    private readonly RequestDelegate _next;
    private readonly IMemoryCache _cache;
    private readonly RateLimiterByAntiForgeryTokenMiddlewareOptions _options;

    /// <summary>Per-token bookkeeping stored in the cache.</summary>
    private class Request
    {
        public string RequestVerificationToken { get; set; }
        public DateTime LastRequest { get; set; }
    }

    /// <summary>
    /// Creates the middleware.
    /// </summary>
    /// <param name="next">Next delegate in the request pipeline.</param>
    /// <param name="memoryCache">Cache used to remember when each token was last accepted.</param>
    /// <param name="options">Throttle window and cache expiration settings.</param>
    /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
    public RateLimiterByAntiForgeryTokenMiddleware(RequestDelegate next, IMemoryCache memoryCache,
        RateLimiterByAntiForgeryTokenMiddlewareOptions options)
    {
        // Fail at construction instead of with a NullReferenceException on a later request.
        if (next == null)
        {
            throw new ArgumentNullException(nameof(next));
        }
        if (memoryCache == null)
        {
            throw new ArgumentNullException(nameof(memoryCache));
        }
        if (options == null)
        {
            throw new ArgumentNullException(nameof(options));
        }

        _next = next;
        _cache = memoryCache;
        _options = options;
    }

    /// <summary>
    /// Handles one request: forwards it, or short-circuits with 429 when the
    /// same token was accepted less than <c>PeriodOffsetForOneRequest</c> ago.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    /// <returns>A task that completes when the request has been handled.</returns>
    public async Task Invoke(HttpContext context)
    {
        StringValues headerValue;
        context.Request.Headers.TryGetValue(TokenHeaderName, out headerValue);

        // Only the first header value is considered. A missing header yields
        // null here, so one check covers both "absent" and "empty".
        var requestVerificationToken = headerValue.FirstOrDefault()?.Trim();
        if (string.IsNullOrEmpty(requestVerificationToken))
        {
            await _next(context);
            return;
        }

        var key = CacheKeyPrefix + GetMd5Hash(requestVerificationToken);

        // Sample the clock once so the comparison and the stored timestamp agree.
        var now = DateTime.UtcNow;

        // NOTE: this read-check-update sequence is not synchronized, so
        // concurrent requests with the same token can both pass the check.
        object cached;
        _cache.TryGetValue(key, out cached);

        // A missing entry, or an entry of an unexpected type, is treated as
        // "first request for this token" rather than dereferencing null.
        var entry = cached as Request;
        if (entry == null)
        {
            entry = new Request
            {
                LastRequest = now,
                RequestVerificationToken = requestVerificationToken
            };
        }
        else
        {
            if (now - entry.LastRequest < _options.PeriodOffsetForOneRequest)
            {
                // Rejected requests do not update LastRequest, so the window
                // is measured from the last accepted request.
                context.Response.StatusCode = 429;
                await context.Response.WriteAsync("Too many requests");
                return;
            }

            entry.LastRequest = now;
        }

        var cacheEntryOptions = new MemoryCacheEntryOptions()
            .SetSlidingExpiration(_options.CacheExpireOffset);
        _cache.Set(key, entry, cacheEntryOptions);

        await _next(context);
    }

    /// <summary>
    /// Returns the lowercase hexadecimal MD5 digest of the UTF-8 bytes of
    /// <paramref name="input"/>. Used only to derive a cache key from the token.
    /// </summary>
    private static string GetMd5Hash(string input)
    {
        using (var md5Hash = MD5.Create())
        {
            byte[] data = md5Hash.ComputeHash(Encoding.UTF8.GetBytes(input));

            // Two hex characters per byte.
            var sBuilder = new StringBuilder(data.Length * 2);
            foreach (byte b in data)
            {
                sBuilder.Append(b.ToString("x2"));
            }

            return sBuilder.ToString();
        }
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment