Created
September 12, 2019 03:48
-
-
Save tixxit/f67cabd2389823b520b35afcab0564b4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/build.sbt b/build.sbt | |
index d239642..c97e604 100644 | |
--- a/build.sbt | |
+++ b/build.sbt | |
@@ -6,10 +6,12 @@ crossScalaVersions in ThisBuild := Seq("2.11.12", "2.12.7") | |
scalacOptions in ThisBuild ++= Seq("-deprecation", "-feature", "-unchecked", "-language:higherKinds") | |
-val catsEffectVersion = "1.0.0" | |
+val catsEffectVersion = "1.4.0" | |
val declineVersion = "0.5.0" | |
val scalaTestVersion = "3.0.5" | |
val scalaCheckVersion = "1.14.0" | |
+val fastparseVersion = "2.1.3" | |
+val jline3Version = "3.9.0" | |
val commonSettings = Seq( | |
libraryDependencies ++= Seq( | |
@@ -27,7 +29,9 @@ lazy val core = project | |
commonSettings, | |
libraryDependencies ++= Seq( | |
"org.typelevel" %% "cats-effect" % catsEffectVersion, | |
- "com.monovore" %% "decline" % declineVersion | |
+ "com.monovore" %% "decline" % declineVersion, | |
+ "com.lihaoyi" %% "fastparse" % fastparseVersion, | |
+ "org.jline" % "jline" % jline3Version | |
), | |
PB.targets in Compile := Seq( | |
scalapb.gen() -> (sourceManaged in Compile).value | |
diff --git a/core/src/main/scala/net/tixxit/gulon/command/Main.scala b/core/src/main/scala/net/tixxit/gulon/command/Main.scala | |
index d59644d..163ba2f 100644 | |
--- a/core/src/main/scala/net/tixxit/gulon/command/Main.scala | |
+++ b/core/src/main/scala/net/tixxit/gulon/command/Main.scala | |
@@ -8,6 +8,7 @@ object Main extends IOApp { | |
Opts.subcommand(BuildIndex.command) | |
.orElse(Opts.subcommand(Query.command)) | |
.orElse(Opts.subcommand(QueryWords.command)) | |
+ .orElse(Opts.subcommand(Repl.command)) | |
.orElse(Opts.subcommand(Test.command)) | |
val app: Command[IO[ExitCode]] = | |
diff --git a/core/src/main/scala/net/tixxit/gulon/command/Repl.scala b/core/src/main/scala/net/tixxit/gulon/command/Repl.scala | |
new file mode 100644 | |
index 0000000..f2a9874 | |
--- /dev/null | |
+++ b/core/src/main/scala/net/tixxit/gulon/command/Repl.scala | |
@@ -0,0 +1,135 @@ | |
+package net.tixxit.gulon | |
+package command | |
+ | |
+import java.nio.file.Path | |
+ | |
+import cats.data.{Validated, ValidatedNel} | |
+import cats.effect.{ContextShift, ExitCode, IO} | |
+import cats.implicits._ | |
+import com.monovore.decline._ | |
+import org.jline.reader.{EndOfFileException, LineReader, LineReaderBuilder} | |
+import org.jline.utils.{AttributedString, AttributedStringBuilder, AttributedStyle} | |
+ | |
+/** | |
+ * The `repl` sub-command: an interactive prompt for querying an ANN index. | |
+ * | |
+ * Queries accepted at the prompt: | |
+ * {{{ | |
+ * "king"                      // the point stored in the index for a word | |
+ * "king" - "man" + "woman"    // point arithmetic, applied left to right | |
+ * ("king" - "man") + "woman"  // parentheses group point expressions | |
+ * nn("king" - "man")          // nearest neighbour of a point (k = 1) | |
+ * nn("king" - "man", 5)       // the 5 nearest neighbours of a point | |
+ * }}} | |
+ */ | |
+object Repl { | |
+ | |
+  /** Command-line options of the `repl` sub-command. */ | |
+  final case class Options(index: Path) | |
+ | |
+  object Options { | |
+    val index: Opts[Path] = Opts.option[Path]( | |
+      "index", short="i", metavar="file", | |
+      help="path to ANN index") | |
+ | |
+    val opts: Opts[Options] = index.map(Options(_)) | |
+  } | |
+ | |
+  /** The value a query evaluates to. */ | |
+  sealed trait QueryResult | |
+  /** A single point, as produced by a lookup or by point arithmetic. */ | |
+  final case class PointResult(value: Array[Float]) extends QueryResult | |
+  /** The outcome of running a point against the index with `Index.query`. */ | |
+  final case class IndexResult(value: Index.Result) extends QueryResult | |
+ | |
+  /** Syntax tree of a query; `A` is the kind of result it evaluates to. */ | |
+  sealed trait Query[+A] | |
+  final case class Add(lhs: Query[PointResult], rhs: Query[PointResult]) extends Query[PointResult] | |
+  final case class Sub(lhs: Query[PointResult], rhs: Query[PointResult]) extends Query[PointResult] | |
+  final case class Lookup(word: String) extends Query[PointResult] | |
+  final case class NN(query: Query[PointResult], k: Int) extends Query[IndexResult] | |
+ | |
+  /** fastparse grammar turning one line of input into a [[Query]]. */ | |
+  object Parser { | |
+    import fastparse._ | |
+    import NoWhitespace._ | |
+ | |
+    // Whitespace is matched explicitly (NoWhitespace is in effect); only spaces count. | |
+    def ws[_: P]: P[Unit] = " ".rep | |
+ | |
+    // A double-quoted word. There is no escape syntax: the word ends at the next quote. | |
+    def lookup[_: P]: P[Query[PointResult]] = P( "\"" ~/ (!"\"" ~ AnyChar.!).rep ~ "\"" ) | |
+      .map(_.mkString) | |
+      .map(Lookup(_)) | |
+ | |
+    def parens[_: P]: P[Query[PointResult]] = P( "(" ~/ pointExpr ~ ")" ) | |
+ | |
+    def term[_: P]: P[Query[PointResult]] = P( parens | lookup ) | |
+ | |
+    // term (('+' | '-') term)*, folded from the left so "a" - "b" + "c" is ("a" - "b") + "c". | |
+    def pointExpr[_: P]: P[Query[PointResult]] = P( term ~ ws ~ (CharIn("+\\-").! ~ ws ~/ term).rep ) | |
+      .map { case (init, terms) => | |
+        terms.foldLeft(init) { | |
+          // CharIn above only ever captures "+" or "-", so these two cases are exhaustive. | |
+          case (acc, ("+", operand)) => Add(acc, operand) | |
+          case (acc, ("-", operand)) => Sub(acc, operand) | |
+        } | |
+      } | |
+ | |
+    // Neighbour count: 1 to 9 digits, so that `toInt` can never overflow. | |
+    private def count[_: P]: P[Int] = P( CharIn("0-9").rep(min = 1, max = 9).! ).map(_.toInt) | |
+ | |
+    // nn(<point expression>) or nn(<point expression>, k); k defaults to 1 when omitted. | |
+    def indexExpr[_: P]: P[Query[IndexResult]] = | |
+      P( "nn(" ~/ ws ~ pointExpr ~ ws ~ ("," ~ ws ~ count ~ ws).? ~ ")" ) | |
+        .map { case (point, k) => NN(point, k.getOrElse(1)) } | |
+ | |
+    // A whole input line; spaces around the query are ignored. | |
+    def expr[_: P]: P[Query[QueryResult]] = P( ws ~ (pointExpr | indexExpr) ~ ws ~ End ) | |
+  } | |
+ | |
+  /** Result of evaluating a query: the value, or every error that was encountered. */ | |
+  type Result[A] = ValidatedNel[String, A] | |
+ | |
+  /** | |
+   * Evaluates `query` against `index`. | |
+   * | |
+   * Words are resolved with `index.lookup`; a word that cannot be resolved | |
+   * contributes a `lookup failed: <word>` error. Both sides of an addition or | |
+   * subtraction are always evaluated, so failures from both sides are reported | |
+   * together. | |
+   * | |
+   * @param index the index used for word lookups and nearest-neighbour queries | |
+   * @param query the parsed query to evaluate | |
+   * @return the evaluated result, or the accumulated error messages | |
+   */ | |
+  def exec[A](index: Index, query: Query[A]): Result[A] = { | |
+    def recur[A0](q: Query[A0]): Result[A0] = | |
+      q match { | |
+        case Add(lhs, rhs) => | |
+          (recur(lhs), recur(rhs)).mapN { (l, r) => PointResult(MathUtils.add(l.value, r.value)) } | |
+        case Sub(lhs, rhs) => | |
+          (recur(lhs), recur(rhs)).mapN { (l, r) => PointResult(MathUtils.subtract(l.value, r.value)) } | |
+        case Lookup(word) => | |
+          index.lookup(word) | |
+            .map { point => Validated.validNel(PointResult(point)) } | |
+            .getOrElse(Validated.invalidNel(s"lookup failed: $word")) | |
+        case NN(point, k) => | |
+          recur(point).map { p => IndexResult(index.query(k, p.value)) } | |
+      } | |
+ | |
+    recur(query) | |
+  } | |
+ | |
+  /** | |
+   * Runs the interactive loop: prompt, read a line, parse and evaluate it, | |
+   * print the outcome, repeat. The loop ends when the line reader signals end | |
+   * of input with an `EndOfFileException`. | |
+   * | |
+   * @param index the index that queries are evaluated against | |
+   */ | |
+  def repl(index: Index): IO[Unit] = { | |
+    import fastparse.{Parsed, parse} | |
+ | |
+    // Prefixes a message with a red "error> " tag, rendered as ANSI escapes. | |
+    def formatError(message: String): String = | |
+      new AttributedStringBuilder() | |
+        .style(AttributedStyle.DEFAULT.foreground(AttributedStyle.RED)) | |
+        .append("error> ") | |
+        .style(AttributedStyle.DEFAULT) | |
+        .append(message) | |
+        .toAnsi | |
+ | |
+    // Parses and evaluates one line, returning the text to print for it. | |
+    def parseExecFormat(line: String): String = | |
+      parse(line, Parser.expr(_)) match { | |
+        case Parsed.Success(query, _) => | |
+          exec(index, query) match { | |
+            case Validated.Valid(PointResult(result)) => | |
+              // Short points are printed in full; longer ones are abbreviated | |
+              // to their first two components and the last one. | |
+              if (result.length <= 4) { | |
+                s"[${result.mkString(", ")}]" | |
+              } else { | |
+                s"[${result.take(2).mkString(", ")}, ..., ${result.last}]" | |
+              } | |
+            case Validated.Valid(IndexResult(result)) => | |
+              // Only the first component of each entry is shown | |
+              // (presumably the matched word - confirm against Index.Result). | |
+              s"[${result.map(_._1).mkString(", ")}]" | |
+            case Validated.Invalid(errors) => | |
+              errors.map(formatError(_)).toList.mkString("\n") | |
+          } | |
+        case Parsed.Failure(_, _, extra) => | |
+          formatError(s"invalid query: ${extra.trace().longMsg}") | |
+      } | |
+ | |
+    val prompt = new AttributedString("gulon> ", AttributedStyle.DEFAULT.foreground(AttributedStyle.CYAN)).toAnsi | |
+ | |
+    // Reads one line; `None` means end of input. The handler is attached to the | |
+    // read alone, so it is not re-installed around every recursive step of `loop`. | |
+    def readLine(reader: LineReader): IO[Option[String]] = | |
+      IO.delay[Option[String]](Option(reader.readLine(prompt))) | |
+        .recover { case _: EndOfFileException => None } | |
+ | |
+    def loop(reader: LineReader): IO[Unit] = | |
+      readLine(reader).flatMap { | |
+        case Some(line) => IO.delay(println(parseExecFormat(line))) >> loop(reader) | |
+        case None       => IO.unit | |
+      } | |
+ | |
+    IO.delay(LineReaderBuilder.builder().build()) | |
+      .flatMap(loop) | |
+  } | |
+ | |
+  /** Reads the index named on the command line, then runs the REPL on it. */ | |
+  def run(implicit contextShift: ContextShift[IO]): Opts[IO[ExitCode]] = | |
+    Options.opts.map { options => | |
+      for { | |
+        index <- Index.read(CommandUtils.openPath(options.index)) | |
+        _ <- repl(index) | |
+      } yield ExitCode.Success | |
+    } | |
+ | |
+  /** The `repl` sub-command as registered with the command-line parser. */ | |
+  def command(implicit contextShift: ContextShift[IO]): Command[IO[ExitCode]] = | |
+    Command("repl", "open up a REPL for an index", true)(run) | |
+} | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment