Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save TakuTsuzuki/25541b78a596cbbd6fb0 to your computer and use it in GitHub Desktop.
Save TakuTsuzuki/25541b78a596cbbd6fb0 to your computer and use it in GitHub Desktop.
RandomForestClassifier Component of BriCA1
# coding: utf-8
import numpy as np
import matplotlib.pyplot as plt
from sklearn import ensemble, datasets
import brica1
# RandomForestClassifier component definition
class RandomForestClassifierComponent(brica1.Component):
    """BriCA1 component that wraps a scikit-learn random forest classifier.

    The component has one input port, ``in0`` (created with length ``n_in``),
    and one output port, ``out0`` (created with length 1).  Call ``fit`` to
    train the wrapped classifier before the component is fired.
    """

    def __init__(self, n_in):
        """Create the wrapped classifier and the component's ports.

        Args:
            n_in: Length passed to ``make_in_port`` for ``in0``.
        """
        super(RandomForestClassifierComponent, self).__init__()
        self.classifier = ensemble.RandomForestClassifier()
        self.make_in_port("in0", n_in)
        self.make_out_port("out0", 1)

    def fire(self):
        """Predict a label for the current ``in0`` value and store it in ``out0``."""
        x = self.inputs["in0"]
        # predict() works on a batch of samples, so the single sample is
        # wrapped in a list; the result is a one-element array.
        z = self.classifier.predict([x])
        self.results["out0"] = z

    def fit(self, X, y):
        """Train the wrapped classifier.

        Args:
            X: Training samples, forwarded unchanged to the classifier's ``fit``.
            y: Target labels, forwarded unchanged to the classifier's ``fit``.
        """
        self.classifier.fit(X, y)
# Load the iris dataset, keeping only the first two features of each sample
iris = datasets.load_iris()
X = iris.data[:, :2]
y = iris.target
# Port sizes below follow the number of features actually kept in X
n_features = X.shape[1]

# Set up the data feeder component
feeder = brica1.ConstantComponent()
feeder.make_out_port("out0", n_features)

# Set up the RandomForestClassifier component and train it
RFC = RandomForestClassifierComponent(n_features)
RFC.fit(X, y)

# Connect the feeder's output to the classifier's input
brica1.connect((feeder, "out0"), (RFC, "in0"))

# Add the components to a module
mod = brica1.Module()
mod.add_component("feeder", feeder)
mod.add_component("RFC", RFC)

# Set up the scheduler and agent
s = brica1.VirtualTimeSyncScheduler()
a = brica1.Agent(s)
a.add_submodule("mod", mod)

# Test the classifier on the same samples it was trained on
for sample, actual in zip(X, y):
    feeder.set_state("out0", sample)  # Feed the current sample
    a.step()  # Execute prediction
    # Read the output buffer once and reuse it for both printed fields
    prediction = RFC.get_out_port("out0").buffer[0]
    # Single-argument print(...) behaves the same on Python 2 and Python 3
    print("Actual: {}\tPrediction: {}\t{}".format(actual, prediction, actual == prediction))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment