Created March 17, 2012 19:24
Save axiak/2064499 to your computer and use it in GitHub Desktop.
pybrain milk
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Michael Axiak (e-mail address was obfuscated by the page this file was copied from)
import numpy as np | |
from milk.supervised.base import supervised_model | |
def _label_to_map(label, labelMap):
    """Return the one-hot target vector for ``label``.

    ``labelMap`` maps each class label to its output index.  The result is a
    float array of length ``len(labelMap)`` holding 1 at ``labelMap[label]``
    and 0 everywhere else.

    Raises KeyError if ``label`` is not a key of ``labelMap``.
    """
    # np.zeros rather than np.ndarray: np.ndarray(n) hands back uninitialised
    # memory, which would leave arbitrary values in the "off" positions of
    # the training target.
    val = np.zeros(len(labelMap))
    val[labelMap[label]] = 1
    return val
def _label_from_map(i, labelMap):
    """Return the label whose output index in ``labelMap`` equals ``i``.

    This is the inverse of the ``label -> index`` lookup and works by
    scanning ``labelMap`` linearly.

    Raises StopIteration if no label maps to ``i``.
    """
    # The builtin next() works on Python 2.6+ and Python 3; the original
    # ``iterator.next()`` method call does not exist on Python 3.
    return next(key for key, val in labelMap.items() if val == i)
class _nn_learner(object):
    """milk-style learner that trains a PyBrain feed-forward network.

    hidden_layers: passed to ``buildNetwork`` as its middle positional
        argument, i.e. the size of the single hidden layer.  Despite the
        name, this is a unit count, not a number of layers.
    max_epochs: passed as ``maxEpochs`` to
        ``BackpropTrainer.trainUntilConvergence``.
    """

    def __init__(self, hidden_layers=3, max_epochs=15):
        self.hidden_layers = hidden_layers
        self.max_epochs = max_epochs

    def train(self, features, labels, normalisedLabels=False, **kwargs):
        """Fit a network to ``features``/``labels`` and return an ``_nn_model``.

        ``features`` must expose ``.shape[1]`` (e.g. a 2-D numpy array); each
        row is paired with the corresponding entry of ``labels``.

        If ``normalisedLabels`` is true, the network has one output and each
        label is used directly as its target.  The returned model then
        predicts that raw output value.

        Otherwise every distinct label gets its own output unit and targets
        are one-hot encoded.  The returned model then predicts the label of
        the most active output.

        Extra keyword arguments are accepted and ignored.
        """
        # Imported here (presumably so the module can be imported without
        # PyBrain installed).
        from pybrain.supervised.trainers import BackpropTrainer
        from pybrain.tools.shortcuts import buildNetwork
        from pybrain.datasets import SupervisedDataSet

        n_inputs = features.shape[1]
        if normalisedLabels:
            # No label map in this mode.  Binding it to None (instead of
            # leaving it unbound) fixes the NameError at the return below.
            labelMap = None
            n_outputs = 1
        else:
            labelMap = {v: i for i, v in enumerate(set(labels))}
            n_outputs = len(labelMap)

        ds = SupervisedDataSet(n_inputs, n_outputs)
        nn = buildNetwork(n_inputs, self.hidden_layers, n_outputs)
        for input_val, output_val in zip(features, labels):
            if normalisedLabels:
                ds.addSample(input_val, output_val)
            else:
                ds.addSample(input_val, _label_to_map(output_val, labelMap))
        trainer = BackpropTrainer(nn, ds)
        trainer.trainUntilConvergence(maxEpochs=self.max_epochs)
        return _nn_model(nn, labelMap)


class _nn_model(supervised_model):
    """Trained model returned by ``_nn_learner.train``.

    ``labelMap`` is either the ``label -> output index`` dict used for
    training, or ``None`` when the network was trained with
    ``normalisedLabels=True`` (single output unit).
    """

    def __init__(self, neural_network, labelMap):
        self.neural_network = neural_network
        self.label_map = labelMap

    def apply(self, q):
        """Predict for a single feature vector ``q``.

        Returns the label whose output unit is most active.  For a model
        with no label map (``normalisedLabels=True`` training), returns the
        network's single output value instead.
        """
        output = self.neural_network.activate(q)
        if self.label_map is None:
            # argmax over one output would always be 0, so return the
            # value the network was actually trained to produce.
            return output[0]
        return _label_from_map(np.argmax(output), self.label_map)
def nn_learner(hidden_layers=3, max_epochs=15):
    """Return a milk learner: simple feature selection, then a PyBrain network.

    ``hidden_layers`` (size of the network's hidden layer) and
    ``max_epochs`` are forwarded unchanged to ``_nn_learner``.  The
    ``hidden_layers`` default of 3 matches ``_nn_learner``'s own default;
    callers that pass it explicitly are unaffected.
    """
    from milk.supervised.defaultlearner import feature_selection_simple
    from milk.supervised.classifier import ctransforms
    # ctransforms chains its arguments, so the network is trained on the
    # output of feature_selection_simple.
    return ctransforms(feature_selection_simple(),
                       _nn_learner(hidden_layers, max_epochs))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.