import cntk
import cntk.ops as C
import numpy as np
from sklearn.preprocessing import OneHotEncoder
from cntk.blocks import default_options, Input # Building blocks
from cntk.initializer import glorot_uniform
from cntk.layers import Dense # Layers
from cntk.learner import sgd, learning_rate_schedule, UnitType
from cntk.utils import get_train_eval_criterion, get_train_loss, ProgressPrinter
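
# A scikit-learn style wrapper (fit/predict/predict_proba) around a simple
# fully-connected CNTK network. Note: the imports above target the CNTK 2.0
# beta API (cntk.blocks, cntk.learner, cntk.utils); later CNTK releases moved
# or renamed these modules, so this file may need adjusting on newer versions.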
class CntkClassifier:
    def __init__(self, hidden_layers=None, batch_size='auto', learning_rate=0.001,
                 num_passes=1, display_step=1, verbose=True):
        self.hidden_layer_sizes = hidden_layers
        self.batch_size = batch_size
        self.learning_rate_init = learning_rate
        self.num_passes = num_passes
        # How often (in minibatches) training progress is reported
        self.training_progress_output_freq = display_step
        self.verbose = verbose
    # Utility that reports the training progress every 'frequency' minibatches
    def print_training_progress(self, trainer, mb, frequency):
        training_loss = "NA"
        eval_error = "NA"
        if mb % frequency == 0:
            training_loss = get_train_loss(trainer)
            eval_error = get_train_eval_criterion(trainer)
            if self.verbose:
                print("Minibatch: {0}, Loss: {1:.4f}, Error: {2:.2f}%".format(
                    mb + 1, training_loss, eval_error * 100))
        return mb, training_loss, eval_error
    def fit(self, x, y):
        if len(y.shape) == 1:
            y = np.reshape(y, (-1, 1))
        # Map the y's to [0, nlevels)
        self.classes_ = np.sort(np.unique(y))
        yz = np.searchsorted(self.classes_, y)
        # One-hot encode them
        self.ohe = OneHotEncoder(n_values=len(self.classes_), sparse=False)
        yy = self.ohe.fit_transform(yz)
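        # Illustrative example (values assumed, not from the original code):
        # with y = [-1, 1, 1, -1], classes_ is [-1, 1], yz is
        # [[0], [1], [1], [0]], and the one-hot yy is
        # [[1, 0], [0, 1], [0, 1], [1, 0]].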
        # Build the classifier
        input = Input(x.shape[1])
        label = Input(yy.shape[1])
        # Default to two hidden layers, each twice as wide as the input
        if self.hidden_layer_sizes is None:
            self.hidden_layer_sizes = np.full(2, x.shape[1] * 2, dtype=int)
        hh = input
        for ii in range(len(self.hidden_layer_sizes)):
            hh = Dense(self.hidden_layer_sizes[ii], init=glorot_uniform(), activation=C.relu)(hh)
        hh = Dense(yy.shape[1], init=glorot_uniform(), activation=None)(hh)
        loss = C.cross_entropy_with_softmax(hh, label)
        label_error = C.classification_error(hh, label)
        lr_per_minibatch = learning_rate_schedule(self.learning_rate_init, UnitType.minibatch)
        trainer = cntk.Trainer(hh, loss, label_error, [sgd(hh.parameters, lr=lr_per_minibatch)])
        if isinstance(self.batch_size, str):
            if self.batch_size == 'auto':
                self.batch_size = min(200, x.shape[0])
            else:
                raise ValueError("'auto' is the only acceptable string for batch_size")
        num_batches = x.shape[0] // self.batch_size
        # Train our neural network, cycling through the minibatches num_passes times
        tf = np.array_split(x, num_batches)
        tl = np.array_split(yy, num_batches)
        for ii in range(num_batches * self.num_passes):
            features = np.ascontiguousarray(tf[ii % num_batches])
            labels = np.ascontiguousarray(tl[ii % num_batches])
            # Map the input variables in the model to the actual minibatch data
            trainer.train_minibatch({input: features, label: labels})
            # Some reporting
            batchsize, loss, error = self.print_training_progress(trainer, ii, self.training_progress_output_freq)
        self.input = input
        self.label = label
        self.model = hh
    def predict(self, xx):
        probs = self.predict_proba(xx)
        return self.classes_[np.argmax(probs, 1)]

    def predict_proba(self, xx):
        out = C.softmax(self.model)
        res = np.squeeze(out.eval({self.input: xx}))
        # Add a dimension back if we squeezed too much (single-row input)
        if len(res.shape) == 1:
            res = np.reshape(res, (1, -1))
        return res
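
# --- Usage sketch (not part of the original classifier) ---
# A minimal, hedged example of how the class above is meant to be driven,
# assuming the CNTK 2.0 beta API imported at the top of this file.
# sklearn.datasets.make_classification is used purely to fabricate
# illustrative data; CNTK generally expects float32 features.
if __name__ == '__main__':
    from sklearn.datasets import make_classification

    xx, yy = make_classification(n_samples=1000, n_features=20, n_classes=3,
                                 n_informative=5, random_state=42)
    xx = xx.astype(np.float32)

    clf = CntkClassifier(num_passes=5, verbose=False)
    clf.fit(xx, yy)

    preds = clf.predict(xx)
    print("In-sample accuracy: {0:.2f}%".format(100.0 * np.mean(preds == yy)))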