BERT Inference
import os

# Suppress TensorFlow C++ log output, e.g. the warning
# "Could not load dynamic library 'cudart64_110.dll'; dlerror: cudart64_110.dll not found".
# Level "3" hides INFO, WARNING and ERROR messages ("0" would show everything).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import logging
logging.getLogger('tensorflow').disabled = True  # also silence the Python-side TF logger

import tensorflow as tf
tf.autograph.set_verbosity(0)  # suppress AutoGraph/XLA warnings

import re
import numpy as np
from tensorflow import keras

from bert.loader import StockBertConfig, map_stock_config_to_params, load_stock_weights
from bert.tokenization.bert_tokenization import FullTokenizer
from bert.model import BertModelLayer
class Bert_Classifier:
    def __init__(self, max_seq_len, lr):
        # Resolve all paths relative to the repository root (two levels above this file).
        _ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        self.bert_ckpt_dir = os.path.join(
            _ROOT, "Models", "BERT_Fake_News_Classification", ".model", "uncased_L-12_H-768_A-12")
        self.bert_ckpt_file = os.path.join(self.bert_ckpt_dir, "bert_model.ckpt")
        self.bert_config_file = os.path.join(self.bert_ckpt_dir, "bert_config.json")
        self.max_seq_len = max_seq_len
        self.lr = lr
        # Fine-tuned classifier weights saved after training.
        self.bert_weight = os.path.join(
            _ROOT, "Models", "BERT_Fake_News_Classification", "pretrained_weights", "bert_news.h5")
    def clean_txt(self, text):
        """Remove apostrophes, collapse non-word characters into spaces, and lowercase."""
        text = re.sub(r"'", "", text)
        text = re.sub(r"\W+", " ", text)
        return text.lower()
    def get_split(self, text):
        """
        Split the news text into overlapping subtexts of at most 150 words,
        stepping 120 words at a time (consecutive windows overlap by 30 words).
        """
        words = text.split()
        n = max(len(words) // 120, 1)
        l_total = []
        for w in range(n):
            l_partial = words[w * 120:w * 120 + 150]
            l_total.append(" ".join(l_partial))
        return l_total
    def create_model(self):
        """
        Create the BERT classification model.

        Architecture: token-id input -> BERT encoder -> [CLS] embedding
        -> dropout layer (to reduce overfitting) -> dense softmax layer
        that outputs the predicted class probabilities.

        Uses self.max_seq_len (maximum sequence length) and self.lr
        (learning rate of the optimizer).
        """
        # Build the BERT layer from the pre-trained config.
        with tf.io.gfile.GFile(self.bert_config_file, "r") as reader:
            bc = StockBertConfig.from_json_string(reader.read())
            bert_params = map_stock_config_to_params(bc)
            bert = BertModelLayer.from_params(bert_params, name="bert")

        input_ids = keras.layers.Input(
            shape=(self.max_seq_len,), dtype='int32', name="input_ids")
        output = bert(input_ids)  # shape: (batch, max_seq_len, hidden_size)

        # Use the embedding of the [CLS] token (position 0) as the sequence representation.
        cls_out = keras.layers.Lambda(lambda seq: seq[:, 0, :])(output)
        # Dropout layer
        cls_out = keras.layers.Dropout(0.8)(cls_out)
        # Dense layer with probability output
        probs = keras.layers.Dense(units=2, activation="softmax")(cls_out)

        model = keras.Model(inputs=input_ids, outputs=probs)
        model.build(input_shape=(None, self.max_seq_len))

        # Load the pre-trained BERT weights into the encoder.
        load_stock_weights(bert, self.bert_ckpt_file)

        # The dense layer already applies softmax, so the loss consumes
        # probabilities rather than logits.
        model.compile(optimizer=keras.optimizers.Adam(learning_rate=self.lr),
                      loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                      metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")])
        model.summary()
        return model
    def predict_new(self, text):
        """
        Predict the label of a new document with the trained model.

        text: input document as a string.
        Returns (doc_label, avg_pred).
        """
        # Clean the text and split it into overlapping subtexts.
        doc = self.clean_txt(text)
        docs = self.get_split(doc)

        # Tokenize each subtext, add the BERT special tokens, and pad to max_seq_len.
        tokenizer = FullTokenizer(
            vocab_file=os.path.join(self.bert_ckpt_dir, "vocab.txt"))
        pred_tokens = map(tokenizer.tokenize, docs)
        pred_tokens = map(lambda tok: ["[CLS]"] + tok + ["[SEP]"], pred_tokens)
        pred_token_ids = list(map(tokenizer.convert_tokens_to_ids, pred_tokens))
        pred_token_ids = map(
            lambda tids: tids + [0] * (self.max_seq_len - len(tids)), pred_token_ids)
        pred_token_ids = np.array(list(pred_token_ids))

        # Create the model and load the fine-tuned weights.
        model = self.create_model()
        model.load_weights(self.bert_weight)

        # Predict each subtext and average the class-1 probabilities.
        predictions = model.predict(pred_token_ids)
        avg_pred = predictions[:, 1].mean()
        # Label mapping as in training: 0 if the averaged class-1
        # probability exceeds 0.5, otherwise 1.
        doc_label = 0 if avg_pred > 0.5 else 1
        return doc_label, avg_pred
# The following code runs only when the file is executed directly (for testing).
if __name__ == "__main__":
    input_text = "Chinese President Xi Jinping on Friday offered help to India in the fight against the coronavirus pandemic, Chinese state media reported. According to China's state media agency Xinhua, President Xi Jinping also extended condolences to Prime Minister Narendra Modi over the Covid-19 pandemic in India. President Xi Jinping sent a message of condolences to Indian Prime Minister Narendra Modi over the Covid-19 pandemic in the country, Xinhua reported. In his message, President Xi Jinping said China is willing to strengthen anti-pandemic cooperation with India, and provide support and help. This development comes amid strained relations between India and China for nearly a year over the standoff in Eastern Ladakh and violent clashes between the two sides. India is currently reeling under the second wave of coronavirus infections and has been consistently recording over 3,00,000 new cases every day. The death toll too has mounted as hospitals are facing an acute shortage of oxygen-supported beds, ICU beds, ventilators and medical oxygen. Chinese Ambassador to India, Sun Weidong in a series of tweets said the Chinese President has sent a message of sympathy to Prime Minister Narendra Modi. Chinese President Xi Jinping sends a message of sympathy to Indian Prime Minister Narendra Modi today. President Xi says, I am very concerned about the recent situation of Covid-19 pandemic in India. On behalf of the Chinese Government and people, as well as in my own name, I would like to express sincere sympathies to the Indian Government and people."
    obj = Bert_Classifier(max_seq_len=150, lr=1e-5)
    print(obj.predict_new(input_text))
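For reference, here is a minimal standalone sketch of the windowing logic used by get_split. The helper name split_words and the 400-word sample document are hypothetical, invented only for illustration; the 150-word window and 120-word stride match the class above.

# Hypothetical standalone illustration of get_split's sliding window:
# windows of up to 150 words, starting every 120 words.
def split_words(words, window=150, stride=120):
    n = max(len(words) // stride, 1)
    return [words[w * stride:w * stride + window] for w in range(n)]

words = ["w%d" % i for i in range(400)]  # hypothetical 400-word document
chunks = split_words(words)
print([len(c) for c in chunks])  # [150, 150, 150]
# The windows cover words 0-149, 120-269 and 240-389; at this length the
# final 10 words (390-399) fall outside the last window.

The 30-word overlap between consecutive windows means text near a window boundary is still classified with some surrounding context.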
Bash script to download the BERT pretrained weights using Google Cloud Shell
mkdir -p .model .model/uncased_L-12_H-768_A-12
gsutil cp gs://bert_models/2018_10_18/uncased_L-12_H-768_A-12/bert_config.json .model/uncased_L-12_H-768_A-12
gsutil cp gs://bert_models/2018_10_18/uncased_L-12_H-768_A-12/vocab.txt .model/uncased_L-12_H-768_A-12
gsutil cp gs://bert_models/2018_10_18/uncased_L-12_H-768_A-12/bert_model.ckpt.meta .model/uncased_L-12_H-768_A-12
gsutil cp gs://bert_models/2018_10_18/uncased_L-12_H-768_A-12/bert_model.ckpt.index .model/uncased_L-12_H-768_A-12
gsutil cp gs://bert_models/2018_10_18/uncased_L-12_H-768_A-12/bert_model.ckpt.data-00000-of-00001 .model/uncased_L-12_H-768_A-12
ls -la .model .model/uncased_L-12_H-768_A-12

mkdir pretrained_model
gdown --id 1pI12nFTsNdcoAY-tEyuybuYgP8kKOILA
mv bert_news.h5 ~/pretrained_model
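Note that Bert_Classifier resolves its paths relative to the project root, so after running the commands above the downloaded files are assumed to end up in the layout below (the exact root directory is project-specific; this tree is inferred from the paths in __init__):

Models/BERT_Fake_News_Classification/
    .model/uncased_L-12_H-768_A-12/
        bert_config.json
        vocab.txt
        bert_model.ckpt.meta
        bert_model.ckpt.index
        bert_model.ckpt.data-00000-of-00001
    pretrained_weights/
        bert_news.h5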