Skip to content

Instantly share code, notes, and snippets.

@agibsonccc
Created August 28, 2014 02:02
Show Gist options
  • Save agibsonccc/c334c8a424ec54a542b7 to your computer and use it in GitHub Desktop.
Save agibsonccc/c334c8a424ec54a542b7 to your computer and use it in GitHub Desktop.
package org.deeplearning4j.models.featuredetectors.rbm;
import static org.junit.Assert.*;
import org.deeplearning4j.linalg.api.ndarray.INDArray;
import org.deeplearning4j.linalg.factory.NDArrays;
import org.deeplearning4j.linalg.lossfunctions.LossFunctions;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Smoke tests for {@link RBM}.
 *
 * <p>Created by agibsonccc on 8/27/14.
 */
public class RBMTests {
    private static final Logger log = LoggerFactory.getLogger(RBMTests.class);

    /**
     * Builds an RBM via {@code RBM.Builder} over a small binary dataset and fits it once.
     *
     * <p>NOTE(review): this test makes no assertions about the trained model; it only
     * verifies that configuration, construction and {@code fit} complete without throwing.
     */
    @Test
    public void testBasic() {
        // 7 examples x 6 binary features. Rows 0-2 are active mostly in columns 0-2,
        // rows 3-6 mostly in columns 2-4; column 5 is always 0.
        float[][] data = new float[][] {
            {1, 1, 1, 0, 0, 0},
            {1, 0, 1, 0, 0, 0},
            {1, 1, 1, 0, 0, 0},
            {0, 0, 1, 1, 1, 0},
            {0, 0, 1, 1, 0, 0},
            {0, 0, 1, 1, 1, 0},
            {0, 0, 1, 1, 1, 0}
        };
        INDArray input = NDArrays.create(data);

        // nIn(6) matches the 6 columns of `data`.
        NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
                .lossFunction(LossFunctions.LossFunction.RMSE_XENT)
                .learningRate(1e-1f)
                .nIn(6)
                .nOut(4)
                .build();

        Model rbm = new RBM.Builder().configure(conf).withInput(input).build();

        // Second argument is passed as null -- presumably the labels, which an RBM
        // does not use; confirm against Model.fit.
        rbm.fit(input, null);
        log.info("RBM fit completed on {} examples", data.length);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment