📄 nagmatrix.cc
字号:
// // NagMatrix.cc// // A class for manipulating (2D) matrices (uses either NAG or BLAS/LAPACK library)//#include <cstdio>#include <cstdlib>#include <cassert>#include <cstring> // for memcpy#include "NagMatrix.h"#include "tracker_defines_types_and_helpers.h"#include "text_output.h"namespace ReadingPeopleTracker{#ifdef USE_FLOAT// single precision math#define dsyev_ ssyev_#define dgetrf_ sgetrf_#define dgetri_ sgetri_#define dgemv_ sgemv_#define dgemm_ sgemm_#define dgeev_ sgeev_#define dgeqrf_ sgeqrf_#define dorgqr_ sorgqr_#define dgesvd_ sgesvd_#define dgels_ sgels_// please note: the following conversions might not be needed anymore// as this was NAG's old naming scheme: ...f for double, ...e for single// precision. Please check with your NAG implementation. #define f01aaf_ f01aae_#define f02abf_ f02abe_#define f02agf_ f02age_#define f01qcf_ f01qce_#define f01qef_ f01qee_#define f02wef_ f02wee_#define f04jgf_ f04jge_#endif#ifdef USE_LAPACK#ifdef WIN32#include <mkl.h> // Intel Math Kernel Library, a BLAS/LAPACK implementation#ifdef USE_FLOAT// single precision math#define dsyev_ SSYEV#define dgetrf_ SGETRF#define dgetri_ SGETRI#define dgemm_ SGEMM#define dgeev_ SGEEV#define dgemv_ SGEMV#define dgeqrf_ SGEQRF#define dorgqr_ SORGQR#define dgesvd_ SGESVD#define dgels_ SGELS#else// double precision math#define dsyev_ DSYEV#define dgetrf_ DGETRF#define dgetri_ DGETRI#define dgemv_ DGEMV#define dgemm_ DGEMM#define dgeev_ DGEEV#define dgeqrf_ DGEQRF#define dorgqr_ DORGQR#define dgesvd_ DGESVD#define dgels_ DGELS#endif // ifdef USE_FLOAT#else // ifdef WIN32extern "C"{ int ilaenv_(int *ispec, char *name, char *opts, int *n1, int *n2, int *n3, int *n4); void dsyev_(char *jobz, char *uplo, int *n, realno *a, int *lda, realno *w, realno *work, int *lwork, int *info); void dgetrf_(int *m, int *n, realno *a, int * lda, int *ipiv, int *info); void dgetri_ (int *n, realno *a, int *lda, int *ipiv, realno *work, int *lwork, int *info); void dgeev_ (char *jobvl, char *jobvr, int *n, realno *a, int *lda, realno *wr, realno *wi, realno *vl, int *ldvl, realno *vr, int *ldvr, realno *work, int *lwork, int *info); void dgeqrf_(int *m, int *n, realno *a, int *lda, realno *tau, realno *work, int *lwork, int *info); void dorgqr_(int *m, int *n, int *k, realno *a, int *lda, realno *tau, realno *work, int *lwork, int *info); void dgesvd_(char *cha, char *chb, int *m, int *n, realno *a,int *lda, realno *s, realno *u, int *ldu, realno *v, int *ldvt, realno *work, int *lwork, int *info ); void dgels_(char *ch, int *m, int *n, int *nrhs, realno *a, int *lda, realno *b, int *ldb, realno *work, int *lwork, int *info); }#endif // ifdef WIN32 else#else // ifdef USE_LAPACK// use NAG// default value for IFAIL variable, determines what NAG should do on failureconst int NagMatrix::DEF_IFAIL = -1; // show error message but continueextern "C"{ void f01aaf_ (realno*, int*, int*, realno*, int*, realno*, int* ); void f02abf_ (realno*, int*, int*, realno*, realno*, int*, realno*, int*); void f02agf_(realno*, int*, int*, realno*, realno*, realno*, int*, realno*, int*, int*, int*); void f01qcf_(int*, int*, realno*, int*, realno*, int*); void f01qef_(char*, int*, int*, int*, realno* , int*, realno*, realno *, int*); void f02wef_(int*, int*, realno*, int*, int*, realno*, int*, int*, realno*, int*, realno*, int*, realno*, int*, realno* , int*); void f04jgf_(int *M, int *N, realno *A, int *NRA, realno *B, realno *TOL, bool *SVD, realno *SIGMA, int *IRANK, realno *WORK, int *LWORK, int *IFAIL);}#endif // ifdef USE_LAPACK else#ifndef WIN32extern "C"{ void dgemv_ (char*, int*, int*, realno*, realno*, int*, const realno*, int*, realno*, realno*, int* ); void dgemm_ (char*, char*, int*, int*, int*, realno*, realno*, int*, realno*, int*, realno*, realno*, int*);}#endifvoid NagMatrix::matrix_error(const char *message) const{ cerror << "Error in NagMatrix library routine: " << message << endl; abort();}NagMatrix::NagMatrix(unsigned int n, unsigned int m){ rows = n; columns = m; data = new realno[n*m]; own_memory = true;}void NagMatrix::reconstruct(unsigned int n, unsigned int m){ if ((data != NULL) && own_memory) delete [] data; rows = n; columns = m; if ((n*m) > 0) { data = new realno[n*m]; own_memory = true; } else { data = NULL; own_memory = false; }}NagMatrix::NagMatrix(unsigned int n, unsigned int m, realno init){ rows = n; columns = m; data = new realno[n*m]; own_memory = true; realno *element = data; register unsigned int num_elements = n * m; for (register unsigned int i = 0; i < num_elements; i++) *element++ = init;}NagMatrix::NagMatrix(unsigned int n, unsigned int m, const NagVector &v){ rows = n; columns = m; assert(((n * m) == v.get_size()) || ((n == m) && (n == v.get_size()))); // enforce valid initialisation if ((n * m) == v.get_size()) { data = new realno[n*m]; own_memory = true; memcpy((void*)data, (void*)v.get_data_const(), n * m * sizeof(realno)); } else if ((n == m) && (n == v.get_size())) { // create matrix with values from v on the diagonal, 0 elsewhere data = new realno[n*m]; own_memory = true; clear(0); for (unsigned int i = 0; i < rows; i++) set(i,i, v[i]); } else { // no valid data given for initialisation data = NULL; own_memory = false; } }void NagMatrix::reconstruct(unsigned int n, unsigned int m, const NagVector &v){ if ((data != NULL) && own_memory) delete [] data; rows = n; columns = m; assert (n*m == v.get_size()); // force same size data = new realno[n*m]; own_memory = true; memcpy((void*)data, (void*)v.get_data_const(), n * m * sizeof(realno));}realno NagMatrix::identity_fn(unsigned int i, unsigned int j){ if (i == j) return 1; return 0;}// initialising matrix with values generated by a function of row and columnNagMatrix::NagMatrix(unsigned int n, unsigned int m, realno (*func) (unsigned int, unsigned int)){ rows = n; columns = m; data = new realno[n*m]; own_memory = true; for (unsigned int i = 0; i < rows; i++) for (unsigned int j = 0; j < columns; j++) set(i,j, (*func) (i,j));}NagMatrix::~NagMatrix(){ if ((data != NULL) && own_memory) delete [] data;}void NagMatrix::transpose (NagMatrix &result) const{ if (result.data == NULL) result.reconstruct(columns, rows); for (unsigned int i = 0; i < rows; i++) for (unsigned int j = 0; j < columns; j++) result.set(j,i,read(i,j));}void NagMatrix::invert(NagMatrix &result) const{ if (rows != columns) matrix_error(" cannot invert non-square matrix");#ifdef USE_LAPACK copy(result); int N = (int) columns; int LDA = (int) rows; int *IPIV = new int[N]; int INFO;// int ISPEC = 1; //int NB = ilaenv_(&ISPEC, "dgetri", "", &N, &LDA, &minus1, &minus1); int NB = 64; int LWORK = (N * NB); NagVector WORK(LWORK); dgetrf_(&LDA, &N, result.data, &LDA, IPIV, &INFO); if (INFO != 0) matrix_error(" cannot factorise"); dgetri_(&N, result.data, &LDA, IPIV, WORK.get_data(), &LWORK, &INFO); if (INFO != 0) matrix_error(" cannot invert matrix");#else if (result.data == NULL) result.reconstruct(rows, columns); int IA = (int) rows; int N = (int) rows; int IX = result.no_rows(); NagMatrix A (*this); // make a copy because the NAG routine modifies A. NagVector P(rows); int IFAIL = DEF_IFAIL; // return value: > 0 on failure, value on entry determines error message f01aaf_(A.data, &IA, &N, result.data, &IX, P.data, &IFAIL); if (IFAIL != 0) matrix_error(" failed to invert");#endif}void NagMatrix::scale(const realno s, NagMatrix &result) const{ if (result.data == NULL) result.reconstruct(rows, columns); assert (result.no_rows() == rows); assert (result.no_columns() == columns); realno *data1 = data; realno *data2 = result.data; for (unsigned int i = rows * columns; i > 0 ; i--) *data2++ = s * *data1++;}void NagMatrix::clear(const realno s){ realno *data1 = data; for (unsigned int i = rows * columns; i > 0 ; i--) *data1++ = s;}void NagMatrix::add (const NagMatrix &m2, NagMatrix &result) const{ if (result.data == NULL) result.reconstruct(rows, columns); if ((rows != m2.rows) || (columns != m2.columns) || (rows != result.rows) || (columns != result.columns)) matrix_error(" illegal Matrix to add"); realno *data1 = data; realno *data2 = m2.data; realno *data_res = result.data; for (unsigned int i = rows * columns; i > 0; i--) *data_res++ = (*data1++) + (*data2++);}void NagMatrix::subtract (const NagMatrix &m2, NagMatrix &result) const{ if (result.data == NULL) result.reconstruct(rows, columns); if ((rows != m2.rows) || (columns != m2.columns) || (rows != result.rows) || (columns != result.columns)) matrix_error(" illegal Matrix to subtract"); realno *data1 = data; realno *data2 = m2.data; realno *data_res = result.data; for (unsigned int i = rows * columns; i > 0; i--) *data_res++ = (*data1++) - (*data2++);}void NagMatrix::multiply (const NagMatrix &m2, NagMatrix &result) const{ if (result.data == NULL) result.reconstruct(rows, m2.columns); if ((rows != result.rows) || (m2.columns != result.columns) || (columns != m2.rows)) matrix_error(" error in matrix multiply"); char transa = 'N'; char transb = 'N'; int M = (int) rows; int N = (int) m2.columns; int K = (int) columns; realno alpha = 1.0; realno beta = 0.0; int LDA = (int) rows; int LDB = (int) m2.rows; int LDC = (int) result.rows; realno *B = m2.data; realno *C = result.data; realno *A = data; dgemm_(&transa, &transb, &M, &N, &K, &alpha, A, &LDA, B, &LDB, &beta, C, &LDC);}void NagMatrix::multiply (const NagMatrix &m2, NagMatrix &result, char transa, char transb, realno alpha) const{ if (transa == 't') transa = 'T'; if (transb == 't') transb = 'T'; int M, N, K; if (transa == 'T') M = (int) columns; else M = (int) rows; if (transb == 'T') N = (int) m2.rows; else N = (int) m2.columns; if (transa == 'T') K = (int) rows; else K = (int) columns; if (((transb == 'T') && (K != m2.columns)) || ((transb != 'T') && (K != m2.rows))) matrix_error("bad call to matrix multiply"); realno beta = 0.0; if (result.data == NULL) result.reconstruct(M, N);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -