This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class MetaLearner: | |
""" | |
This is nothing more than a regular learning flow. However, we create this | |
class, as we plan on using separate (meta-)learners for each task. | |
""" | |
def __init__(self, | |
model:torch.nn.Module, | |
loss_fn:Callable, | |
optimizer): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Reptile:
    """
    Reptile optimization, as described by Ravi et al.
    (https://openreview.net/pdf?id=rJY0-Kcl).

    Stores the model and the list of per-task meta-learners; the task count
    is taken from the length of that list.

    NOTE(review): the attribution and link are carried over from the
    original docstring -- confirm they point at the intended paper.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        metalearners: List[MetaLearner],
    ):
        self.model = model
        self.metalearners = metalearners
        # One task per supplied meta-learner.
        self.n_tasks = len(metalearners)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Run the training loop inside a monitored session that checkpoints into
# args['save_dir'] every 600 seconds.
with tf.train.MonitoredTrainingSession(
        master=server.target,
        is_chief=is_chiefing,
        checkpoint_dir=args['save_dir'],  # was misspelled `arsg` (NameError)
        hooks=hooks,
        save_checkpoint_secs=600.) as mon_sess:
    tf_feed = ctx.get_data_feed(train_mode=True)
    step = 0
    # Keep going until the session or the data feed asks to stop, or until
    # args['steps'] iterations have been reached.
    # NOTE(review): `step` is not incremented in the lines visible here --
    # presumably that happens further down the loop body; confirm.
    while (not mon_sess.should_stop()
           and not tf_feed.should_stop()
           and step < args['steps']):
        batch_data, batch_labels = get_next_batch(
            tf_feed.next_batch(args['batch_size']))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
with tf.device(tf.train.replica_device_setter( | |
worker_device="/job:worker/task:%d" % task_index, | |
cluster=cluster)): | |
def build_model(): | |
model_input = tf.placeholder(tf.float32,\ | |
[None,args['num_features'] ]) | |
model_labels = tf.placeholder(tf.float32, [None, args['num_classes'] ]) | |
logits = tf.keras.layers.Dense(args['num_classes'])(model_input) | |
model_output = tf.nn.softmax(logits) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Dispatch on the requested mode: "train" feeds the data for `epochs` passes,
# any other value runs inference and keeps the resulting RDD.
if mode != "train":
    labelRDD = cluster.inference(dataRDD)
else:
    cluster.train(dataRDD, epochs)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def map_fun(args, ctx): | |
try: | |
import tensorflow as tf | |
#utils | |
from datetime import datetime | |
import time | |
import logging | |
import numpy as np | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pyspark.sql import SparkSession | |
from tensorflowonspark import TFCluster, TFNode | |
# Build (or reuse) the Spark session. The builder chain is wrapped in
# parentheses so every call belongs to a single expression; the original
# relied on backslash continuations and was missing one after `.config(...)`,
# which made the following `.appName(...)` line a syntax error.
spark = (
    SparkSession
    .builder
    .config("...")  # NOTE(review): placeholder argument kept from the original
    .appName("model-training")
    .getOrCreate()
)
# NOTE(review): "{version}" is a literal part of the path here (the string is
# not an f-string and is never formatted) -- substitute the real archive name.
spark.sparkContext.addPyFile("/usr/local/tensorflow/tfspark-{version}.zip")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def map_fun(args, ctx): | |
worker_num = ctx.worker_num | |
job_name = ctx.job_name | |
task_index = ctx.task_index | |
cluster, server = ctx.start_cluster_server(1) | |
if job_name == "ps": | |
server.join() | |
elif job_name == "worker": | |
is_chiefing = (task_index == 0) | |
with tf.device(tf.train.replica_device_setter( |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Describe a cluster with a single "worker" job whose only task listens on
# localhost:2222, then create a server from that description.
cluster_spec = tf.train.ClusterSpec(
    {'worker': ['localhost:2222']}
)
server = tf.train.Server(cluster_spec)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def RosenbrockOpt(optimizer,MAX_EPOCHS = 4000, MAX_STEP = 100): | |
''' | |
returns distance of each step*MAX_STEP w.r.t minimum (1,1) | |
''' | |
x1_data = tf.Variable(initial_value=tf.random_uniform([1], minval=-3, maxval=3,seed=0),name='x1') | |
x2_data = tf.Variable(initial_value=tf.random_uniform([1], minval=-3, maxval=3,seed=1), name='x2') | |
y = tf.add(tf.pow(tf.subtract(1.0, x1_data), 2.0), | |
tf.multiply(100.0, tf.pow(tf.subtract(x2_data, tf.pow(x1_data, 2.0)), 2.0)), 'y') |
NewerOlder