-
-
Save aria42/578348 to your computer and use it in GitHub Desktop.
(ns type-level-tagger
  {:doc "Implements State-of-the-art Unsupervised Part-of-speech Tagger
  from \"Simple Type-Level Unsupervised POS Tagging\"
  by Yoong-Keok Lee, Aria Haghighi and Regina Barzilay
  (http://www.cs.berkeley.edu/~aria42/pubs/typetagging.pdf)
  blog post: http://wp.me/pcW6S-x"
   :author "Aria Haghighi ([email protected])"}
  (:use [clojure.java.io :only [reader]]
        [clojure.contrib.duck-streams :only [with-out-writer]]
        [clojure.contrib.seq-utils :only [indexed]]
        [clojure.contrib.def :only [defvar]]))
;; Counter: map from object to count, with the sum of all counts cached
;; counts: map object -> count (inc-count drops keys whose count reaches zero)
;; total:  running sum of every weight added through inc-count
(defrecord Counter [counts total])
(defn get-count
  "Returns the count recorded for k in counter, or 0.0 when k has no entry.
  The postcondition fails if the stored count has gone negative."
  [counter k]
  {:post [(not (neg? %))]}
  (-> counter :counts (get k 0.0)))
(defn inc-count
  "Returns a new Counter in which k's count and the cached total have both
  grown by weight (weight may be negative). A key whose count becomes
  exactly zero is removed from the counts map instead of being stored."
  [counter k weight]
  (let [updated (+ (get-count counter k) weight)
        counts (:counts counter)
        counts' (if (zero? updated)
                  (dissoc counts k)
                  (assoc counts k updated))
        total' (+ (:total counter) weight)]
    (Counter. counts' total')))
;; Probability Distribution: a multinomial smoothed by a symmetric constant
;; counter: Counter of observed objects
;; lambda: smoothing constant added to every key's count
;; num-keys: number of possible keys, needed to normalize
;;           (log-prob divides by total + lambda * num-keys)
(defrecord DirichletMultinomial [counter lambda num-keys])
(defn new-dirichlet
  "Creates a DirichletMultinomial with no observations yet (empty counts,
  zero total), smoothing constant lambda and num-keys possible keys."
  [lambda num-keys]
  (let [empty-counter (Counter. {} 0)]
    (DirichletMultinomial. empty-counter lambda num-keys)))
(defn log-prob
  "Smoothed log probability of key under the DirichletMultinomial distr:
    log((count(key) + lambda) / (total + lambda * num-keys))
  The postcondition rejects a result of negative infinity."
  [distr key]
  {:post [(> % Double/NEGATIVE_INFINITY)]}
  (let [counter (:counter distr)
        lambda (:lambda distr)
        numerator (+ (get-count counter key) lambda)
        denominator (+ (:total counter) (* lambda (:num-keys distr)))]
    (Math/log (/ numerator denominator))))
(defn observe
  "Records an observation of key with the given weight in a
  DirichletMultinomial (a negative weight removes observations) and returns
  the updated distribution. Only :counter changes; lambda, num-keys and any
  other entries of distr are carried over as they are."
  [distr key weight]
  (update-in distr [:counter] inc-count key weight))
; Word Information
; word: string of word
; count: # of usages (tokens) of the word
; feats: map of feature-type to feature-value (as built by get-feats)
; contexts: counter of [before-word after-word] usages (for HMM)
(defrecord WordInfo [word count feats contexts])
(defn get-feats
  "Type-level features of the word string w, as a map feature-type -> value:
    :hasInitCap - true when the first character is an uppercase letter
    :hasPunc    - true when w contains a character that is not a letter,
                  combining mark, decimal digit or underscore
    :suffix     - the last (up to) 3 characters of w
  The character classes are Unicode-aware, so accented words are treated
  like ASCII ones: an ASCII-only [A-Z] misses capitals such as the O-umlaut,
  and an ASCII-only \\W flags every accented letter as punctuation."
  [#^String w]
  {:hasInitCap (boolean (re-matches #"\p{Lu}.*" w))
   :hasPunc (boolean (re-matches #".*[^\p{L}\p{M}\p{Nd}_].*" w))
   :suffix (let [suffix-length (min 3 (.length w))]
             (.substring w (- (.length w) suffix-length)))})
(defn new-word-info
  "Creates the WordInfo for a word seen for the first time: zero usages,
  the features computed by get-feats, and an empty context Counter."
  [word]
  (let [feats (get-feats word)
        no-contexts (Counter. {} 0)]
    (WordInfo. word 0 feats no-contexts)))
(defn tally-usage
  "Records one usage of the word between the tokens before and after:
  increments :count and adds 1 to the [before after] entry of :contexts."
  [word-info before after]
  (let [counted (update-in word-info [:count] inc)]
    (update-in counted [:contexts] inc-count [before after] 1)))
(defn assoc-if-absent
  "Returns m with k mapped to (f k) when m has no entry for k; otherwise
  returns m unchanged without calling f. Presence is checked with
  contains?, so an existing nil or false value counts as present rather
  than being overwritten."
  [m k f]
  (if (contains? m k)
    m
    (assoc m k (f k))))
(defn tally-sent
  "Folds one sentence (a token seq) into vocab, a map of word string ->
  WordInfo. Every token that has both a left and a right neighbour is
  tallied with that [left right] context; the first and last tokens serve
  only as context."
  [vocab sent]
  (let [tally-window (fn [acc [before word after]]
                       (update-in (assoc-if-absent acc word new-word-info)
                                  [word]
                                  tally-usage before after))]
    (reduce tally-window vocab (partition 3 1 sent))))
(defn build-vocab
  "Tallies every sentence in sents and returns the resulting WordInfo
  records, one per distinct word, in the tally map's iteration order."
  [sents]
  (->> sents
       (reduce tally-sent {})
       vals))
;; Gibbs Sampling State - All distributions are DirichletMultinomial
;; type-assigns: map word string to tag state (integer); the "#start#" and
;;   "#stop#" markers are mapped to :start and :stop
;; tag-prior: prior distr over tag assignment
;; trans-distrs: map of tag => P(tag' | tag) distribution (also has a :start entry)
;; emission-distrs: tag => P(word | tag) distribution, word=string representation
;; feat-distrs: tag => feat-type => P(feat-val | feat-type,tag) distribution
(defrecord State [type-assigns tag-prior trans-distrs emission-distrs feat-distrs])
;; Globals -- the three dynamic vars are rebound for each run by -main
(def #^{:doc "number of tag states"} *K* nil)
(def #^{:doc "seq of word infos"} *vocab* nil)
(def #^{:doc "where to write each iteration word assignments"} *outfile* nil)
;; Shared random source with a fixed seed (0), so sampling is repeatable.
(def +rand+ (new java.util.Random 0))
;; Updating Counts after word assignment
(defn obs-transitions
  "If word-info's word is set to tag, update the transition counts from
  every context usage [before after] of the word by weight times the
  usage count: observe P(tag | tag(before)) and P(tag(after) | tag), where
  neighbours' tags come from type-assigns with the word itself set to tag.
  Returns the updated trans-distrs map."
  [trans-distrs type-assigns word-info tag weight]
  ;; The word's own assignment is the same for every context, so it is
  ;; added once up front. A context may contain the word itself
  ;; (e.g. "the the"), in which case that neighbour resolves to tag.
  (let [assigns (assoc type-assigns (:word word-info) tag)]
    (reduce
     (fn [res [[before after] n]]
       (let [w (* n weight)]
         (-> res
             ;; Observe P(tag | tag-assign(before))
             (update-in [(assigns before)] observe tag w)
             ;; Observe P(tag-assign(after) | tag)
             (update-in [tag] observe (assigns after) w))))
     trans-distrs
     (-> word-info :contexts :counts))))
(defn obs-features
  "Observes each feature value of word-info, with the given weight, in the
  matching per-feature-type distribution of tag-feat-distrs. Returns the
  updated feature-type -> distribution map."
  [tag-feat-distrs word-info weight]
  (loop [distrs tag-feat-distrs
         feats (seq (:feats word-info))]
    (if-let [[feat-type feat-val] (first feats)]
      (recur (update-in distrs [feat-type] observe feat-val weight)
             (next feats))
      distrs)))
(defn obs-emissions
  "Adds (positive weight) or removes (otherwise) a word type in a tag's
  emission distribution: num-keys moves up or down by one, and the word is
  observed with weight times its number of occurrences."
  [tag-emission-distr word-info weight]
  (let [adjust (if (pos? weight) inc dec)
        resized (update-in tag-emission-distr [:num-keys] adjust)
        occurrences (:count word-info)]
    (observe resized (:word word-info) (* weight occurrences))))
(defn update-state
  "Adds (add? true) or removes (add? false) the assignment of word-info's
  word to tag, together with all the counts it contributes: the tag prior,
  the transitions over the word's contexts, the tag's emission distribution
  and the tag's feature distributions. Returns the updated state."
  [state word-info tag add?]
  (let [word (:word word-info)
        weight (if add? 1 -1)
        ;; Transitions are observed against the assignments as they were on
        ;; entry; obs-transitions itself fixes word -> tag for its lookups.
        old-assigns (:type-assigns state)]
    (assoc state
      ;; Remove only the word's entry: a variadic (dissoc m word tag) would
      ;; also drop any key equal to the tag value.
      :type-assigns (if add?
                      (assoc old-assigns word tag)
                      (dissoc old-assigns word))
      :tag-prior (observe (:tag-prior state) tag weight)
      :trans-distrs (obs-transitions (:trans-distrs state) old-assigns word-info tag weight)
      :emission-distrs (update-in (:emission-distrs state) [tag]
                                  obs-emissions word-info weight)
      :feat-distrs (update-in (:feat-distrs state) [tag]
                              obs-features word-info weight))))
(defn assign
  "Adds word-info's word to the state under tag, along with all its counts."
  [state word-info tag]
  (-> state
      (update-state word-info tag true)))
(defn unassign
  "Removes word-info's word from the state, subtracting the counts that
  were added under its current tag."
  [state word-info]
  (let [current-tag (get (:type-assigns state) (:word word-info))]
    (update-state state word-info current-tag false)))
(defn sum
  "Sums xs, or (map f xs) when f is given.
  The one-argument form starts from 0.0, so its result is always a double.
  The two-argument form has no seed: integer inputs give an integer result
  (learn formats (sum :count ...) with %d), and empty input gives 0."
  ([f xs]
   (apply + (map f xs)))
  ([xs]
   (apply + 0.0 xs)))
(defn make-map
  "Returns a map from each x in xs to (f x), applying f in the order of xs."
  [f xs]
  (into {} (for [x xs] [x (f x)])))
(defn map-vals
  "Returns a map with the same keys as m, each value v replaced by (f v)."
  [f m]
  (into {} (for [[k v] m] [k (f v)])))
(defn log-add
  "log (sum xs) from seq of log-x, computed stably by factoring out the
  maximum. Terms more than 30 below the maximum are dropped (relative
  contribution under ~1e-13).
  Returns negative infinity for an empty seq (the log of an empty sum).
  When the maximum is infinite it is returned directly, which avoids the
  NaN produced by subtracting infinities."
  [log-xs]
  (if (empty? log-xs)
    Double/NEGATIVE_INFINITY
    (let [max-log-x (double (apply max log-xs))]
      (if (Double/isInfinite max-log-x)
        max-log-x
        (+ max-log-x
           (Math/log (reduce + 0.0
                             (for [log-x log-xs
                                   :let [diff (- log-x max-log-x)]
                                   :when (> diff -30)]
                               (Math/exp diff)))))))))
(defn log-normalize
  "Turns a seq of log scores into probabilities, exp(x - log-add(xs)) for
  each x, returned lazily in the original order."
  [log-xs]
  (let [log-z (log-add log-xs)]
    (for [log-x log-xs]
      (Math/exp (- log-x log-z)))))
(defn sample-from-scores
  "Samples an index i with probability proportional to exp of the i-th
  log score, using one draw from +rand+.
  Rounding can leave the normalized probabilities summing to slightly less
  than 1. If the draw lands in that sliver, the last index with positive
  probability is returned instead of failing. A RuntimeException
  \"Impossible\" is thrown only when no index has positive probability."
  [log-scores]
  (let [trg (.nextDouble +rand+)
        probs (vec (log-normalize log-scores))
        n (count probs)]
    (loop [i 0
           so-far 0.0
           last-pos nil]
      (if (< i n)
        (let [p (probs i)
              so-far (+ so-far p)]
          (if (< trg so-far)
            i
            (recur (inc i) so-far (if (pos? p) i last-pos))))
        (or last-pos
            (throw (RuntimeException. "Impossible")))))))
(defn score-assign
  "Log probability of assigning word to tag, as the sum of:
   - the tag prior log prob,
   - each feature value's log prob under the tag's feature distributions,
   - for every [before after] context (weighted by its usage count): the
     word's emission log prob plus the before->tag and tag->after
     transition log probs, with neighbours' tags looked up after setting
     the word to tag."
  [state word-info tag]
  (let [word (:word word-info)
        trans (:trans-distrs state)
        ;; Tag Prior
        prior-lp (log-prob (:tag-prior state) tag)
        ;; Feature Prob
        feat-lp (sum (fn [[feat-type feat-val]]
                       (log-prob (get-in state [:feat-distrs tag feat-type]) feat-val))
                     (:feats word-info))
        assigns (assoc (:type-assigns state) word tag)
        ;; Subtlety: the word type is not currently counted under tag
        ;; (gibbs-sample unassigns it before scoring), so num-keys of the
        ;; emission distribution is bumped by one to account for it.
        word-lp (log-prob (update-in (get (:emission-distrs state) tag) [:num-keys] inc)
                          word)
        ;; Token Transition/Emission Prob
        token-lp (sum (fn [[[before after] n]]
                        (* n (+ word-lp
                                (log-prob (get trans (assigns before)) tag)
                                (log-prob (get trans tag) (assigns after)))))
                      (-> word-info :contexts :counts))]
    (+ prior-lp feat-lp token-lp)))
(defn gibbs-sample
  "Resamples the tag of one word type: removes its current assignment,
  scores each of the *K* tags, draws a tag from those scores and assigns
  the word to it."
  [state word-info]
  (let [without-word (unassign state word-info)
        scores (for [tag (range *K*)]
                 (score-assign without-word word-info tag))]
    (->> scores
         sample-from-scores
         (assign without-word word-info))))
(defn gibbs-sample-iter
  "One Gibbs sweep: resamples every word type in *vocab*, in order, and
  prints the elapsed time. Returns the resulting state."
  [state]
  (time
   (reduce (fn [s word-info] (gibbs-sample s word-info))
           state
           *vocab*)))
(defn init-state-helper
  "Builds the initial State before any counts are added:
   - type-assigns: each vocab word gets a random tag in [0, *K*) drawn from
     +rand+; \"#start#\" and \"#stop#\" are fixed to :start and :stop
   - tag-prior: empty distribution, smoothing alpha, over *K* tags
   - trans-distrs: for every tag and :start, an empty distribution with
     smoothing beta over *K* + 1 values (the tags and :stop)
   - emission-distrs: per tag, an empty distribution with smoothing beta
     and num-keys 0 (obs-emissions grows it as word types are assigned)
   - feat-distrs: per tag and feature type, an empty distribution with
     smoothing alpha over the distinct values that feature takes in *vocab*"
  [alpha beta]
  (let [num-feat-map ; map feature-type to num possible values
        (->> *vocab*
             (mapcat :feats)
             (group-by first) ; group by feature type
             (map-vals (comp count set)))]
    (State.
     ; random word to tag assignment - also fix assignments to start/stop
     (let [rand-assign (make-map (fn [_] (.nextInt +rand+ *K*)) (map :word *vocab*))]
       (assoc rand-assign "#start#" :start "#stop#" :stop))
     ; tag prior
     (new-dirichlet alpha *K*)
     ; transition distributions: all tags have same prior on successors
     ; K+1 possible values for tags and :stop state
     ; Also need a transition distribution for :start state
     (make-map
      (constantly (new-dirichlet beta (inc *K*)))
      (conj (range *K*) :start))
     ; emission distributions: all tags have same prior
     (make-map
      (constantly (new-dirichlet beta 0))
      (range *K*))
     ; feat distributions
     (let [tag-feat-distrs (map-vals (partial new-dirichlet alpha) num-feat-map)]
       (make-map (constantly tag-feat-distrs) (range *K*))))))
(defn init-state
  "Builds the starting State: takes the random assignment from
  init-state-helper, then adds each vocab word's counts under the tag it
  was given there."
  [alpha beta]
  (loop [state (init-state-helper alpha beta)
         infos (seq *vocab*)]
    (if infos
      (let [info (first infos)
            tag (get (:type-assigns state) (:word info))]
        (recur (assign state info tag) (next infos)))
      state)))
(defn learn
  "Runs num-iters Gibbs sampling sweeps over *vocab*, starting from a
  random assignment (reported as iteration 0). After the random
  initialization and after every sweep, the current word -> tag assignment
  is written to *outfile* (overwriting it), one \"word<TAB>tag\" line per
  word, skipping the :start/:stop markers. A num-iters below 1 runs no
  sweeps. Returns the final State."
  [num-iters alpha beta]
  (println "Learning on " num-iters " iterations.")
  (println (format "Num Word Types: %d Num Tokens: %d" (count *vocab*) (sum :count *vocab*)))
  (let [boundary-tags #{:start :stop}]
    (loop [iter 0 state (init-state alpha beta)]
      (with-out-writer *outfile*
        (doseq [[word tag] (:type-assigns state) :when (not (boundary-tags tag))]
          (println (format "%s\t%s" word tag))))
      (println (format "Finished Iteration %d %s Wrote current assignment to %s"
                       iter (if (zero? iter) "(Random)" "") *outfile*))
      ;; >= rather than =, so a non-positive num-iters cannot loop forever
      (if (>= iter num-iters)
        state
        (recur (inc iter) (gibbs-sample-iter state))))))
(defn -main
  "Main entry run with infile outfile num-iters K alpha beta
  infile: file with one sentence per-line tokenized so split on space gives tokens (no start/stop)
  outfile: file to write word to tag mapping after each iteration
  num-iters: number of gibbs sampling iters to run
  K: number of tags to use
  alpha,beta: doubles representing smoothing (try 0.1 1)
  Empty tokens (from leading whitespace or blank lines) are ignored. With
  fewer than six arguments a usage line is printed and nil is returned."
  [& args]
  (if (< (count args) 6)
    (println "Usage: infile outfile num-iters K alpha beta")
    (let [[infile outfile num-iters K alpha beta] args
          ;; parse numbers first so bad arguments fail before the file is read
          num-iters (Integer/parseInt num-iters)
          K (Integer/parseInt K)
          alpha (Double/parseDouble alpha)
          beta (Double/parseDouble beta)
          to-sent (fn [#^String line]
                    (concat ["#start#"]
                            (remove empty? (.split line "\\s+"))
                            ["#stop#"]))
          ;; build-vocab reduces eagerly over every sentence, so the whole
          ;; file is consumed before with-open closes the reader.
          vocab (with-open [rdr (reader infile)]
                  (build-vocab (map to-sent (line-seq rdr))))]
      (binding [*K* K *vocab* vocab *outfile* outfile]
        (learn num-iters alpha beta)))))
;; Script entry point: invokes -main with the command-line arguments every
;; time this file is loaded (including when it is loaded from a REPL).
(apply -main *command-line-args*)
I have a script that just runs java with clojure.main, adding the Clojure and clojure-contrib jars to the classpath. Make sure you use the -server option and give the JVM some extra memory.
Edited to remove the (neg? %) postcondition, since with little enough data a probability can become 1 (i.e., its log probability can be zero).
If you want some sample English data to run this on, I have the Brown corpus, tokenized with one sentence per line and gzipped, at http://dl.dropbox.com/u/1205228/brown.txt.gz. The data is for non-commercial use only and is about 1.1 million tokens.
Hi aria42,
Firstly, thank you for posting this.
I was wondering what license the code is under.
My preference is the Apache License (ASL 2.0) if you don't have one.
Regards,
Ian
The GitHub project is under the Eclipse Public License, just like Clojure. Is that acceptable? It doesn't really make a difference to me.
Thanks, aria42.
That's 100% acceptable.
Nice! I struggled a bit to find a REPL script that would handle the command-line arguments properly; cljr finally worked. What do you use?