A minimal implementation of JoBimText in Scala
#!/usr/bin/env scalas

/***
scalaVersion := "2.11.8"

resolvers ++= Seq(
  "apache-snapshots" at "http://repository.apache.org/snapshots/"
)

val sparkVersion = "2.1.1"

libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-core" % sparkVersion,
  "org.apache.spark" %% "spark-sql" % sparkVersion
)
*/
// Create observations.csv file
new java.io.PrintWriter("/tmp/observations.csv") {
  write("term\tcontext\npython\tfeature1\npython\tfeature2\nsnake\tfeature1")
  close()
}
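// The file written above renders as the following tab-separated table:
//
//   term     context
//   python   feature1
//   python   feature2
//   snake    feature1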
// Run Spark pipeline
import org.apache.spark.sql.SparkSession

val spark = SparkSession
  .builder()
  .config("spark.master", "local[*]")
  .getOrCreate()

spark
  .read
  .format("csv")
  .option("delimiter", "\t")
  .option("quote", "")
  .option("header", "true")
  .load("/tmp/observations.csv")
  .createOrReplaceTempView("observations")
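// Optional sanity check (not part of the pipeline itself): print the parsed
// view to confirm that the header and the tab delimiter were picked up.
spark.sql("SELECT * FROM observations").show()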
/*
 * Calculating significance scores with Local (or Lexical) Mutual Information (LMI)
 * (I did not find a Wikipedia article for LMI or another canonical online resource.)
 *
 * LMI is defined by Biemann and Riedl, "Text: Now in 2D!", 2013,
 * with the help of PMI (https://en.wikipedia.org/wiki/Pointwise_mutual_information):
 *
 *   PMI(term, feature) = log2 { f(term, feature) / [f(term) * f(feature)] }
 *   LMI(term, feature) = f(term, feature) * PMI(term, feature)
 *
 * Here f(thing) signifies "count of thing". Strictly speaking, PMI is defined
 * over probabilities rather than raw counts; the count-based form above omits
 * the normalization by the total number of observations N (the exact
 * count-based PMI would be log2 { N * f(term, feature) / [f(term) * f(feature)] }).
 * This script computes the simplified form.
 *
 * To read the source code below, the following notation is easier: "feature"
 * is renamed to "context", and f(*) is abbreviated to f_*, which matches the
 * column names of the temporary views:
 *
 *   PMI(term, context) = log2 { f_tc / [f_t * f_c] }
 *   LMI(term, context) = f_tc * PMI(term, context)
 */
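/*
 * Worked through for the three sample observations written above:
 *
 *   f_t(python) = 2, f_t(snake) = 1, f_c(feature1) = 2, f_c(feature2) = 1,
 *   and every f_tc = 1, so:
 *
 *   LMI(python, feature1) = 1 * log2( 1 / (2 * 2) ) = -2.0
 *   LMI(python, feature2) = 1 * log2( 1 / (2 * 1) ) = -1.0
 *   LMI(snake,  feature1) = 1 * log2( 1 / (1 * 2) ) = -1.0
 */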
spark
  .sql("SELECT term, COUNT(*) AS f_t FROM observations GROUP BY term")
  .createOrReplaceTempView("term_counts")

spark
  .sql("SELECT context, COUNT(*) AS f_c FROM observations GROUP BY context")
  .createOrReplaceTempView("context_counts")

spark
  .sql("SELECT term, context, COUNT(*) AS f_tc FROM observations GROUP BY term, context")
  .createOrReplaceTempView("term_context_counts")
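// On the sample data, the three count views contain:
//   term_counts:         (python, 2), (snake, 1)
//   context_counts:      (feature1, 2), (feature2, 1)
//   term_context_counts: (python, feature1, 1), (python, feature2, 1), (snake, feature1, 1)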
// Joining the aggregated counts (rather than the raw observations) yields
// exactly one score row per (term, context) pair, even if a pair was observed
// more than once.
spark
  .sql(
    """
      | SELECT tcc.term, tcc.context, f_tc * log2( f_tc / (f_t * f_c) ) AS score
      | FROM term_context_counts tcc
      | JOIN term_counts tc ON tcc.term = tc.term
      | JOIN context_counts cc ON tcc.context = cc.context
    """.stripMargin)
  .createOrReplaceTempView("lmi_scores")
spark
  .sql("SELECT term, context, score FROM lmi_scores")
  .write
  .format("csv")
  .option("delimiter", "\t")
  .option("quote", "")
  .option("header", "true")
  .save("/tmp/lmi")
/**
 * Calculating a Distributional Thesaurus (DT)
 * by taking the top 10 contexts per term.
 */
// Note: DENSE_RANK keeps ties, so a term can end up with more than 10
// contexts if several contexts share the same score.
spark
  .sql(
    """
      | SELECT term, CONCAT(context, ":", score) AS scored_context, rank
      | FROM (
      |   SELECT term, context, score,
      |     DENSE_RANK() OVER (PARTITION BY term ORDER BY score DESC) AS rank
      |   FROM lmi_scores
      | ) ranked_contexts WHERE rank <= 10
    """.stripMargin)
  .createOrReplaceTempView("top_contexts")
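// A sketch of an alternative (not used here): ROW_NUMBER() instead of
// DENSE_RANK() would cap each term at exactly 10 contexts, breaking ties
// arbitrarily:
//
//   SELECT term, context, score,
//     ROW_NUMBER() OVER (PARTITION BY term ORDER BY score DESC) AS rank
//   FROM lmi_scores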
// Note: COLLECT_LIST does not guarantee an order, so the contexts within each
// list are not necessarily sorted by score.
spark
  .sql(
    """
      | SELECT term, CONCAT_WS(",", COLLECT_LIST(scored_context)) AS contexts
      | FROM top_contexts
      | GROUP BY term
    """.stripMargin)
  .createOrReplaceTempView("distributional_thesaurus")
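// On the sample data, the thesaurus contains (order within each list is not
// guaranteed, see the note above):
//   python   feature2:-1.0,feature1:-2.0
//   snake    feature1:-1.0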
spark
  .sql("SELECT term, contexts FROM distributional_thesaurus")
  .write
  .format("csv")
  .option("delimiter", "\t")
  .option("quote", "")
  .option("header", "true")
  .save("/tmp/dt")

spark.stop()

// The results can be found in /tmp/dt and /tmp/lmi. Note that Spark writes
// each of these as a directory containing one part-*.csv file per partition,
// not as a single CSV file.
To run this script with the `scalas` script runner, follow the instructions at: http://www.scala-sbt.org/0.13/docs/Scripts.html