Skip to content

Instantly share code, notes, and snippets.

@ahmadmo
Last active August 21, 2020 18:26
Show Gist options
  • Save ahmadmo/bc16475a8b64a4e3a0a7137b1de6ca4f to your computer and use it in GitHub Desktop.
Save ahmadmo/bc16475a8b64a4e3a0a7137b1de6ca4f to your computer and use it in GitHub Desktop.
Builds a vocabulary from a set of English movie subtitles (.srt files)
import java.io.IOException
import java.nio.charset.Charset
import java.nio.charset.StandardCharsets
import java.nio.file.FileVisitResult
import java.nio.file.Files
import java.nio.file.Path
import java.nio.file.Paths
import java.nio.file.SimpleFileVisitor
import java.nio.file.StandardOpenOption
import java.nio.file.attribute.BasicFileAttributes
import java.util.Locale
import java.util.regex.Pattern
import kotlin.math.ln
import kotlin.streams.asSequence
/** Largest number of words joined into a single shingle; shingle sizes run from 1 up to this value. */
const val MAX_SHINGLE_SIZE: Int = 4

/**
 * Accepted tf-idf score window per shingle size, as (inclusive lower bound, exclusive upper bound).
 *
 * The entry at index `size - 1` applies to shingles of `size` words, so this array must hold
 * one entry for every size from 1 to [MAX_SHINGLE_SIZE].
 */
val tfIdfThresholds: Array<Pair<Double, Double>> = arrayOf(
    Pair(0.0004, 0.0036),
    Pair(0.0004, 0.0056),
    Pair(0.0008, 0.0072),
    Pair(0.0016, 0.0104)
)
/**
 * Entry point: builds one vocabulary file per shingle size from every `.srt` file found
 * under the configured input directories.
 *
 * Documents are built in parallel, then regrouped by shingle size so that each size gets
 * its own corpus; the vocabulary for each size is computed and written in parallel as
 * `<size>.txt` inside the output directory.
 */
fun main() {
    val inputDirs = arrayOf(
        Paths.get("/path/to/movies"),
        Paths.get("/path/to/other/movies")
    )
    val outputDir = Paths.get("/path/to/vocab")

    // Create the output directory before doing any work: otherwise a missing directory is
    // only discovered when the first vocabulary is saved, after all documents were processed.
    Files.createDirectories(outputDir)

    val srtFiles = inputDirs.flatMap(::listSrtFiles)

    // Each document maps shingle size -> shingle counts; regroup into size -> list of counts.
    val corpusBySize = srtFiles.parallelStream()
        .map(::buildDocument)
        .sequential()
        .asSequence()
        .flatMap { it.asSequence() }
        .groupBy({ it.key }, { it.value })

    corpusBySize.entries.parallelStream()
        .forEach { (shingleSize, corpus) ->
            buildVocab(corpus, shingleSize).save(outputDir, shingleSize)
        }
}
// ------------------------------------------------
/** Occurrence count per shingle text. */
typealias Shingles = HashMap<String, Int>

/** Shingle counts of one subtitle file, keyed by shingle size (number of words). */
typealias Document = HashMap<Int, Shingles>

/** A shingle text paired with its score. */
typealias ScoredShingle = Pair<String, Double>

/** Every score collected for a shingle, keyed by shingle text. */
typealias CumulativeScoredShingles = HashMap<String, ArrayList<Double>>

/** A single (averaged) score per shingle text. */
typealias AverageScoredShingles = Map<String, Double>

/** Counts one more occurrence of [shingle], starting from zero if it has not been seen yet. */
operator fun Shingles.plusAssign(shingle: String) {
    val seenSoFar = this[shingle]
    this[shingle] = if (seenSoFar == null) 1 else seenSoFar + 1
}

/** Appends the score of [shingle] to the list kept for its text, creating the list on first use. */
operator fun CumulativeScoredShingles.plusAssign(shingle: ScoredShingle) {
    val (text, score) = shingle
    getOrPut(text) { ArrayList() }.add(score)
}
// ------------------------------------------------
/**
 * Recursively collects every regular file under [dir] whose name ends in `.srt`.
 *
 * The walk is done in-process instead of spawning `find` through a shell, so directory
 * names containing spaces or shell metacharacters are handled correctly and no external
 * tool is required. Symbolic links are not followed and are not reported as files.
 *
 * @param dir directory to search (a single file is also accepted and tested on its own)
 * @return the matching paths, each rooted at [dir]; empty if [dir] does not exist.
 *         The order is whatever the file system yields and is not sorted.
 */
fun listSrtFiles(dir: Path): List<Path> {
    // A missing directory yields an empty result rather than an exception.
    if (Files.notExists(dir)) return emptyList()

    val matches = ArrayList<Path>()
    Files.walkFileTree(dir, object : SimpleFileVisitor<Path>() {
        override fun visitFile(file: Path, attrs: BasicFileAttributes): FileVisitResult {
            if (attrs.isRegularFile && file.fileName.toString().endsWith(".srt")) {
                matches.add(file)
            }
            return FileVisitResult.CONTINUE
        }

        // Skip entries that cannot be read and keep walking instead of aborting the scan.
        override fun visitFileFailed(file: Path, exc: IOException): FileVisitResult =
            FileVisitResult.CONTINUE
    })
    return matches
}
// ------------------------------------------------
/**
 * Reduces raw subtitle text to lower-case words separated by single spaces.
 *
 * Steps, in order: markup tags and ASS override tags are replaced by spaces, every run of
 * non-word characters other than the apostrophe becomes a space, apostrophes that end a
 * word are dropped, whitespace runs are collapsed, and the result is lower-cased and trimmed.
 *
 * Lower-casing uses [Locale.ROOT] so the output does not depend on the JVM default locale
 * (under a Turkish locale, for example, `I` would otherwise become a dotless `ı`).
 */
fun String.normalize(): String =
    replaceMlTags()
        .replaceAssTags()
        .replaceNonWords()
        .replaceTrailingApostrophes()
        .replaceLargeSpaces()
        .toLowerCase(Locale.ROOT)
        .trim()
/**
 * Markup-language tag: a `<`, one or more characters other than `>`, then a closing `>`.
 */
val mlTagPattern: Pattern = Pattern.compile("<[^>]+>")

/** Returns this string with every markup-language tag replaced by a single space. */
fun String.replaceMlTags(): String {
    val tagMatcher = mlTagPattern.matcher(this)
    return tagMatcher.replaceAll(" ")
}
/**
 * ASS override block: `{`, a backslash, one or more characters other than `}`, then `}`.
 *
 * @see <a href="http://docs.aegisub.org/3.2/ASS_Tags/">ASS Tags</a>
 */
val assTagPattern: Pattern = Pattern.compile("\\{\\\\[^}]+}")

/** Returns this string with every ASS override block replaced by a single space. */
fun String.replaceAssTags(): String {
    val overrideMatcher = assTagPattern.matcher(this)
    return overrideMatcher.replaceAll(" ")
}
/** One or more consecutive non-word characters, excluding the apostrophe. */
val nonWordPattern: Pattern = Pattern.compile("[\\W&&[^']]+")

/** Returns this string with each run of non-word characters (apostrophes excepted) replaced by one space. */
fun String.replaceNonWords(): String {
    val runMatcher = nonWordPattern.matcher(this)
    return runMatcher.replaceAll(" ")
}
/** An apostrophe immediately followed by a whitespace character or by the end of the input. */
val trailingApostrophePattern: Pattern = Pattern.compile("'(?:\\s|$)")

/**
 * Returns this string with every word-final apostrophe (together with the whitespace
 * character after it, if any) replaced by a single space.
 */
fun String.replaceTrailingApostrophes(): String {
    val apostropheMatcher = trailingApostrophePattern.matcher(this)
    return apostropheMatcher.replaceAll(" ")
}
/** Two or more consecutive whitespace characters. */
val largeSpacePattern: Pattern = Pattern.compile("\\s{2,}")

/** Returns this string with every run of two or more whitespace characters collapsed to one space. */
fun String.replaceLargeSpaces(): String {
    val gapMatcher = largeSpacePattern.matcher(this)
    return gapMatcher.replaceAll(" ")
}
// ------------------------------------------------
/** Charsets tried, in this order, when decoding a subtitle file. */
val srtCharsets: Array<Charset> = arrayOf(
    StandardCharsets.UTF_8,
    StandardCharsets.ISO_8859_1
)

/**
 * A whole line of the form `HH:MM:SS[,mmm] --> HH:MM:SS[,mmm]`, where the hour field has
 * two or more digits and the millisecond part is optional on either side.
 */
val srtTimeRangePattern: Pattern = Pattern.compile(
    "^\\d{2,}:\\d{2}:\\d{2}(?:,\\d{3})? --> \\d{2,}:\\d{2}:\\d{2}(?:,\\d{3})?$"
)
/**
 * Builds the shingle document for [srtFile], trying each charset in [srtCharsets] in order
 * and returning the result of the first one that succeeds.
 *
 * @throws Exception the failure raised for the last charset when every attempt fails;
 *         failures from earlier charsets are attached to it as suppressed exceptions so
 *         the original cause is not lost.
 * @throws IllegalStateException if [srtCharsets] is empty.
 */
fun buildDocument(srtFile: Path): Document {
    println("building document = $srtFile")
    val failures = ArrayList<Exception>()
    for (charset in srtCharsets) {
        try {
            return buildDocument(srtFile, charset)
        } catch (ex: Exception) {
            failures.add(ex)
        }
    }
    val lastFailure = failures.lastOrNull()
        ?: throw IllegalStateException("no charsets configured for reading $srtFile")
    // Keep the earlier attempts visible in the stack trace of the exception that is rethrown.
    failures.dropLast(1).forEach { lastFailure.addSuppressed(it) }
    throw lastFailure
}
/**
 * Reads [srtFile] with the given [charset] and returns the shingle counts of its text.
 *
 * Blank lines and time-range lines are ignored. Any other line that does not parse as an
 * integer is buffered as text; a line that does parse as an integer (presumably the SRT
 * cue counter) hands the buffered text to [addShingles] and clears the buffer. Text still
 * buffered at the end of the file is handed over as well.
 */
fun buildDocument(srtFile: Path, charset: Charset): Document {
    val result = Document()
    val cueText = StringBuilder()
    // forEachLine closes the reader once every line has been consumed.
    Files.newBufferedReader(srtFile, charset).forEachLine { line ->
        val ignorable = line.isBlank() || srtTimeRangePattern.matcher(line).matches()
        if (!ignorable) {
            if (line.toIntOrNull() == null) {
                cueText.append(line).append(' ')
            } else {
                addShingles(result, cueText.toString())
                cueText.setLength(0)
            }
        }
    }
    addShingles(result, cueText.toString())
    return result
}
/**
 * Normalizes [text], splits it into words and counts its shingles in [document] for every
 * size from 1 to [MAX_SHINGLE_SIZE].
 *
 * Words containing a digit are dropped first. For each size the remaining words are cut
 * into consecutive, non-overlapping groups of exactly that many words; each group is joined
 * with single spaces and counted only if the joined text is longer than `size * 2` characters.
 * Nothing is recorded when [text] is blank before or after normalization.
 *
 * NOTE(review): the groups do not overlap (positions 0, size, 2*size, ...), which is not a
 * sliding window. Confirm this is intended before changing it, since the score thresholds
 * may have been tuned against the current behaviour.
 */
fun addShingles(document: Document, text: String) {
    if (text.isBlank()) return
    val cleaned = text.normalize()
    if (cleaned.isBlank()) return

    val tokens = ArrayList<String>()
    for (token in cleaned.split(' ')) {
        if (token.none { it.isDigit() }) tokens.add(token)
    }

    for (size in 1..MAX_SHINGLE_SIZE) {
        val counts = document.computeIfAbsent(size) { Shingles() }
        var start = 0
        // Only full groups are considered; a shorter tail is ignored.
        while (start + size <= tokens.size) {
            val shingle = tokens.subList(start, start + size).joinToString(separator = " ")
            if (shingle.length > size * 2) counts += shingle
            start += size
        }
    }
}
// ------------------------------------------------
/**
 * Computes the tf-idf score of [shingle] within [source], relative to [corpus].
 *
 * The term frequency is the count of [shingle] in [source] divided by the sum of all
 * counts in [source]. The inverse document frequency is `ln(corpus.size / (1 + n)) + 1`,
 * where `n` is the number of corpus entries that contain [shingle].
 *
 * @throws NoSuchElementException if [source] has no entry for [shingle]
 * @see <a href="https://en.wikipedia.org/wiki/Tf%E2%80%93idf">tf–idf</a>
 */
fun tfIdf(shingle: String, source: Shingles, corpus: List<Shingles>): Double {
    val occurrences = source.getValue(shingle)
    val totalOccurrences = source.values.sum()
    val termFrequency = occurrences.toDouble() / totalOccurrences

    var containingDocuments = 0
    for (candidate in corpus) {
        if (shingle in candidate) containingDocuments++
    }
    val inverseDocumentFrequency = ln(corpus.size / (1.0 + containingDocuments)) + 1.0

    return termFrequency * inverseDocumentFrequency
}
/**
 * Scores every shingle of this map with [tfIdf] against [corpus], keeps those whose score
 * passes [filterTfIdfScore] for [shingleSize], and returns them in ascending score order.
 */
fun Shingles.filteredSortByTfIdf(corpus: List<Shingles>, shingleSize: Int): List<ScoredShingle> {
    val accepted = ArrayList<ScoredShingle>()
    for (shingle in keys) {
        val score = tfIdf(shingle, this, corpus)
        if (filterTfIdfScore(score, shingleSize)) {
            accepted.add(ScoredShingle(shingle, score))
        }
    }
    accepted.sortBy { it.second }
    return accepted
}
/**
 * Averages the scores collected for each shingle and returns only the shingles whose
 * average passes [filterTfIdfScore] for [shingleSize], mapped to that average.
 */
fun CumulativeScoredShingles.toFilteredAverage(shingleSize: Int): AverageScoredShingles {
    val averaged = ArrayList<ScoredShingle>()
    for ((shingle, scores) in this) {
        val mean = scores.average()
        if (filterTfIdfScore(mean, shingleSize)) {
            averaged.add(ScoredShingle(shingle, mean))
        }
    }
    return averaged.toMap()
}
/**
 * Tells whether [score] lies inside the window configured for [shingleSize]:
 * at or above the lower bound and strictly below the upper bound.
 *
 * The window is read from `tfIdfThresholds[shingleSize - 1]`, so [shingleSize] must be
 * between 1 and the number of configured thresholds.
 */
fun filterTfIdfScore(score: Double, shingleSize: Int): Boolean {
    val (lowerInclusive, upperExclusive) = tfIdfThresholds[shingleSize - 1]
    return lowerInclusive <= score && score < upperExclusive
}
// ------------------------------------------------
/**
 * Builds the vocabulary for one shingle size.
 *
 * Every entry of [corpus] contributes its filtered tf-idf scores; the scores are collected
 * per shingle across the whole corpus, then averaged and filtered once more.
 *
 * @param corpus shingle counts of every document, all for the same [shingleSize]
 * @return the surviving shingles mapped to their average score
 */
fun buildVocab(corpus: List<Shingles>, shingleSize: Int): AverageScoredShingles {
    println("building vocab = $shingleSize")
    val collected = CumulativeScoredShingles()
    corpus.forEach { documentShingles ->
        for (scored in documentShingles.filteredSortByTfIdf(corpus, shingleSize)) {
            collected += scored
        }
    }
    return collected.toFilteredAverage(shingleSize)
}
/**
 * Writes the shingles of this vocabulary, one per line in sorted order, to
 * `<shingleSize>.txt` inside [dir]. An existing file is truncated; scores are not written.
 */
fun AverageScoredShingles.save(dir: Path, shingleSize: Int) {
    val target = dir.resolve("$shingleSize.txt")
    Files.newBufferedWriter(
        target,
        StandardOpenOption.CREATE,
        StandardOpenOption.TRUNCATE_EXISTING,
        StandardOpenOption.WRITE
    ).use { out ->
        for (shingle in keys.sorted()) {
            out.write("$shingle\n")
        }
    }
}
@ahmadmo
Copy link
Author

ahmadmo commented Aug 21, 2020

Sample results:

1.txt
2.txt
3.txt
4.txt

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment