// cudamatrix/cu-tp-matrix.h // Copyright 2013 Ehsan Variani // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // http://www.apache.org/licenses/LICENSE-2.0 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. // #ifndef KALDI_CUDAMATRIX_CU_TP_MATRIX_H_ #define KALDI_CUDAMATRIX_CU_TP_MATRIX_H_ #include #include "cudamatrix/cu-common.h" #include "matrix/matrix-common.h" #include "matrix/tp-matrix.h" #include "cudamatrix/cu-array.h" #include "cudamatrix/cu-math.h" #include "cudamatrix/cu-packed-matrix.h" #include "cudamatrix/cu-matrix.h" namespace kaldi { template class CuTpMatrix; template class CuTpMatrix : public CuPackedMatrix { friend class CuMatrixBase; friend class CuMatrixBase; friend class CuVectorBase; friend class CuSubMatrix; friend class CuRand; friend class CuTpMatrix; friend class CuTpMatrix; public: CuTpMatrix() : CuPackedMatrix() {} explicit CuTpMatrix(MatrixIndexT r, MatrixResizeType resize_type = kSetZero) : CuPackedMatrix(r, resize_type) {} explicit CuTpMatrix(const TpMatrix &orig) : CuPackedMatrix(orig) {} explicit CuTpMatrix(const CuTpMatrix &orig) : CuPackedMatrix(orig) {} explicit CuTpMatrix(const CuMatrixBase &orig, MatrixTransposeType trans = kNoTrans); ~CuTpMatrix() {} void CopyFromMat(const CuMatrixBase &M, MatrixTransposeType Trans = kNoTrans); void CopyFromTp(const CuTpMatrix &other) { CuPackedMatrix::CopyFromPacked(other); } void CopyFromTp(const TpMatrix &other) { CuPackedMatrix::CopyFromPacked(other); } void Cholesky(const CuSpMatrix& Orig); void Invert(); protected: inline const TpMatrix &Mat() const { return *(reinterpret_cast* >(this)); } inline TpMatrix &Mat() { return *(reinterpret_cast* >(this)); } }; } // namespace #endif