// nnet/nnet-affine-transform.h // Copyright 2011-2014 Brno University of Technology (author: Karel Vesely) // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #ifndef KALDI_NNET_NNET_AFFINE_TRANSFORM_H_ #define KALDI_NNET_NNET_AFFINE_TRANSFORM_H_ #include "nnet/nnet-component.h" #include "nnet/nnet-various.h" #include "cudamatrix/cu-math.h" namespace kaldi { namespace nnet1 { class AffineTransform : public UpdatableComponent { public: AffineTransform(int32 dim_in, int32 dim_out) : UpdatableComponent(dim_in, dim_out), linearity_(dim_out, dim_in), bias_(dim_out), linearity_corr_(dim_out, dim_in), bias_corr_(dim_out) { } ~AffineTransform() { } Component* Copy() const { return new AffineTransform(*this); } ComponentType GetType() const { return kAffineTransform; } void InitData(std::istream &is) { // define options float bias_mean = -2.0, bias_range = 2.0, param_stddev = 0.1; // parse config std::string token; while (!is.eof()) { ReadToken(is, false, &token); /**/ if (token == "") ReadBasicType(is, false, ¶m_stddev); else if (token == "") ReadBasicType(is, false, &bias_mean); else if (token == "") ReadBasicType(is, false, &bias_range); else KALDI_ERR << "Unknown token " << token << ", a typo in config?" << " (ParamStddev|BiasMean|BiasRange)"; is >> std::ws; // eat-up whitespace } // // initialize // Matrix mat(output_dim_, input_dim_); for (int32 r=0; r vec(output_dim_); for (int32 i=0; i* wei_copy) const { wei_copy->Resize(NumParams()); int32 linearity_num_elem = linearity_.NumRows() * linearity_.NumCols(); wei_copy->Range(0,linearity_num_elem).CopyRowsFromMat(Matrix(linearity_)); wei_copy->Range(linearity_num_elem, bias_.Dim()).CopyFromVec(Vector(bias_)); } std::string Info() const { return std::string("\n linearity") + MomentStatistics(linearity_) + "\n bias" + MomentStatistics(bias_); } std::string InfoGradient() const { return std::string("\n linearity_grad") + MomentStatistics(linearity_corr_) + "\n bias_grad" + MomentStatistics(bias_corr_); } void PropagateFnc(const CuMatrix &in, CuMatrix *out) { // precopy bias out->AddVecToRows(1.0, bias_, 0.0); // multiply by weights^t out->AddMatMat(1.0, in, kNoTrans, linearity_, kTrans, 1.0); } void BackpropagateFnc(const CuMatrix &in, const CuMatrix &out, const CuMatrix &out_diff, CuMatrix *in_diff) { // multiply error derivative by weights in_diff->AddMatMat(1.0, out_diff, kNoTrans, linearity_, kNoTrans, 0.0); } void Update(const CuMatrix &input, const CuMatrix &diff) { // we use following hyperparameters from the option class const BaseFloat lr = opts_.learn_rate; const BaseFloat mmt = opts_.momentum; const BaseFloat l2 = opts_.l2_penalty; const BaseFloat l1 = opts_.l1_penalty; // we will also need the number of frames in the mini-batch const int32 num_frames = input.NumRows(); // compute gradient (incl. momentum) linearity_corr_.AddMatMat(1.0, diff, kTrans, input, kNoTrans, mmt); bias_corr_.AddRowSumMat(1.0, diff, mmt); // l2 regularization if (l2 != 0.0) { linearity_.AddMat(-lr*l2*num_frames, linearity_); } // l1 regularization if (l1 != 0.0) { cu::RegularizeL1(&linearity_, &linearity_corr_, lr*l1*num_frames, lr); } // update linearity_.AddMat(-lr, linearity_corr_); bias_.AddVec(-lr, bias_corr_); } /// Accessors to the component parameters const CuVector& GetBias() { return bias_; } void SetBias(const CuVector& bias) { KALDI_ASSERT(bias.Dim() == bias_.Dim()); bias_.CopyFromVec(bias); } const CuMatrix& GetLinearity() { return linearity_; } void SetLinearity(const CuMatrix& linearity) { KALDI_ASSERT(linearity.NumRows() == linearity_.NumRows()); KALDI_ASSERT(linearity.NumCols() == linearity_.NumCols()); linearity_.CopyFromMat(linearity); } const CuVector& GetBiasCorr() { return bias_corr_; } const CuMatrix& GetLinearityCorr() { return linearity_corr_; } private: CuMatrix linearity_; CuVector bias_; CuMatrix linearity_corr_; CuVector bias_corr_; }; } // namespace nnet1 } // namespace kaldi #endif