Builds vocabulary from a set of English movie subtitles (.srt files)
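The script below walks one or more directories for .srt files, builds per-film shingle counts (groups of 1 to 4 words), scores every shingle with a smoothed tf-idf over the whole corpus, keeps the shingles whose score falls inside per-size thresholds, and writes one vocabulary file per shingle size (1.txt to 4.txt). A rough way to run it, assuming a recent Kotlin compiler on the PATH and after pointing inputDirs and outputDir in main at real paths: save the code as, say, BuildVocab.kt (the file name is only an example), compile it with "kotlinc BuildVocab.kt -include-runtime -d build-vocab.jar", and start it with "java -jar build-vocab.jar".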
import java.nio.charset.Charset
import java.nio.charset.StandardCharsets
import java.nio.file.Files
import java.nio.file.Path
import java.nio.file.Paths
import java.nio.file.StandardOpenOption
import java.util.regex.Pattern
import kotlin.math.ln
import kotlin.streams.asSequence
// Shingles are groups of 1 to MAX_SHINGLE_SIZE consecutive words.
const val MAX_SHINGLE_SIZE = 4

// Per shingle size: the [min, max) tf-idf range a shingle must fall into to be
// kept (index 0 is for single words, index 3 for 4-word shingles).
val tfIdfThresholds = arrayOf(
    0.0004 to 0.0036,
    0.0004 to 0.0056,
    0.0008 to 0.0072,
    0.0016 to 0.0104
)
fun main() {
    val inputDirs = arrayOf(
        Paths.get("/path/to/movies"),
        Paths.get("/path/to/other/movies")
    )
    val outputDir = Paths.get("/path/to/vocab")

    val srtFiles = inputDirs.flatMap(::listSrtFiles)

    // Build one document (shingle counts per size) for each subtitle file in
    // parallel, regroup the results by shingle size, then build and save one
    // vocabulary per size.
    srtFiles.parallelStream()
        .map(::buildDocument)
        .sequential()
        .asSequence().flatMap { it.asSequence() }
        .groupBy({ it.key }, { it.value })
        .entries.parallelStream()
        .forEach { (shingleSize, corpus) ->
            buildVocab(corpus, shingleSize).save(outputDir, shingleSize)
        }
}
// ------------------------------------------------

// shingle text -> occurrence count
typealias Shingles = HashMap<String, Int>
// shingle size -> shingles of that size, for a single subtitle file
typealias Document = HashMap<Int, Shingles>
typealias ScoredShingle = Pair<String, Double>
// shingle text -> tf-idf scores collected across the corpus
typealias CumulativeScoredShingles = HashMap<String, ArrayList<Double>>
typealias AverageScoredShingles = Map<String, Double>

operator fun Shingles.plusAssign(shingle: String) {
    this[shingle] = (this[shingle] ?: 0) + 1
}

operator fun CumulativeScoredShingles.plusAssign(shingle: ScoredShingle) {
    computeIfAbsent(shingle.first) { ArrayList() } += shingle.second
}
// ------------------------------------------------

// Lists all .srt files under dir, recursively. Relies on /bin/sh and find,
// so it is Unix-only; see the NIO-based alternative below.
fun listSrtFiles(dir: Path): List<Path> =
    Runtime.getRuntime()
        .exec(arrayOf("/bin/sh", "-c", "find $dir -type f -name \"*.srt\""))
        .let { proc ->
            val files = proc.inputStream.use { stream ->
                stream.bufferedReader().lineSequence()
                    .filter { it.startsWith('/') }
                    .map(Paths::get)
                    .toList()
            }
            proc.destroy()
            files
        }
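
// Illustrative addition, not part of the original gist: a portable variant of
// listSrtFiles that walks the directory with java.nio instead of shelling out
// to find, for running the script outside a Unix environment.
fun listSrtFilesNio(dir: Path): List<Path> =
    Files.walk(dir).use { stream ->
        stream.asSequence()
            .filter { Files.isRegularFile(it) && it.toString().endsWith(".srt") }
            .toList()
    }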
// ------------------------------------------------

fun String.normalize(): String =
    replaceMlTags()
        .replaceAssTags()
        .replaceNonWords()
        .replaceTrailingApostrophes()
        .replaceLargeSpaces()
        .toLowerCase().trim()

/**
 * Markup Language Tags
 */
val mlTagPattern: Pattern = Pattern.compile("<[^>]+>")

fun String.replaceMlTags(): String = mlTagPattern.matcher(this).replaceAll(" ")

/**
 * @see <a href="http://docs.aegisub.org/3.2/ASS_Tags/">ASS Tags</a>
 */
val assTagPattern: Pattern = Pattern.compile("\\{\\\\[^}]+}")

fun String.replaceAssTags(): String = assTagPattern.matcher(this).replaceAll(" ")

val nonWordPattern: Pattern = Pattern.compile("[\\W&&[^']]+")

fun String.replaceNonWords(): String = nonWordPattern.matcher(this).replaceAll(" ")

val trailingApostrophePattern: Pattern = Pattern.compile("'(?:\\s|$)")

fun String.replaceTrailingApostrophes(): String = trailingApostrophePattern.matcher(this).replaceAll(" ")

val largeSpacePattern: Pattern = Pattern.compile("\\s{2,}")

fun String.replaceLargeSpaces(): String = largeSpacePattern.matcher(this).replaceAll(" ")
// ------------------------------------------------

// Subtitle files come in a mix of encodings; try UTF-8 first, then Latin-1.
val srtCharsets = arrayOf(StandardCharsets.UTF_8, StandardCharsets.ISO_8859_1)

val srtTimeRangePattern: Pattern =
    Pattern.compile("^\\d{2,}:\\d{2}:\\d{2}(?:,\\d{3})? --> \\d{2,}:\\d{2}:\\d{2}(?:,\\d{3})?$")

fun buildDocument(srtFile: Path): Document {
    println("building document = $srtFile")
    lateinit var exception: Exception
    for (charset in srtCharsets) try {
        return buildDocument(srtFile, charset)
    } catch (ex: Exception) {
        exception = ex
    }
    // Every charset failed; rethrow the last error.
    throw exception
}
fun buildDocument(srtFile: Path, charset: Charset): Document {
    val document = Document()
    Files.newBufferedReader(srtFile, charset).useLines { lines ->
        val text = StringBuilder()
        for (line in lines) when {
            // Skip blank lines and timing lines.
            line.isBlank() -> continue
            srtTimeRangePattern.matcher(line).matches() -> continue
            // Dialogue lines are accumulated...
            line.toIntOrNull() == null -> text.append(line).append(' ')
            // ...until the numeric index of the next cue, which flushes them.
            else -> {
                addShingles(document, text.toString())
                text.setLength(0)
            }
        }
        // Flush the last cue.
        addShingles(document, text.toString())
    }
    return document
}
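
// For reference (illustrative, not part of the original gist), a cue in an
// .srt file looks like:
//
//   1
//   00:00:01,000 --> 00:00:03,500
//   Some line of dialogue
//   A second line of dialogue
//
// so the parser above keeps only the dialogue lines of each cue.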
fun addShingles(document: Document, text: String) {
    if (text.isBlank()) return
    val normalizedText = text.normalize()
    if (normalizedText.isBlank()) return
    // Drop words containing digits, then count word groups of every size.
    val words = normalizedText.split(' ').filter { word -> word.none { it.isDigit() } }
    (1..MAX_SHINGLE_SIZE).forEach { size ->
        val shingles = document.computeIfAbsent(size) { Shingles() }
        // Note: chunked() builds non-overlapping groups of `size` consecutive
        // words (not a sliding window); incomplete trailing groups and very
        // short shingles are discarded.
        words.asSequence()
            .chunked(size)
            .filter { it.size == size }
            .map { it.joinToString(separator = " ") }
            .filter { it.length > size * 2 }
            .forEach { shingles += it }
    }
}
// ------------------------------------------------

/**
 * Term frequency of [shingle] in [source] times a smoothed inverse document
 * frequency over [corpus]: tf * (ln(N / (1 + n)) + 1), where N is the corpus
 * size and n the number of documents containing the shingle.
 *
 * @see <a href="https://en.wikipedia.org/wiki/Tf%E2%80%93idf">tf–idf</a>
 */
fun tfIdf(shingle: String, source: Shingles, corpus: List<Shingles>): Double {
    val tf = source.getValue(shingle).toDouble() / source.values.sum()
    val n = corpus.count { shingle in it }
    val idf = ln(corpus.size / (1.0 + n)) + 1.0
    return tf * idf
}
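
// Illustrative, not part of the original gist: a tiny sanity check of the
// scoring above. Here "hello" appears twice among the four shingles of the
// source document and in two of the three corpus documents, so
// tf = 2 / 4 = 0.5 and idf = ln(3 / (1 + 2)) + 1 = 1.0, giving 0.5.
fun tfIdfExample(): Double {
    val source = Shingles()
    source += "hello"; source += "hello"; source += "world"; source += "again"
    val other = Shingles()
    other += "hello"
    val third = Shingles()
    third += "other"
    return tfIdf("hello", source, listOf(source, other, third))
}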
fun Shingles.filteredSortByTfIdf(corpus: List<Shingles>, shingleSize: Int): List<ScoredShingle> =
    keys.asSequence()
        .map { ScoredShingle(it, tfIdf(it, this, corpus)) }
        .filter { filterTfIdfScore(it.second, shingleSize) }
        .sortedBy { it.second }
        .toList()

fun CumulativeScoredShingles.toFilteredAverage(shingleSize: Int): AverageScoredShingles =
    asSequence().map { ScoredShingle(it.key, it.value.average()) }
        .filter { filterTfIdfScore(it.second, shingleSize) }
        .toMap()

fun filterTfIdfScore(score: Double, shingleSize: Int): Boolean {
    val thresholds = tfIdfThresholds[shingleSize - 1]
    return score >= thresholds.first && score < thresholds.second
}
// ------------------------------------------------

fun buildVocab(corpus: List<Shingles>, shingleSize: Int): AverageScoredShingles {
    println("building vocab = $shingleSize")
    val vocab = CumulativeScoredShingles()
    for (shingles in corpus) {
        shingles.filteredSortByTfIdf(corpus, shingleSize).forEach { vocab += it }
    }
    return vocab.toFilteredAverage(shingleSize)
}

fun AverageScoredShingles.save(dir: Path, shingleSize: Int) {
    val file = dir.resolve("$shingleSize.txt")
    val openOptions = arrayOf(StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING, StandardOpenOption.WRITE)
    Files.newBufferedWriter(file, *openOptions).use { writer ->
        keys.asSequence().sorted().forEach { shingle ->
            writer.write("$shingle\n")
        }
    }
}
Sample results (one vocabulary file per shingle size):
1.txt
2.txt
3.txt
4.txt