// nnet/nnet-rbm.h // Copyright 2012-2013 Brno University of Technology (Author: Karel Vesely) // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #ifndef KALDI_NNET_NNET_RBM_H_ #define KALDI_NNET_NNET_RBM_H_ #include "nnet/nnet-component.h" #include "nnet/nnet-nnet.h" #include "nnet/nnet-various.h" #include "cudamatrix/cu-math.h" namespace kaldi { namespace nnet1 { class RbmBase : public Component { public: typedef enum { Bernoulli, Gaussian } RbmNodeType; RbmBase(int32 dim_in, int32 dim_out) : Component(dim_in, dim_out) { } /*Is included in Component:: itf virtual void Propagate( const CuMatrix &vis_probs, CuMatrix *hid_probs ) = 0; */ virtual void Reconstruct( const CuMatrix &hid_state, CuMatrix *vis_probs ) = 0; virtual void RbmUpdate( const CuMatrix &pos_vis, const CuMatrix &pos_hid, const CuMatrix &neg_vis, const CuMatrix &neg_hid ) = 0; virtual RbmNodeType VisType() const = 0; virtual RbmNodeType HidType() const = 0; virtual void WriteAsNnet(std::ostream& os, bool binary) const = 0; /// Set training hyper-parameters to the network and its UpdatableComponent(s) void SetRbmTrainOptions(const RbmTrainOptions& opts) { rbm_opts_ = opts; } /// Get training hyper-parameters from the network const RbmTrainOptions& GetRbmTrainOptions() const { return rbm_opts_; } protected: RbmTrainOptions rbm_opts_; //// Make these methods inaccessible for descendants. // private: // For RBMs we use Reconstruct(.) void Backpropagate(const CuMatrix &in, const CuMatrix &out, const CuMatrix &out_diff, CuMatrix *in_diff) { } void BackpropagateFnc(const CuMatrix &in, const CuMatrix &out, const CuMatrix &out_diff, CuMatrix *in_diff) { } // //// }; class Rbm : public RbmBase { public: Rbm(int32 dim_in, int32 dim_out) : RbmBase(dim_in, dim_out) { } ~Rbm() { } Component* Copy() const { return new Rbm(*this); } ComponentType GetType() const { return kRbm; } void InitData(std::istream &is) { // define options std::string vis_type; std::string hid_type; float vis_bias_mean = 0.0, vis_bias_range = 0.0, hid_bias_mean = 0.0, hid_bias_range = 0.0, param_stddev = 0.1; std::string vis_bias_cmvn_file; // initialize biases to logit(p_active) // parse config std::string token; while (!is.eof()) { ReadToken(is, false, &token); /**/ if (token == "") ReadToken(is, false, &vis_type); else if (token == "") ReadToken(is, false, &hid_type); else if (token == "") ReadBasicType(is, false, &vis_bias_mean); else if (token == "") ReadBasicType(is, false, &vis_bias_range); else if (token == "") ReadBasicType(is, false, &hid_bias_mean); else if (token == "") ReadBasicType(is, false, &hid_bias_range); else if (token == "") ReadBasicType(is, false, ¶m_stddev); else if (token == "") ReadToken(is, false, &vis_bias_cmvn_file); else KALDI_ERR << "Unknown token " << token << " Typo in config?"; is >> std::ws; // eat-up whitespace } // // initialize // if (vis_type == "bern" || vis_type == "Bernoulli") vis_type_ = RbmBase::Bernoulli; else if (vis_type == "gauss" || vis_type == "Gaussian") vis_type_ = RbmBase::Gaussian; else KALDI_ERR << "Wrong " << vis_type; // if (hid_type == "bern" || hid_type == "Bernoulli") hid_type_ = RbmBase::Bernoulli; else if (hid_type == "gauss" || hid_type == "Gaussian") hid_type_ = RbmBase::Gaussian; else KALDI_ERR << "Wrong " << hid_type; // visible-hidden connections Matrix mat(output_dim_, input_dim_); for (int32 r=0; r vec(output_dim_); for (int32 i=0; i vec2(input_dim_); for (int32 i=0; i " << vis_bias_cmvn_file; Nnet cmvn; cmvn.Read(vis_bias_cmvn_file); // getting probablity that neuron fires: Vector p(dynamic_cast(cmvn.GetComponent(0)).GetShiftVec()); p.Scale(-1.0); // compute logit: Vector logit_p(p.Dim()); for(int32 d = 0; d < p.Dim(); d++) { if(p(d) < 0.0001) p(d) = 0.0001; if(p(d) > 0.9999) p(d) = 0.9999; logit_p(d) = log(p(d)) - log(1.0 - p(d)); } vis_bias_ = logit_p; KALDI_ASSERT(vis_bias_.Dim() == InputDim()); } // } void ReadData(std::istream &is, bool binary) { std::string vis_node_type, hid_node_type; ReadToken(is, binary, &vis_node_type); ReadToken(is, binary, &hid_node_type); if(vis_node_type == "bern") { vis_type_ = RbmBase::Bernoulli; } else if(vis_node_type == "gauss") { vis_type_ = RbmBase::Gaussian; } if(hid_node_type == "bern") { hid_type_ = RbmBase::Bernoulli; } else if(hid_node_type == "gauss") { hid_type_ = RbmBase::Gaussian; } vis_hid_.Read(is, binary); vis_bias_.Read(is, binary); hid_bias_.Read(is, binary); KALDI_ASSERT(vis_hid_.NumRows() == output_dim_); KALDI_ASSERT(vis_hid_.NumCols() == input_dim_); KALDI_ASSERT(vis_bias_.Dim() == input_dim_); KALDI_ASSERT(hid_bias_.Dim() == output_dim_); } void WriteData(std::ostream &os, bool binary) const { switch (vis_type_) { case Bernoulli : WriteToken(os,binary,"bern"); break; case Gaussian : WriteToken(os,binary,"gauss"); break; default : KALDI_ERR << "Unknown type " << vis_type_; } switch (hid_type_) { case Bernoulli : WriteToken(os,binary,"bern"); break; case Gaussian : WriteToken(os,binary,"gauss"); break; default : KALDI_ERR << "Unknown type " << hid_type_; } vis_hid_.Write(os, binary); vis_bias_.Write(os, binary); hid_bias_.Write(os, binary); } // Component API void PropagateFnc(const CuMatrix &in, CuMatrix *out) { // precopy bias out->AddVecToRows(1.0, hid_bias_, 0.0); // multiply by weights^t out->AddMatMat(1.0, in, kNoTrans, vis_hid_, kTrans, 1.0); // optionally apply sigmoid if (hid_type_ == RbmBase::Bernoulli) { out->Sigmoid(*out); } } // RBM training API void Reconstruct(const CuMatrix &hid_state, CuMatrix *vis_probs) { // check the dim if (output_dim_ != hid_state.NumCols()) { KALDI_ERR << "Nonmatching dims, component:" << output_dim_ << " data:" << hid_state.NumCols(); } // optionally allocate buffer if (input_dim_ != vis_probs->NumCols() || hid_state.NumRows() != vis_probs->NumRows()) { vis_probs->Resize(hid_state.NumRows(), input_dim_); } // precopy bias vis_probs->AddVecToRows(1.0, vis_bias_, 0.0); // multiply by weights vis_probs->AddMatMat(1.0, hid_state, kNoTrans, vis_hid_, kNoTrans, 1.0); // optionally apply sigmoid if (vis_type_ == RbmBase::Bernoulli) { vis_probs->Sigmoid(*vis_probs); } } void RbmUpdate(const CuMatrix &pos_vis, const CuMatrix &pos_hid, const CuMatrix &neg_vis, const CuMatrix &neg_hid) { KALDI_ASSERT(pos_vis.NumRows() == pos_hid.NumRows() && pos_vis.NumRows() == neg_vis.NumRows() && pos_vis.NumRows() == neg_hid.NumRows() && pos_vis.NumCols() == neg_vis.NumCols() && pos_hid.NumCols() == neg_hid.NumCols() && pos_vis.NumCols() == input_dim_ && pos_hid.NumCols() == output_dim_); //lazy initialization of buffers if ( vis_hid_corr_.NumRows() != vis_hid_.NumRows() || vis_hid_corr_.NumCols() != vis_hid_.NumCols() || vis_bias_corr_.Dim() != vis_bias_.Dim() || hid_bias_corr_.Dim() != hid_bias_.Dim() ){ vis_hid_corr_.Resize(vis_hid_.NumRows(),vis_hid_.NumCols(),kSetZero); //vis_bias_corr_.Resize(vis_bias_.Dim(),kSetZero); //hid_bias_corr_.Resize(hid_bias_.Dim(),kSetZero); vis_bias_corr_.Resize(vis_bias_.Dim()); hid_bias_corr_.Resize(hid_bias_.Dim()); } // // ANTI-WEIGHT-EXPLOSION PROTECTION // in the following section we detect that the weights in Gaussian-Bernoulli RBM // are about to explode. The weight explosion is caused by large variance of the // reconstructed data, which causes increase of weight variance towards the explosion. // // To avoid explosion, the variance of the visible-data and reconstructed-data // should be about the same. The model is particularly sensitive at the very // beginning of the CD-1 training. // // We compute variance of a)input mini-batch b)reconstruction. // When the ratio b)/a) is larger than 2, we: // 1. scale down the weights and biases by b)/a) (for next mini-batch b)/a) gets 1.0) // 2. shrink learning rate by 0.9x // 3. reset the momentum buffer // // Wa also display a warning. Note that in later stage // the training returns back to higher learning rate. // if (vis_type_ == RbmBase::Gaussian) { //get the standard deviations of pos_vis and neg_vis data //pos_vis CuMatrix pos_vis_pow2(pos_vis); pos_vis_pow2.MulElements(pos_vis); CuVector pos_vis_second(pos_vis.NumCols()); pos_vis_second.AddRowSumMat(1.0,pos_vis_pow2,0.0); CuVector pos_vis_mean(pos_vis.NumCols()); pos_vis_mean.AddRowSumMat(1.0/pos_vis.NumRows(),pos_vis,0.0); Vector pos_vis_second_h(pos_vis_second.Dim()); pos_vis_second.CopyToVec(&pos_vis_second_h); Vector pos_vis_mean_h(pos_vis_mean.Dim()); pos_vis_mean.CopyToVec(&pos_vis_mean_h); Vector pos_vis_stddev(pos_vis_mean_h); pos_vis_stddev.MulElements(pos_vis_mean_h); pos_vis_stddev.Scale(-1.0); pos_vis_stddev.AddVec(1.0/pos_vis.NumRows(),pos_vis_second_h); /* set negative values to zero before the square root */ for (int32 i=0; i neg_vis_pow2(neg_vis); neg_vis_pow2.MulElements(neg_vis); CuVector neg_vis_second(neg_vis.NumCols()); neg_vis_second.AddRowSumMat(1.0,neg_vis_pow2,0.0); CuVector neg_vis_mean(neg_vis.NumCols()); neg_vis_mean.AddRowSumMat(1.0/neg_vis.NumRows(),neg_vis,0.0); Vector neg_vis_second_h(neg_vis_second.Dim()); neg_vis_second.CopyToVec(&neg_vis_second_h); Vector neg_vis_mean_h(neg_vis_mean.Dim()); neg_vis_mean.CopyToVec(&neg_vis_mean_h); Vector neg_vis_stddev(neg_vis_mean_h); neg_vis_stddev.MulElements(neg_vis_mean_h); neg_vis_stddev.Scale(-1.0); neg_vis_stddev.AddVec(1.0/neg_vis.NumRows(),neg_vis_second_h); /* set negative values to zero before the square root */ for (int32 i=0; i(pos_vis.NumRows()); vis_hid_corr_.AddMatMat(-lr/N, neg_hid, kTrans, neg_vis, kNoTrans, mmt); vis_hid_corr_.AddMatMat(+lr/N, pos_hid, kTrans, pos_vis, kNoTrans, 1.0); vis_hid_corr_.AddMat(-lr*l2, vis_hid_, 1.0); vis_hid_.AddMat(1.0, vis_hid_corr_, 1.0); // UPDATE visbias vector // // visbiasinc = momentum*visbiasinc + (epsilonvb/numcases)*(posvisact-negvisact); // vis_bias_corr_.AddRowSumMat(-lr/N, neg_vis, mmt); vis_bias_corr_.AddRowSumMat(+lr/N, pos_vis, 1.0); vis_bias_.AddVec(1.0, vis_bias_corr_, 1.0); // UPDATE hidbias vector // // hidbiasinc = momentum*hidbiasinc + (epsilonhb/numcases)*(poshidact-neghidact); // hid_bias_corr_.AddRowSumMat(-lr/N, neg_hid, mmt); hid_bias_corr_.AddRowSumMat(+lr/N, pos_hid, 1.0); hid_bias_.AddVec(1.0, hid_bias_corr_, 1.0); } RbmNodeType VisType() const { return vis_type_; } RbmNodeType HidType() const { return hid_type_; } void WriteAsNnet(std::ostream& os, bool binary) const { //header WriteToken(os,binary,Component::TypeToMarker(Component::kAffineTransform)); WriteBasicType(os,binary,OutputDim()); WriteBasicType(os,binary,InputDim()); if(!binary) os << "\n"; //data vis_hid_.Write(os,binary); hid_bias_.Write(os,binary); //optionally sigmoid activation if(HidType() == Bernoulli) { WriteToken(os,binary,Component::TypeToMarker(Component::kSigmoid)); WriteBasicType(os,binary,OutputDim()); WriteBasicType(os,binary,OutputDim()); } if(!binary) os << "\n"; } protected: CuMatrix vis_hid_; ///< Matrix with neuron weights CuVector vis_bias_; ///< Vector with biases CuVector hid_bias_; ///< Vector with biases CuMatrix vis_hid_corr_; ///< Matrix for linearity updates CuVector vis_bias_corr_; ///< Vector for bias updates CuVector hid_bias_corr_; ///< Vector for bias updates RbmNodeType vis_type_; RbmNodeType hid_type_; }; } // namespace nnet1 } // namespace kaldi #endif