[keras utils] #python #ml

JZTrainCategory.py: training strategy (save on improvement; otherwise reload the best weights and decay the learning rate)

import warnings

import keras
import keras.backend as K
import numpy as np


class JZTrainCategory(keras.callbacks.Callback):
    """Checkpoint-style callback: save the model when the monitored metric
    improves; otherwise reload the best weights saved so far and decay the
    learning rate by `factor`."""

    def __init__(self, filepath, monitor='val_loss', factor=0.1, verbose=1,
                 save_weights_only=False, mode='auto', period=1):
        super(JZTrainCategory, self).__init__()
        self.monitor = monitor
        self.verbose = verbose
        self.filepath = filepath
        self.factor = factor
        self.save_weights_only = save_weights_only
        self.period = period  # kept for ModelCheckpoint-style API compatibility, currently unused
        self.best_path = None  # path of the most recently saved checkpoint

        if mode not in ['auto', 'min', 'max']:
            warnings.warn('JZTrainCategory mode %s is unknown, '
                          'fallback to auto mode.' % mode,
                          RuntimeWarning)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
            self.best = np.inf
        elif mode == 'max':
            self.monitor_op = np.greater
            self.best = -np.inf
        else:
            # auto: maximize accuracy/f-measure-like metrics, minimize everything else
            if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
                self.monitor_op = np.greater
                self.best = -np.inf
            else:
                self.monitor_op = np.less
                self.best = np.inf

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # expose the current learning rate so it can be used in the filepath template
        logs['lr'] = K.get_value(self.model.optimizer.lr)

        filepath = self.filepath.format(epoch=epoch + 1, **logs)
        current = logs.get(self.monitor)
        if current is None:
            warnings.warn('Can save best model only with %s available, '
                          'skipping.' % self.monitor, RuntimeWarning)
            return

        if self.monitor_op(current, self.best):
            if self.verbose > 0:
                print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
                      ' saving model to %s'
                      % (epoch + 1, self.monitor, self.best,
                         current, filepath))
            self.best = current
            self.best_path = filepath
            if self.save_weights_only:
                self.model.save_weights(filepath, overwrite=True)
            else:
                self.model.save(filepath, overwrite=True)
        else:
            if self.verbose > 0:
                print('\nEpoch %05d: %s did not improve from %0.5f' %
                      (epoch + 1, self.monitor, self.best))
            # roll back to the best checkpoint saved so far, then decay the learning rate
            if self.best_path is not None:
                self.model.load_weights(self.best_path)
            old_lr = K.get_value(self.model.optimizer.lr)
            new_lr = old_lr * self.factor
            K.set_value(self.model.optimizer.lr, new_lr)
            if self.verbose > 0:
                print('\nReload best weights and decay learning rate '
                      'from {} to {}\n'.format(old_lr, new_lr))
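
A minimal usage sketch (the model, data, and filepath template below are illustrative, not from the original gist), assuming the JZTrainCategory class above is defined in the same module:

import numpy as np
from keras.models import Sequential
from keras.layers import Dense
from keras.utils import to_categorical

# toy model and data, purely to show how the callback is wired into fit()
model = Sequential([Dense(10, activation='softmax', input_shape=(300,))])
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])

x_train = np.random.random((200, 300))
y_train = to_categorical(np.random.randint(10, size=(200,)), num_classes=10)

# placeholders in the filepath are filled from `epoch` and the metrics in `logs`
ckpt = JZTrainCategory('model_{epoch:02d}_{val_loss:.4f}.h5',
                       monitor='val_loss', factor=0.5, verbose=1)
history = model.fit(x_train, y_train, validation_split=0.2,
                    epochs=10, callbacks=[ckpt])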

For debugging: converting between NumPy arrays and tensors

import numpy as np
import tensorflow as tf

x = np.random.random((200, 300))
sess = tf.Session()
with sess.as_default():
    # NumPy array -> tensor
    tensor = tf.constant(x)
    print(tensor)
    # tensor -> NumPy array, evaluated inside the default session
    numpy_array_2 = tensor.eval()
    print(numpy_array_2)
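
Under TensorFlow 2.x with eager execution the same round trip needs no session; a minimal sketch, assuming TF 2.x is installed:

import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution on by default)

x = np.random.random((200, 300))
tensor = tf.convert_to_tensor(x)  # NumPy array -> tensor
numpy_array_2 = tensor.numpy()    # tensor -> NumPy array
print(tensor.shape, numpy_array_2.shape)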

Plot training/validation loss and accuracy from a Keras History object

import os

import matplotlib
matplotlib.use("agg")  # non-interactive backend, safe on headless servers
import matplotlib.pyplot as plt
import numpy as np


def plot_loss(H):
    # grab the history object dictionary
    H = H.history
    # plot the training loss and accuracy
    N = np.arange(0, len(H["loss"]))
    plt.style.use("ggplot")
    plt.figure()
    plt.plot(N, H["loss"], label="train_loss")
    plt.plot(N, H["val_loss"], label="test_loss")
    plt.plot(N, H["acc"], label="train_acc")
    plt.plot(N, H["val_acc"], label="test_acc")
    plt.title("xxx")
    plt.xlabel("Epoch #")
    plt.ylabel("Loss/Accuracy")
    plt.legend()
    # save the figure
    os.makedirs('loss', exist_ok=True)
    plt.savefig(os.path.join('loss', 'loss.png'))
    plt.close()
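
A usage sketch for plot_loss, reusing the illustrative `history` object from the callback sketch above (any History returned by model.fit works the same way):

# plot_loss consumes the History returned by model.fit and writes loss/loss.png
plot_loss(history)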

Limit Keras/TensorFlow to one GPU and enable memory growth

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"  # only expose GPU 2 to this process

import tensorflow as tf
from keras.backend.tensorflow_backend import set_session

config = tf.ConfigProto()
# config.gpu_options.per_process_gpu_memory_fraction = 0.3  # optional hard cap on GPU memory
config.gpu_options.allow_growth = True  # allocate GPU memory on demand
set_session(tf.Session(config=config))
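
Under TensorFlow 2.x, where tf.ConfigProto and set_session no longer apply, a rough equivalent looks like this (a minimal sketch, assuming TF 2.1+):

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"  # only expose GPU 2; set before TensorFlow touches the GPUs

import tensorflow as tf  # assumes TensorFlow 2.1+

# enable on-demand memory growth for every visible GPU
for gpu in tf.config.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(gpu, True)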