|
using Newtonsoft.Json; |
|
using Newtonsoft.Json.Linq; |
|
using System.Net; |
|
using System.Net.Http.Json; |
|
using System.Text; |
|
using System.Web; |
|
using Discord; |
|
using Discord.WebSocket; |
|
using FreneticUtilities.FreneticToolkit; |
|
using Microsoft.Extensions.DependencyInjection; |
|
using System.Xml.Linq; |
|
|
|
namespace TestDiscordAIBot; |
|
|
|
/// <summary>
/// Generation settings sent along with a prompt.
/// The snake_case field names are intentional: <c>TextGenAPI.SendRequest</c> copies every field
/// into the JSON request body under a key with exactly the same name.
/// </summary>
public class LLMParams
{
    // Each initializer below is the value that is sent when the caller does not override the field.

    public int max_new_tokens = 500;

    public bool do_sample = true;

    public float temperature = 0.7f;

    public float top_p = 0.1f;

    public float typical_p = 1f;

    public float repetition_penalty = 1.18f;

    public float encoder_repetition_penalty = 1f;

    public int top_k = 40;

    public int min_length = 0;

    public int no_repeat_ngram_size = 0;

    public int num_beams = 1;

    public float penalty_alpha = 0f;

    public int length_penalty = 1;

    public bool early_stopping = false;

    // NOTE(review): -1 presumably asks the backend to choose a seed itself - confirm against the backend docs.
    public int seed = -1;

    public bool add_bos_token = false;

    public bool skip_special_tokens = true;

    /// <summary>Sent as the <c>stopping_strings</c> array; starts out empty rather than null.</summary>
    public string[] stopping_strings = Array.Empty<string>();
}
|
|
|
/// <summary>
/// Thin HTTP client for the text-generation backend: posts a prompt plus <see cref="LLMParams"/>
/// as JSON to <c>{URLBase}/api/v1/generate</c> and returns the generated text.
/// </summary>
public static class TextGenAPI
{
    /// <summary>Single shared HTTP client, reused for every request.</summary>
    public static HttpClient Client = new();

    /// <summary>Base URL of the backend. The request path is appended directly, so do not end it with a slash.</summary>
    public static string URLBase = "YOUR TEXT GEN WEBUI HERE"; // !!!!!!!!!!!!!! FILL ME IN

    /// <summary>UTF-8 without a byte-order mark, used to encode the request body.</summary>
    public static UTF8Encoding Encoding = new(false);

    static TextGenAPI()
    {
        Client.DefaultRequestHeaders.Add("user-agent", "TestDiscordAIBot/1.0");
    }

    /// <summary>
    /// Blocking form of <see cref="SendRequestAsync"/>, kept for synchronous callers.
    /// </summary>
    /// <param name="prompt">Full prompt text to send.</param>
    /// <param name="llmParam">Generation settings; must not be null.</param>
    /// <returns>The text of the first result returned by the backend.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="llmParam"/> is null.</exception>
    /// <exception cref="HttpRequestException">The request failed or the backend answered with a non-success status.</exception>
    /// <exception cref="InvalidOperationException">The response JSON has no <c>results[0].text</c> entry.</exception>
    public static string SendRequest(string prompt, LLMParams llmParam)
    {
        // GetAwaiter().GetResult() rethrows the original exception instead of wrapping it in an AggregateException.
        return SendRequestAsync(prompt, llmParam).GetAwaiter().GetResult();
    }

    /// <summary>
    /// Sends one generation request and returns the generated text.
    /// </summary>
    /// <param name="prompt">Full prompt text to send.</param>
    /// <param name="llmParam">Generation settings; must not be null.</param>
    /// <returns>The text of the first result returned by the backend.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="llmParam"/> is null.</exception>
    /// <exception cref="HttpRequestException">The request failed or the backend answered with a non-success status.</exception>
    /// <exception cref="InvalidOperationException">The response JSON has no <c>results[0].text</c> entry.</exception>
    public static async Task<string> SendRequestAsync(string prompt, LLMParams llmParam)
    {
        ArgumentNullException.ThrowIfNull(llmParam);
        // Keys are spelled out explicitly so the wire format does not depend on how LLMParams is declared.
        JObject jData = new()
        {
            ["prompt"] = prompt,
            ["max_new_tokens"] = llmParam.max_new_tokens,
            ["do_sample"] = llmParam.do_sample,
            ["temperature"] = llmParam.temperature,
            ["top_p"] = llmParam.top_p,
            ["typical_p"] = llmParam.typical_p,
            ["repetition_penalty"] = llmParam.repetition_penalty,
            ["encoder_repetition_penalty"] = llmParam.encoder_repetition_penalty,
            ["top_k"] = llmParam.top_k,
            ["min_length"] = llmParam.min_length,
            ["no_repeat_ngram_size"] = llmParam.no_repeat_ngram_size,
            ["num_beams"] = llmParam.num_beams,
            ["penalty_alpha"] = llmParam.penalty_alpha,
            ["length_penalty"] = llmParam.length_penalty,
            ["early_stopping"] = llmParam.early_stopping,
            ["seed"] = llmParam.seed,
            ["add_bos_token"] = llmParam.add_bos_token,
            ["skip_special_tokens"] = llmParam.skip_special_tokens,
            ["stopping_strings"] = JToken.FromObject(llmParam.stopping_strings)
        };
        string serialized = JsonConvert.SerializeObject(jData);
        Console.WriteLine($"will send: {serialized}");
        // Both the request body and the response own buffers/streams, so dispose them when done.
        using StringContent body = new(serialized, Encoding, "application/json");
        using HttpResponseMessage response = await Client.PostAsync($"{URLBase}/api/v1/generate", body).ConfigureAwait(false);
        // Log the content type: interpolating response.Content itself would only print its .NET type name.
        Console.WriteLine($"Response type: {(int)response.StatusCode} {response.StatusCode}, {response.Content.Headers.ContentType}");
        string responseText = await response.Content.ReadAsStringAsync().ConfigureAwait(false);
        Console.WriteLine($"Response text: {responseText}");
        if (!response.IsSuccessStatusCode)
        {
            // Fail with the real cause instead of a confusing JSON/null error further down.
            throw new HttpRequestException($"Text generation request failed with status {(int)response.StatusCode} {response.StatusCode}.");
        }
        // SelectToken returns null (rather than throwing) when 'results' is missing, empty, or lacks 'text'.
        JToken textToken = JObject.Parse(responseText).SelectToken("results[0].text");
        if (textToken is null)
        {
            throw new InvalidOperationException("Text generation response did not contain results[0].text.");
        }
        string result = textToken.ToString();
        Console.WriteLine($"Result text: {result}");
        return result;
    }
}
|
|
|
/// <summary>
/// Bot entry point. Connects to Discord, answers guild messages that mention the bot or reply to one of
/// its messages by sending a prompt through <c>TextGenAPI.SendRequest</c>, and runs a console prompt loop.
/// </summary>
public static class Program
{
    /// <summary>Text placed before every prompt. Supports the placeholders {{user}}, {{username}}, {{helper}} and {{date}}.</summary>
    public static string PrePrompt = "YOUR PRE PROMPT HERE"; // !!!!!!!!!!!!!! FILL ME IN

    /// <summary>The Discord client; assigned in <see cref="Main"/>.</summary>
    public static DiscordSocketClient Client;

    /// <summary>Matcher for ASCII letters and digits, used to clean user names before they enter a prompt.</summary>
    public static AsciiMatcher AlphanumericMatcher = new(AsciiMatcher.BothCaseLetters + AsciiMatcher.Digits);

    /// <summary>The parts of a Discord message needed to rebuild a reply chain. <c>RefId</c> is 0 when the message is not a reply.</summary>
    public record class CachedMessage(string Content, ulong RefId, ulong Author, string AuthorName);

    /// <summary>Message id to cached data. A null value records that the message could not be fetched. Nothing in this class removes entries.</summary>
    public static Dictionary<ulong, CachedMessage> MessageCache = new();

    /// <summary>
    /// Returns the cached data for a message, fetching it from Discord and caching it on first use.
    /// </summary>
    /// <param name="channel">Id of the channel that contains the message.</param>
    /// <param name="id">Id of the message.</param>
    /// <returns>The cached message, or null when the channel or the message is not available.</returns>
    public static CachedMessage GetMessageCached(ulong channel, ulong id)
    {
        if (MessageCache.TryGetValue(id, out CachedMessage res))
        {
            return res;
        }
        if (Client.GetChannel(channel) is not SocketTextChannel textChannel)
        {
            // Channel unknown or not a text channel. Do not cache this outcome: the channel may become available later.
            return null;
        }
        IMessage message = textChannel.GetMessageAsync(id).GetAwaiter().GetResult();
        // Log the requested id rather than message.Id: 'message' is null when the fetch finds nothing.
        Console.WriteLine($"Must fill cache on message {id}");
        if (message is null)
        {
            MessageCache[id] = null;
            return null;
        }
        CachedMessage cache = new(message.Content, message.Reference?.MessageId.GetValueOrDefault(0) ?? 0, message.Author?.Id ?? 0, message.Author?.Username ?? "");
        MessageCache[id] = cache;
        return cache;
    }

    /// <summary>Reduces a user name to ASCII letters and digits, substituting "User" when fewer than 3 characters remain.</summary>
    private static string CleanUserName(string name)
    {
        string cleaned = AlphanumericMatcher.TrimToMatches(name);
        return cleaned.Length < 3 ? "User" : cleaned;
    }

    /// <summary>
    /// Walks the reply chain that starts at <paramref name="repliedToId"/> and renders it, oldest exchange first,
    /// as alternating user / helper lines. Returns false when the chain is not a conversation with this bot
    /// (a link is missing, or a message expected to be the bot's was written by someone else).
    /// </summary>
    private static bool TryBuildReplyHistory(ulong channelId, ulong repliedToId, string helper, out string prior)
    {
        prior = "";
        ulong botId = Client.CurrentUser.Id;
        CachedMessage botMessage = GetMessageCached(channelId, repliedToId);
        if (botMessage is null)
        {
            return false;
        }
        while (botMessage is not null)
        {
            // Every bot message in the chain must itself be a reply to the user message that prompted it.
            if (botMessage.Author != botId || botMessage.RefId == 0)
            {
                return false;
            }
            CachedMessage userMessage = GetMessageCached(channelId, botMessage.RefId);
            if (userMessage is null)
            {
                return false;
            }
            // Prepend, because the chain is walked from newest to oldest.
            prior = $"{CleanUserName(userMessage.AuthorName)}: {userMessage.Content}\n{helper}: {botMessage.Content}\n{prior}";
            botMessage = userMessage.RefId == 0 ? null : GetMessageCached(channelId, userMessage.RefId);
        }
        return true;
    }

    /// <summary>Escapes Discord markup characters and strips URL schemes from generated text so it cannot mention users or post clickable links.</summary>
    private static string EscapeForDiscord(string text)
    {
        return text.Replace("\\", "\\\\").Replace("<", "\\<").Replace(">", "\\>").Replace("@", "\\@ ")
            .Replace("http://", "").Replace("https://", "").Trim();
    }

    /// <summary>
    /// Handles one incoming Discord message: decides whether it is addressed to the bot, builds the prompt,
    /// requests a completion and replies with it.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the generation request and the reply run synchronously inside the event handler, so no other
    /// gateway event is processed while a completion is being generated. That keeps requests to the backend
    /// serialized; confirm it is intended before moving the work to a background task.
    /// </remarks>
    private static Task HandleMessage(SocketMessage message, LLMParams llmParams)
    {
        if (message.Content is null || message.Author.IsBot || message.Author.IsWebhook || message is not IUserMessage userMessage || message.Channel is not IGuildChannel guildChannel)
        {
            return Task.CompletedTask;
        }
        string prePrompt, user, helper;
        string rawUser = AlphanumericMatcher.TrimToMatches(message.Author.Username);
        user = rawUser.Length < 3 ? "User" : rawUser;
        // 'guild' and the per-user 'user' value are only consumed by the optional template below.
        ulong guild = guildChannel.GuildId;
        prePrompt = PrePrompt;
        /* // !!!!!!!!!!!!!! FILL ME IN - OPTIONAL PROMPT SWAPPER PER GUILD
        if (guild == 123ul)
        {
            prePrompt = PrePromptA;
            user = "User";
            helper = "Helper";
        }
        else if (guild == 456ul)
        {
            prePrompt = PrePromptB;
            helper = "Llama";
        }
        else if (guild == 789ul)
        {
            prePrompt = PrePromptC;
        }
        else
        {
            Console.WriteLine("Bad guild");
            return Task.CompletedTask;
        }*/
        // These fixed labels override whatever was chosen above.
        user = "### Human";
        helper = "### Assistant";
        ulong botId = Client.CurrentUser.Id;
        string mention = $"<@{botId}>";
        string nickMention = $"<@!{botId}>";
        string prior = "";
        bool isSelfRef = message.Content.Contains(mention) || message.Content.Contains(nickMention);
        if (message.Reference is not null && message.Reference.ChannelId == message.Channel.Id)
        {
            if (!TryBuildReplyHistory(message.Channel.Id, message.Reference.MessageId.Value, helper, out prior))
            {
                return Task.CompletedTask;
            }
            // A reply whose chain leads back to the bot counts as addressed to the bot even without a mention.
            isSelfRef = true;
        }
        if (!isSelfRef)
        {
            return Task.CompletedTask;
        }
        // Invariant culture keeps the date in the prompt identical regardless of the host's regional settings.
        string date = DateTimeOffset.Now.ToString("yyyy-MM-dd HH:mm", System.Globalization.CultureInfo.InvariantCulture);
        prePrompt = prePrompt.Replace("{{user}}", user).Replace("{{username}}", rawUser).Replace("{{helper}}", helper).Replace("{{date}}", date);
        string input = message.Content.Replace(mention, "").Replace(nickMention, "").Trim();
        Console.WriteLine($"Got input: {prior} {input}");
        if (input.StartsWith("[nopreprompt]", StringComparison.Ordinal))
        {
            prePrompt = "";
            input = input["[nopreprompt]".Length..].Trim();
        }
        else
        {
            input = input.Replace("\n", " ");
        }
        using (message.Channel.EnterTypingState())
        {
            string res;
            try
            {
                res = TextGenAPI.SendRequest($"{prePrompt}{prior}{user}: {input}\n{helper}:", llmParams);
            }
            catch (Exception ex)
            {
                // Report the failure and fall through to the "[Error]" reply instead of leaving the user without an answer.
                Console.WriteLine($"Text generation failed: {ex}");
                res = "";
            }
            // Ordinal search: a culture-sensitive IndexOf can miss "\n" when it directly follows "\r".
            int line = res.IndexOf("\n###", StringComparison.Ordinal);
            if (line != -1)
            {
                res = res[..line];
            }
            Console.WriteLine($"\n\nUser: {input}\n{helper}:{res}\n\n");
            res = EscapeForDiscord(res);
            if (string.IsNullOrWhiteSpace(res))
            {
                res = "[Error]";
            }
            else if (res.Length > DiscordConfig.MaxMessageSize)
            {
                // Discord rejects over-long messages; cut to the limit without splitting a surrogate pair.
                int keep = DiscordConfig.MaxMessageSize;
                if (char.IsHighSurrogate(res[keep - 1]))
                {
                    keep--;
                }
                res = res[..keep];
            }
            userMessage.ReplyAsync(res, allowedMentions: AllowedMentions.None).GetAwaiter().GetResult();
        }
        return Task.CompletedTask;
    }

    /// <summary>
    /// Starts the bot, then reads prompts from standard input until it is closed, printing the first line of each completion.
    /// </summary>
    public static void Main()
    {
        Console.WriteLine("Starting...");
        DiscordSocketConfig config = new()
        {
            MessageCacheSize = 50,
            AlwaysDownloadUsers = true,
            GatewayIntents = GatewayIntents.AllUnprivileged | GatewayIntents.MessageContent
        };
        Client = new DiscordSocketClient(config);
        // Print the library's own log output so problems it reports (including handler failures) are visible.
        Client.Log += (logMessage) =>
        {
            Console.WriteLine(logMessage.ToString());
            return Task.CompletedTask;
        };
        Client.Ready += () =>
        {
            Console.WriteLine("Bot ready.");
            return Task.CompletedTask;
        };
        LLMParams llmParams = new() { stopping_strings = new[] { "\n###" } };
        Client.MessageReceived += (message) => HandleMessage(message, llmParams);
        Console.WriteLine("Logging in to Discord...");
        Client.LoginAsync(TokenType.Bot, "YOUR TOKEN HERE").GetAwaiter().GetResult(); // !!!!!!!!!!!!!! FILL ME IN
        Console.WriteLine("Connecting to Discord...");
        Client.StartAsync().GetAwaiter().GetResult();
        Console.WriteLine("Running Discord!");
        while (true)
        {
            string input = Console.ReadLine();
            if (input is null)
            {
                // Standard input was closed: leave Main, which ends the program.
                return;
            }
            input = input.Replace("\n", " ");
            string fullPrompt = $"User: {input}\nHelper: ";
            string res;
            try
            {
                res = TextGenAPI.SendRequest(fullPrompt, llmParams);
            }
            catch (Exception ex)
            {
                // A failed console test prompt must not take the running bot down with it.
                Console.WriteLine($"Text generation failed: {ex}");
                continue;
            }
            Console.WriteLine($"AI says back: {res}");
            int line = res.IndexOf('\n');
            if (line != -1)
            {
                res = res[..line];
            }
            Console.WriteLine($"\n\nUser: {input}\nHelper: {res}\n\n");
        }
    }
}