Last active
October 9, 2023 17:55
-
-
Save gabbhack/0609c9813d3287fad3d3f07c4514a9eb to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
from aiogram import Bot, Dispatcher, executor, types | |
from aiogram.contrib.fsm_storage.redis import RedisStorage2 | |
from aiogram.dispatcher import DEFAULT_RATE_LIMIT | |
from aiogram.dispatcher.handler import CancelHandler, current_handler | |
from aiogram.dispatcher.middlewares import BaseMiddleware | |
from aiogram.utils.exceptions import Throttled | |
# Placeholder: substitute the real bot token before running.
TOKEN = 'BOT TOKEN HERE'

# One event loop object is shared by the bot below and by the polling executor.
loop = asyncio.get_event_loop()

# In this example Redis storage is used (database index 5).
storage = RedisStorage2(db=5)

# The dispatcher is bound to the bot and keeps its state in the Redis storage.
bot = Bot(TOKEN, loop=loop)
dp = Dispatcher(bot=bot, storage=storage)
def rate_limit(limit: int, key=None):
    """
    Build a decorator that attaches throttling settings to a function.

    :param limit: value stored on the function as ``throttling_rate_limit``
    :param key: optional value stored as ``throttling_key``; nothing is
        stored when it is falsy
    :return: decorator that marks the function and returns that same object
    """
    # Collect the attributes once; the key is only recorded when provided.
    settings = {'throttling_rate_limit': limit}
    if key:
        settings['throttling_key'] = key

    def decorator(func):
        for attr_name, attr_value in settings.items():
            setattr(func, attr_name, attr_value)
        return func

    return decorator
class ThrottlingMiddleware(BaseMiddleware):
    """
    Anti-flood middleware built on ``Dispatcher.throttle``.

    Before a message or callback query reaches its handler, the middleware
    calls ``Dispatcher.throttle`` with a key and rate. The key and rate come
    from the ``throttling_key`` / ``throttling_rate_limit`` attributes of the
    current handler when present, otherwise from the middleware defaults.
    When ``Throttled`` is raised the user is notified and the handler is
    cancelled with ``CancelHandler``.
    """

    def __init__(self, limit=DEFAULT_RATE_LIMIT, key_prefix='antiflood_'):
        """
        :param limit: rate used when the handler does not define its own
        :param key_prefix: prefix of the automatically generated throttle keys
        """
        self.rate_limit = limit
        self.prefix = key_prefix
        super().__init__()

    def _resolve_key(self, handler):
        """Return the throttle key for ``handler`` (or the default key)."""
        if not handler:
            return f"{self.prefix}_message"
        try:
            return handler.throttling_key
        except AttributeError:
            # The name is looked up only when no explicit key is configured,
            # so callables without ``__name__`` still get a usable key.
            name = getattr(handler, '__name__', type(handler).__name__)
            return f"{self.prefix}_{name}"

    def _resolve_limit(self, handler):
        """Return the rate configured on ``handler`` or the default rate."""
        if handler:
            return getattr(handler, 'throttling_rate_limit', self.rate_limit)
        return self.rate_limit

    async def _throttle(self, event, on_throttled):
        """
        Throttle the current handler and cancel it when the limit is hit.

        :param event: the incoming message or callback query
        :param on_throttled: coroutine function called as
            ``on_throttled(event, throttled)`` before the handler is cancelled
        :raises CancelHandler: when ``Dispatcher.throttle`` raises ``Throttled``
        """
        # Get current handler
        handler = current_handler.get()
        # Get dispatcher from context
        dispatcher = Dispatcher.get_current()
        limit = self._resolve_limit(handler)
        key = self._resolve_key(handler)
        # Use Dispatcher.throttle method.
        try:
            await dispatcher.throttle(key, rate=limit)
        except Throttled as t:
            # Execute action
            await on_throttled(event, t)
            # Cancel current handler
            raise CancelHandler()

    async def on_process_message(self, message: types.Message, data: dict):
        """
        This handler is called when dispatcher receives a message

        :param message: incoming message
        :param data: middleware data (not used here)
        :raises CancelHandler: when the message is throttled
        """
        await self._throttle(message, self.message_throttled)

    async def on_process_callback_query(self, query: types.CallbackQuery, data: dict):
        """
        This handler is called when dispatcher receives a callback query

        :param query: incoming callback query
        :param data: middleware data (not used here)
        :raises CancelHandler: when the callback query is throttled
        """
        # NOTE(review): when there is no current handler the fallback key is
        # the same "<prefix>_message" key used for messages - confirm that
        # sharing it between both update types is intended.
        await self._throttle(query, self.callback_query_throttled)

    async def message_throttled(self, message: types.Message, throttled: Throttled):
        """
        Notify user only on first exceed and notify about unlocking only on last exceed

        :param message: the throttled message
        :param throttled: exception raised by ``Dispatcher.throttle``
        """
        handler = current_handler.get()
        dispatcher = Dispatcher.get_current()
        key = self._resolve_key(handler)

        # Calculate how much time is left till the block ends
        delta = throttled.rate - throttled.delta

        # Prevent flooding
        if throttled.exceeded_count <= 2:
            await message.reply('Too many requests! ')

        # Sleep.
        await asyncio.sleep(delta)

        # Check lock status
        thr = await dispatcher.check_key(key)

        # If current message is not last with current key - do not send message
        if thr.exceeded_count == throttled.exceeded_count:
            await message.reply('Unlocked.')

    async def callback_query_throttled(self, query: types.CallbackQuery, throttled: Throttled):
        """
        Show a "too many requests" alert for the first throttled callback queries

        :param query: the throttled callback query
        :param throttled: exception raised by ``Dispatcher.throttle``
        """
        # Prevent flooding
        if throttled.exceeded_count <= 2:
            await query.answer('Too many requests! ', show_alert=True)
@dp.message_handler(commands=['start'])
@rate_limit(5, 'start')  # optional: sets a per-handler throttling limit and key
async def cmd_test(message: types.Message):
    """Reply to the ``start`` command with a confirmation text."""
    # The handler is marked with a rate limit of 5 and the key 'start'.
    reply_text = 'Test passed! You can use this command every 5 seconds.'
    await message.reply(reply_text)
if __name__ == '__main__':
    # Register the throttling middleware with its default limit and prefix.
    throttling = ThrottlingMiddleware()
    dp.middleware.setup(throttling)

    # Start long-polling on the shared event loop.
    executor.start_polling(dp, loop=loop)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment