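# Telethon script that watches one Telegram chat (id -1001085244538),
# classifies every text message with three Russian BERT models
# (toxicity, sensitive topics, inappropriateness), reacts to flagged
# messages from one specific sender, and logs the predicted labels.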
from telethon import TelegramClient, events, sync  # `sync` lets client.start()/get_me() be called without await
import numpy as np
import asyncio
import json
import random
from telethon.tl.functions.messages import SendReactionRequest
from telethon.tl.types import ReactionEmoji
import torch
import os
import sys
import logging
import time
from transformers import BertTokenizer, BertForSequenceClassification
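
# The three helpers below pull pre-trained Russian-language BERT sequence
# classifiers from the Hugging Face hub; the first call downloads and caches
# the weights, later calls load them from the local cache.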
def load_toxic_model():
    model_name = 'Skoltech/russian-sensitive-topics'
    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertForSequenceClassification.from_pretrained(model_name)
    return tokenizer, model


def load_sensitive_model():
    model_name = 'apanc/russian-sensitive-topics'
    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertForSequenceClassification.from_pretrained(model_name)
    return tokenizer, model


def load_inappropriate_model():
    model_name = 'apanc/russian-inappropriate-messages'
    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertForSequenceClassification.from_pretrained(model_name)
    return tokenizer, model

logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                    level=logging.INFO)
logger = logging.getLogger(__name__)


def get_env(name, message):
    if name in os.environ:
        return os.environ[name]
    return input(message)


# Load all three classifiers up front (slow on the first run while weights download).
sens_tokenizer, sens_model = load_sensitive_model()
inapp_tokenizer, inapp_model = load_inappropriate_model()
toxic_tokenizer, toxic_model = load_toxic_model()

# Mapping from class index to sensitive-topic label; id2topic.json must be
# present in the working directory.
with open("id2topic.json") as f:
    target_variables_id2topic_dict = json.load(f)
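
# Collapse the sensitive-topics logits into a list of topic names, skipping
# class 0 (presumably the "no sensitive topic" class in id2topic.json).
# Note that only the argmax is kept, so each message is reduced to its single
# highest-scoring topic.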
def adjust_multilabel(y, is_pred=False):
    y_adjusted = []
    for y_c in y:
        index = str(int(np.argmax(y_c)))
        if index == '0':
            continue
        y_adjusted.append(target_variables_id2topic_dict[index])
    return y_adjusted

# These example values won't work. You must get your own api_id and
# api_hash from https://my.telegram.org, under API Development.
api_id = REDACTED
api_hash = REDACTED

# Telethon client
client = TelegramClient('wuzabot', api_id, api_hash, device_model="Linux")
client.start()
me = client.get_me()
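
# Classification helpers. Each one tokenizes a single message (clipped to its
# first 511 characters) and runs one of the BERT classifiers loaded above.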
def predict_toxic(text):
    # truncation/max_length keep the sequence within BERT's 512-token limit even
    # when 511 characters tokenize to more than 512 tokens.
    batch = toxic_tokenizer.encode(text[:511], truncation=True, max_length=512, return_tensors='pt')
    with torch.no_grad():
        output = toxic_model(batch)
    y_pred = np.argmax(output.logits.detach().numpy(), axis=1)
    return bool(y_pred[0])

def predict_inapp(text):
    tokenized = inapp_tokenizer.batch_encode_plus(
        [text[:511]],
        max_length=512,
        truncation=True,
        return_token_type_ids=False
    )
    tokens_ids = torch.tensor(tokenized['input_ids'])
    mask = torch.tensor(tokenized['attention_mask'])
    with torch.no_grad():
        model_output = inapp_model(tokens_ids, mask)
    return bool(torch.argmax(model_output['logits'], dim=1)[0])

def predict_sens_topics(text):
    tokenized = sens_tokenizer.batch_encode_plus(
        [text[:511]],
        max_length=512, truncation=True, return_token_type_ids=False)
    tokens_ids = torch.tensor(tokenized['input_ids'])
    mask = torch.tensor(tokenized['attention_mask'])
    with torch.no_grad():
        model_output = sens_model(tokens_ids, mask)
    preds = adjust_multilabel(model_output['logits'], is_pred=True)
    return preds
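
# Core handler: classify a message, react with an emoji when it comes from one
# specific sender (id 87677941) and is flagged, print the labels, and remember
# the last processed message id in /tmp/shit_last so the next run can resume.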
async def process_message(message):
    if not message.text:
        return
    # Run all three classifiers on the message text.
    is_toxic = predict_toxic(message.text)
    labels = predict_sens_topics(message.text)
    is_inapp = predict_inapp(message.text)
    if message.sender.id == 87677941 and (is_toxic or is_inapp):
        print("Ivan", "shit him", message.id, message.text)
        # 💩 for toxic messages, 🤮 for merely inappropriate ones.
        reaction = ReactionEmoji("💩") if is_toxic else ReactionEmoji("🤮")
        await client(SendReactionRequest(
            peer=message.peer_id,
            msg_id=message.id,
            reaction=[reaction],
        ))
        await asyncio.sleep(random.randrange(1, 4))
    if is_toxic:
        labels.append('toxic')
    if is_inapp:
        labels.append('inappropriate')
    print("Message",
          message.sender.username,
          message.sender.id,
          message.text,
          labels)
    # else:
    #     print("Non-toxic message",
    #           message.sender.username,
    #           message.sender.id,
    #           message.text,
    #           output.logits.detach().numpy()
    #           )
    with open("/tmp/shit_last", "w") as f:
        f.write(str(message.id))

# Resume from the last processed message id, if one was recorded.
try:
    with open("/tmp/shit_last", "r") as f:
        offset_id = int(f.read())
except (FileNotFoundError, ValueError):
    offset_id = None


async def main():
    # Catch up on every message posted since the recorded offset.
    if offset_id:
        async for message in client.iter_messages(-1001085244538, reverse=True, offset_id=offset_id):
            await process_message(message=message)


@client.on(events.MessageDeleted(chats=[-1001085244538]))
async def on_delete(event):
    print(event.to_json())


@client.on(events.NewMessage(chats=[-1001085244538]))
async def handler(event):
    await process_message(message=event.message)


client.loop.run_until_complete(main())
# client.run_until_disconnected()
client.disconnect()
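# Note: as written the script only processes the backlog in main() and then
# disconnects; swapping client.disconnect() for the commented-out
# client.run_until_disconnected() would keep the NewMessage/MessageDeleted
# handlers listening for live updates.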