// nnet2/nnet-example.cc
// Copyright 2012-2013 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "nnet2/nnet-example.h"
#include "lat/lattice-functions.h"
#include "hmm/posterior.h"
namespace kaldi {
namespace nnet2 {
// Serializes this example as a sequence of tagged fields:
//   <NnetExample> <Labels> <count> (first second)* <InputFrames> [matrix]
//   <LeftContext> [left_context] <SpkInfo> [spk_info] </NnetExample>
// Every tag written here is checked by the corresponding ExpectToken() call in
// NnetExample::Read(), so the two lists of tags must be kept in sync.
void NnetExample::Write(std::ostream &os, bool binary) const {
  // Note: labels, input_frames, left_context and spk_info are members.  This
  // is a struct.
  WriteToken(os, binary, "<NnetExample>");
  WriteToken(os, binary, "<Labels>");
  // The count is stored as int32; make the size_t -> int32 narrowing explicit.
  int32 size = static_cast<int32>(labels.size());
  WriteBasicType(os, binary, size);
  for (int32 i = 0; i < size; i++) {
    WriteBasicType(os, binary, labels[i].first);
    WriteBasicType(os, binary, labels[i].second);
  }
  WriteToken(os, binary, "<InputFrames>");
  input_frames.Write(os, binary);  // can be read as regular Matrix.
  WriteToken(os, binary, "<LeftContext>");
  WriteBasicType(os, binary, left_context);
  WriteToken(os, binary, "<SpkInfo>");
  spk_info.Write(os, binary);
  WriteToken(os, binary, "</NnetExample>");
}
void NnetExample::Read(std::istream &is, bool binary) {
// Note: weight, label, input_frames, left_context and spk_info are members.
// This is a struct.
ExpectToken(is, binary, "");
ExpectToken(is, binary, "");
int32 size;
ReadBasicType(is, binary, &size);
labels.resize(size);
for (int32 i = 0; i < size; i++) {
ReadBasicType(is, binary, &(labels[i].first));
ReadBasicType(is, binary, &(labels[i].second));
}
ExpectToken(is, binary, "");
input_frames.Read(is, binary);
ExpectToken(is, binary, ""); // Note: this member is
// recently added, but I don't think we'll get too much back-compatibility
// problems from not handling the old format.
ReadBasicType(is, binary, &left_context);
ExpectToken(is, binary, "");
spk_info.Read(is, binary);
ExpectToken(is, binary, "");
}
// Serializes this example as a sequence of tagged fields:
//   <DiscriminativeNnetExample> <Weight> [weight] <NumAli> [int vector]
//   [den_lat as CompactLattice] <InputFrames> [compressed matrix]
//   <LeftContext> [left_context] <SpkInfo> [spk_info]
//   </DiscriminativeNnetExample>
// The tags must match the ExpectToken() calls in
// DiscriminativeNnetExample::Read().  Throws via KALDI_ERR if the lattice
// cannot be written.
void DiscriminativeNnetExample::Write(std::ostream &os,
                                      bool binary) const {
  // Note: weight, num_ali, den_lat, input_frames, left_context and spk_info are
  // members.  This is a struct.
  WriteToken(os, binary, "<DiscriminativeNnetExample>");
  WriteToken(os, binary, "<Weight>");
  WriteBasicType(os, binary, weight);
  WriteToken(os, binary, "<NumAli>");
  WriteIntegerVector(os, binary, num_ali);
  // The lattice is written without a preceding tag; Read() relies on this.
  if (!WriteCompactLattice(os, binary, den_lat)) {
    // We can't return error status from this function so we
    // throw an exception.
    KALDI_ERR << "Error writing CompactLattice to stream";
  }
  WriteToken(os, binary, "<InputFrames>");
  {
    // Stored compressed to save space; this can be read as a regular matrix.
    CompressedMatrix cm(input_frames);
    cm.Write(os, binary);
  }
  WriteToken(os, binary, "<LeftContext>");
  WriteBasicType(os, binary, left_context);
  WriteToken(os, binary, "<SpkInfo>");
  spk_info.Write(os, binary);
  WriteToken(os, binary, "</DiscriminativeNnetExample>");
}
void DiscriminativeNnetExample::Read(std::istream &is,
bool binary) {
// Note: weight, num_ali, den_lat, input_frames, left_context and spk_info are
// members. This is a struct.
ExpectToken(is, binary, "");
ExpectToken(is, binary, "");
ReadBasicType(is, binary, &weight);
ExpectToken(is, binary, "");
ReadIntegerVector(is, binary, &num_ali);
CompactLattice *den_lat_tmp = NULL;
if (!ReadCompactLattice(is, binary, &den_lat_tmp) || den_lat_tmp == NULL) {
// We can't return error status from this function so we
// throw an exception.
KALDI_ERR << "Error reading CompactLattice from stream";
}
den_lat = *den_lat_tmp;
delete den_lat_tmp;
ExpectToken(is, binary, "");
input_frames.Read(is, binary);
ExpectToken(is, binary, "");
ReadBasicType(is, binary, &left_context);
ExpectToken(is, binary, "");
spk_info.Read(is, binary);
ExpectToken(is, binary, "");
}
void DiscriminativeNnetExample::Check() const {
KALDI_ASSERT(weight > 0.0);
KALDI_ASSERT(!num_ali.empty());
int32 num_frames = static_cast(num_ali.size());
std::vector times;
int32 num_frames_den = CompactLatticeStateTimes(den_lat, ×);
KALDI_ASSERT(num_frames == num_frames_den);
KALDI_ASSERT(input_frames.NumRows() >= left_context + num_frames);
}
} // namespace nnet2
} // namespace kaldi