This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| /** A single labelled observation: a feature vector for one key at one timestamp. | |
| * | |
| * @param key series identifier (getFeaturesDataFrame reads it from the `app_id` column) | |
| * @param ts timestamp of the observation (read from the `event_hour` column) | |
| * @param features numeric feature values; presumably includes day-of-week and | |
| * hour-of-day -- TODO confirm against the full getFeaturesDataFrame | |
| * @param label target value (the `installs` column converted to Float) | |
| */ | |
| case class FeaturesRecord(key: String, ts:Timestamp, | |
| features: Seq[Float], label: Float) | |
| private def getFeaturesDataFrame(df: DataFrame): Try[Dataset[FeaturesRecord]] = Try { | |
| df.map(row => { | |
| val key = row.getAs[String]("app_id") | |
| val label = row.getAs[Int]("installs").toFloat | |
| val ts = row.getAs[Timestamp]("event_hour") | |
| val dayOfWeek = row.getAs[Int]("day_of_week").toFloat | |
| val hourOfDay = row.getAs[Int]("hour_of_day").toFloat |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| /** Outcome of scoring one observation. | |
| * | |
| * @param key identifier of the series the observation belongs to | |
| * @param ts timestamp of the observation | |
| * @param label observed value | |
| * @param prediction value forecast by the model | |
| * @param ratio presumably relates `label` to `prediction`; the exact formula is | |
| * not visible in this snippet -- TODO confirm | |
| */ | |
| case class PredictionResult(key: String, ts:Timestamp, | |
| label: Float, prediction: Float, ratio: Float) | |
| def predict(appId:String, recordsIter:Iterator[FeaturesRecord]): | |
| Seq[PredictionResult] = { | |
| val predDF = for { | |
| (trainSeq, actualValSeq) <- getForecastDataset(appId, recordsIter.toSeq) | |
| booster <- trainXGBBooster(trainSeq) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| /** Splits the records of one series into a training set and a validation set. | |
|   * | |
|   * Records are ordered by timestamp. The most recent observation is dropped | |
|   * because it might be partial, the two observations before it form the | |
|   * validation set, and everything earlier is used for training. | |
|   * | |
|   * @param app_id          identifier of the series (not used by the body) | |
|   * @param recordsIterator records to split, in any order | |
|   * @return (training records, validation records), or a Failure if anything throws | |
|   */ | |
| private def getForecastDatasets(app_id: String, | |
| recordsIterator: Seq[FeaturesRecord]): | |
| Try[(Seq[FeaturesRecord], Seq[FeaturesRecord])] = Try { | |
|   val ordered = recordsIterator.sortBy(_.ts.getTime) | |
|   val total = ordered.length | |
|   // validation window: the two observations preceding the (possibly partial) last one | |
|   val validationFrom = total - 3 | |
|   val validationUntil = total - 1 | |
|   val training = ordered.slice(0, validationFrom) | |
|   val validation = ordered.slice(validationFrom, validationUntil) | |
|   (training, validation) | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| /** Trains an XGBoost booster on the given records. | |
|   * | |
|   * Training runs for 50 rounds with eta 0.1, max depth 4 and the | |
|   * `reg:squarederror` objective. Any exception thrown while building the | |
|   * matrix or training is captured in the returned Failure. | |
|   * | |
|   * @param trainSeq records to fit the model on | |
|   * @return the trained booster wrapped in a Try | |
|   */ | |
| private def trainXGBBooster(trainSeq: Seq[FeaturesRecord]): | |
| Try[Booster] = Try { | |
|   val trainingMatrix = trainSeq.toDMatrix | |
|   val params: Map[String, Any] = | |
|     Map("eta" -> 0.1f, "max_depth" -> 4, "objective" -> "reg:squarederror") | |
|   val rounds = 50 | |
|   XGBoost.train(trainingMatrix, params, rounds) | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| /** Adds a `toDMatrix` conversion to any sequence of FeaturesRecord. */ | |
| implicit class DMatrixConverter(seq: Seq[FeaturesRecord]) { | |
|   /** Builds a DMatrix with one labelled point per record, in sequence order. */ | |
|   def toDMatrix: DMatrix = { | |
|     val points = for (record <- seq) yield { | |
|       val values = record.features.toArray | |
|       // indices are passed as null, exactly as the LabeledPoint call requires here | |
|       LabeledPoint(record.label, record.features.size, null, values) | |
|     } | |
|     new DMatrix(points.iterator) | |
|   } | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| /** Outcome of scoring one observation. | |
| * | |
| * @param key identifier of the series the observation belongs to | |
| * @param ts timestamp of the observation | |
| * @param label observed value | |
| * @param prediction value forecast by the model | |
| * @param ratio presumably relates `label` to `prediction`; the exact formula is | |
| * not visible in this snippet -- TODO confirm | |
| */ | |
| case class PredictionResult(key: String, ts:Timestamp, | |
| label: Float, prediction: Float, ratio: Float) | |
| private def predictXGBBooster(app_id: String, booster: Booster, | |
| predictSeq: Seq[FeaturesRecord]): | |
| Try[Seq[PredictionResult]] = Try { | |
| val forecastedVal = booster.predict(predictSeq.toDMatrix) | |
| predictSeq.zip(forecastedVal).map { case (FeaturesRecord(_, ts, _, label), forecast) => |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| // DAO holds a single DataDriver; DataDriver is declared outside this snippet. | |
| type DAO struct { | |
| driver DataDriver | |
| } | |
| // the data we want to load | |
| // The `+` must end the first line: Go inserts a semicolon after a line that ends | |
| // in a string literal, so a continuation line starting with `+` does not compile. | |
| // The space after "integer)" keeps the FROM keyword separated in the final SQL. | |
| QUsersTable := "CREATE TABLE users AS SELECT name, last_name, cast(age as integer) " + | |
|     "FROM read_parquet('%s') " | |
| // the parquet files location |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import ( | |
| ... | |
| "database/sql" | |
| "database/sql/driver" | |
| "github.com/marcboeker/go-duckdb" | |
| ... | |
| ) | |
| // the driver struct wrapping the duckdb connection | |
| type DuckDBDriver struct { |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| // Execute runs statement through d.db.Exec, discards the result value and | |
| // returns the error from Exec (nil on success). | |
| func (d *DuckDBDriver) Execute(statement string) error { | |
|     _, execErr := d.db.Exec(statement) | |
|     return execErr | |
| } | |
| func (d *DuckDBDriver) Query(statement string) (*sql.Rows, error) { |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| // describe the model repository at the requested revision; both the model name | |
| // and the revision are parsed from strings and any parse error is propagated | |
| let repo = Repo::with_revision(model_name.parse()?, RepoType::Model, revision.parse()?); | |
| let api = Api::new()?; | |
| // rebind `api` to a handle scoped to that repository | |
| let api = api.repo(repo); | |
| // fetch the three model artifacts; the returned values are used as file paths | |
| // below (presumably local cached copies -- confirm against the Api docs) | |
| let config_filename = api.get("config.json")?; | |
| let tokenizer_filename = api.get("tokenizer.json")?; | |
| let weights_filename = api.get("model.safetensors")?; | |
| // load the model config | |
| let config = std::fs::read_to_string(config_filename)?; | |
| // deserialize the JSON text into the typed Config (shadows the raw string) | |
| let config: Config = serde_json::from_str(&config)?; |