# -*- coding: utf-8 -*- # Copyright 2012 Mirko Hannemann BUT, mirko.hannemann@gmail.com import sys import codecs # for UTF-8/unicode if len(sys.argv) != 2: print 'usage: reverse_arpa arpa.in' sys.exit() arpaname = sys.argv[1] #\data\ #ngram 1=4 #ngram 2=2 #ngram 3=2 # #\1-grams: #-5.234679 a -3.3 #-3.456783 b #0.0000000 -2.5 #-4.333333 # #\2-grams: #-1.45678 a b -3.23 #-1.30490 a -4.2 # #\3-grams: #-0.34958 a b #-0.23940 a b #\end\ # read language model in ARPA format try: file = codecs.open(arpaname, "r", "utf-8") except IOError: print 'file not found: ' + arpaname sys.exit() text=file.readline() while (text and text[:6] != "\\data\\"): text=file.readline() if not text: print "invalid ARPA file" sys.exit() #print text, while (text and text[:5] != "ngram"): text=file.readline() # get ngram counts cngrams=[] n=0 while (text and text[:5] == "ngram"): ind = text.split("=") counts = int(ind[1].strip()) r = ind[0].split() read_n = int(r[1].strip()) if read_n != n+1: print "invalid ARPA file:", text sys.exit() n = read_n cngrams.append(counts) #print text, text=file.readline() # read all n-grams order by order sentprob = 0.0 # sentence begin unigram ngrams=[] inf=float("inf") for n in range(1,len(cngrams)+1): # unigrams, bigrams, trigrams while (text and "-grams:" not in text): text=file.readline() if n != int(text[1]): print "invalid ARPA file:", text sys.exit() #print text,cngrams[n-1] this_ngrams={} # stores all read ngrams for ng in range(cngrams[n-1]): while (text and len(text.split())<2): text=file.readline() if (not text) or ((len(text.split())==1) and (("-grams:" in text) or (text[:5] == "\\end\\"))): break if (not text) or ((len(text.split())==1) and (("-grams:" in text) or (text[:5] == "\\end\\"))): break # to deal with incorrect ARPA files entry = text.split() prob = float(entry[0]) if len(entry)>n+1: back = float(entry[-1]) words = entry[1:n+1] else: back = 0.0 words = entry[1:] ngram = " ".join(words) if (n==1) and words[0]=="": sentprob = prob prob = 0.0 this_ngrams[ngram] = (prob,back) #print prob,ngram.encode("utf-8"),back for x in range(n-1,0,-1): # add all missing backoff ngrams for reversed lm l_ngram = " ".join(words[:x]) # shortened ngram r_ngram = " ".join(words[1:1+x]) # shortened ngram with offset one if l_ngram not in ngrams[x-1]: # create missing ngram ngrams[x-1][l_ngram] = (0.0,inf) #print ngram, "create 0.0", l_ngram, "inf" if r_ngram not in ngrams[x-1]: # create missing ngram ngrams[x-1][r_ngram] = (0.0,inf) #print ngram, "create 0.0", r_ngram, "inf",x,n,h_ngram # add all missing backoff ngrams for forward lm h_ngram = " ".join(words[n-x:]) # shortened history if h_ngram not in ngrams[x-1]: # create missing ngram ngrams[x-1][h_ngram] = (0.0,inf) #print "create inf", h_ngram, "0.0" text=file.readline() if (not text) or ((len(text.split())==1) and (("-grams:" in text) or (text[:5] == "\\end\\"))): break ngrams.append(this_ngrams) while (text and text[:5] != "\\end\\"): text=file.readline() if not text: print "invalid ARPA file" sys.exit() file.close() #print text, #fourgram "maxent" model (b(ABCD)=0): #p(A)+b(A) A 0 #p(AB)+b(AB)-b(A)-p(B) AB 0 #p(ABC)+b(ABC)-b(AB)-p(BC) ABC 0 #p(ABCD)+b(ABCD)-b(ABC)-p(BCD) ABCD 0 #fourgram reverse ARPA model (b(ABCD)=0): #p(A)+b(A) A 0 #p(AB)+b(AB)-p(B)+p(A) BA 0 #p(ABC)+b(ABC)-p(BC)+p(AB)-p(B)+p(A) CBA 0 #p(ABCD)+b(ABCD)-p(BCD)+p(ABC)-p(BC)+p(AB)-p(B)+p(A) DCBA 0 # compute new reversed ARPA model print "\\data\\" for n in range(1,len(cngrams)+1): # unigrams, bigrams, trigrams print "ngram "+str(n)+"="+str(len(ngrams[n-1].keys())) offset = 0.0 for n in range(1,len(cngrams)+1): # unigrams, bigrams, trigrams print "\\"+str(n)+"-grams:" keys = ngrams[n-1].keys() keys.sort() for ngram in keys: prob = ngrams[n-1][ngram] # reverse word order words = ngram.split() rstr = " ".join(reversed(words)) # swap and rev_ngram = rstr.replace("","").replace("","").replace("","") revprob = prob[0] if (prob[1] != inf): # only backoff weights from not newly created ngrams revprob = revprob + prob[1] #print prob[0],prob[1] # sum all missing terms in decreasing ngram order for x in range(n-1,0,-1): l_ngram = " ".join(words[:x]) # shortened ngram if l_ngram not in ngrams[x-1]: sys.stderr.write(rev_ngram+": not found "+l_ngram+"\n") p_l = ngrams[x-1][l_ngram][0] #print p_l,l_ngram revprob = revprob + p_l r_ngram = " ".join(words[1:1+x]) # shortened ngram with offset one if r_ngram not in ngrams[x-1]: sys.stderr.write(rev_ngram+": not found "+r_ngram+"\n") p_r = ngrams[x-1][r_ngram][0] #print -p_r,r_ngram revprob = revprob - p_r if n != len(cngrams): #not highest order back = 0.0 if rev_ngram[:3] == "": # special handling since arpa2fst ignores weight if n == 1: offset = revprob # remember weight revprob = sentprob # apply weight from forward model back = offset elif n == 2: revprob = revprob + offset # add weight to bigrams starting with if (prob[1] != inf): # only backoff weights from not newly created ngrams print revprob,rev_ngram.encode("utf-8"),back else: print revprob,rev_ngram.encode("utf-8"),"-100000.0" else: # highest order - no backoff weights if (n==2) and (rev_ngram[:3] == ""): revprob = revprob + offset print revprob,rev_ngram.encode("utf-8") print "\\end\\"