// cudamatrix/cu-matrix.h // Copyright 2009-2012 Karel Vesely // 2013 Johns Hopkins University (author: Daniel Povey) // 2013 Hainan Xu // 2013 Xiaohui Zhang // 2013 Johns Hopkins University (author: Guoguo Chen) // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #ifndef KALDI_CUDAMATRIX_CU_MATRIX_H_ #define KALDI_CUDAMATRIX_CU_MATRIX_H_ #include #include "cudamatrix/cu-matrixdim.h" #include "cudamatrix/cu-common.h" #include "cudamatrix/cu-value.h" #include "matrix/matrix-common.h" #include "matrix/kaldi-matrix.h" #include "cudamatrix/cu-array.h" #include "cudamatrix/cu-math.h" #include "cudamatrix/cu-rand.h" namespace kaldi { template Real TraceMatMat(const CuMatrixBase &A, const CuMatrixBase &B, MatrixTransposeType trans = kNoTrans); /** * Matrix for CUDA computing. * Does the computation on the CUDA card when CUDA is compiled in and * we have a suitable GPU (CuDevice::Instantiate().Enabled() == true); * otherwise, does it on the CPU. */ /* template struct MatrixElement { int row; int column; Real weight; }; // */ template class CuMatrixBase { public: friend class CuMatrixBase; friend class CuMatrixBase; friend class CuVectorBase; friend class CuVectorBase; friend class VectorBase; friend class CuSpMatrix; friend class CuTpMatrix; friend class CuTpMatrix; friend class CuVectorBase; friend class CuSubMatrix; friend class CuRand; friend class CuSubVector; friend class CuBlockMatrix; friend void cu::RegularizeL1(CuMatrixBase *weight, CuMatrixBase *grad, Real l1, Real lr); friend void cu::Splice(const CuMatrix &src, const CuArray &frame_offsets, CuMatrix *tgt); friend void cu::Copy(const CuMatrix &src, const CuArray ©_from_indices, CuMatrix *tgt); friend void cu::Randomize(const CuMatrixBase &src, const CuArray ©_from_idx, CuMatrixBase *tgt); /// Copies column r from column indices[r] of src. /// As a special case, if indexes[i] == -1, sets column i to zero /// indices.size() must equal this->NumCols(), /// all elements of "reorder" must be in [-1, src.NumCols()-1], /// and src.NumRows() must equal this.NumRows() void CopyCols(const CuMatrixBase &src, const std::vector &indices); /// Version of CopyCols that takes CuArray argument. void CopyCols(const CuMatrixBase &src, const CuArray &indices); /// Copies row r from row indices[r] of src. /// As a special case, if indexes[i] <== -1, sets row i to zero /// "reorder".size() must equal this->NumRows(), /// all elements of "reorder" must be in [0, src.NumRows()-1], /// and src.NumCols() must equal this.NumCols() void CopyRows(const CuMatrixBase &src, const std::vector &indices); /// For each row r of this and for each column c, sets (*this)(r, c) to the /// sum \sum_j src(r, j), where j ranges from indices[c].first through /// indices[c].second - 1. void SumColumnRanges(const CuMatrixBase &src, const CuArray &indices); friend Real TraceMatMat(const CuMatrixBase &A, const CuMatrixBase &B, MatrixTransposeType trans); void AddToDiag(Real value); /// Dimensions MatrixIndexT NumRows() const { return num_rows_; } MatrixIndexT NumCols() const { return num_cols_; } MatrixIndexT Stride() const { return stride_; } // MatrixDim is a struct containing "rows", "cols" and "stride", // that is an argument of most CUDA kernels. ::MatrixDim Dim() const { ::MatrixDim d = { num_rows_, num_cols_, stride_ }; return d; } Real FrobeniusNorm() const { return sqrt(TraceMatMat(*this, *this, kTrans)); } bool IsUnit(Real tol = 0.001) const; /// True if ((*this)-other).FrobeniusNorm() <= tol * this->FrobeniusNorm() bool ApproxEqual(const CuMatrixBase &other, float tol = 0.01) const; /// Get size of matrix in bytes MatrixIndexT SizeInBytes() const { return num_rows_*stride_*sizeof(Real); } // Copy functions. These do not resize. template void CopyFromMat(const MatrixBase &src, MatrixTransposeType trans = kNoTrans); void CopyFromMat(const MatrixBase &src, MatrixTransposeType trans = kNoTrans); void CopyFromSp(const CuSpMatrix &M); template void CopyFromTp(const CuTpMatrix &M, MatrixTransposeType trans = kNoTrans); template void CopyFromMat(const CuMatrixBase &M, MatrixTransposeType trans = kNoTrans); template void CopyToMat(MatrixBase *dst, MatrixTransposeType trans = kNoTrans) const; void CopyRowsFromVec(const CuVectorBase &v); void CopyRowsFromVec(const VectorBase &v); /// Copy vector into specific column of matrix. void CopyColFromVec(const CuVectorBase &v, const MatrixIndexT col); /// Set each element to the sigmoid of the corresponding element of "src": /// element by element, x = 1 / (1 + exp(-x)) void Sigmoid(const CuMatrixBase &src); /// Apply the function y = log(1 + exp(x)), to each element. /// Note: the derivative of this function is the sigmoid function. /// This is like a soft ReLU. void SoftHinge(const CuMatrixBase &src); /// Apply the function y(i) = (sum_{j = i*G}^{(i+1)*G-1} x_j ^ (power)) ^ (1 / p) /// where G = x.NumCols() / y.NumCols() must be an integer. void GroupPnorm(const CuMatrixBase &src, Real pow); /// Calculate derivatives for the GroupPnorm function above... /// if "input" is the input to the GroupPnorm function above (i.e. the "src" variable), /// and "output" is the result of the computation (i.e. the "this" of that function /// call), and *this has the same dimension as "input", then it sets each element /// of *this to the derivative d(output-elem)/d(input-elem) for each element of "input", where /// "output-elem" is whichever element of output depends on that input element. void GroupPnormDeriv(const CuMatrixBase &input, const CuMatrixBase &output, Real power); /// Compute the hyperbolic tangent (tanh) function; element by element, /// *this = tanh(src). void Tanh(const CuMatrixBase &src); /// Differentiate backward through the sigmoid function. Here, "value" is the /// sigmoid output. Does, element-by-element, *this = diff * value * (1 - value). void DiffSigmoid(const CuMatrixBase &value, const CuMatrixBase &diff); /// Differentiate backward through the tanh function. Here, "value" is the /// tanh output. Does, element-by-element, *this = diff * (1 - value^2). void DiffTanh(const CuMatrixBase &value, const CuMatrixBase &diff); /// Differentiate the block [softmax+cross-entropy] : /// dE/da = posterior_mat - target_mat, /// 'E' is error function, 'a' is activation on softmax input /// /// Interface: /// tgt ... index vector, encodes the matrix of targets /// net_out_or_diff ... before invocation net output, after diff dE/da /// log_post_tgt ... per-frame statistics for cross-entropy computations : /// log(sum_row(posterior_mat .* target_mat)) void DiffXent(const CuArray &tgt, CuVector *log_post_tgt); /// This method may be only called for symmetric matrices (it accesses the /// upper as well as lower triangle). The result is put in the lower /// triangle, and the upper triangle zeroed. void Cholesky(); void SymInvertPosDef(); ///< Inversion for positive definite symmetric matrices. ///< Requires that the input is symmetric (we do not check this). ///< The output is symmetric. void ApplyPow(Real power); void ApplyHeaviside(); ///< For each element, sets x = (x > 0 ? 1.0 : 0.0) void ApplyFloor(Real floor_val); void ApplyCeiling(Real ceiling_val); void ApplyExp(); /// Softmax nonlinearity /// Y = Softmax(X) : Yij = e^Xij / sum_k(e^Xik), done to each row /// for each row, the max value is first subtracted for good numerical stability void ApplySoftMaxPerRow(const CuMatrixBase &src); /// Find the id of the maximal element for each row void FindRowMaxId(CuArray *id) const; /* // Copy row interval from matrix // @param r [in] number of rows to copy. // @param src [in] source matrix. // @param src_ro [in] source matrix row offset. // @param dst_ro [in] destination matrix row offset. // void CopyRowsFromMat(int32 r, const CuMatrixBase &src, int32 src_ro, int32 dst_ro); */ /// Math operations, some calling kernels void SetZero(); void Set(Real value); void Add(Real value); void SetZeroUpperDiag(); void Scale(Real value); void ApplyLog(); /// Multiply two matrices elementwise: C = A .* C void MulElements(const CuMatrixBase &A); /// Do, elementwise, *this = max(*this, A). void Max(const CuMatrixBase &A); /// scale i'th column by scale[i] void MulColsVec(const CuVectorBase &scale); /// scale i'th row by scale[i] void MulRowsVec(const CuVectorBase &scale); /// divide each row into src.NumCols() groups, and then scale i'th row's jth group of elements by src[i, j]. void MulRowsGroupMat(const CuMatrixBase &src); /// divide i'th row by scale[i] void DivRowsVec(const CuVectorBase &div); /// invert the matrix by elements. void InvertElements(); /// B = aplha * A + beta * B void AddMat(Real alpha, const CuMatrixBase &A, Real beta=1.0); /// B = aplha * row + beta * B void AddVecToCols(Real alpha, const CuVectorBase &col, Real beta = 1.0); /// B = aplha * row + beta * B void AddVecToRows(Real alpha, const CuVectorBase &row, Real beta = 1.0); /// C = alpha * A(^T)*B(^T) + beta * C void AddMatMat(Real alpha, const CuMatrixBase &A, MatrixTransposeType transA, const CuMatrixBase &B, MatrixTransposeType transB, Real beta); /// *this = a * b / c (by element; when c = 0, *this = a) void AddMatMatDivMat(const CuMatrixBase &A, const CuMatrixBase &B, const CuMatrixBase &C); /// *this = beta * *this + alpha * M M^T, for symmetric matrices. It only /// updates the lower triangle of *this. It will leave the matrix asymmetric; /// if you need it symmetric as a regular matrix, do CopyLowerToUpper(). void SymAddMat2(const Real alpha, const CuMatrixBase &M, MatrixTransposeType transA, Real beta); /// This function is like AddMatMat but for where the second argument is of /// type CuBlockMatrix (a block-diagonal matrix of blocks). void AddMatBlock(Real alpha, const CuMatrixBase &A, MatrixTransposeType transA, const CuBlockMatrix &B, MatrixTransposeType transB, Real beta); /// *this = beta * *this + alpha * diag(v) * M [or M^T]. /// The same as adding M but scaling each row M_i by v(i). void AddDiagVecMat(const Real alpha, CuVectorBase &v, const CuMatrixBase &M, MatrixTransposeType transM, Real beta = 1.0); /// this <-- beta*this + alpha*A*B void AddMatSp(const Real alpha, const CuMatrixBase &A, MatrixTransposeType transA, const CuSpMatrix &B, const Real beta) { CuMatrix M(B); return AddMatMat(alpha, A, transA, M, kNoTrans, beta); } /// this <-- beta*this + alpha*SpA*B void AddSpMat(const Real alpha, const CuSpMatrix &A, const CuMatrixBase &B, MatrixTransposeType transB, const Real beta) { CuMatrix M(A); return AddMatMat(alpha, M, kNoTrans, B, transB, beta); } /// this <-- beta*this + alpha*A*B. void AddTpMat(const Real alpha, const CuTpMatrix &A, MatrixTransposeType transA, const CuMatrixBase &B, MatrixTransposeType transB, const Real beta) { CuMatrix M(A); return AddMatMat(alpha, M, transA, B, transB, beta); } /// this <-- beta*this + alpha*A*B. void AddMatTp(const Real alpha, const CuMatrixBase &A, MatrixTransposeType transA, const CuTpMatrix &B, MatrixTransposeType transB, const Real beta) { CuMatrix M(B); return AddMatMat(alpha, A, transA, M, transB, beta); } void CopyFromBlock(const CuBlockMatrix &B, MatrixTransposeType trans = kNoTrans); void CopyLowerToUpper(); void CopyUpperToLower(); inline CuSubMatrix Range(const MatrixIndexT row_offset, const MatrixIndexT num_rows, const MatrixIndexT col_offset, const MatrixIndexT num_cols) const { return CuSubMatrix(*this, row_offset, num_rows, col_offset, num_cols); } inline CuSubMatrix RowRange(const MatrixIndexT row_offset, const MatrixIndexT num_rows) const { return CuSubMatrix(*this, row_offset, num_rows, 0, num_cols_); } inline CuSubMatrix ColRange(const MatrixIndexT col_offset, const MatrixIndexT num_cols) const { return CuSubMatrix(*this, 0, num_rows_, col_offset, num_cols); } inline const CuSubVector Row(MatrixIndexT i) const { KALDI_ASSERT(static_cast(i) < static_cast(num_rows_)); return CuSubVector(data_ + (i * stride_), NumCols()); } inline CuSubVector Row(MatrixIndexT i) { KALDI_ASSERT(static_cast(i) < static_cast(num_rows_)); return CuSubVector(data_ + (i * stride_), NumCols()); } inline CuValue operator() (MatrixIndexT r, MatrixIndexT c) { KALDI_PARANOID_ASSERT(static_cast(r) < static_cast(num_rows_) && static_cast(c) < static_cast(num_cols_)); return CuValue(data_ + r * stride_ + c); } inline Real operator() (MatrixIndexT r, MatrixIndexT c) const { KALDI_PARANOID_ASSERT(static_cast(r) < static_cast(num_rows_) && static_cast(c) < static_cast(num_cols_)); return CuValue(data_ + r * stride_ + c); // will be casted to Real. } Real Sum() const; /// Return the trace. If check_square = true, will crash if matrix is not square. Real Trace(bool check_square = true) const; void SetRandn(); void SetRandUniform(); void Write(std::ostream &os, bool binary) const; // This function, adds a list of MatrixElements (scaled by alpha) to corresponding locations to // (*this). void AddElements(Real alpha, const std::vector >& input); // This function resizes the output to indices.size(), and for each element of // "indices" it interprets it as a (row, column) index into *this, and puts // (*this)(row, column) into the corresponding element of "output". void Lookup(const std::vector &indices, std::vector *output) const; // Creates binary mask with per-element equality predicates of *this, mat. // Output stored to 'mask', values : 1.0 = equal, 0.0 = not-equal. void EqualElementMask(const CuMatrixBase &mat, CuMatrix *mask) const; protected: // The following two functions should only be called if we did not compile with CUDA // or could not get a CUDA card; in that case the contents are interpreted the // same as a regular matrix. inline const MatrixBase &Mat() const { return *(reinterpret_cast* >(this)); } inline MatrixBase &Mat() { return *(reinterpret_cast* >(this)); } /// Get raw row pointer inline const Real* RowData(MatrixIndexT r) const { return data_ + r * stride_; } inline Real* RowData(MatrixIndexT r) { return data_ + r * stride_; } inline const Real *Data() const { return data_; } inline Real *Data() { return data_; } // The constructors are protected to prevent the user creating an instance of // this class. /// Default constructor CuMatrixBase(): data_(NULL), num_cols_(0), num_rows_(0), stride_(0) { } /// This constructor takes the #rows, #cols and stride; it's called from /// the constructor of CuSubMatrix. CuMatrixBase(Real *data, MatrixIndexT num_rows, MatrixIndexT num_cols, MatrixIndexT stride): data_(data), num_cols_(num_cols), num_rows_(num_rows), stride_(stride) { } Real *data_; ///< GPU data pointer (or regular matrix data pointer, ///< if either CUDA was not compiled in or we could not ///< acquire the device). // Note: it might seem a bit backwards that we have the number of columns // first here; it's necessary because we need the data to be laid out the same // as for MatrixBase so the Mat() function call will work. We don't want to // change the layout of MatrixBase at this point, or there will be crashes if // people don't thoroughly recompile. MatrixIndexT num_cols_; MatrixIndexT num_rows_; MatrixIndexT stride_; private: KALDI_DISALLOW_COPY_AND_ASSIGN(CuMatrixBase); }; // class CuMatrixBase /// This class represents a matrix that's stored on the GPU if we have one, /// and in memory if not. template class CuMatrix: public CuMatrixBase { public: CuMatrix() { } /// Constructor with memory initialisation CuMatrix(MatrixIndexT rows, MatrixIndexT cols, MatrixResizeType resize_type = kSetZero) { Resize(rows, cols, resize_type); } // Note: we had to remove the "explicit" keyword due // to problems with STL vectors of CuMatrixBase. CuMatrix(const CuMatrix &other, MatrixTransposeType trans = kNoTrans); explicit CuMatrix(const CuBlockMatrix &other, MatrixTransposeType trans = kNoTrans); explicit CuMatrix(const CuMatrixBase &other, MatrixTransposeType trans = kNoTrans); template explicit CuMatrix(const MatrixBase &other, MatrixTransposeType trans = kNoTrans); /// Copy constructor taking SpMatrix... explicit CuMatrix(const CuSpMatrix &M) : CuMatrixBase() { Resize(M.NumRows(), M.NumRows(), kUndefined); this->CopyFromSp(M); } /// Copy constructor taking TpMatrix... template explicit CuMatrix(const CuTpMatrix & M, MatrixTransposeType trans = kNoTrans) : CuMatrixBase() { Resize(M.NumCols(), M.NumRows(), kUndefined); this->CopyFromTp(M, trans); } /// Copy constructor: as above, but from another type. template explicit CuMatrix(const CuMatrixBase &M, MatrixTransposeType trans = kNoTrans); CuMatrix &operator = (const CuMatrixBase &other) { this->Resize(other.NumRows(), other.NumCols(), kUndefined); this->CopyFromMat(other); return *this; } CuMatrix &operator = (const CuMatrix &other) { this->Resize(other.NumRows(), other.NumCols(), kUndefined); this->CopyFromMat(other); return *this; } CuMatrix &operator = (const MatrixBase &other) { this->Resize(other.NumRows(), other.NumCols(), kUndefined); this->CopyFromMat(other); return *this; } void Transpose(); /// Allocate the memory void Resize(MatrixIndexT rows, MatrixIndexT cols, MatrixResizeType resize_type = kSetZero); void Swap(Matrix *mat); void Swap(CuMatrix *mat); template void Swap(CuMatrix *mat); /// I/O functions void Read(std::istream &is, bool binary); /// Destructor ~CuMatrix() { Destroy(); } inline const Matrix &Mat() const { return *(reinterpret_cast* >(this)); } inline Matrix &Mat() { return *(reinterpret_cast* >(this)); } /// This function does: for each element { row, column, weight } indexed i in /// the vector "elements", let x(i) = A(row(i), column(i)); then it does /// (*this)(row(i), column(i)) += weight(i) / x(i), and /// *tot_objf = \sum_i weight(i) * log(x(i)), and /// *tot_weight = \sum_i weight(i) /// Preconditions: A must be strictly positive, and no (row, column) pair /// may be repeated within "elements" void CompObjfAndDeriv(const std::vector > &elements, const CuMatrix &A, Real *tot_objf, Real* tot_weight); private: void Destroy(); }; /// This class is used for a piece of a CuMatrix. template class CuSubMatrix: public CuMatrixBase { public: inline CuSubMatrix(const CuMatrixBase &mat, const MatrixIndexT row_offset, const MatrixIndexT num_rows, const MatrixIndexT col_offset, const MatrixIndexT num_cols); /// This type of constructor is needed for Range() to work [in CuMatrix base /// class]. Cannot make it explicit or that breaks. inline CuSubMatrix (const CuSubMatrix &other): CuMatrixBase (other.data_, other.num_cols_, other.num_rows_, other.stride_) {} private: /// Disallow assignment. CuSubMatrix &operator = (const CuSubMatrix &other); }; template bool ApproxEqual(const CuMatrixBase &A, const CuMatrixBase &B, Real tol = 0.01) { return A.ApproxEqual(B, tol); } template inline void AssertEqual(CuMatrixBase &A, CuMatrixBase &B, float tol = 0.01) { KALDI_ASSERT(A.ApproxEqual(B, tol)); } template bool SameDim(const CuMatrixBase &M, const CuMatrixBase &N) { return (M.NumRows() == N.NumRows() && M.NumCols() == N.NumCols()); } template bool SameDimAndStride(const CuMatrixBase &M, const CuMatrixBase &N) { return (M.NumRows() == N.NumRows() && M.NumCols() == N.NumCols() && M.Stride() == N.Stride()); } /// I/O template std::ostream &operator << (std::ostream &out, const CuMatrixBase &mat); template template Matrix::Matrix(const CuMatrixBase &M, MatrixTransposeType trans) { if (trans == kNoTrans) Init(M.NumRows(), M.NumCols()); else Init(M.NumCols(), M.NumRows()); M.CopyToMat(this, trans); } template template void MatrixBase::CopyFromMat(const CuMatrixBase &cu, MatrixTransposeType trans) { cu.CopyToMat(this, trans); } } // namespace #include "cudamatrix/cu-matrix-inl.h" #endif