Created
August 13, 2015 20:49
-
-
Save tmbdev/928422f4b491f4aef4f5 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
# <nbformat>3.0</nbformat>
# <codecell>
import openfst
from scipy.ndimage import measurements,filters
from collections import defaultdict
from pylab import *
import unicodedata
class AsciiCodec:
    """Codec mapping printable ASCII characters to integer class labels.

    Characters 32..127 are assigned classes 1..96; class 0 is left free
    (it is used as epsilon/skip by the FST code in this file).

    Ported from Python 2 (`unichr`/`unicode`): `chr` is used to build the
    table, and the ASCII-encoded bytes are decoded back to `str` before
    lookup so iteration yields characters, not ints, on Python 3.
    """
    def __init__(self):
        # Bidirectional table: class index -> char and char -> class index.
        d = {}
        for i, code in enumerate(range(32, 128)):
            ch = chr(code)
            d[i+1] = ch
            d[ch] = i+1
        self.d = d
    def nclasses(self):
        # 96 printable characters plus the reserved class 0.
        return 97
    def encode(self, s):
        """Return an int32 array of class labels for string `s`.

        Non-ASCII input is NFKD-normalized and unencodable characters are
        replaced with '?'; anything still not in the table maps to '~'.
        """
        s = str(s)
        s = unicodedata.normalize('NFKD', s)
        s = s.encode("ascii", "replace").decode("ascii")
        reject = self.d["~"]
        return array([self.d.get(c, reject) for c in s], 'i')
    def decode(self, a):
        """Return the string for a sequence of class labels; unknown labels become '~'."""
        return "".join([self.d.get(i, "~") for i in a])
# Shared default codec instance used by callers that don't supply their own.
default_codec = AsciiCodec()
def fstprint(fst, states=None):
    """Print every arc of `fst` as "src -> dst (ilabel, olabel) weight".

    If `states` is given, only arcs leaving those states are printed.
    Converted from the Python 2 print statement to a %-formatted print()
    call that produces the same output on both Python 2 and 3.
    """
    for state in fst:
        for x in fst.iterarcs(state):
            if states is None or state in states:
                print("%s -> %s %s %s" % (state, x.nextstate,
                                          (x.ilabel, x.olabel),
                                          x.weight.Value()))
# <markdowncell>
# We start by making an FST for the transcript. This is of the form
# "#+A+#+B+#+C+#+", where "#" is the "other" label. This transduces
# class labels to positions in the ground truth string. We need to
# make sure class labels start at 1, since 0 is reserved for epsilon.
# <codecell>
def make_transcript_fst(targets):
    """Build the transcript FST for the label sequence `targets`.

    Class 0 is the "skip"/other class and real classes start at 1; since
    OpenFST reserves label 0 for epsilon, every class is offset by +1.
    The ground-truth sequence interleaves the skip label around each
    target: skip, t0, skip, t1, skip, ...

    Returns (fst, gt) where gt is the offset label sequence.
    """
    gt = [1]
    for label in targets:
        gt.extend([label + 1, 1])
    fst = openfst.LogVectorFst()
    states = [fst.AddState() for _ in range(len(gt) + 1)]
    for pos, label in enumerate(gt):
        # Self-loop (stay on this ground-truth position) ...
        fst.AddArc(states[pos], int(label), pos + 1, 0.0, states[pos])
        # ... and advance to the next position; output label is the
        # 1-based position in gt.
        fst.AddArc(states[pos], int(label), pos + 1, 0.0, states[pos + 1])
    fst.SetStart(states[0])
    fst.SetFinal(states[-1], 0.0)
    return fst, gt
# <markdowncell>
# Next, for demonstration purposes, we just generate random outputs
# from the classifier and transform those into a transducer. The
# input labels are times and the output labels are classes. We add 1
# again because 0 means epsilon.
# <codecell>
def make_output_fst(outputs, threshold=100.0):
    """Turn a (time x classes) cost array into a linear FST.

    Each arc goes from the state for frame t to the state for frame t+1,
    with input label t+1 (the time) and output label class+1 (both offset
    by 1 because 0 is epsilon in OpenFST), weighted by the cost in
    `outputs`. Arcs whose cost reaches `threshold` are pruned.
    """
    nframes, nclasses = outputs.shape
    fst = openfst.LogVectorFst()
    states = [fst.AddState() for _ in range(nframes + 1)]
    for t in range(nframes):
        for k in range(nclasses):
            cost = outputs[t, k]
            if cost < threshold:
                fst.AddArc(states[t], t + 1, int(k) + 1, cost, states[t + 1])
    fst.SetStart(states[0])
    fst.SetFinal(states[-1], 0.0)
    return fst
# <markdowncell>
# Now we compose the two.
# <codecell>
def shortest_distance(comp, reverse=False):
    """Wrap openfst.ShortestDistance, returning the distances as a NumPy array.

    `reverse=True` computes distances from the final states (backward pass).
    """
    weights = openfst.vector_logweight()
    openfst.ShortestDistance(comp, weights, reverse)
    values = [w.Value() for w in weights]
    return array(values)
# <codecell> | |
def compute_time(comp):
    """Map each state of `comp` to a time step.

    A state's time is the unique input label on its outgoing arcs, or -1
    when its arcs carry more than one input label. One extra entry past
    the highest state id is added with time -1 (the final state has no
    outgoing arcs and would otherwise be missing).
    """
    from collections import defaultdict
    seen = defaultdict(set)
    for state in comp:
        for arc in comp.iterarcs(state):
            seen[state].add(arc.ilabel)
    # Collapse each label set to a single time (or -1 if ambiguous).
    time = defaultdict(set)
    for state, labels in seen.items():
        if len(labels) == 1:
            time[state] = next(iter(labels))
        else:
            time[state] = -1
    time[1 + max(time.keys())] = -1
    return time
# <codecell> | |
def compute_transitions(comp, time, gt, dist, rdist):
    """Group arcs of `comp` by the (source time, target time) pair.

    For every arc, the ground-truth label is looked up via the arc's
    output label (1-based index into `gt`), and the total path cost
    through that arc is forward distance + arc weight + backward
    distance. Returns a dict mapping (t0, t1) -> list of (label, cost).
    """
    transitions = defaultdict(list)
    for state in comp:
        for arc in comp.iterarcs(state):
            src_time = time[state]
            dst_time = time[arc.nextstate]
            label = gt[arc.olabel - 1]
            # Total cost of the best use of this arc:
            # forward distance + arc weight + backward distance.
            total = dist[state] + arc.weight.Value() + rdist[arc.nextstate]
            transitions[(src_time, dst_time)].append((label, total))
    return transitions
# <codecell> | |
def arc_posteriors(ts, nc=None):
    """Convert arcs (label, negative-log cost) into a posterior over labels.

    Costs are normalized in log space (after shifting by the minimum for
    numerical stability), exponentiated, and summed per label class.
    Returns an array of length `nc` (default: max label + 1).
    """
    labels = array([t[0] for t in ts], 'i')
    costs = array([t[1] for t in ts])
    if nc is None:
        nc = amax(labels) + 1
    costs = costs - amin(costs)                 # shift so the best cost is 0
    logp = -costs - log(sum(exp(-costs)))       # normalize in log space
    # Sum the probability mass of all arcs sharing a label.
    return measurements.sum(exp(logp), labels, range(nc))
# <codecell> | |
def ctc_align(outputs,transcript,threshold=100.0,verbose=0):
    """CTC-style alignment of classifier outputs against a transcript.

    Each transcript element x is expanded to the pattern _+x+_+ (via
    make_transcript_fst), the expanded FST is composed with the signal
    FST built from `outputs`, and forward/backward shortest-distance
    computations yield per-frame label posteriors.

    Parameters:
        outputs: 2D array (time x classes) of per-frame costs.
        transcript: sequence of class labels (converted to an int array).
        threshold: per-arc cost cutoff forwarded to make_output_fst.
        verbose: if nonzero, print progress stage names.

    Returns an array of posteriors with the epsilon column stripped;
    note it covers frames 1..n-1 and so has one row fewer than `outputs`.

    Fix: Python 2 print statements converted to print() calls (all
    single-argument, so the printed text is unchanged).
    """
    n,nc = outputs.shape
    signal_fst = make_output_fst(outputs,threshold=threshold)
    assert openfst.Verify(signal_fst)
    transcript = array(transcript,'i')
    transcript_fst,gt = make_transcript_fst(transcript)
    assert openfst.Verify(transcript_fst)
    comp = openfst.LogVectorFst()
    # Sort arcs so Compose can match signal output labels against
    # transcript input labels.
    openfst.ArcSortOutput(signal_fst)
    openfst.ArcSortInput(transcript_fst)
    if verbose: print("compose")
    openfst.Compose(signal_fst,transcript_fst,comp)
    openfst.Connect(comp)
    assert openfst.Verify(comp)
    if verbose: print("sd1")
    dist = shortest_distance(comp)          # forward distances
    if verbose: print("sd2")
    rdist = shortest_distance(comp,True)    # backward distances
    if verbose: print("compute time")
    time = compute_time(comp)
    if verbose: print("transitions")
    transitions = compute_transitions(comp,time,gt,dist,rdist)
    result = []
    if verbose: print("posteriors")
    for i in range(1,n):
        # Posterior over nc+1 labels (class 0 shifted by the epsilon offset).
        ps = arc_posteriors(transitions[(i,i+1)],nc+1)
        result.append(ps)
    if verbose: print("done")
    result = array(result)
    # Drop column 0: it corresponds to the +1 epsilon offset, not a real class.
    return result[:,1:]
# <codecell> | |
if __name__=="__main__":
    # Demo: align a 50-label transcript against smoothed random "classifier"
    # outputs (500 frames x 51 classes) and plot both.
    transcript = arange(50,dtype='i')+1
    outputs = filters.gaussian_filter(rand(500,51),1.0)
    ctc = ctc_align(outputs,transcript)
    # NOTE(review): `figsize` is an IPython-notebook helper, not a pylab
    # function — confirm this script is run inside a notebook environment.
    figsize(15,4)
    # Fix: Python 2 print statement -> print(); output stays space-separated.
    print(outputs.shape, ctc.shape)
    subplot(211); imshow(outputs.T)
    subplot(212); imshow(ctc.T)
    figsize(8,8)
    for i in range(ctc.shape[1]):
        plot(ctc[:,i])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment