Skip to content

Instantly share code, notes, and snippets.

@trevorflahardy
Created November 28, 2021 06:30
Show Gist options
  • Save trevorflahardy/86c0566bddc9f1b92723dc94ff80984c to your computer and use it in GitHub Desktop.
Save trevorflahardy/86c0566bddc9f1b92723dc94ff80984c to your computer and use it in GitHub Desktop.
A simple addon to slash utils that allows for max_concurrency and cooldowns
from __future__ import annotations
from enum import Enum
import datetime
import inspect
from collections import defaultdict
import discord
import discord, discord.channel, discord.http, discord.state
from discord.ext import commands
from discord.utils import MISSING, utcnow
# New cooldown imports; sort them yourself if desired
from discord.ext.commands.cooldowns import _Semaphore # Needs to be imported speficially
from discord.ext.commands._types import CoroFunc
from typing import Coroutine, Optional, TypeVar, Union, get_args, get_origin, overload, Generic, TYPE_CHECKING, Dict
T = TypeVar('T')
BotT = TypeVar("BotT", bound='Bot')
CtxT = TypeVar("CtxT", bound='Context')
CogT = TypeVar("CogT", bound='ApplicationCog')
NumT = Union[int, float]
__all__ = ['describe', 'SlashCommand', 'ApplicationCog', 'Range', 'Context', 'Bot', 'slash_command', 'message_command', 'user_command']
if TYPE_CHECKING:
from typing import Any, Awaitable, Callable, ClassVar
from typing_extensions import Concatenate, ParamSpec
import datetime
CmdP = ParamSpec("CmdP")
CmdT = Callable[Concatenate[CogT, CtxT, CmdP], Awaitable[Any]]
MsgCmdT = Callable[[CogT, CtxT, discord.Message], Awaitable[Any]]
UsrCmdT = Callable[[CogT, CtxT, discord.Member], Awaitable[Any]]
CtxMnT = Union[MsgCmdT, UsrCmdT]
RngT = TypeVar("RngT", bound='Range')
# Maps a parameter's Python annotation to the Discord application command
# option type sent in the command payload.
command_type_map: dict[type[Any], int] = {
    str: 3,  # STRING
    int: 4,  # INTEGER
    bool: 5,  # BOOLEAN
    discord.User: 6,  # USER
    discord.Member: 6,  # USER
    discord.TextChannel: 7,  # CHANNEL
    discord.VoiceChannel: 7,  # CHANNEL
    discord.CategoryChannel: 7,  # CHANNEL
    discord.Role: 8,  # ROLE
    float: 10,  # NUMBER
}

# Maps a guild channel class to the Discord channel type used to fill the
# ``channel_types`` filter of a channel option.
channel_filter: dict[type[discord.abc.GuildChannel], int] = {
    discord.TextChannel: 0,  # GUILD_TEXT
    discord.VoiceChannel: 2,  # GUILD_VOICE
    discord.CategoryChannel: 4,  # GUILD_CATEGORY
}
def describe(**kwargs):
    """
    Attaches descriptions to the named parameters of a slash command. Sample usage:
    ```python
    @slash_util.slash_command()
    @describe(channel="The channel to ping")
    async def mention(self, ctx: slash_util.Context, channel: discord.TextChannel):
        await ctx.send(f'{channel.mention}')
    ```
    Each keyword is a parameter name and its value is the description text.
    Parameters that are never described fall back to a "No description provided" placeholder."""
    def _inner(cmd):
        # Works on both a bare callback and an already-built SlashCommand.
        target = cmd.func if isinstance(cmd, SlashCommand) else cmd
        for param_name, text in kwargs.items():
            if hasattr(target, '_param_desc_'):
                target._param_desc_[param_name] = text
            else:
                # First description for this callback: create the table.
                target._param_desc_ = {param_name: text}
        return cmd
    return _inner
def slash_command(**kwargs) -> Callable[[CmdT], SlashCommand]:
    """
    Decorator factory that turns a coroutine into a slash-type application command.
    Parameters:
    - name: ``str``
    - - The display name of the command. Defaults to the function's name.
    - guild_id: ``Optional[int]``
    - - The guild ID the command belongs to. When omitted, the command is uploaded globally.
    - description: ``str``
    - - The description of the command. Defaults to the function's docstring, or "No description provided" when there is none.
    """
    def decorator(callback: CmdT) -> SlashCommand:
        # All keyword options are forwarded untouched to the command class.
        return SlashCommand(callback, **kwargs)
    return decorator
def message_command(**kwargs) -> Callable[[MsgCmdT], MessageCommand]:
    """
    Decorator factory that turns a coroutine into a message-type application command.
    Parameters:
    - name: ``str``
    - - The display name of the command. Defaults to the function's name.
    - guild_id: ``Optional[int]``
    - - The guild ID the command belongs to. When omitted, the command is uploaded globally.
    """
    def decorator(callback: MsgCmdT) -> MessageCommand:
        # All keyword options are forwarded untouched to the command class.
        return MessageCommand(callback, **kwargs)
    return decorator
def user_command(**kwargs) -> Callable[[UsrCmdT], UserCommand]:
    """
    Decorator factory that turns a coroutine into a user-type application command.
    Parameters:
    - name: ``str``
    - - The display name of the command. Defaults to the function's name.
    - guild_id: ``Optional[int]``
    - - The guild ID the command belongs to. When omitted, the command is uploaded globally.
    """
    def decorator(callback: UsrCmdT) -> UserCommand:
        # All keyword options are forwarded untouched to the command class.
        return UserCommand(callback, **kwargs)
    return decorator
# Command cooldowns
# We need to re-create the BucketType class so it works with our custom Context.
# Because we don't have a Message object to take a created_at from, and there is no
# Interaction.created_at, it's best to add one to the Context class and continue.
# Please note that an enum with members can't be subclassed, so it has to be re-created.
class BucketType(Enum):
    """
    Bucket scopes for cooldowns and max concurrency, keyed from interactions.

    Each member decides which key a bucket is grouped by. The key is computed
    from either a ``discord.Interaction`` or this module's ``Context``.
    Members are callable: ``BucketType.user(obj)`` is ``BucketType.user.get_key(obj)``.
    """
    default = 0
    user = 1
    guild = 2
    channel = 3
    member = 4
    category = 5
    role = 6

    def get_key(self, obj: discord.Interaction | Context) -> Any:
        """
        Returns the bucket key for ``obj`` under this bucket type.

        Parameters:
        - obj: ``Union[discord.Interaction, Context]``
        - - The interaction (or context wrapping one) to derive the key from.

        Returns:
        - An id, or a ``(guild_id, user_id)`` tuple for ``member``.
        - ``None`` for ``default`` (a single shared bucket).
        """
        # A Context exposes the invoking user as ``author``; a raw interaction as ``user``.
        current_user = obj.author if isinstance(obj, Context) else obj.user
        if self is BucketType.user:
            return current_user.id  # type: ignore
        elif self is BucketType.guild:
            # Outside of a guild, fall back to the user's id.
            return (obj.guild or current_user).id  # type: ignore
        elif self is BucketType.channel:
            return obj.channel.id  # type: ignore
        elif self is BucketType.member:
            return ((obj.guild and obj.guild.id), current_user.id)  # type: ignore
        elif self is BucketType.category:
            # Channels without a category are keyed by the channel itself.
            return (obj.channel.category or obj.channel).id  # type: ignore
        elif self is BucketType.role:
            # we return the channel id of a private-channel as there are only roles in guilds
            # and that yields the same result as for a guild with only the @everyone role
            # NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are
            # receiving a DMChannel or GroupChannel which inherit from PrivateChannel and do
            channel = obj.channel
            # ``PrivateChannel`` lives in ``discord.abc``; the bare name was never imported here.
            if isinstance(channel, discord.abc.PrivateChannel):
                return channel.id  # type: ignore
            return current_user.top_role.id  # type: ignore
        # BucketType.default: every invocation shares one bucket.
        return None

    def __call__(self, obj: discord.Interaction | Context) -> Any:
        """Shorthand for :meth:`get_key`, so a member can be used as a key callable."""
        return self.get_key(obj)
# All we do here is override the methods that take a message parameter so that
# they take an interaction object instead. This trips the type checker every time
# (hence the `# type: ignore` comments), but that is acceptable.
class CooldownMapping(commands.CooldownMapping):
    """
    A cooldown mapping whose buckets are keyed from interactions instead of messages.

    Parameters:
    - original: ``Optional[commands.Cooldown]``
    - - The cooldown copied for every new bucket, or ``None`` for no cooldown.
    - type: ``Callable[[discord.Interaction], Any]``
    - - A ``BucketType`` member or any callable returning the bucket key.
    """

    def __init__(
        self,
        original: Optional[commands.Cooldown],
        type: Callable[[discord.Interaction], Any],
    ) -> None:
        # The parent initialiser is not called; the attributes are set directly.
        self._cache: Dict[Any, commands.Cooldown] = {}
        self._cooldown: Optional[commands.Cooldown] = original
        self._type: Callable[[discord.Interaction], Any] = type

    def copy(self) -> CooldownMapping:
        """Returns a copy of this mapping with its own bucket cache."""
        # Overridden so the copy is this interaction-aware class rather than
        # the message-based parent class.
        ret = CooldownMapping(self._cooldown, self._type)
        ret._cache = self._cache.copy()
        return ret

    def _bucket_key(self, interaction: discord.Interaction) -> Any:
        """Computes the cache key for ``interaction``."""
        return self._type(interaction)  # type: ignore

    def create_bucket(self, interaction: discord.Interaction) -> commands.Cooldown:
        """Creates a fresh bucket by copying the template cooldown."""
        return self._cooldown.copy()  # type: ignore

    def get_bucket(self, interaction: discord.Interaction, current: Optional[float] = None) -> commands.Cooldown:
        """
        Returns the bucket that applies to ``interaction``, creating it if needed.

        Parameters:
        - interaction: ``discord.Interaction``
        - - The interaction used to compute the bucket key.
        - current: ``Optional[float]``
        - - Timestamp passed to the cache integrity check.
        """
        # The default bucket type shares the template cooldown itself.
        if self._type is BucketType.default:
            return self._cooldown  # type: ignore

        self._verify_cache_integrity(current)
        key = self._bucket_key(interaction)
        if key not in self._cache:
            bucket = self.create_bucket(interaction)
            if bucket is not None:
                self._cache[key] = bucket
        else:
            bucket = self._cache[key]
        return bucket

    def update_rate_limit(self, interaction: discord.Interaction, current: Optional[float] = None) -> Optional[float]:
        """Updates the bucket for ``interaction`` and returns the bucket's ``update_rate_limit`` result."""
        bucket = self.get_bucket(interaction, current)
        return bucket.update_rate_limit(current)
class DynamicCooldownMapping(CooldownMapping):
    """
    A cooldown mapping that asks a factory for each new bucket instead of
    copying a fixed template cooldown.
    """

    def __init__(
        self,
        factory: Callable[[discord.Interaction], commands.Cooldown],
        type: Callable[[discord.Interaction], Any]
    ) -> None:
        # There is no template cooldown; buckets come from the factory.
        super().__init__(None, type)
        self._factory: Callable[[discord.Interaction], commands.Cooldown] = factory

    def copy(self) -> DynamicCooldownMapping:
        """Returns a new mapping with the same factory and a copy of the cache."""
        clone = DynamicCooldownMapping(self._factory, self._type)  # type: ignore
        clone._cache = self._cache.copy()
        return clone

    @property
    def valid(self) -> bool:
        """Always ``True``."""
        return True

    def create_bucket(self, message: discord.Interaction) -> commands.Cooldown:
        """Builds the bucket for the given interaction by calling the factory."""
        return self._factory(message)
# Please note that we can't call super().__init__() because an error would be raised
# because of our custom BucketType.
class MaxConcurrency(commands.MaxConcurrency):
    """
    Limits how many invocations may run at once per bucket key, with keys
    derived from interactions through this module's ``BucketType``.
    """

    def __init__(self, number: int, *, per: BucketType, wait: bool) -> None:
        # One semaphore per bucket key, created lazily in ``acquire``.
        self._mapping: Dict[Any, _Semaphore] = {}
        self.per: BucketType = per
        self.number: int = number
        self.wait: bool = wait

        if number <= 0:
            raise ValueError("max_concurrency 'number' cannot be less than 1")

    def get_key(self, interaction: discord.Interaction) -> Any:
        """Returns the bucket key for ``interaction``."""
        return self.per.get_key(interaction)  # type: ignore

    async def acquire(self, interaction: discord.Interaction) -> None:
        """
        Takes a slot for the interaction's bucket.

        Raises ``commands.MaxConcurrencyReached`` when no slot could be acquired.
        """
        bucket_key = self.get_key(interaction)
        semaphore = self._mapping.get(bucket_key)
        if semaphore is None:
            semaphore = _Semaphore(self.number)
            self._mapping[bucket_key] = semaphore

        if not await semaphore.acquire(wait=self.wait):
            raise commands.errors.MaxConcurrencyReached(self.number, self.per)  # type: ignore

    async def release(self, interaction: discord.Interaction) -> None:
        """Gives back a slot for the interaction's bucket, dropping idle semaphores."""
        # Technically there's no reason for this function to be async
        # But it might be more useful in the future
        bucket_key = self.get_key(interaction)
        semaphore = self._mapping.get(bucket_key)
        if semaphore is None:
            # Nothing was acquired under this key.
            return

        semaphore.release()
        fully_free = semaphore.value >= self.number
        if fully_free and not semaphore.is_active():
            del self._mapping[bucket_key]
# Now let's add our decorators.
# Please note these use copy_doc, so the docstrings will be slightly off because
# regular discord.py uses messages and not interactions.
@discord.utils.copy_doc(commands.cooldown)
def cooldown(rate: int, per: float, type: Union[BucketType, Callable[[discord.Interaction], Any]] = BucketType.default) -> Callable[[T], T]:
    # The docstring is supplied by ``copy_doc`` above.
    def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]:
        mapping = CooldownMapping(commands.Cooldown(rate, per), type)  # type: ignore
        # Built commands store the mapping directly; bare callbacks carry it
        # on a marker attribute.
        attribute = '_buckets' if isinstance(func, Command) else '__commands_cooldown__'
        setattr(func, attribute, mapping)
        return func
    return decorator
@discord.utils.copy_doc(commands.dynamic_cooldown)
def dynamic_cooldown(cooldown: Union[BucketType, Callable[[discord.Interaction], Any]], type: BucketType = BucketType.default) -> Callable[[T], T]:
    # The docstring is supplied by ``copy_doc`` above.
    # NOTE(review): ``cooldown`` is used as the bucket factory, so its annotation
    # probably should not include ``BucketType`` -- confirm before changing it.
    if not callable(cooldown):
        raise TypeError("A callable must be provided")

    def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]:
        mapping = DynamicCooldownMapping(cooldown, type)  # type: ignore
        # Built commands store the mapping directly; bare callbacks carry it
        # on a marker attribute.
        attribute = '_buckets' if isinstance(func, Command) else '__commands_cooldown__'
        setattr(func, attribute, mapping)
        return func
    return decorator
@discord.utils.copy_doc(commands.max_concurrency)
def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait: bool = False) -> Callable[[T], T]:
    # The docstring is supplied by ``copy_doc`` above.
    def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]:
        limiter = MaxConcurrency(number, per=per, wait=wait)
        # Built commands store the limiter directly; bare callbacks carry it
        # on a marker attribute.
        attribute = '_max_concurrency' if isinstance(func, Command) else '__commands_max_concurrency__'
        setattr(func, attribute, limiter)
        return func
    return decorator
class _RangeMeta(type):
    """Metaclass that makes ``cls[...]`` build an instance of ``cls``."""

    # The overloads tell type checkers that the subscripted form stands for the
    # number type; at runtime an instance of the class is returned.
    @overload
    def __getitem__(cls: type[RngT], max: int) -> type[int]: ...
    @overload
    def __getitem__(cls: type[RngT], max: tuple[int, int]) -> type[int]: ...
    @overload
    def __getitem__(cls: type[RngT], max: float) -> type[float]: ...
    @overload
    def __getitem__(cls: type[RngT], max: tuple[float, float]) -> type[float]: ...

    def __getitem__(cls, max):
        # ``cls[a, b]`` arrives as a tuple; ``cls[b]`` is a maximum with no minimum.
        bounds = max if isinstance(max, tuple) else (None, max)
        return cls(*bounds)
class Range(metaclass=_RangeMeta):
    """
    Bounds an int or float parameter between a minimum and a maximum. The minimum may be omitted.
    ```python
    async def number(self, ctx, num: slash_util.Range[0, 10], other_num: slash_util.Range[10]):
        ...
    ```"""

    def __init__(self, min: NumT | None, max: NumT):
        # A missing minimum means "no lower bound", so there is nothing to compare.
        bounds_out_of_order = min is not None and min >= max
        if bounds_out_of_order:
            raise ValueError("`min` value must be lower than `max`")

        self.min = min
        self.max = max
class Bot(commands.Bot):
    """A ``commands.Bot`` that uploads application commands from its cogs on start."""

    async def start(self, token: str, *, reconnect: bool = True) -> None:
        """Logs in, records the application id, syncs commands, then connects."""
        await self.login(token)

        # The application id is needed to build the command routes.
        info = await self.application_info()
        self._connection.application_id = info.id

        await self.sync_commands()
        await self.connect(reconnect=reconnect)

    def get_application_command(self, name: str) -> Command | None:
        """
        Looks up an application command by name across all application cogs.
        Parameters:
        - name: ``str``
        - - The name of the command.
        Returns:
        - [Command](#deco-slash_commandkwargs)
        - - The relevant command object
        - ``None``
        - - No command by that name was found.
        """
        for cog in self.cogs.values():
            if not isinstance(cog, ApplicationCog):
                continue
            match = cog._commands.get(name)
            if match:
                return match

    async def delete_all_commands(self, guild_id: int | None = None):
        """
        Deletes every command on the given guild, or every global command when no guild id is given.
        Parameters:
        - guild_id: ``Optional[int]``
        - - The guild ID to delete from, or ``None`` to delete global commands.
        """
        guild_part = f'/guilds/{guild_id}' if guild_id is not None else ''
        listing = discord.http.Route("GET", f'/applications/{self.application_id}{guild_part}/commands')

        # Fetch the existing commands, then remove them one at a time.
        existing = await self.http.request(listing)
        for entry in existing:
            await self.delete_command(entry['id'], guild_id=guild_id)

    async def delete_command(self, id: int, *, guild_id: int | None = None):
        """
        Deletes the command with the given ID. The ID is a snowflake, not the name of the command.
        Parameters:
        - id: ``int``
        - - The ID of the command to delete.
        - guild_id: ``Optional[int]``
        - - The guild ID to delete from, or ``None`` to delete a global command.
        """
        guild_part = f"/guilds/{guild_id}" if guild_id else ""
        endpoint = f'/applications/{self.application_id}{guild_part}/commands/{id}'
        await self.http.request(discord.http.Route('DELETE', endpoint))

    async def sync_commands(self) -> None:
        """
        Uploads every command found on application cogs to discord.
        Global commands will take up to an hour to update. Guild specific commands will update immediately.
        """
        if not self.application_id:
            raise RuntimeError("sync_commands must be called after `run`, `start` or `login`")

        for cog in self.cogs.values():
            if not isinstance(cog, ApplicationCog):
                continue

            for cmd in cog._commands.values():
                # Bind the command to the cog instance that owns it.
                cmd.cog = cog

                guild_part = f"/guilds/{cmd.guild_id}" if cmd.guild_id else ""
                endpoint = f"/applications/{self.application_id}{guild_part}/commands"

                payload = cmd._build_command_payload()
                await self.http.request(discord.http.Route('POST', endpoint), json=payload)
class Context(Generic[BotT, CogT]):
    """
    The command interaction context.
    Attributes
    - bot: [``slash_util.Bot``](#class-botcommand_prefix-help_commanddefault-help-command-descriptionnone-options)
    - - Your bot object.
    - command: Union[[SlashCommand](#deco-slash_commandkwargs), [UserCommand](#deco-user_commandkwargs), [MessageCommand](deco-message_commandkwargs)]
    - - The command used with this interaction.
    - interaction: [``discord.Interaction``](https://discordpy.readthedocs.io/en/master/api.html#discord.Interaction)
    - - The interaction tied to this context.
    - created_at: ``datetime.datetime``
    - - The time this context object was created.
    """

    def __init__(self, bot: BotT, command: Command[CogT], interaction: discord.Interaction):
        self.bot = bot
        self.command = command
        self.interaction = interaction
        # Whether the initial interaction response has been sent through ``send``.
        self._responded = False
        # This makes it a bit easier to strap onto buckets
        self.created_at: datetime.datetime = utcnow()

    @overload
    def send(self, content: str = MISSING, *, embed: discord.Embed = MISSING, ephemeral: bool = MISSING, tts: bool = MISSING, view: discord.ui.View = MISSING, file: discord.File = MISSING) -> Coroutine[Any, Any, Union[discord.InteractionMessage, discord.WebhookMessage]]: ...

    @overload
    def send(self, content: str = MISSING, *, embed: discord.Embed = MISSING, ephemeral: bool = MISSING, tts: bool = MISSING, view: discord.ui.View = MISSING, files: list[discord.File] = MISSING) -> Coroutine[Any, Any, Union[discord.InteractionMessage, discord.WebhookMessage]]: ...

    @overload
    def send(self, content: str = MISSING, *, embeds: list[discord.Embed] = MISSING, ephemeral: bool = MISSING, tts: bool = MISSING, view: discord.ui.View = MISSING, file: discord.File = MISSING) -> Coroutine[Any, Any, Union[discord.InteractionMessage, discord.WebhookMessage]]: ...

    @overload
    def send(self, content: str = MISSING, *, embeds: list[discord.Embed] = MISSING, ephemeral: bool = MISSING, tts: bool = MISSING, view: discord.ui.View = MISSING, files: list[discord.File] = MISSING) -> Coroutine[Any, Any, Union[discord.InteractionMessage, discord.WebhookMessage]]: ...

    async def send(self, content = MISSING, **kwargs) -> Union[discord.InteractionMessage, discord.WebhookMessage]:
        """
        Responds to the given interaction. If the interaction has already been responded to
        (through this method, a deferral, or directly on the interaction), the follow-up webhook is used instead.
        Parameters ``embed`` and ``embeds`` cannot be specified together.
        Parameters ``file`` and ``files`` cannot be specified together.
        Parameters:
        - content: ``str``
        - - The content of the message to respond with
        - embed: [``discord.Embed``](https://discordpy.readthedocs.io/en/master/api.html#discord.Embed)
        - - An embed to send with the message. Incompatible with ``embeds``.
        - embeds: ``List[``[``discord.Embed``](https://discordpy.readthedocs.io/en/master/api.html#discord.Embed)``]``
        - - A list of embeds to send with the message. Incompatible with ``embed``.
        - file: [``discord.File``](https://discordpy.readthedocs.io/en/master/api.html#discord.File)
        - - A file to send with the message. Incompatible with ``files``.
        - files: ``List[``[``discord.File``](https://discordpy.readthedocs.io/en/master/api.html#discord.File)``]``
        - - A list of files to send with the message. Incompatible with ``file``.
        - ephemeral: ``bool``
        - - Whether the message should be ephemeral (only visible to the interaction user).
        - tts: ``bool``
        - - Whether the message should be played via Text To Speech. Send TTS Messages permission is required.
        - view: [``discord.ui.View``](https://discordpy.readthedocs.io/en/master/api.html#discord.ui.View)
        - - Components to attach to the sent message.
        Returns
        - [``discord.InteractionMessage``](https://discordpy.readthedocs.io/en/master/api.html#discord.InteractionMessage) if this is the first time responding.
        - [``discord.WebhookMessage``](https://discordpy.readthedocs.io/en/master/api.html#discord.WebhookMessage) for consecutive responses.
        """
        response = self.interaction.response
        # Also honour responses made outside of this method (e.g. a manual defer);
        # sending a second initial response would raise.
        if self._responded or response.is_done():
            self._responded = True
            return await self.interaction.followup.send(content, wait=True, **kwargs)

        # ``MISSING`` is falsy, so an omitted content is sent as ``None``.
        await response.send_message(content or None, **kwargs)
        self._responded = True
        return await self.interaction.original_message()

    @property
    def cog(self) -> CogT:
        """The cog this command belongs to."""
        return self.command.cog

    @property
    def guild(self) -> discord.Guild:
        """The guild this interaction was executed in."""
        return self.interaction.guild  # type: ignore

    @property
    def message(self) -> discord.Message:
        """The message that executed this interaction."""
        return self.interaction.message  # type: ignore

    @property
    def channel(self) -> discord.interactions.InteractionChannel:
        """The channel the interaction was executed in."""
        return self.interaction.channel  # type: ignore

    @property
    def author(self) -> discord.Member:
        """The user that executed this interaction."""
        return self.interaction.user  # type: ignore
class Command(Generic[CogT]):
    """
    Common base for application commands.

    Subclasses provide the upload payload and the argument extraction;
    ``invoke`` calls the wrapped callback with the owning cog and the context.
    """
    cog: CogT
    func: Callable
    name: str
    guild_id: int | None

    def _build_command_payload(self) -> dict[str, Any]:
        """Builds the payload uploaded for this command. Must be overridden."""
        raise NotImplementedError

    def _build_arguments(self, interaction: discord.Interaction, state: discord.state.ConnectionState) -> dict[str, Any]:
        """Extracts the callback arguments from an interaction. Must be overridden."""
        raise NotImplementedError

    async def invoke(self, context: Context[BotT, CogT], **params) -> None:
        """Runs the callback as ``func(cog, context, **params)``."""
        callback = self.func
        await callback(self.cog, context, **params)
class SlashCommand(Command[CogT]):
    """
    A slash-type application command with optional cooldown and max concurrency.

    Keyword options read by the constructor:
    - name: display name, defaults to the callback's name.
    - description: defaults to the callback's docstring, or "No description provided."
    - guild_id: guild the command belongs to, ``None`` for a global command.
    - cooldown: a ``CooldownMapping``; a ``__commands_cooldown__`` attribute on the callback takes precedence.
    - max_concurrency: a ``MaxConcurrency``; a ``__commands_max_concurrency__`` attribute on the callback takes precedence.
    """

    def __init__(self, func: CmdT, **kwargs):
        self.func = func
        self.cog: CogT
        self.name: str = kwargs.get("name", func.__name__)
        self.description: str = inspect.cleandoc(kwargs.get('description', func.__doc__ or 'No description provided.'))
        self.guild_id: int | None = kwargs.get("guild_id")
        self.parameters = self._build_parameters()
        self._parameter_descriptions: dict[str, str] = defaultdict(lambda: "No description provided")

        # Command cooldowns: the decorator marker wins over the keyword option.
        try:
            cooldown = func.__commands_cooldown__
        except AttributeError:
            cooldown = kwargs.get('cooldown')

        if cooldown is None:
            buckets = CooldownMapping(cooldown, BucketType.default)  # type: ignore
        elif isinstance(cooldown, CooldownMapping):
            buckets = cooldown
        else:
            raise TypeError("Cooldown must be a an instance of CooldownMapping or None.")
        self._buckets: CooldownMapping = buckets

        # Max concurrency: the decorator marker wins over the keyword option.
        try:
            max_concurrency = func.__commands_max_concurrency__
        except AttributeError:
            max_concurrency = kwargs.get('max_concurrency')
        self._max_concurrency: Optional[MaxConcurrency] = max_concurrency

    def _prepare_cooldowns(self, ctx: Context) -> None:
        """
        Applies the cooldown for this invocation.

        Raises ``commands.CommandOnCooldown`` when the bucket is rate limited.
        """
        if not self._buckets.valid:
            return

        # The context's creation time stands in for the message timestamp.
        current = ctx.created_at.replace(tzinfo=datetime.timezone.utc).timestamp()
        bucket = self._buckets.get_bucket(ctx.interaction, current)
        if bucket is None:
            return

        retry_after = bucket.update_rate_limit(current)
        if retry_after:
            raise commands.CommandOnCooldown(bucket, retry_after, self._buckets.type)  # type: ignore

    # We overwrite invoke so we can do some operations with cooldown and max concurrency
    async def invoke(self, ctx: Context, **kwargs) -> Any:
        """
        Runs the callback after applying max concurrency and cooldown checks.

        Raises whatever ``MaxConcurrency.acquire`` or ``_prepare_cooldowns`` raise,
        as well as any exception from the callback itself.
        """
        limiter = self._max_concurrency
        if limiter is not None:
            await limiter.acquire(ctx)  # type: ignore

        try:
            self._prepare_cooldowns(ctx)
            await super().invoke(ctx, **kwargs)
        finally:
            # Always give the slot back -- on cooldown, on failure and on success --
            # otherwise the command locks up once ``number`` invocations have run.
            if limiter is not None:
                await limiter.release(ctx)  # type: ignore

    def _build_arguments(self, interaction, state):
        """Builds the keyword arguments for the callback from the interaction's options."""
        resolved = _parse_resolved_data(interaction, interaction.data.get('resolved'), state)
        result = {}
        # Please note that the key options can not be included
        for option in interaction.data.get('options', []):
            value = option['value']
            # Option types 6, 7 and 8 carry an id that is swapped for the resolved object.
            if option['type'] in (6, 7, 8):
                value = resolved[int(value)]
            result[option['name']] = value
        return result

    def _build_parameters(self) -> dict[str, inspect.Parameter]:
        """
        Returns the callback's parameters, minus the leading ``self`` and context.

        Raises ``ValueError`` when either of the two leading parameters is missing.
        """
        params = list(inspect.signature(self.func).parameters.values())
        try:
            params.pop(0)
        except IndexError:
            raise ValueError("expected argument `self` is missing")

        try:
            params.pop(0)
        except IndexError:
            raise ValueError("expected argument `context` is missing")

        return {p.name: p for p in params}

    def _build_descriptions(self):
        """
        Copies the descriptions stored on the callback (``_param_desc_``) into this command.

        Raises ``TypeError`` when a description names a parameter the callback does not have.
        """
        if not hasattr(self.func, '_param_desc_'):
            return

        for k, v in self.func._param_desc_.items():
            if k not in self.parameters:
                raise TypeError(f"@describe used to describe a non-existant parameter `{k}`")
            self._parameter_descriptions[k] = v

    def _build_command_payload(self):
        """
        Builds the payload uploaded for this command, including one option per parameter.

        Raises ``TypeError`` for a parameter without an annotation or for a ``Union``
        that contains anything other than guild channel types.
        """
        self._build_descriptions()
        payload = {
            "name": self.name,
            "description": self.description,
            "type": 1
        }

        params = self.parameters
        if params:
            options = []
            for name, param in params.items():
                ann = param.annotation
                if ann is param.empty:
                    raise TypeError(f"missing type annotation for parameter `{param.name}` for command `{self.name}`")

                if isinstance(ann, str):
                    # String annotations are source text written by the command author.
                    # Resolve them with the callback's own module globals as well, so
                    # names imported only in that module can be found.
                    namespace = {**globals(), **getattr(self.func, '__globals__', {})}
                    ann = eval(ann, namespace)

                is_union = get_origin(ann) is Union
                if isinstance(ann, Range):
                    real_t = type(ann.max)
                elif is_union:
                    real_t = get_args(ann)[0]
                else:
                    real_t = ann

                typ = command_type_map[real_t]
                option = {
                    'type': typ,
                    'name': name,
                    'description': self._parameter_descriptions[name]
                }
                if param.default is param.empty:
                    option['required'] = True

                if isinstance(ann, Range):
                    option['max_value'] = ann.max
                    # ``Range[max]`` has no minimum; leave the key out instead of sending ``None``.
                    if ann.min is not None:
                        option['min_value'] = ann.min
                elif is_union:
                    args = get_args(ann)
                    if not all(issubclass(k, discord.abc.GuildChannel) for k in args):
                        raise TypeError("Union parameter types only supported on *Channel types")
                    # A union of every known channel type needs no filter.
                    if len(args) != len(channel_filter):
                        option['channel_types'] = [channel_filter[i] for i in args]
                elif issubclass(ann, discord.abc.GuildChannel):
                    option['channel_types'] = [channel_filter[ann]]

                options.append(option)

            # Required options are listed before optional ones.
            options.sort(key=lambda f: not f.get('required'))
            payload['options'] = options
        return payload
class ContextMenuCommand(Command[CogT]):
    """
    Base for context-menu commands. Subclasses set ``_type`` to the
    command type placed in the upload payload.
    """
    _type: ClassVar[int]

    def __init__(self, func: CtxMnT, **kwargs):
        self.func = func
        self.guild_id: int | None = kwargs.get('guild_id', None)
        self.name: str = kwargs.get('name', func.__name__)

    def _build_command_payload(self):
        """Builds the upload payload; ``guild_id`` is included only when set."""
        payload = {'name': self.name, 'type': self._type}
        if self.guild_id is not None:
            payload['guild_id'] = self.guild_id
        return payload

    def _build_arguments(self, interaction: discord.Interaction, state: discord.state.ConnectionState) -> dict[str, Any]:
        """Returns ``{'target': obj}`` where ``obj`` is the resolved object the menu was opened on."""
        resolved = _parse_resolved_data(interaction, interaction.data.get('resolved'), state)  # type: ignore
        target_id = int(interaction.data['target_id'])  # type: ignore
        return {'target': resolved[target_id]}

    async def invoke(self, context: Context[BotT, CogT], **params) -> None:
        """Runs the callback, passing the parameter values positionally."""
        positional = params.values()
        await self.func(self.cog, context, *positional)
class MessageCommand(ContextMenuCommand[CogT]):
    """A message-type context-menu command."""
    _type = 3  # MESSAGE application command type
class UserCommand(ContextMenuCommand[CogT]):
    """A user-type context-menu command."""
    _type = 2  # USER application command type
def _parse_resolved_data(interaction: discord.Interaction, data, state: discord.state.ConnectionState):
    """
    Turns the ``resolved`` section of an interaction payload into objects keyed by integer id.

    Parameters:
    - interaction: ``discord.Interaction``
    - - The interaction the data belongs to; supplies the guild and channel.
    - data: ``Optional[dict]``
    - - The raw ``resolved`` mapping. A falsy value yields an empty result.
    - state: ``discord.state.ConnectionState``
    - - The connection state handed to the constructed objects.

    Returns:
    - ``dict[int, Any]`` mapping ids to members/users, roles, channels and messages.

    Raises:
    - ``RuntimeError`` when channel or role data is present but the interaction has no guild.
    """
    if not data:
        return {}

    guild = interaction.guild
    resolved = {}

    def _require_guild(kind):
        # Channels and roles cannot be constructed without a guild.
        if guild is None:
            raise RuntimeError(f"cannot resolve {kind} for an interaction outside of a guild")
        return guild

    resolved_users = data.get('users')
    if resolved_users:
        # ``members`` is absent in DMs and lacks users who are not in the guild.
        resolved_members = data.get('members') or {}
        for id, d in resolved_users.items():
            member_data = resolved_members.get(id)
            if guild is not None and member_data is not None:
                # Member payloads here come without the user object; attach it.
                member_data['user'] = d
                resolved[int(id)] = discord.Member(data=member_data, guild=guild, state=state)
            else:
                resolved[int(id)] = discord.User(state=state, data=d)

    # Role options (type 8) are looked up in this mapping as well.
    resolved_roles = data.get('roles')
    if resolved_roles:
        role_guild = _require_guild('roles')
        for id, d in resolved_roles.items():
            resolved[int(id)] = discord.Role(guild=role_guild, state=state, data=d)

    resolved_channels = data.get('channels')
    if resolved_channels:
        channel_guild = _require_guild('channels')
        for id, d in resolved_channels.items():
            # The payload has no position; the key is filled in before constructing the channel.
            d['position'] = None
            cls, _ = discord.channel._guild_channel_factory(d['type'])
            resolved[int(id)] = cls(state=state, guild=channel_guild, data=d)

    resolved_messages = data.get('messages')
    if resolved_messages:
        for id, d in resolved_messages.items():
            msg = discord.Message(state=state, channel=interaction.channel, data=d)  # type: ignore
            resolved[int(id)] = msg

    return resolved
class ApplicationCog(commands.Cog, Generic[BotT]):
    """
    The cog that must be used for application commands.
    Attributes:
    - bot: [``slash_util.Bot``](#class-botcommand_prefix-help_commanddefault-help-command-descriptionnone-options)
    - - The bot instance."""

    def __init__(self, bot: BotT):
        self.bot: BotT = bot
        self._commands: dict[str, Command] = {}

        # Collect every command object defined on this cog, keyed by display name.
        for _, command in inspect.getmembers(self, lambda member: isinstance(member, Command)):
            # Bind the command right away so it can be invoked even when the cog
            # is added after the bot's initial command sync.
            command.cog = self
            self._commands[command.name] = command

    @commands.Cog.listener("on_interaction")
    async def _internal_interaction_handler(self, interaction: discord.Interaction):
        """Dispatches an application command interaction to the matching command of this cog."""
        if interaction.type is not discord.InteractionType.application_command:
            return

        name = interaction.data['name']  # type: ignore
        # Every cog receives every interaction; ignore commands owned by another cog.
        command = self._commands.get(name)
        if command is None:
            return

        state = self.bot._connection
        params: dict = command._build_arguments(interaction, state)
        ctx = Context(self.bot, command, interaction)
        await command.invoke(ctx, **params)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment