// fstext/fstext-utils-test.cc // Copyright 2009-2012 Microsoft Corporation Daniel Povey // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #include "base/kaldi-common.h" // for exceptions #include "fstext/fstext-utils.h" #include "fstext/fst-test-utils.h" #include "util/stl-utils.h" namespace fst { template void TestMakeLinearAcceptor() { typedef typename Arc::Label Label; typedef typename Arc::StateId StateId; typedef typename Arc::Weight Weight; int len = rand() % 10; vector vec; vector vec_nozeros; for (int i = 0; i < len; i++) { int j = rand() % len; vec.push_back(j); if (j != 0) vec_nozeros.push_back(j); } VectorFst vfst; MakeLinearAcceptor(vec, &vfst); vector vec2; vector vec3; Weight w; GetLinearSymbolSequence(vfst, &vec2, &vec3, &w); assert(w == Weight::One()); assert(vec_nozeros == vec2); assert(vec_nozeros == vec3); if (vec2.size() != 0 || vec3.size() != 0) { // This test might not work // for empty sequences... { vector > vecs2; vector > vecs3; vector ws; GetLinearSymbolSequences(vfst, &vecs2, &vecs3, &ws); assert(vecs2.size() == 1); assert(vecs2[0] == vec2); assert(vecs3[0] == vec3); assert(ApproxEqual(ws[0], w)); } { vector > fstvec; NbestAsFsts(vfst, 1, &fstvec); KALDI_ASSERT(fstvec.size() == 1); assert(RandEquivalent(vfst, fstvec[0], 2/*paths*/, 0.01/*delta*/, rand()/*seed*/, 100/*path length-- max?*/)); } } bool include_eps = (rand() % 2 == 0); if (!include_eps) vec = vec_nozeros; kaldi::SortAndUniq(&vec); vector vec4; GetInputSymbols(vfst, include_eps, &vec4); assert(vec4 == vec); vector vec5; GetInputSymbols(vfst, include_eps, &vec5); } template void TestDeterminizeStarInLog() { VectorFst *fst = RandFst(); VectorFst fst_copy(fst); typename Arc::Label next_sym = 1 + HighestNumberedInputSymbol(*fst); vector syms; PreDeterminize(fst, NULL, "#", next_sym, &syms); } // Don't instantiate with log semiring, as RandEquivalent may fail. template void TestSafeDeterminizeWrapper() { // also tests SafeDeterminizeMinimizeWrapper(). typedef typename Arc::Label Label; typedef typename Arc::StateId StateId; typedef typename Arc::Weight Weight; VectorFst *fst = new VectorFst(); int n_syms = 2 + rand() % 5, n_states = 3 + rand() % 10, n_arcs = 5 + rand() % 30, n_final = 1 + rand()%3; // Up to 2 unique symbols. cout << "Testing pre-determinize with "< all_syms; // including epsilon. // Put symbols in the symbol table from 1..n_syms-1. for (size_t i = 0;i < (size_t)n_syms;i++) { std::stringstream ss; if (i == 0) ss << ""; else ss<AddSymbol(ss.str()); assert(cur_lab == (Label)i); all_syms.push_back(cur_lab); } assert(all_syms[0] == 0); // Create states. vector all_states; for (size_t i = 0;i < (size_t)n_states;i++) { StateId this_state = fst->AddState(); if (i == 0) fst->SetStart(i); all_states.push_back(this_state); } // Set final states. for (size_t j = 0;j < (size_t)n_final;j++) { StateId id = all_states[rand() % n_states]; Weight weight = (Weight)(0.33*(rand() % 5) ); printf("calling SetFinal with %d and %f\n", id, weight.Value()); fst->SetFinal(id, weight); } // Create arcs. for (size_t i = 0;i < (size_t)n_arcs;i++) { Arc a; a.nextstate = all_states[rand() % n_states]; a.ilabel = all_syms[rand() % n_syms]; a.olabel = all_syms[rand() % n_syms]; // same input+output vocab. a.weight = (Weight) (0.33*(rand() % 2)); StateId start_state = all_states[rand() % n_states]; fst->AddArc(start_state, a); } std::cout <<" printing before trimming\n"; { FstPrinter fstprinter(*fst, sptr, sptr, NULL, false, true); fstprinter.Print(&std::cout, "standard output"); } // Trim resulting FST. Connect(fst); std::cout <<" printing after trimming\n"; { FstPrinter fstprinter(*fst, sptr, sptr, NULL, false, true); fstprinter.Print(&std::cout, "standard output"); } VectorFst *fst_copy_orig = new VectorFst(*fst); VectorFst *fst_det = new VectorFst; vector