Last active
March 31, 2019 16:09
-
-
Save cmantas/fe4d055563ab466b0e912592e9c621a3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from random import shuffle | |
import requests | |
from urllib.parse import urlparse, parse_qs | |
import tensorflow as tf | |
tf.logging.set_verbosity(tf.logging.ERROR) | |
import matplotlib.pyplot as plt | |
def read_data(fname, token='#'):
    """Load ``(name, category_id)`` records from a delimited text file.

    Every line is stripped of surrounding whitespace and split on ``token``.
    Lines that do not produce exactly two fields (blank lines, lines with no
    delimiter, lines containing the delimiter more than once) are skipped.

    Args:
        fname: path of the text file to read.
        token: field separator; ``'#'`` by default.

    Returns:
        A list of ``(name, category_id)`` tuples in file order, where
        ``category_id`` has been converted with ``int()``.

    Raises:
        ValueError: if a two-field line has a second field ``int()`` rejects.
    """
    records = []
    with open(fname) as handle:
        for raw in handle:
            fields = raw.strip().split(token)
            if len(fields) == 2:
                name, category = fields
                records.append((name, int(category)))
    return records
def binarize(data, balance=False, target_cid=40):
    """Turn ``(text, category_id)`` pairs into ``(text, is_target)`` pairs.

    Examples whose category equals ``target_cid`` are labelled ``True``;
    all others are labelled ``False``. The negatives are shuffled and, when
    ``balance`` is true, truncated to at most the number of positives. A
    summary line is printed, and the combined list is shuffled again before
    being returned. Shuffling uses the module-level ``random`` state.

    Args:
        data: iterable of ``(text, category_id)`` pairs (iterated once).
        balance: if true, keep at most ``len(positives)`` negatives.
        target_cid: category id treated as the positive class.

    Returns:
        A new shuffled list of ``(text, bool)`` pairs.
    """
    buckets = {True: [], False: []}
    for text, cid in data:
        is_target = bool(cid == target_cid)
        buckets[is_target].append((text, is_target))
    positive, negative = buckets[True], buckets[False]

    shuffle(negative)
    if balance:
        negative = negative[:len(positive)]

    print("Returning %d positive and %d negative examples" %
          (len(positive), len(negative)))

    combined = negative + positive
    shuffle(combined)
    return combined
def vectorize_batch(batch, tokenizer, encoder):
    """Convert a batch of ``(text, label)`` pairs into model inputs/targets.

    Args:
        batch: non-empty sequence of ``(text, label)`` pairs.
        tokenizer: object exposing ``texts_to_matrix(texts, mode=...)``;
            presumably a Keras ``Tokenizer`` (confirm against caller).
        encoder: object exposing ``transform(labels)``; presumably a
            scikit-learn ``LabelEncoder`` (confirm against caller).

    Returns:
        ``(X, Y)`` where ``X`` is the tokenizer's count matrix for the texts
        and ``Y`` is the encoder's transform of the labels.
    """
    phrases, labels = zip(*batch)  # transposes pairs into two tuples
    features = tokenizer.texts_to_matrix(phrases, mode='count')
    targets = encoder.transform(labels)
    return features, targets
def batcher(phrases, batch_size):
    """Lazily yield consecutive slices of ``phrases``.

    Each slice holds ``batch_size`` items except possibly the last, which
    holds the remainder. Slices have the same type as ``phrases``.

    Args:
        phrases: any sliceable sequence.
        batch_size: number of items per slice.

    Yields:
        ``phrases[start:start + batch_size]`` for successive starts.
    """
    starts = range(0, len(phrases), batch_size)
    yield from (phrases[begin:begin + batch_size] for begin in starts)
def training_gen(texts, batch_size, tokenizer, label_encoder):
    """Yield ``(X, Y)`` training batches forever, reshuffling every epoch.

    Each pass over the data shuffles ``texts`` **in place** (the caller's
    list is reordered), slices it into batches with ``batcher`` and
    vectorizes every batch with ``vectorize_batch``.

    Args:
        texts: mutable, non-empty sequence of ``(text, label)`` pairs.
        batch_size: positive number of examples per batch.
        tokenizer: passed through to ``vectorize_batch``.
        label_encoder: passed through to ``vectorize_batch``.

    Yields:
        ``(X, Y)`` tuples as returned by ``vectorize_batch``.

    Raises:
        ValueError: on the first ``next()`` if ``texts`` is empty or
            ``batch_size`` is not positive. Without this check, the
            ``while True`` loop would spin forever without ever yielding,
            because ``batcher`` produces no batches in those cases.
    """
    if len(texts) == 0:
        raise ValueError("training_gen needs at least one training example")
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer, got %r"
                         % (batch_size,))
    while True:
        shuffle(texts)  # new example order every epoch (mutates caller's list)
        for batch in batcher(texts, batch_size):
            X, Y = vectorize_batch(batch, tokenizer, label_encoder)
            yield (X, Y)
def plot_accuracy(history):
    """Plot the accuracy curves recorded in a training history and show them.

    Looks for each known accuracy metric name in ``history.history`` and
    plots every one that is present (one value per epoch). Older Keras
    versions record ``acc``/``val_acc``; newer ones record
    ``accuracy``/``val_accuracy``. Sparse-categorical models use the
    ``sparse_categorical_accuracy`` names.

    Args:
        history: object with a ``.history`` dict mapping metric name to a
            per-epoch sequence of values; presumably the Keras ``History``
            returned by ``fit`` (confirm against caller).
    """
    metrics = {
        'acc': 'training accuracy',
        'val_acc': 'validation accuracy',
        'accuracy': 'training accuracy',
        'val_accuracy': 'validation accuracy',
        'sparse_categorical_accuracy': 'training accuracy',
        'val_sparse_categorical_accuracy': 'validation accuracy',
    }
    recorded = history.history
    labels = []
    for metric, label in metrics.items():
        if metric in recorded:
            plt.plot(recorded[metric])
            labels.append(label)
    plt.title('Model Metrics')
    plt.xlabel('epoch')
    plt.ylabel('accuracy')
    # An empty legend only produces a matplotlib warning, so skip it.
    if labels:
        plt.legend(labels, loc='lower right')
    plt.show()
# credit: https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url?answertab=votes#tab-top | |
def get_confirm_token(response):
    """Return Google Drive's download-confirmation token, if present.

    Scans the response cookies and returns the value of the first cookie
    whose name starts with ``'download_warning'``.

    Args:
        response: object whose ``cookies`` supports ``.items()``.

    Returns:
        The matching cookie value, or ``None`` when no such cookie exists.
    """
    warnings = (
        value
        for key, value in response.cookies.items()
        if key.startswith('download_warning')
    )
    return next(warnings, None)
def save_response_content(response, destination):
    """Stream a response body to a local file in 32 KiB chunks.

    Prints the response object first, then writes every non-empty chunk
    from ``response.iter_content`` to ``destination``, overwriting it.

    Args:
        response: object exposing ``iter_content(chunk_size)`` that yields
            ``bytes``.
        destination: path of the file to (over)write in binary mode.
    """
    chunk_size = 32768
    print("Response" + str(response))
    with open(destination, "wb") as sink:
        # Empty chunks are keep-alive artefacts; drop them.
        sink.writelines(
            piece for piece in response.iter_content(chunk_size) if piece
        )
def get_drive_id(link):
    """Extract the Google Drive file id from a share/download link.

    Two link shapes are understood:

    * query form, e.g. ``https://drive.google.com/open?id=<ID>`` or
      ``https://docs.google.com/uc?export=download&id=<ID>``. The first
      ``id`` value is returned, as before.
    * path form, e.g. ``https://drive.google.com/file/d/<ID>/view``. The
      path segment following a ``d`` segment is returned.

    Args:
        link: URL string to inspect.

    Returns:
        The file id as a string.

    Raises:
        KeyError: if neither form yields an id. This is the same exception
            type the query-only lookup raised before.
    """
    parsed = urlparse(link)
    ids = parse_qs(parsed.query).get('id')
    if ids:
        return ids[0]
    segments = parsed.path.split('/')
    for marker, candidate in zip(segments, segments[1:]):
        if marker == 'd' and candidate:
            return candidate
    raise KeyError('id')
def download_file_from_google_drive_id(id, destination, timeout=60):
    """Download a Google Drive file by id and save it to ``destination``.

    Requests the file from the ``uc?export=download`` endpoint. If the first
    response carries a ``download_warning`` confirmation cookie, the request
    is repeated with that token as ``confirm``. The final response is then
    streamed to disk via ``save_response_content``.

    The session and every streamed response are closed on exit, so pooled
    connections are not leaked.

    Args:
        id: Google Drive file id. The parameter name is kept for backward
            compatibility even though it shadows the builtin.
        destination: local path to write the file to.
        timeout: connect/read timeout in seconds passed to ``requests``.
            ``None`` waits indefinitely (the previous behaviour).

    Raises:
        requests.HTTPError: if the final response has a 4xx/5xx status.
            Previously such error bodies were silently saved as the file.
    """
    url = "https://docs.google.com/uc?export=download"
    with requests.Session() as session:
        response = session.get(url, params={'id': id}, stream=True,
                               timeout=timeout)
        token = get_confirm_token(response)
        print('Token: ' + str(token))
        if token:
            response.close()  # release the unconsumed first response
            params = {'id': id, 'confirm': token}
            response = session.get(url, params=params, stream=True,
                                   timeout=timeout)
        with response:
            response.raise_for_status()
            save_response_content(response, destination)
def download_gdrive_link(link, destination):
    """Download the Google Drive file referenced by ``link`` to ``destination``.

    Resolves the file id with ``get_drive_id``, prints it, and delegates the
    transfer to ``download_file_from_google_drive_id``.

    Args:
        link: Google Drive URL containing the file id.
        destination: local path to write the file to.
    """
    file_id = get_drive_id(link)
    print("id: " + file_id)
    download_file_from_google_drive_id(file_id, destination)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment