Last active
March 9, 2018 01:00
-
-
Save atremblay/25a7654f844013dcafdbcc7ab60148da to your computer and use it in GitHub Desktop.
Keras SMS Callback
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import tensorflow as tf | |
from twilio.rest import Client | |
class SMSCallback(tf.keras.callbacks.Callback):
    """
    SMSCallback

    Because you are not always standing in front of your computer waiting for
    the next epoch to be completed. This callback will use your Twilio account
    to send you (or anyone else you want to keep informed) the evolution of
    the training every ``period`` epochs. Your ACCOUNT_SID and AUTH_TOKEN need
    to be in your environment variables; they are read when the callback is
    constructed.
    """

    def __init__(self, to, sender, period=1):
        """
        Params
        ------
        to: str or iterable of str
            Number(s) to send the message to. Format: '+15555555555'
        sender: str
            The number from twilio that you want to use. Same format as 'to': '+15555555555'
        period: int
            Send an SMS every ``period`` epochs (must be >= 1).

        Raises
        ------
        ValueError
            If ``period`` is smaller than 1.
        KeyError
            If ACCOUNT_SID or AUTH_TOKEN is missing from the environment.
        """
        super(SMSCallback, self).__init__()
        # Fail fast here instead of with a ZeroDivisionError mid-training.
        if period < 1:
            raise ValueError("period must be >= 1, got {!r}".format(period))
        # Accept a single number or any iterable of numbers. Previously only
        # the str case assigned self.to, so passing a list crashed later.
        if isinstance(to, str):
            self.to = [to]
        else:
            self.to = list(to)
        self.sender = sender
        self.client = Client(
            os.environ['ACCOUNT_SID'],
            os.environ['AUTH_TOKEN']
        )
        self.period = period

    def on_epoch_end(self, epoch, logs=None):
        """
        Send the epoch's metrics by SMS to every receiver, every ``period`` epochs.

        Params
        ------
        epoch: int
            Epoch index as passed by Keras; it is reported unchanged in the message.
        logs: dict or None
            Metric name -> value for the finished epoch, as passed by Keras.
            It is only read, never modified, so other callbacks see it intact.
        """
        # (epoch + 1) so that period=N fires after the N-th, 2N-th, ... epoch.
        if (epoch + 1) % self.period != 0:
            return
        # Build the text once, not once per receiver.
        body = self._format_message(epoch, logs or {})
        for receiver in self.to:
            self.client.messages.create(
                to=receiver,
                from_=self.sender,
                body=body
            )

    @staticmethod
    def _format_message(epoch, logs):
        """
        Build the SMS text from whatever metrics ``logs`` contains.

        Training metrics come first, then validation metrics (keys starting
        with ``val_``), each rendered as ``name: value`` with 4 decimals, e.g.
        ``epoch 3 | loss: 0.1234, acc: 0.9876 val_loss: 0.2345, val_acc: 0.9765``.
        No metric name is required, so models compiled with e.g.
        ``metrics=['accuracy']`` or no metrics at all do not raise KeyError.
        """
        def fmt(name, value):
            try:
                return "{}: {:.4f}".format(name, value)
            except (TypeError, ValueError):
                # Non-numeric entry: show it verbatim rather than abort training.
                return "{}: {}".format(name, value)

        train = [fmt(k, v) for k, v in logs.items() if not k.startswith("val_")]
        val = [fmt(k, v) for k, v in logs.items() if k.startswith("val_")]
        msg = "epoch {} | {}".format(epoch, ", ".join(train))
        if val:
            msg += " " + ", ".join(val)
        return msg
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment