This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ( | |
... | |
"database/sql" | |
"database/sql/driver" | |
"github.com/marcboeker/go-duckdb" | |
... | |
) | |
// the driver struct wrapping the duckdb connection | |
type DuckDBDriver struct { |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
type DAO struct { | |
driver DataDriver | |
} | |
// the data we want to load
// QUsersTable is a SQL template that creates the users table from a parquet
// source; '%s' is a placeholder for the parquet location (presumably filled in
// with fmt.Sprintf — confirm at the call site).
// The `+` has to end the first line: Go inserts a semicolon after a line that
// ends in a string literal, so a `+` leading the next line does not compile.
QUsersTable := "CREATE TABLE users AS SELECT name, last_name, cast(age as integer)" +
	"FROM read_parquet('%s') "
// the parquet files location |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** A single prediction row: a key and timestamp together with the label,
  * the predicted value and a ratio.
  *
  * NOTE(review): `ts` and `label` presumably come from the matching
  * `FeaturesRecord`, and `ratio` presumably relates `prediction` to `label` —
  * confirm where instances are constructed.
  *
  * @param key        string key of the record
  * @param ts         timestamp of the record
  * @param label      label value
  * @param prediction predicted value
  * @param ratio      ratio value stored alongside the prediction
  */
case class PredictionResult(
  key: String,
  ts: Timestamp,
  label: Float,
  prediction: Float,
  ratio: Float
)
private def predictXGBBooster(app_id: String, booster: Booster, | |
predictSeq: Seq[FeaturesRecord]): | |
Try[Seq[PredictionResult]] = Try { | |
val forecastedVal = booster.predict(predictSeq.toDMatrix) | |
predictSeq.zip(forecastedVal).map { case (FeaturesRecord(_, ts, _, label), forecast) => |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** Adds `toDMatrix` to any `Seq[FeaturesRecord]`. */
implicit class DMatrixConverter(seq: Seq[FeaturesRecord]) {

  /** Builds a `DMatrix` with one `LabeledPoint` per record, in sequence order.
    *
    * Each point carries the record's label, the number of features and the
    * feature values as an array.
    */
  def toDMatrix: DMatrix = {
    val points = for (record <- seq) yield {
      val featureCount = record.features.size
      // indices are passed as null; presumably this marks the point as dense —
      // TODO confirm against the XGBoost4J LabeledPoint documentation
      LabeledPoint(record.label, featureCount, null, record.features.toArray)
    }
    new DMatrix(points.iterator)
  }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** Trains an XGBoost booster on the given records.
  *
  * The records are converted with `toDMatrix` and passed to `XGBoost.train`
  * with `eta` 0.1, `max_depth` 4 and the `reg:squarederror` objective.
  *
  * @param trainSeq records to train on
  * @return the trained booster, or a `Failure` holding any non-fatal exception
  *         thrown during conversion or training
  */
private def trainXGBBooster(trainSeq: Seq[FeaturesRecord]): Try[Booster] = Try {
  val trainingParams: Map[String, Any] = Map(
    "eta" -> 0.1f,
    "max_depth" -> 4,
    "objective" -> "reg:squarederror"
  )
  // third argument of XGBoost.train; presumably the number of boosting rounds — confirm
  val rounds = 50
  XGBoost.train(trainSeq.toDMatrix, trainingParams, rounds)
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** Splits the records into a training set and a held-out set, by time.
  *
  * Records are ordered by ascending timestamp. The most recent record is
  * dropped because it might be partial; the two records before it form the
  * held-out set and everything earlier is the training set. With fewer than
  * three records the training set is empty.
  *
  * @param app_id          not used by this method
  * @param recordsIterator records to split, in any order
  * @return `(trainSeq, actualValSeq)`, or a `Failure` holding any non-fatal
  *         exception thrown while sorting or splitting
  */
private def getForecastDatasets(app_id: String,
    recordsIterator: Seq[FeaturesRecord]):
    Try[(Seq[FeaturesRecord], Seq[FeaturesRecord])] = Try {
  val ordered = recordsIterator.sortBy(_.ts.getTime)
  // drop the latest observation: it might be partial
  val withoutLatest = ordered.dropRight(1)
  // everything before the final three observations trains; the rest is held out
  withoutLatest.splitAt(ordered.length - 3)
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** A single prediction row: a key and timestamp together with the label,
  * the predicted value and a ratio.
  *
  * NOTE(review): `ratio` presumably relates `prediction` to `label` —
  * confirm where instances are constructed.
  *
  * @param key        string key of the record
  * @param ts         timestamp of the record
  * @param label      label value
  * @param prediction predicted value
  * @param ratio      ratio value stored alongside the prediction
  */
case class PredictionResult(
  key: String,
  ts: Timestamp,
  label: Float,
  prediction: Float,
  ratio: Float
)
def predict(appId:String, recordsIter:Iterator[FeaturesRecord]): | |
Seq[PredictionResult] = { | |
val predDF = for { | |
(trainSeq, actualValSeq) <- getForecastDataset(appId, recordsIter.toSeq) | |
booster <- trainXGBBooster(trainSeq) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** One observation: a key, a timestamp, a feature vector and a label.
  *
  * NOTE(review): presumably `key` is the `app_id` column, `ts` the
  * `event_hour` column and `label` the `installs` column read in
  * `getFeaturesDataFrame` — confirm where the record is constructed.
  *
  * @param key      string key of the observation
  * @param ts       timestamp of the observation
  * @param features feature values
  * @param label    label value
  */
case class FeaturesRecord(
  key: String,
  ts: Timestamp,
  features: Seq[Float],
  label: Float
)
private def getFeaturesDataFrame(df: DataFrame): Try[Dataset[FeaturesRecord]] = Try { | |
df.map(row => { | |
val key = row.getAs[String]("app_id") | |
val label = row.getAs[Int]("installs").toFloat | |
val ts = row.getAs[Timestamp]("event_hour") | |
val dayOfWeek = row.getAs[Int]("day_of_week").toFloat | |
val hourOfDay = row.getAs[Int]("hour_of_day").toFloat |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def getStreamTopology(inputTopic:String):Topology = { | |
val builder = new StreamsBuilder() | |
val reqStream = builder.stream[String, PredictRequest](inputTopic) | |
reqStream | |
.map( (_, request) => { | |
Classifier.predict(request.recordID, request.featuresVector) | |
}) | |
.split() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** Wraps a feature vector in a `DMatrix` with one row.
  *
  * The matrix has as many columns as `rawVector` has elements, and
  * `Float.NaN` is passed as the missing-value argument.
  *
  * @param rawVector feature values for a single record
  * @return a 1 x `rawVector.length` matrix holding those values
  */
private def getInputVector(rawVector: Seq[Float]): DMatrix = {
  val values: Array[Float] = rawVector.toArray
  // a single row whose width is the number of features
  new DMatrix(values, 1, values.length, Float.NaN)
}
def predict(recordID:String, features:Seq[Float]): (String, Float) = { | |
val xgbInput = getInputVector(features) |