Last active
May 25, 2023 21:16
-
-
Save avidale/7abc1aa027afd69f6b50eaf7527ed294 to your computer and use it in GitHub Desktop.
BERT-toxicity-classification.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "BERT-toxicity-classification.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"authorship_tag": "ABX9TyOlRuBAiOilI73kwx/pYsaT", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/avidale/7abc1aa027afd69f6b50eaf7527ed294/bert-toxicity-classification.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "T7mxnesuCO7T" | |
}, | |
"source": [ | |
"Этот блокнот позволяет поиграться с моделькой, классифицирующей тексты как токсичные или неполиткорректные. \n", | |
"\n", | |
"Всё самое весёлое - в последней ячейке. \n", | |
"\n", | |
"https://huggingface.co/cointegrated/rubert-tiny-toxicity" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "aJCuD52e7HSn" | |
}, | |
"source": [ | |
"!pip install transformers sentencepiece --quiet" | |
], | |
"execution_count": 6, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "mkJuRwEp7Kmn" | |
}, | |
"source": [ | |
"import torch\n", | |
"from transformers import AutoTokenizer, AutoModelForSequenceClassification" | |
], | |
"execution_count": 7, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "zgJAWNaz7RBO" | |
}, | |
"source": [ | |
"model_checkpoint = 'cointegrated/rubert-tiny-toxicity'\n", | |
"tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)\n", | |
"model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)\n", | |
"if torch.cuda.is_available():\n", | |
" model.cuda()" | |
], | |
"execution_count": 8, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "ggRV-Z5R7Z34", | |
"outputId": "cbbe368c-8a2f-4d69-c45f-8382a3c5a45f" | |
}, | |
"source": [ | |
"def text2toxicity(text, aggregate=True):\n", | |
" \"\"\" Calculate toxicity of a text (if aggregate=True) or a vector of toxicity aspects (if aggregate=False)\"\"\"\n", | |
" with torch.no_grad():\n", | |
" inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(model.device)\n", | |
" proba = torch.sigmoid(model(**inputs).logits).cpu().numpy()\n", | |
" if isinstance(text, str):\n", | |
" proba = proba[0]\n", | |
" if aggregate:\n", | |
" return 1 - proba.T[0] * (1 - proba.T[-1])\n", | |
" return proba\n", | |
"\n", | |
"print(text2toxicity('я люблю нигеров', True))\n", | |
"# 0.57240640889815\n", | |
"\n", | |
"print(text2toxicity('я люблю нигеров', False))\n", | |
"# [9.9336821e-01 6.1555761e-03 1.2781911e-03 9.2758919e-04 5.6955177e-01]\n", | |
"\n", | |
"print(text2toxicity(['я люблю нигеров', 'я люблю африканцев'], True))\n", | |
"# [0.5724064 0.20111847]\n", | |
"\n", | |
"print(text2toxicity(['я люблю нигеров', 'я люблю африканцев'], False))\n", | |
"# [[9.9336821e-01 6.1555761e-03 1.2781911e-03 9.2758919e-04 5.6955177e-01]\n", | |
"# [9.9828428e-01 1.1138428e-03 1.1492912e-03 4.6551935e-04 1.9974548e-01]]" | |
], | |
"execution_count": 9, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"0.9350118728093193\n", | |
"[0.9715758 0.0180863 0.0045551 0.00189755 0.9331106 ]\n", | |
"[0.93501186 0.04156357]\n", | |
"[[9.7157580e-01 1.8086294e-02 4.5550885e-03 1.8975559e-03 9.3311059e-01]\n", | |
" [9.9979788e-01 1.9048342e-04 1.5297388e-04 1.7452303e-04 4.1369814e-02]]\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "qhleF3BY77bq", | |
"outputId": "ca4156e1-cbce-4143-9f68-478372ab3384" | |
}, | |
"source": [ | |
"%%time\n", | |
"print(text2toxicity('Иди ты нафиг!'))" | |
], | |
"execution_count": 11, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"0.4770178304282737\n", | |
"CPU times: user 9.94 ms, sys: 147 µs, total: 10.1 ms\n", | |
"Wall time: 17.5 ms\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "a6G0bvwJ7gg2" | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": 10, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment