Group by sliding window in TensorFlow
import numpy as np

# X, GCLOUD_SENSOR_COLS, step and window_size are defined elsewhere in the project.
length = int(X.shape[0] / step) - window_size
Xt = np.empty((length, window_size * len(GCLOUD_SENSOR_COLS)))
for i in range(length):
    # Flatten each window of `window_size` consecutive rows into one feature row.
    Xt[i] = X[i * step:i * step + window_size].ravel()
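For context, a minimal self-contained sketch of the same grouping on synthetic data; the sensor column names, array sizes and parameter values below are made-up stand-ins, not the gist's real configuration:

import numpy as np

GCLOUD_SENSOR_COLS = ["acc_x", "acc_y", "acc_z"]   # hypothetical sensor columns
X = np.random.rand(1000, len(GCLOUD_SENSOR_COLS))  # synthetic sensor readings
window_size, step = 50, 5

length = int(X.shape[0] / step) - window_size
Xt = np.empty((length, window_size * len(GCLOUD_SENSOR_COLS)))
for i in range(length):
    Xt[i] = X[i * step:i * step + window_size].ravel()

print(Xt.shape)  # (150, 150): 150 windows, each flattened to 50 rows x 3 columns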
/*
 * Welcome to the Java Apache Beam version!!
 */
package com.dermatrack.dataflow.preprocessing;

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.*;
import org.apache.beam.sdk.extensions.gcp.storage.GcsCreateOptions;
import org.apache.beam.sdk.io.*;
import org.apache.beam.sdk.io.fs.EmptyMatchTreatment;
import org.apache.beam.sdk.io.fs.MatchResult;
import org.apache.beam.sdk.io.fs.ResourceId;
import org.apache.beam.sdk.io.gcp.pubsub.PubsubOptions;
import org.apache.beam.sdk.options.*;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.util.MimeTypes;
import org.apache.beam.sdk.values.PCollection;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.SeekableByteChannel;
import java.nio.channels.WritableByteChannel;
import java.util.ArrayList;
import java.util.List;
import java.util.NoSuchElementException;

/**
 * Read all merged CSV files and apply the group-by-sliding-window transformation.
 */
public class PreProcessing {

    /**
     * Pipeline options: output location and sliding-window parameters.
     */
    interface Options extends PubsubOptions {
        @Description("Output file bucket location")
        @Default.String("gs://dermatrack-mlengine/cleaned_data/regr_filt_augm/v1/w50/")
        @Validation.Required
        String getOutputDir();
        void setOutputDir(String value);

        @Description("Size of the sliding window")
        @Default.Integer(50)
        int getWindowSize();
        void setWindowSize(int value);

        @Description("Number of steps between consecutive sliding windows")
        @Default.Integer(5)
        int getStepSize();
        void setStepSize(int value);
    }
    /**
     * Utility class to read a file line by line from a ReadableByteChannel.
     */
    private static class LineReader {
        private ReadableByteChannel channel = null;
        private long nextLineStart = 0;
        private long currentLineStart = 0;
        private final ByteBuffer buf;
        private static final int BUF_SIZE = 1024;
        private String currentValue = null;

        public LineReader(final ReadableByteChannel channel)
                throws IOException {
            buf = ByteBuffer.allocate(BUF_SIZE);
            buf.flip();
            boolean removeLine = false;
            // If we are not at the beginning of a line, we should ignore the current line.
            if (channel instanceof SeekableByteChannel) {
                SeekableByteChannel seekChannel = (SeekableByteChannel) channel;
                if (seekChannel.position() > 0) {
                    // Start from one character back and read till we find a new line.
                    seekChannel.position(seekChannel.position() - 1);
                    removeLine = true;
                }
                nextLineStart = seekChannel.position();
            }
            this.channel = channel;
            if (removeLine) {
                nextLineStart += readNextLine(new ByteArrayOutputStream());
            }
        }

        private int readNextLine(final ByteArrayOutputStream out) throws IOException {
            int byteCount = 0;
            while (true) {
                if (!buf.hasRemaining()) {
                    buf.clear();
                    int read = channel.read(buf);
                    if (read < 0) {
                        break;
                    }
                    buf.flip();
                }
                byte b = buf.get();
                byteCount++;
                if (b == '\n') {
                    break;
                }
                out.write(b);
            }
            return byteCount;
        }

        public boolean readNextLine() throws IOException {
            currentLineStart = nextLineStart;
            ByteArrayOutputStream buf = new ByteArrayOutputStream();
            int offsetAdjustment = readNextLine(buf);
            if (offsetAdjustment == 0) {
                // EOF
                return false;
            }
            nextLineStart += offsetAdjustment;
            // When running on Windows, each line obtained from 'readNextLine()' will end with a '\r'
            // since we use '\n' as the line boundary of the reader. So we trim it off here.
            currentValue = CoderUtils.decodeFromByteArray(StringUtf8Coder.of(), buf.toByteArray()).trim();
            return true;
        }

        public String getCurrent() {
            return currentValue;
        }

        public long getCurrentLineStart() {
            return currentLineStart;
        }
    }
    /**
     * Group a list of records into a single window row. This also reduces the per-record labels
     * into a single label as:
     *   reduced_label = n_itch / total_records
     * @param records the records falling inside one window
     * @return the window as a single CSV line
     */
    private static String recordsToWindow(String[] records) {
        int itch = 0;
        StringBuilder window = new StringBuilder();
        for (String rec : records) {
            String[] values = rec.split(",");
            String label = values[values.length - 1];
            if (label.equals("itch")) {
                itch += 1;
            }
            // Keep the record's feature values (including the trailing comma), drop the label.
            window.append(rec.substring(0, rec.length() - label.length()));
        }
        // Floating-point division so the reduced label is the fraction of "itch" records.
        window.append((double) itch / records.length);
        return window.toString() + "\n";
    }
    public static void main(String[] args) throws Exception {
        Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class);
        // Enforce that this pipeline is always run in streaming mode.
        options.setStreaming(true);

        final int windowLen = options.getWindowSize();
        final int stepSize = options.getStepSize();
        final String outputDir = options.getOutputDir();

        Pipeline pipeline = Pipeline.create(options);
        MatchResult ms = FileSystems.match("gs://dermatrack-mlengine/cleaned_data/merged/*.csv");
        PCollection<MatchResult.Metadata> filesMetadata = pipeline.apply("RetrieveFileList", Create.of(ms.metadata()));

        filesMetadata
                .apply("GroupBySlidingWindow",
                        ParDo.of(new DoFn<MatchResult.Metadata, String>() {
                            @ProcessElement
                            public void processElement(ProcessContext c) throws IOException {
                                System.out.println(c.element());
                                ReadableByteChannel channel = FileSystems.open(c.element().resourceId());
                                LineReader lineReader = new LineReader(channel);
                                List<String> windows = new ArrayList<String>();
                                int outputSize = 0;

                                // The first window deserves special treatment: read windowLen lines from scratch.
                                String[] records = new String[windowLen];
                                for (int i = 0; i < windowLen; i++) {
                                    if (!lineReader.readNextLine()) {
                                        return;
                                    } else {
                                        records[i] = lineReader.getCurrent();
                                        outputSize += records[i].getBytes().length;
                                    }
                                }
                                windows.add(recordsToWindow(records));

                                // Subsequent windows reuse the previously read records, shifted by stepSize.
                                while (true) {
                                    String[] new_records = new String[windowLen];
                                    for (int i = 0; i < windowLen - stepSize; i++) {
                                        new_records[i] = records[i + stepSize];
                                    }
                                    boolean endOfFile = false;
                                    for (int i = windowLen - stepSize; i < windowLen; i++) {
                                        if (!lineReader.readNextLine()) {
                                            endOfFile = true;
                                            break;
                                        } else {
                                            new_records[i] = lineReader.getCurrent();
                                            outputSize += new_records[i].getBytes().length;
                                        }
                                    }
                                    if (!endOfFile) {
                                        records = new_records;
                                        windows.add(recordsToWindow(records));
                                    }
                                    // Not enough records left to fill this window: stop.
                                    else {
                                        break;
                                    }
                                }

                                // Write the windows to an output blob named after the input file.
                                String outFileName = outputDir + c.element().resourceId().getFilename();
                                WritableByteChannel writeChannel = FileSystems.create(
                                        FileSystems.matchNewResource(outFileName, false),
                                        GcsCreateOptions.builder().setMimeType(MimeTypes.BINARY).build());
                                for (String rec : windows) {
                                    byte[] record = rec.getBytes();
                                    ByteBuffer byteBuffer = ByteBuffer.allocate(record.length);
                                    byteBuffer.put(record);
                                    byteBuffer.position(0);
                                    writeChannel.write(byteBuffer);
                                }
                                writeChannel.close();
                            }
                        }));
        pipeline.run().waitUntilFinish();
    }
}
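To make the label reduction in recordsToWindow easier to follow, here is the same logic sketched in Python on two hypothetical CSV records (the feature values and the non-itch label are invented for illustration):

records = [
    "0.12,0.98,-0.03,itch",   # hypothetical sensor row labelled "itch"
    "0.10,0.95,-0.01,none",   # hypothetical sensor row with any other label
]

itch = 0
window = ""
for rec in records:
    label = rec.split(",")[-1]
    if label == "itch":
        itch += 1
    # keep the feature values (including the trailing comma), drop the label
    window += rec[:len(rec) - len(label)]
# reduced label = fraction of records in the window labelled "itch"
window += str(itch / len(records))
print(window)  # 0.12,0.98,-0.03,0.10,0.95,-0.01,0.5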
import numpy as np
import tensorflow as tf
from tensorflow.python.data.ops import dataset_ops


def filter_overlapping_values(x, window_size):
    # Take the second half of window(i) and the first half of window(i+1)
    # to build the window that overlaps the two consecutive batches.
    s1 = tf.slice(x[0], [window_size // 2, 0], [-1, -1])
    s2 = tf.slice(x[1], [0, 0], [window_size // 2, -1])
    return tf.concat((s1, s2), axis=0)


length = 12
components = np.array([[i] for i in range(length)], dtype=np.int64)
# components = np.arange(6 * 4, dtype=np.int64).reshape((-1, 4))
dataset = dataset_ops.Dataset.from_tensor_slices(components)
window_size = 4

# Window consecutive elements with batch
dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(window_size))
# [[0][1][2][3]]
# [[4][5][6][7]]
# [[8][9][10][11]]

# Duplicate every row and skip the first copy; this allows the creation of overlapping windows
dataset1 = dataset.apply(tf.contrib.data.group_by_window(lambda x: 0, lambda k, d: d.repeat(2), window_size=1)).skip(1)
# [[0][1][2][3]]
# [[4][5][6][7]]
# [[4][5][6][7]]
# [[8][9][10][11]]
# [[8][9][10][11]]

# Use batch to merge duplicate rows into a single row holding both window(i) and window(i+1)
dataset1 = dataset1.apply(tf.contrib.data.batch_and_drop_remainder(2))
# [ [[0][1][2][3]] [[4][5][6][7]] ]
# [ [[4][5][6][7]] [[8][9][10][11]] ]

# Slice out only the values needed for the overlapping windows
dataset1 = dataset1.map(lambda x: filter_overlapping_values(x, window_size))
# [[2][3][4][5]]
# [[6][7][8][9]]

# Zip the overlapping windows with the original ones so they end up in the right position
dataset = tf.data.Dataset.zip((dataset, dataset1))
# x0: [[0][1][2][3]] x1: [[2][3][4][5]]
# x0: [[4][5][6][7]] x1: [[6][7][8][9]]

# Flatten the original windows and the overlapping windows into a single dataset
dataset = dataset.flat_map(lambda x0, x1: tf.data.Dataset.from_tensors(x0).concatenate(tf.data.Dataset.from_tensors(x1)))
# [[0][1][2][3]]
# [[2][3][4][5]]
# [[4][5][6][7]]
# [[6][7][8][9]]
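One way to materialise and inspect the resulting windows is a one-shot iterator; this is a minimal sketch assuming TensorFlow 1.x graph mode, the same version in which tf.contrib.data is available:

iterator = dataset.make_one_shot_iterator()
next_window = iterator.get_next()

with tf.Session() as sess:
    try:
        while True:
            print(sess.run(next_window))
    except tf.errors.OutOfRangeError:
        pass
# Prints four windows, each a 4x1 array: [0..3], [2..5], [4..7], [6..9]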