Last active
April 1, 2017 09:15
-
-
Save jiqiujia/665335474cb5c372f0219ac1ac6c1bd3 to your computer and use it in GitHub Desktop.
keras
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Adjust the learning-rate policy with a callback
def scheduler(epoch):
    """Return the learning rate to use for the given epoch.

    At epoch 5 the learning rate stored on the module-level ``model`` is
    overwritten with 0.02; for every epoch the currently stored value is
    read back and returned.

    Args:
        epoch: Epoch index supplied by the caller of the schedule.

    Returns:
        The current learning rate as a plain Python ``float``.
    """
    # NOTE(review): on many Keras versions the learning-rate variable lives
    # at ``model.optimizer.lr`` rather than ``model.lr`` -- confirm against
    # the Keras version and the ``model`` object in use.
    if epoch == 5:
        model.lr.set_value(.02)
    # get_value() hands back a NumPy value; the schedule callback is
    # documented to expect a float, so convert explicitly.
    return float(model.lr.get_value())
# Wrap the schedule function in a callback so that fit() can apply it.
# NOTE(review): LearningRateScheduler is presumably imported from
# keras.callbacks elsewhere; that import is not visible in this snippet.
change_lr = LearningRateScheduler(scheduler)
# NOTE(review): with nb_epoch=1 the ``epoch == 5`` branch of scheduler()
# is never reached -- confirm the intended number of epochs.
model.fit(x_embed, y, nb_epoch=1, batch_size=batch_size, show_accuracy=True,
          callbacks=[change_lr])  # was ``chage_lr``, an undefined name
# Limit GPU memory usage when TensorFlow is the Keras backend
import os | |
import tensorflow as tf | |
import keras.backend.tensorflow_backend as KTF | |
def get_session(gpu_fraction=0.3):
    """Create a TensorFlow session capped to a fraction of GPU memory.

    Example: with 6GB of GPU memory, the default fraction of 0.3 allocates
    roughly 2GB.

    If the ``OMP_NUM_THREADS`` environment variable is set to a non-empty
    value, it is also used as ``intra_op_parallelism_threads``.

    Args:
        gpu_fraction: Value passed to ``tf.GPUOptions`` as
            ``per_process_gpu_memory_fraction``.

    Returns:
        A ``tf.Session`` built from the resulting ``tf.ConfigProto``.

    Raises:
        ValueError: If ``OMP_NUM_THREADS`` is set but is not an integer.
    """
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_fraction)
    config_kwargs = {'gpu_options': gpu_options}

    num_threads = os.environ.get('OMP_NUM_THREADS')
    if num_threads:
        # Environment values are always strings, but the config field is an
        # integer; passing the raw string would be rejected.
        config_kwargs['intra_op_parallelism_threads'] = int(num_threads)

    return tf.Session(config=tf.ConfigProto(**config_kwargs))
# Build the memory-capped session and register it with the Keras backend.
_session = get_session()
KTF.set_session(_session)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment