Skip to content

Instantly share code, notes, and snippets.

View jumpingfella's full-sized avatar

Stepan Samarin jumpingfella

View GitHub Profile
@jumpingfella
jumpingfella / RecurrentNets.java
Created May 28, 2018 16:23
LSTM + RNN network for time series
public class RecurrentNets {
private static final double learningRate = 0.1;
private static final int seed = 12345;
private static final int nHidden = 200;
private static final int truncatedBPTTLength = 22;
public static MultiLayerNetwork buildLstmNetworks(int nIn, int nOut) {
//Set up network configuration:
@jumpingfella
jumpingfella / RegressionRNN.java
Created May 25, 2018 14:17
1..100 Deeplearning4j RNN example
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;