Skip to content

Instantly share code, notes, and snippets.

@NaniteFactory
Created June 8, 2020 03:37
Show Gist options
  • Save NaniteFactory/51fff8444b11a6c60bb0f1b3b2ead43e to your computer and use it in GitHub Desktop.
Save NaniteFactory/51fff8444b11a6c60bb0f1b3b2ead43e to your computer and use it in GitHub Desktop.
Discordgo sharding
package shard
import (
"errors"
"fmt"
"strconv"
"sync"
"github.com/bwmarrin/discordgo"
)
// Discord manages shards.
//
// It owns a set of discordgo sessions that all authenticate with the same bot
// token. All access to the session list goes through mu, so the exported
// methods may be called from multiple goroutines.
type Discord struct {
// botToken is the raw bot token; the "Bot " prefix is added when a session is created.
botToken string
// mu guards shards.
mu sync.RWMutex
// shards holds the managed sessions. After Open, the slice index of a
// session equals its ShardID; after OpenUnsharded it holds a single session.
shards []*discordgo.Session
}
// NewDiscord returns a Discord manager for the given bot token.
//
// No session is created and no network traffic happens here; call Open or
// OpenUnsharded to connect. The mutex and the shard list start out at their
// zero values.
func NewDiscord(botToken string) *Discord {
	d := new(Discord)
	d.botToken = botToken
	return d
}
// OpenUnsharded opens a single session with no shards.
//
// Every value in handlers is registered on the session with AddHandler
// before the websocket connection is opened, so no early event is missed.
// The session becomes the only entry of the managed shard list.
//
// It returns the error from creating the session or from opening the
// connection, if any.
func (d *Discord) OpenUnsharded(handlers ...interface{}) error {
	d.mu.Lock()
	defer d.mu.Unlock()

	sess, err := discordgo.New("Bot " + d.botToken)
	if err != nil {
		return err
	}
	// Handlers must be attached before Open, exactly as Open does for the
	// sharded case; previously they were accepted but silently dropped.
	for _, handler := range handlers {
		sess.AddHandler(handler)
	}
	d.shards = []*discordgo.Session{sess}
	return sess.Open()
}
// Open shards sessions and adds event handlers to them.
//
// It asks the gateway for its shard count, creates one session per shard
// (ShardID == slice index), registers every handler on every session and then
// opens all websocket connections concurrently.
//
// If anything fails, every session created by this call is closed again, the
// previously managed shard list is left untouched, and the error is returned.
// When several shards fail to open, the error of the lowest shard ID wins.
func (d *Discord) Open(handlers ...interface{}) error {
	d.mu.Lock()
	defer d.mu.Unlock()

	// Get Gateway Bot.
	gateway, err := discordgo.New("Bot " + d.botToken)
	if err != nil {
		return err
	}
	st, err := gateway.GatewayBot()
	if err != nil {
		return err
	}
	if st.Shards < 1 {
		// make() panics on a negative length, and zero shards is unusable.
		return fmt.Errorf("gateway reported invalid shard count %d", st.Shards)
	}

	// Shard into sessions. Build into a local slice so that a failure half
	// way through never publishes a list containing nil sessions.
	shards := make([]*discordgo.Session, st.Shards)
	for i := range shards {
		sess, err := discordgo.New("Bot " + d.botToken)
		if err != nil {
			return err
		}
		sess.ShardID = i
		sess.ShardCount = st.Shards
		for _, handler := range handlers {
			sess.AddHandler(handler)
		}
		shards[i] = sess
	}

	// Open ws connections. Each goroutine writes only its own slot of errs,
	// so no extra locking is needed and there is no data race.
	errs := make([]error, len(shards))
	var wg sync.WaitGroup
	for i, sess := range shards {
		wg.Add(1)
		go func(i int, sess *discordgo.Session) {
			defer wg.Done()
			errs[i] = sess.Open()
		}(i, sess)
	}
	wg.Wait()

	var errOpen error
	for _, err := range errs {
		if err != nil {
			errOpen = err
			break
		}
	}
	if errOpen == nil {
		d.shards = shards
		return nil
	}

	// Close if fail.
	for _, sess := range shards {
		wg.Add(1)
		go func(sess *discordgo.Session) {
			defer wg.Done()
			// Best effort: shards that never opened report an error here,
			// and the open error is the one worth returning.
			_ = sess.Close()
		}(sess)
	}
	wg.Wait()
	return errOpen
}
// Close tries to close all sessions.
//
// All sessions are closed concurrently and Close waits for every one of them
// to finish. If one or more sessions fail to close, the error of the session
// with the lowest index is returned; otherwise the result is nil. The shard
// list itself is not modified.
func (d *Discord) Close() error {
	d.mu.Lock()
	defer d.mu.Unlock()

	// One slot per session: each goroutine writes only its own element, which
	// avoids the data race of several goroutines assigning a shared variable.
	errs := make([]error, len(d.shards))
	var wg sync.WaitGroup
	for i, sess := range d.shards {
		if sess == nil {
			// Skip slots that were never populated.
			continue
		}
		wg.Add(1)
		go func(i int, sess *discordgo.Session) {
			defer wg.Done()
			errs[i] = sess.Close()
		}(i, sess)
	}
	wg.Wait()

	for _, err := range errs {
		if err != nil {
			return err
		}
	}
	return nil
}
// ShardByShardID looks up the managed session stored at index id.
//
// It returns an error when id is negative or not smaller than the number of
// managed sessions.
func (d *Discord) ShardByShardID(id int) (*discordgo.Session, error) {
	d.mu.RLock()
	defer d.mu.RUnlock()

	if 0 <= id && id < len(d.shards) {
		return d.shards[id], nil
	}
	return nil, errors.New("outbound shard id")
}
// ShardByGuild returns the session responsible for the guild with the given
// ID.
//
// guildID must be a decimal integer string; the strconv.Atoi error is
// returned unchanged when it is not. The shard list is read-locked for the
// duration of the lookup.
func (d *Discord) ShardByGuild(guildID string) (*discordgo.Session, error) {
	d.mu.RLock()
	defer d.mu.RUnlock()

	numericID, parseErr := strconv.Atoi(guildID)
	if parseErr == nil {
		return d.shardByGuild(numericID)
	}
	return nil, parseErr
}
// shardByGuild returns the session whose index matches the shard computed
// from guildID with Discord's sharding formula:
//
//	(guild_id >> 22) % num_shards == shard_id
//
// See https://discord.com/developers/docs/topics/gateway#sharding-sharding-formula
//
// It takes no lock itself; the caller must hold d.mu. An error is returned
// when no sessions are managed (which would otherwise be a division by zero)
// or when the computed index is out of range, e.g. for a negative guildID.
func (d *Discord) shardByGuild(guildID int) (*discordgo.Session, error) {
	n := len(d.shards)
	if n == 0 {
		// Guard the modulo below: x % 0 panics at runtime.
		return nil, errors.New("no shards available")
	}
	i := (guildID >> 22) % n
	if i < 0 || i >= n {
		return nil, fmt.Errorf("calculated outbound shard id %d", i)
	}
	return d.shards[i], nil
}
// ShardByChannel returns the session responsible for the channel with the
// given ID. It holds the read lock on the shard list while the lookup runs
// and forwards the result of shardByChannel unchanged.
func (d *Discord) ShardByChannel(channelID string) (*discordgo.Session, error) {
	d.mu.RLock()
	defer d.mu.RUnlock()

	sess, err := d.shardByChannel(channelID)
	return sess, err
}
// shardByChannel resolves the session responsible for channelID.
//
// The channel is first looked up in each session's state cache. If no session
// yields a usable result, the lookup is retried through each session's
// Channel call. The channel's GuildID is then mapped to a session with
// shardByGuild.
//
// It takes no lock itself; the caller must hold d.mu. On failure it returns
// the last error encountered, and it never returns a nil session together
// with a nil error.
func (d *Discord) shardByChannel(channelID string) (*discordgo.Session, error) {
	if len(d.shards) == 0 {
		return nil, errors.New("no shards available")
	}

	var errLast error
	// resolve runs one lookup strategy over every shard and returns the
	// owning session, or nil (recording the latest error in errLast).
	resolve := func(lookup func(*discordgo.Session) (*discordgo.Channel, error)) *discordgo.Session {
		for _, shard := range d.shards {
			ch, err := lookup(shard)
			if err != nil {
				errLast = err
				continue
			}
			guildID, err := strconv.Atoi(ch.GuildID)
			if err != nil {
				errLast = err
				continue
			}
			owner, err := d.shardByGuild(guildID)
			if err != nil {
				errLast = err
				continue
			}
			if owner != nil {
				return owner
			}
		}
		return nil
	}

	// First pass: the state cache of each shard.
	fromState := func(s *discordgo.Session) (*discordgo.Channel, error) {
		return s.State.Channel(channelID)
	}
	if sess := resolve(fromState); sess != nil {
		return sess, nil
	}

	// Retry if state not found.
	fromSession := func(s *discordgo.Session) (*discordgo.Channel, error) {
		return s.Channel(channelID)
	}
	if sess := resolve(fromSession); sess != nil {
		return sess, nil
	}

	if errLast == nil {
		// Nothing failed, yet nothing was found; still report an error.
		errLast = errors.New("no shard found for channel")
	}
	return nil, errLast
}
// Each calls f once for every managed session, each call in its own
// goroutine, and returns after all calls have finished. The shard list is
// read-locked until then.
func (d *Discord) Each(f func(sess *discordgo.Session)) {
	d.mu.RLock()
	defer d.mu.RUnlock()

	var pending sync.WaitGroup
	pending.Add(len(d.shards))
	for i := range d.shards {
		go func(sess *discordgo.Session) {
			f(sess)
			pending.Done()
		}(d.shards[i])
	}
	pending.Wait()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment