Created
November 21, 2016 10:55
-
-
Save thvasilo/67bcb9370b03971f380ae43c4ae6e2d0 to your computer and use it in GitHub Desktop.
A basic online SGD using the Flink stream API.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
* Licensed to the Apache Software Foundation (ASF) under one or more | |
* contributor license agreements. See the NOTICE file distributed with | |
* this work for additional information regarding copyright ownership. | |
* The ASF licenses this file to You under the Apache License, Version 2.0 | |
* (the "License"); you may not use this file except in compliance with | |
* the License. You may obtain a copy of the License at | |
* | |
* http://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, software | |
* distributed under the License is distributed on an "AS IS" BASIS, | |
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
* See the License for the specific language governing permissions and | |
* limitations under the License. | |
*/ | |
package se.sics.quickstart; | |
import org.apache.flink.api.common.functions.MapFunction; | |
import org.apache.flink.api.common.state.ValueState; | |
import org.apache.flink.api.common.state.ValueStateDescriptor; | |
import org.apache.flink.api.common.typeinfo.TypeHint; | |
import org.apache.flink.api.common.typeinfo.TypeInformation; | |
import org.apache.flink.api.java.utils.ParameterTool; | |
import org.apache.flink.configuration.Configuration; | |
import org.apache.flink.streaming.api.TimeCharacteristic; | |
import org.apache.flink.streaming.api.datastream.DataStream; | |
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; | |
import org.apache.flink.streaming.api.functions.co.CoFlatMapFunction; | |
import org.apache.flink.streaming.api.functions.co.CoMapFunction; | |
import org.apache.flink.streaming.api.functions.source.SourceFunction; | |
import org.apache.flink.streaming.api.functions.windowing.RichAllWindowFunction; | |
import org.apache.flink.streaming.api.windowing.windows.GlobalWindow; | |
import org.apache.flink.util.Collector; | |
import java.util.ArrayList; | |
import java.util.Collections; | |
/**
 * Skeleton for an incremental machine learning algorithm: a linear regression
 * model is updated with mini-batch stochastic gradient descent as new training
 * data arrives, while new input data is scored against the most recent model.
 *
 * <p>
 * This may serve as a base for a number of algorithms, e.g. updating an
 * incremental Alternating Least Squares model while also providing
 * predictions.
 *
 * <p>
 * This example shows how to use:
 * <ul>
 * <li>Connected streams
 * <li>CoFunctions
 * <li>Count windows over non-keyed streams
 * <li>Flink managed state ({@code ValueState})
 * </ul>
 */
public class IncrementalLearning { | |
// ************************************************************************* | |
// PROGRAM | |
// ************************************************************************* | |
public static void main(String[] args) throws Exception { | |
// Checking input parameters | |
final ParameterTool params = ParameterTool.fromArgs(args); | |
Double learningRate = params.has("learningRate") ? new Double(params.get("learningRate")) : 0.001; | |
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); | |
env.setStreamTimeCharacteristic(TimeCharacteristic.ProcessingTime); | |
// env.setParallelism(1); | |
// To simplify we make the assumption that the last element in each line is the dependent variable | |
DataStream<ArrayList<Double>> trainingData = env.readTextFile(params.get("training")) | |
.map(new VectorExtractor()); | |
DataStream<ArrayList<Double>> newData = env.readTextFile(params.get("test")) | |
.map(new VectorExtractor()); | |
// build new model on every second of new data | |
DataStream<ArrayList<Double>> model = trainingData | |
.countWindowAll(Integer.parseInt(params.get("batchsize"))) | |
.apply(new PartialModelBuilder(learningRate, Integer.parseInt(params.get("dimensions")))); | |
// model.print(); | |
// use partial model for newData | |
DataStream<Double> errors = newData.connect(model).flatMap(new Evaluator()); | |
errors.print(); | |
// emit result | |
// if (params.has("output")) { | |
// prediction.writeAsText(params.get("output")); | |
// } else { | |
// System.out.println("Printing result to stdout. Use --output to specify output path."); | |
// prediction.print(); | |
// } | |
// execute program | |
env.execute("Streaming Incremental Learning"); | |
} | |
// ************************************************************************* | |
// USER FUNCTIONS | |
// ************************************************************************* | |
/** | |
* Feeds new data for newData. By default it is implemented as constantly | |
* emitting the Integer 1 in a loop. | |
*/ | |
public static class FiniteNewDataSource implements SourceFunction<Integer> { | |
private static final long serialVersionUID = 1L; | |
private int counter; | |
private String filepath; | |
public FiniteNewDataSource(int counter, String filepath) { | |
this.counter = counter; | |
this.filepath = filepath; | |
} | |
@Override | |
public void run(SourceContext<Integer> ctx) throws Exception { | |
Thread.sleep(15); | |
while (counter < 50) { | |
ctx.collect(getNewData()); | |
} | |
} | |
@Override | |
public void cancel() { | |
// No cleanup needed | |
} | |
private Integer getNewData() throws InterruptedException { | |
Thread.sleep(5); | |
counter++; | |
return 1; | |
} | |
} | |
private static Double predict(ArrayList<Double> model, ArrayList<Double> example) { | |
Double prediction = 0.0; | |
for (int i = 0; i < model.size(); i++) { | |
prediction += model.get(i) * example.get(i); | |
} | |
return prediction; | |
} | |
/** | |
* Builds up-to-date partial models on new training data. | |
*/ | |
public static class PartialModelBuilder extends RichAllWindowFunction<ArrayList<Double>, ArrayList<Double>, GlobalWindow> { | |
public PartialModelBuilder(Double learningRate, int dimensions) { | |
this.learningRate = learningRate; | |
this.dimensions = dimensions; | |
} | |
private Double learningRate; | |
private int dimensions; | |
private int applyCount = 0; | |
private static final long serialVersionUID = 1L; | |
private transient ValueState<ArrayList<Double>> modelState; | |
@Override | |
public void open(Configuration config) { | |
ArrayList<Double> allZeroes = new ArrayList<>(Collections.nCopies(dimensions, 0.0)); | |
// obtain key-value state for prediction model | |
// TODO: Do random assignment of weights instead of all zeros? | |
ValueStateDescriptor<ArrayList<Double>> descriptor = | |
new ValueStateDescriptor<>( | |
// state name | |
"modelState", | |
// type information of state | |
TypeInformation.of(new TypeHint<ArrayList<Double>>() {}), | |
// default value of state | |
allZeroes); | |
modelState = getRuntimeContext().getState(descriptor); | |
} | |
private Double squaredError(Double truth, Double prediction) { | |
return 0.5 * (truth - prediction) * (truth - prediction); | |
} | |
private ArrayList<Double> buildPartialModel(Iterable<ArrayList<Double>> trainingBatch) throws Exception{ | |
int batchSize = 0; | |
ArrayList<Double> regressionModel = modelState.value(); | |
ArrayList<Double> gradientSum = new ArrayList<>(Collections.nCopies(dimensions, 0.0)); | |
for (ArrayList<Double> sample : trainingBatch) { | |
batchSize++; | |
Double truth = sample.get(sample.size() - 1); | |
Double prediction = predict(regressionModel, sample); | |
Double error = squaredError(truth, prediction); | |
Double derivative = prediction - truth; | |
for (int i = 0; i < regressionModel.size(); i++) { | |
Double weightGradient = derivative * sample.get(i); | |
Double currentSum = gradientSum.get(i); | |
gradientSum.set(i, currentSum + weightGradient); | |
} | |
} | |
for (int i = 0; i < regressionModel.size(); i++) { | |
Double oldWeight = regressionModel.get(i); | |
Double currentLR = learningRate / Math.sqrt(applyCount); | |
Double change = currentLR * (gradientSum.get(i) / batchSize); | |
regressionModel.set(i, oldWeight - change); | |
} | |
return regressionModel; | |
} | |
@Override | |
public void apply(GlobalWindow window, Iterable<ArrayList<Double>> values, Collector<ArrayList<Double>> out) throws Exception { | |
this.applyCount++; | |
ArrayList<Double> updatedModel = buildPartialModel(values); | |
modelState.update(updatedModel); | |
out.collect(updatedModel); | |
} | |
} | |
public static class Evaluator implements CoFlatMapFunction<ArrayList<Double>, ArrayList<Double>, Double> { | |
ArrayList<Double> partialModel = null; | |
@Override | |
public void flatMap1(ArrayList<Double> example, Collector<Double> out) throws Exception { | |
// System.out.format("Example: %f\n", example.get(0)); | |
if (partialModel != null) { | |
System.out.println("Model was not null!"); | |
Double prediction = predict(partialModel, example); | |
Double error = example.get(example.size() - 1) - prediction; | |
out.collect(error); | |
} else { | |
System.out.println("Model was null!"); | |
} | |
} | |
@Override | |
public void flatMap2(ArrayList<Double> curModel, Collector<Double> out) throws Exception { | |
partialModel = curModel; | |
out.collect(partialModel.get(0)); | |
} | |
} | |
public static class Predictor implements CoMapFunction<ArrayList<Double>, ArrayList<Double>, Double> { | |
ArrayList<Double> partialModel = null; | |
@Override | |
public Double map1(ArrayList<Double> example) throws Exception { | |
if (partialModel != null) { | |
System.out.println("Partial model ready!"); | |
return predict(partialModel, example); | |
} else { | |
return Double.NaN; | |
} | |
} | |
@Override | |
public Double map2(ArrayList<Double> curModel) throws Exception { | |
partialModel = curModel; | |
return -1.0; | |
} | |
} | |
public static class VectorExtractor implements MapFunction<String, ArrayList<Double>> { | |
@Override | |
public ArrayList<Double> map(String s) throws Exception { | |
String[] elements = s.split(","); | |
ArrayList<Double> doubleElements = new ArrayList<>(elements.length); | |
for (int i = 0; i < elements.length; i++) { | |
doubleElements.add(new Double(elements[i])); | |
} | |
return doubleElements; | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment