Last active
December 20, 2015 04:09
-
-
Save MLnick/6068841 to your computer and use it in GitHub Desktop.
Spark Machine Learning API Design Notes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// An Example is an observation with an optional target value and features in the form of a vector of Doubles.
/**
 * A single observation: an optional target value plus its feature values.
 *
 * Note that `target` has a default but comes before `features`, so a caller who
 * wants to omit it must pass `features` by name: `Example(features = v)`.
 *
 * @param target   optional target value; defaults to `None`
 * @param features feature values as a vector of Doubles
 */
case class Example(target: Option[Double] = None, features: Vector[Double])
// The base model API looks something like this:
/**
 * Base API shared by all models: fit on an RDD of [[Example]]s and predict on one.
 *
 * @param modelSettings settings object the model is constructed with
 */
abstract class BaseModel(val modelSettings: Settings)
  extends Serializable
  with Logging {

  /**
   * Fits the model to `data`. Implemented by each concrete model.
   *
   * @param data examples to fit on
   */
  def fit(data: RDD[Example]): Unit

  /**
   * Fits the model to records of an arbitrary type `U` by converting every record
   * to an [[Example]] with `dataMapping` and delegating to `fit(RDD[Example])`.
   *
   * @tparam U          type of the raw input records
   * @param data        raw input records
   * @param dataMapping conversion from a raw record to an [[Example]], resolved implicitly
   */
  def fit[U](data: RDD[U])(implicit dataMapping: U => Example): Unit =
    fit(data.map(dataMapping))

  /**
   * Runs prediction over `data`. Implemented by each concrete model.
   *
   * NOTE(review): this was declared without a result type, which means it returns
   * `Unit`; predictions are therefore not handed back to the caller through this
   * signature -- confirm whether that is intended.
   *
   * @param data examples to predict on
   */
  def predict(data: RDD[Example]): Unit
}
// Models are free to implement their own additional fit methods; for example, ALS does:
/**
 * Additional `fit` overload that accepts `(userId, itemId, score)` triples directly.
 *
 * Each triple becomes an [[Example]] whose target is the score and whose features
 * are the two ids, each decremented by one (presumably shifting 1-based ids to
 * 0-based indices -- confirm), and the result is passed to `fit(RDD[Example])`.
 *
 * NOTE(review): `evidence` is never read in the body; it presumably exists only to
 * give this overload a signature distinct from the other `fit(RDD[...])` variants
 * after type erasure -- confirm before removing.
 *
 * @param data     `(userId, itemId, score)` triples
 * @param evidence implicit manifest for the input RDD type (unused in the body)
 */
def fit(data: RDD[(Int, Int, Double)])(
    implicit evidence: Manifest[RDD[(Int, Int, Double)]]): Unit = {
  val examples = data.map { case (userId, itemId, score) =>
    // A Double can never be null, so Some is used rather than the null-checking Option(...).
    // The ids are converted to Double explicitly, matching DefaultALSMapping.
    Example(Some(score), DenseVector((userId - 1).toDouble, (itemId - 1).toDouble))
  }
  fit(examples)
}
// Models can have a default "DataMapping" from raw data (usually text) to model inputs:
/**
 * Default mapping from a raw text record to an ALS [[Example]].
 *
 * The record is split with `DataMapping.numberStringSplit`; the first three fields
 * are read as user id, item id and score, and any further fields are ignored.
 */
object DefaultALSMapping extends DataMapping[String] {

  /**
   * Converts one raw record into an [[Example]] whose target is the score and whose
   * features are the user and item ids, each decremented by one (presumably shifting
   * 1-based ids to 0-based indices -- confirm).
   *
   * @param str raw input record
   * @return the example built from the first three fields of `str`
   * @throws IllegalArgumentException if the record yields fewer than three fields
   * @throws NumberFormatException    if one of the first three fields is not a number
   */
  override def call(str: String): Example =
    DataMapping.numberStringSplit(str) match {
      case Array(userId, itemId, score, _*) =>
        Example(Some(score.toDouble), DenseVector(userId.toDouble - 1, itemId.toDouble - 1))
      case _ =>
        // Without this case a short record would surface as a bare scala.MatchError
        // that gives no hint about which input line was malformed.
        throw new IllegalArgumentException(
          s"Expected at least 3 fields (userId, itemId, score) but got: '$str'")
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment