OpenMEEG
Loading...
Searching...
No Matches
symmatrix.h
Go to the documentation of this file.
1// Project Name: OpenMEEG (http://openmeeg.github.io)
2// © INRIA and ENPC under the French open source license CeCILL-B.
3// See full copyright notice in the file LICENSE.txt
4// If you make a copy of this file, you must either:
5// - provide also LICENSE.txt and modify this header to refer to it.
6// - replace this header by the LICENSE.txt content.
7
8#pragma once
9
#include <iostream>
#include <cstdlib>
#include <string>
#include <vector>

#include <vector.h>
#include <linop.h>
16
17namespace OpenMEEG {
18
19 class Matrix;
20
21 class OPENMEEGMATHS_EXPORT SymMatrix : public LinOp {
22
23 friend class Vector;
24
25 LinOpValue value;
26
27 public:
28
29 SymMatrix(): LinOp(0,0,SYMMETRIC,2),value() {}
30
31 SymMatrix(const char* fname): LinOp(0,0,SYMMETRIC,2),value() { this->load(fname); }
32 SymMatrix(Dimension N): LinOp(N,N,SYMMETRIC,2),value(size()) { }
33 SymMatrix(Dimension M,Dimension N): LinOp(N,N,SYMMETRIC,2),value(size()) { om_assert(N==M); }
34 SymMatrix(const SymMatrix& S,const DeepCopy): LinOp(S.nlin(),S.nlin(),SYMMETRIC,2),value(S.size(),S.data()) { }
35
36 explicit SymMatrix(const Vector& v);
37 explicit SymMatrix(const Matrix& A);
38
39 size_t size() const { return nlin()*(nlin()+1)/2; };
40 void info() const ;
41
42 Dimension ncol() const { return nlin(); } // SymMatrix only need num_lines
43 Dimension& ncol() { return nlin(); }
44
45 void alloc_data() { value = LinOpValue(size()); }
46 void reference_data(const double* array) { value = LinOpValue(size(),array); }
47
48 bool empty() const { return value.empty(); }
49 void set(double x) ;
50 double* data() const { return value.get(); }
51
52 double operator()(const Index i,const Index j) const {
53 om_assert(i<nlin());
54 om_assert(j<nlin());
55 return data()[(i<=j) ? i+j*(j+1)/2 : j+i*(i+1)/2];
56 }
57
58 double& operator()(const Index i,const Index j) {
59 om_assert(i<nlin());
60 om_assert(j<nlin());
61 return data()[(i<=j) ? i+j*(j+1)/2 : j+i*(i+1)/2];
62 }
63
64 Matrix operator()(const Index i_start,const Index i_end,const Index j_start,const Index j_end) const;
65 Matrix submat(const Index istart,const Index isize,const Index jstart,const Index jsize) const;
66 SymMatrix submat(const Index istart,const Index iend) const;
67 Vector getlin(const Index i) const;
68 void setlin(const Index i,const Vector& v);
69 Vector solveLin(const Vector& B) const;
70 void solveLin(Vector* B,const int nbvect);
72
73 const SymMatrix& operator=(const double d);
74
75 SymMatrix operator+(const SymMatrix& B) const;
76 SymMatrix operator-(const SymMatrix& B) const;
77 Matrix operator*(const SymMatrix& B) const;
78 Matrix operator*(const Matrix& B) const;
79 Vector operator*(const Vector& v) const;
80 SymMatrix operator*(const double x) const;
81 SymMatrix operator/(const double x) const { return (*this)*(1/x); }
82
83 void operator +=(const SymMatrix& B);
84 void operator -=(const SymMatrix& B);
85 void operator *=(const double x);
86 void operator /=(const double x) { (*this)*=(1/x); }
87
88 SymMatrix inverse() const;
89 void invert();
90 SymMatrix posdefinverse() const;
91 double det();
92 // void eigen(Matrix& Z,Vector& D);
93
94 void save(const char* filename) const;
95 void load(const char* filename);
96
97 void save(const std::string& s) const { save(s.c_str()); }
98 void load(const std::string& s) { load(s.c_str()); }
99
100 friend class Matrix;
101 };
102
103 // Returns the solution of (this)*X = B
104
105 inline Vector SymMatrix::solveLin(const Vector& B) const {
106 SymMatrix invA(*this,DEEP_COPY);
107 Vector X(B,DEEP_COPY);
108
109 #ifdef HAVE_LAPACK
110 // Bunch Kaufman factorization
111 BLAS_INT* pivots=new BLAS_INT[nlin()];
112 int Info = 0;
113 DSPTRF('U',sizet_to_int(invA.nlin()),invA.data(),pivots,Info);
114 // Inverse
115 DSPTRS('U',sizet_to_int(invA.nlin()),1,invA.data(),pivots,X.data(),sizet_to_int(invA.nlin()),Info);
116
117 om_assert(Info==0);
118 delete[] pivots;
119 #else
120 std::cout << "solveLin not defined" << std::endl;
121 #endif
122 return X;
123 }
124
125 // stores in B the solution of (this)*X = B, where B is a set of nbvect vector
126
127 inline void SymMatrix::solveLin(Vector* B,const int nbvect) {
128 SymMatrix invA(*this,DEEP_COPY);
129
130 #ifdef HAVE_LAPACK
131 // Bunch Kaufman Factorization
132 BLAS_INT *pivots=new BLAS_INT[nlin()];
133 int Info = 0;
134 //char *uplo="U";
135 DSPTRF('U',sizet_to_int(invA.nlin()),invA.data(),pivots,Info);
136 // Inverse
137 for(int i = 0; i < nbvect; i++)
138 DSPTRS('U',sizet_to_int(invA.nlin()),1,invA.data(),pivots,B[i].data(),sizet_to_int(invA.nlin()),Info);
139
140 om_assert(Info==0);
141 delete[] pivots;
142 #else
143 std::cout << "solveLin not defined" << std::endl;
144 #endif
145 }
146
147 inline void SymMatrix::operator+=(const SymMatrix& B) {
148 om_assert(nlin()==B.nlin());
149 #ifdef HAVE_BLAS
150 BLAS(daxpy,DAXPY)(sizet_to_int(nlin()*(nlin()+1)/2), 1.0, B.data(), 1, data() , 1);
151 #else
152 const size_t sz = size();
153 for (size_t i=0; i<sz; ++i)
154 data()[i] += B.data()[i];
155 #endif
156 }
157
158 inline void SymMatrix::operator-=(const SymMatrix& B) {
159 om_assert(nlin()==B.nlin());
160 #ifdef HAVE_BLAS
161 BLAS(daxpy,DAXPY)(sizet_to_int(nlin()*(nlin()+1)/2), -1.0, B.data(), 1, data() , 1);
162 #else
163 const size_t sz = size();
164 for (size_t i=0; i<sz; ++i)
165 data()[i] -= B.data()[i];
166 #endif
167 }
168
170 // supposes (*this) is definite positive
171 SymMatrix invA(*this,DEEP_COPY);
172 #ifdef HAVE_LAPACK
173 // U'U factorization then inverse
174 int Info = 0;
175 DPPTRF('U', sizet_to_int(nlin()),invA.data(),Info);
176 DPPTRI('U', sizet_to_int(nlin()),invA.data(),Info);
177 om_assert(Info==0);
178 #else
179 std::cerr << "Positive definite inverse not defined" << std::endl;
180 #endif
181 return invA;
182 }
183
184 inline double SymMatrix::det() {
185 SymMatrix invA(*this,DEEP_COPY);
186 double d = 1.0;
187 #ifdef HAVE_LAPACK
188 // Bunch Kaufmqn
189 BLAS_INT *pivots=new BLAS_INT[nlin()];
190 int Info = 0;
191 // TUDUtTt
192 DSPTRF('U', sizet_to_int(invA.nlin()), invA.data(), pivots,Info);
193 if (Info<0)
194 std::cout << "Big problem in det (DSPTRF)" << std::endl;
195 for (size_t i = 0; i< nlin(); i++){
196 if (pivots[i] >= 0) {
197 d *= invA(i,i);
198 } else { // pivots[i] < 0
199 if (i < nlin()-1 && pivots[i] == pivots[i+1]) {
200 d *= (invA(i,i)*invA(i+1,i+1)-invA(i,i+1)*invA(i+1,i));
201 i++;
202 } else {
203 std::cout << "Big problem in det" << std::endl;
204 }
205 }
206 }
207 delete[] pivots;
208 #else
209 throw OpenMEEG::maths::LinearAlgebraError("Determinant not defined without LAPACK");
210 #endif
211 return(d);
212 }
213
214 // inline void SymMatrix::eigen(Matrix& Z,Vector& D ){
215 // // performs the complete eigen-decomposition.
216 // // (*this) = Z.D.Z'
217 // // -> eigenvector are columns of the Matrix Z.
218 // // (*this).Z[:,i] = D[i].Z[:,i]
219 // #ifdef HAVE_LAPACK
220 // SymMatrix symtemp(*this,DEEP_COPY);
221 // D = Vector(nlin());
222 // Z = Matrix(nlin(),nlin());
223 //
224 // int info;
225 // double lworkd;
226 // int lwork;
227 // int liwork;
228 //
229 // DSPEVD('V','U',sizet_to_int(nlin()),symtemp.data(),D.data(),Z.data(),sizet_to_int(nlin()),&lworkd,-1,&liwork,-1,info);
230 // lwork = (int) lworkd;
231 // double * work = new double[lwork];
232 // BLAS_INT *iwork = new BLAS_INT[liwork];
233 // DSPEVD('V','U',sizet_to_int(nlin()),symtemp.data(),D.data(),Z.data(),sizet_to_int(nlin()),work,lwork,iwork,liwork,info);
234 //
235 // delete[] work;
236 // delete[] iwork;
237 // #endif
238 // }
239
240 inline SymMatrix SymMatrix::operator+(const SymMatrix& B) const {
241 om_assert(nlin()==B.nlin());
242 SymMatrix C(*this,DEEP_COPY);
243 C += B;
244 return C;
245 }
246
247 inline SymMatrix SymMatrix::operator-(const SymMatrix& B) const {
248 om_assert(nlin()==B.nlin());
249 SymMatrix C(*this,DEEP_COPY);
250 C -= B;
251 return C;
252 }
253
255 #ifdef HAVE_LAPACK
256 SymMatrix invA(*this, DEEP_COPY);
257 // LU
258 BLAS_INT* pivots = new BLAS_INT[nlin()];
259 int Info = 0;
260 const BLAS_INT M = sizet_to_int(nlin());
261 DSPTRF('U',M,invA.data(),pivots,Info);
262 // Inverse
263 double* work = new double[nlin()*64];
264 DSPTRI('U',M,invA.data(),pivots,work,Info);
265
266 om_assert(Info==0);
267 delete[] pivots;
268 delete[] work;
269 return invA;
270 #else
271 throw OpenMEEG::maths::LinearAlgebraError("Inverse not implemented, requires LAPACK");
272 #endif
273 }
274
275 inline void SymMatrix::invert() {
276 #ifdef HAVE_LAPACK
277 // LU
278 BLAS_INT* pivots = new BLAS_INT[nlin()];
279 int Info = 0;
280 const BLAS_INT M = sizet_to_int(nlin());
281 DSPTRF('U',M,data(),pivots,Info);
282 // Inverse
283 double* work = new double[nlin()*64];
284 DSPTRI('U',M,data(),pivots,work,Info);
285
286 om_assert(Info==0);
287 delete[] pivots;
288 delete[] work;
289 return;
290 #else
291 throw OpenMEEG::maths::LinearAlgebraError("Inverse not implemented, requires LAPACK");
292 #endif
293 }
294
295 inline Vector SymMatrix::operator*(const Vector& v) const {
296 om_assert(nlin()==v.size());
297 Vector y(nlin());
298 #ifdef HAVE_BLAS
299 const BLAS_INT M = sizet_to_int(nlin());
300 DSPMV(CblasUpper,M,1.0,data(),v.data(),1,0.0,y.data(),1);
301 #else
302 for (Index i=0; i<nlin(); ++i) {
303 y(i)=0;
304 for (Index j=0; j<nlin(); ++j)
305 y(i)+=(*this)(i,j)*v(j);
306 }
307 #endif
308 return y;
309 }
310
311 inline Vector SymMatrix::getlin(const Index i) const {
312 om_assert(i<nlin());
313 Vector v(ncol());
314 for (Index j=0; j<ncol(); ++j)
315 v(j) = (*this)(i,j);
316 return v;
317 }
318
319 inline void SymMatrix::setlin(const Index i,const Vector& v) {
320 om_assert(v.size()==nlin());
321 om_assert(i<nlin());
322 for (Index j=0; j<ncol(); ++j)
323 (*this)(i,j) = v(j);
324 }
325}
Dimension nlin() const
Definition: linop.h:48
Matrix class Matrix class.
Definition: matrix.h:28
void save(const std::string &s) const
Definition: symmatrix.h:97
Matrix submat(const Index istart, const Index isize, const Index jstart, const Index jsize) const
SymMatrix posdefinverse() const
Definition: symmatrix.h:169
Matrix operator*(const Matrix &B) const
SymMatrix operator/(const double x) const
Definition: symmatrix.h:81
SymMatrix(const char *fname)
Definition: symmatrix.h:31
Vector solveLin(const Vector &B) const
Definition: symmatrix.h:105
Matrix solveLin(Matrix &B) const
const SymMatrix & operator=(const double d)
Matrix operator*(const SymMatrix &B) const
SymMatrix operator*(const double x) const
Dimension ncol() const
Definition: symmatrix.h:42
void operator+=(const SymMatrix &B)
Definition: symmatrix.h:147
size_t size() const
Definition: symmatrix.h:39
SymMatrix submat(const Index istart, const Index iend) const
double * data() const
Definition: symmatrix.h:50
SymMatrix operator+(const SymMatrix &B) const
Definition: symmatrix.h:240
Matrix operator()(const Index i_start, const Index i_end, const Index j_start, const Index j_end) const
void info() const
void setlin(const Index i, const Vector &v)
Definition: symmatrix.h:319
double & operator()(const Index i, const Index j)
Definition: symmatrix.h:58
Dimension & ncol()
Definition: symmatrix.h:43
Vector getlin(const Index i) const
Definition: symmatrix.h:311
void save(const char *filename) const
SymMatrix inverse() const
Definition: symmatrix.h:254
void reference_data(const double *array)
Definition: symmatrix.h:46
bool empty() const
Definition: symmatrix.h:48
SymMatrix(const SymMatrix &S, const DeepCopy)
Definition: symmatrix.h:34
double operator()(const Index i, const Index j) const
Definition: symmatrix.h:52
void load(const char *filename)
SymMatrix(const Matrix &A)
SymMatrix(const Vector &v)
SymMatrix(Dimension N)
Definition: symmatrix.h:32
void load(const std::string &s)
Definition: symmatrix.h:98
void set(double x)
void operator-=(const SymMatrix &B)
Definition: symmatrix.h:158
SymMatrix(Dimension M, Dimension N)
Definition: symmatrix.h:33
SymMatrix operator-(const SymMatrix &B) const
Definition: symmatrix.h:247
size_t size() const
Definition: vector.h:40
double * data() const
Definition: vector.h:44
Vect3 operator*(const double d, const Vect3 &V)
Definition: vect3.h:105
DeepCopy
Definition: linop.h:84
@ DEEP_COPY
Definition: linop.h:84
unsigned Dimension
Definition: linop.h:32
double det(const Vect3 &V1, const Vect3 &V2, const Vect3 &V3)
Definition: vect3.h:108
unsigned Index
Definition: linop.h:33
BLAS_INT sizet_to_int(const unsigned &num)
Definition: linop.h:26
bool empty() const
Definition: linop.h:96