// decoder/faster-decoder.h // Copyright 2009-2011 Microsoft Corporation // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #ifndef KALDI_DECODER_FASTER_DECODER_H_ #define KALDI_DECODER_FASTER_DECODER_H_ #include "util/stl-utils.h" #include "itf/options-itf.h" #include "util/hash-list.h" #include "fst/fstlib.h" #include "itf/decodable-itf.h" #include "lat/kaldi-lattice.h" // for CompactLatticeArc #ifdef _MSC_VER #include #else #include #endif using std::tr1::unordered_map; namespace kaldi { struct FasterDecoderOptions { BaseFloat beam; int32 max_active; int32 min_active; BaseFloat beam_delta; BaseFloat hash_ratio; FasterDecoderOptions(): beam(16.0), max_active(std::numeric_limits::max()), min_active(20), // This decoder mostly used for // alignment, use small default. beam_delta(0.5), hash_ratio(2.0) { } void Register(OptionsItf *po, bool full) { /// if "full", use obscure /// options too. /// Depends on program. po->Register("beam", &beam, "Decoder beam"); po->Register("max-active", &max_active, "Decoder max active states."); po->Register("min-active", &min_active, "Decoder min active states (don't prune if #active less than this)."); if (full) { po->Register("beam-delta", &beam_delta, "Increment used in decoder [obscure setting]"); po->Register("hash-ratio", &hash_ratio, "Setting used in decoder to control hash behavior"); } } }; class FasterDecoder { public: typedef fst::StdArc Arc; typedef Arc::Label Label; typedef Arc::StateId StateId; typedef Arc::Weight Weight; FasterDecoder(const fst::Fst &fst, const FasterDecoderOptions &config); void SetOptions(const FasterDecoderOptions &config) { config_ = config; } ~FasterDecoder() { ClearToks(toks_.Clear()); } void Decode(DecodableInterface *decodable); bool ReachedFinal(); bool GetBestPath(fst::MutableFst *fst_out); protected: class Token { public: Arc arc_; // contains only the graph part of the cost; // we can work out the acoustic part from difference between // "weight_" and prev->weight_. Token *prev_; int32 ref_count_; Weight weight_; // weight up to current point. inline Token(const Arc &arc, Weight &ac_weight, Token *prev): arc_(arc), prev_(prev), ref_count_(1) { if (prev) { prev->ref_count_++; weight_ = Times(Times(prev->weight_, arc.weight), ac_weight); } else { weight_ = Times(arc.weight, ac_weight); } } inline Token(const Arc &arc, Token *prev): arc_(arc), prev_(prev), ref_count_(1) { if (prev) { prev->ref_count_++; weight_ = Times(prev->weight_, arc.weight); } else { weight_ = arc.weight; } } inline bool operator < (const Token &other) { return weight_.Value() > other.weight_.Value(); // This makes sense for log + tropical semiring. } inline static void TokenDelete(Token *tok) { while (--tok->ref_count_ == 0) { Token *prev = tok->prev_; delete tok; if (prev == NULL) return; else tok = prev; } #ifdef KALDI_PARANOID KALDI_ASSERT(tok->ref_count_ > 0); #endif } }; typedef HashList::Elem Elem; /// Gets the weight cutoff. Also counts the active tokens. BaseFloat GetCutoff(Elem *list_head, size_t *tok_count, BaseFloat *adaptive_beam, Elem **best_elem); void PossiblyResizeHash(size_t num_toks); // ProcessEmitting returns the likelihood cutoff used. BaseFloat ProcessEmitting(DecodableInterface *decodable, int frame); // TODO: first time we go through this, could avoid using the queue. void ProcessNonemitting(BaseFloat cutoff); // HashList defined in ../util/hash-list.h. It actually allows us to maintain // more than one list (e.g. for current and previous frames), but only one of // them at a time can be indexed by StateId. HashList toks_; const fst::Fst &fst_; FasterDecoderOptions config_; std::vector queue_; // temp variable used in ProcessNonemitting, std::vector tmp_array_; // used in GetCutoff. // make it class member to avoid internal new/delete. // It might seem unclear why we call ClearToks(toks_.Clear()). // There are two separate cleanup tasks we need to do at when we start a new file. // one is to delete the Token objects in the list; the other is to delete // the Elem objects. toks_.Clear() just clears them from the hash and gives ownership // to the caller, who then has to call toks_.Delete(e) for each one. It was designed // this way for convenience in propagating tokens from one frame to the next. void ClearToks(Elem *list); KALDI_DISALLOW_COPY_AND_ASSIGN(FasterDecoder); }; } // end namespace kaldi. #endif