IT++ Logo
gmm.cpp
Go to the documentation of this file.
1
29#include <itpp/srccode/gmm.h>
32#include <itpp/base/matfunc.h>
33#include <itpp/base/specmat.h>
34#include <itpp/base/random.h>
35#include <itpp/base/timing.h>
36#include <iostream>
37#include <fstream>
38
40
41namespace itpp
42{
43
44GMM::GMM()
45{
46 d = 0;
47 M = 0;
48}
49
50GMM::GMM(std::string filename)
51{
52 load(filename);
53}
54
55GMM::GMM(int M_in, int d_in)
56{
57 M = M_in;
58 d = d_in;
59 m = zeros(M * d);
60 sigma = zeros(M * d);
61 w = 1. / M * ones(M);
62
63 for (int i = 0;i < M;i++) {
64 w(i) = 1.0 / M;
65 }
66 compute_internals();
67}
68
69void GMM::init_from_vq(const vec &codebook, int dim)
70{
71
72 mat C(dim, dim);
73 int i;
74 vec v;
75
76 d = dim;
77 M = codebook.length() / dim;
78
79 m = codebook;
80 w = ones(M) / double(M);
81
82 C.clear();
83 for (i = 0;i < M;i++) {
84 v = codebook.mid(i * d, d);
85 C = C + outer_product(v, v);
86 }
87 C = 1. / M * C;
88 sigma.set_length(M*d);
89 for (i = 0;i < M;i++) {
90 sigma.replace_mid(i*d, diag(C));
91 }
92
93 compute_internals();
94}
95
96void GMM::init(const vec &w_in, const mat &m_in, const mat &sigma_in)
97{
98 int i, j;
99 d = m_in.rows();
100 M = m_in.cols();
101
102 m.set_length(M*d);
103 sigma.set_length(M*d);
104 for (i = 0;i < M;i++) {
105 for (j = 0;j < d;j++) {
106 m(i*d + j) = m_in(j, i);
107 sigma(i*d + j) = sigma_in(j, i);
108 }
109 }
110 w = w_in;
111
112 compute_internals();
113}
114
115void GMM::set_mean(const mat &m_in)
116{
117 int i, j;
118
119 d = m_in.rows();
120 M = m_in.cols();
121
122 m.set_length(M*d);
123 for (i = 0;i < M;i++) {
124 for (j = 0;j < d;j++) {
125 m(i*d + j) = m_in(j, i);
126 }
127 }
128 compute_internals();
129}
130
131void GMM::set_mean(int i, const vec &means, bool compflag)
132{
133 m.replace_mid(i*length(means), means);
134 if (compflag) compute_internals();
135}
136
137void GMM::set_covariance(const mat &sigma_in)
138{
139 int i, j;
140
141 d = sigma_in.rows();
142 M = sigma_in.cols();
143
144 sigma.set_length(M*d);
145 for (i = 0;i < M;i++) {
146 for (j = 0;j < d;j++) {
147 sigma(i*d + j) = sigma_in(j, i);
148 }
149 }
150 compute_internals();
151}
152
153void GMM::set_covariance(int i, const vec &covariances, bool compflag)
154{
155 sigma.replace_mid(i*length(covariances), covariances);
156 if (compflag) compute_internals();
157}
158
159void GMM::marginalize(int d_new)
160{
161 it_error_if(d_new > d, "GMM.marginalize: cannot change to a larger dimension");
162
163 vec mnew(d_new*M), sigmanew(d_new*M);
164 int i, j;
165
166 for (i = 0;i < M;i++) {
167 for (j = 0;j < d_new;j++) {
168 mnew(i*d_new + j) = m(i * d + j);
169 sigmanew(i*d_new + j) = sigma(i * d + j);
170 }
171 }
172 m = mnew;
173 sigma = sigmanew;
174 d = d_new;
175
176 compute_internals();
177}
178
179void GMM::join(const GMM &newgmm)
180{
181 if (d == 0) {
182 w = newgmm.w;
183 m = newgmm.m;
184 sigma = newgmm.sigma;
185 d = newgmm.d;
186 M = newgmm.M;
187 }
188 else {
189 it_error_if(d != newgmm.d, "GMM.join: cannot join GMMs of different dimension");
190
191 w = concat(double(M) / (M + newgmm.M) * w, double(newgmm.M) / (M + newgmm.M) * newgmm.w);
192 w = w / sum(w);
193 m = concat(m, newgmm.m);
194 sigma = concat(sigma, newgmm.sigma);
195
196 M = M + newgmm.M;
197 }
198 compute_internals();
199}
200
201void GMM::clear()
202{
203 w.set_length(0);
204 m.set_length(0);
205 sigma.set_length(0);
206 d = 0;
207 M = 0;
208}
209
210void GMM::save(std::string filename)
211{
212 std::ofstream f(filename.c_str());
213 int i, j;
214
215 f << M << " " << d << std::endl ;
216 for (i = 0;i < w.length();i++) {
217 f << w(i) << std::endl ;
218 }
219 for (i = 0;i < M;i++) {
220 f << m(i*d) ;
221 for (j = 1;j < d;j++) {
222 f << " " << m(i*d + j) ;
223 }
224 f << std::endl ;
225 }
226 for (i = 0;i < M;i++) {
227 f << sigma(i*d) ;
228 for (j = 1;j < d;j++) {
229 f << " " << sigma(i*d + j) ;
230 }
231 f << std::endl ;
232 }
233}
234
235void GMM::load(std::string filename)
236{
237 std::ifstream GMMFile(filename.c_str());
238 int i, j;
239
240 it_error_if(!GMMFile, std::string("GMM::load : cannot open file ") + filename);
241
242 GMMFile >> M >> d ;
243
244
245 w.set_length(M);
246 for (i = 0;i < M;i++) {
247 GMMFile >> w(i) ;
248 }
249 m.set_length(M*d);
250 for (i = 0;i < M;i++) {
251 for (j = 0;j < d;j++) {
252 GMMFile >> m(i*d + j) ;
253 }
254 }
255 sigma.set_length(M*d);
256 for (i = 0;i < M;i++) {
257 for (j = 0;j < d;j++) {
258 GMMFile >> sigma(i*d + j) ;
259 }
260 }
261 compute_internals();
262 std::cout << " mixtures:" << M << " dim:" << d << std::endl ;
263}
264
265double GMM::likelihood(const vec &x)
266{
267 double fx = 0;
268 int i;
269
270 for (i = 0;i < M;i++) {
271 fx += w(i) * likelihood_aposteriori(x, i);
272 }
273 return fx;
274}
275
276vec GMM::likelihood_aposteriori(const vec &x)
277{
278 vec v(M);
279 int i;
280
281 for (i = 0;i < M;i++) {
282 v(i) = w(i) * likelihood_aposteriori(x, i);
283 }
284 return v;
285}
286
287double GMM::likelihood_aposteriori(const vec &x, int mixture)
288{
289 int j;
290 double s;
291
292 it_error_if(d != x.length(), "GMM::likelihood_aposteriori : dimensions does not match");
293 s = 0;
294 for (j = 0;j < d;j++) {
295 s += normexp(mixture * d + j) * sqr(x(j) - m(mixture * d + j));
296 }
297 return normweight(mixture)*std::exp(s);;
298}
299
300void GMM::compute_internals()
301{
302 int i, j;
303 double s;
304 double constant = 1.0 / std::pow(2 * pi, d / 2.0);
305
306 normweight.set_length(M);
307 normexp.set_length(M*d);
308
309 for (i = 0;i < M;i++) {
310 s = 1;
311 for (j = 0;j < d;j++) {
312 normexp(i*d + j) = -0.5 / sigma(i * d + j); // check time
313 s *= sigma(i * d + j);
314 }
315 normweight(i) = constant / std::sqrt(s);
316 }
317
318}
319
320vec GMM::draw_sample()
321{
322 static bool first = true;
323 static vec cumweight;
324 double u = randu();
325 int k;
326
327 if (first) {
328 first = false;
329 cumweight = cumsum(w);
330 it_error_if(std::abs(cumweight(length(cumweight) - 1) - 1) > 1e-6, "weight does not sum to 0");
331 cumweight(length(cumweight) - 1) = 1;
332 }
333 k = 0;
334 while (u > cumweight(k)) k++;
335
336 return elem_mult(sqrt(sigma.mid(k*d, d)), randn(d)) + m.mid(k*d, d);
337}
338
339GMM gmmtrain(Array<vec> &TrainingData, int M, int NOITER, bool VERBOSE)
340{
341 mat mean;
342 int i, j, d = TrainingData(0).length();
343 vec sig;
344 GMM gmm(M, d);
345 vec m(d*M);
346 vec sigma(d*M);
347 vec w(M);
348 vec normweight(M);
349 vec normexp(d*M);
350 double LL = 0, LLold, fx;
351 double constant = 1.0 / std::pow(2 * pi, d / 2.0);
352 int T = TrainingData.length();
353 vec x1;
354 int t, n;
355 vec msum(d*M);
356 vec sigmasum(d*M);
357 vec wsum(M);
358 vec p_aposteriori(M);
359 vec x2;
360 double s;
361 vec temp1, temp2;
362 //double MINIMUM_VARIANCE=0.03;
363
364 //-----------initialization-----------------------------------
365
366 mean = vqtrain(TrainingData, M, 200000, 0.5, VERBOSE);
367 for (i = 0;i < M;i++) gmm.set_mean(i, mean.get_col(i), false);
368 // for (i=0;i<M;i++) gmm.set_mean(i,TrainingData(randi(0,TrainingData.length()-1)),false);
369 sig = zeros(d);
370 for (i = 0;i < TrainingData.length();i++) sig += sqr(TrainingData(i));
371 sig /= TrainingData.length();
372 for (i = 0;i < M;i++) gmm.set_covariance(i, 0.5*sig, false);
373
374 gmm.set_weight(1.0 / M*ones(M));
375
376 //-----------optimization-----------------------------------
377
378 tic();
379 for (i = 0;i < M;i++) {
380 temp1 = gmm.get_mean(i);
381 temp2 = gmm.get_covariance(i);
382 for (j = 0;j < d;j++) {
383 m(i*d + j) = temp1(j);
384 sigma(i*d + j) = temp2(j);
385 }
386 w(i) = gmm.get_weight(i);
387 }
388 for (n = 0;n < NOITER;n++) {
389 for (i = 0;i < M;i++) {
390 s = 1;
391 for (j = 0;j < d;j++) {
392 normexp(i*d + j) = -0.5 / sigma(i * d + j); // check time
393 s *= sigma(i * d + j);
394 }
395 normweight(i) = constant * w(i) / std::sqrt(s);
396 }
397 LLold = LL;
398 wsum.clear();
399 msum.clear();
400 sigmasum.clear();
401 LL = 0;
402 for (t = 0;t < T;t++) {
403 x1 = TrainingData(t);
404 x2 = sqr(x1);
405 fx = 0;
406 for (i = 0;i < M;i++) {
407 s = 0;
408 for (j = 0;j < d;j++) {
409 s += normexp(i * d + j) * sqr(x1(j) - m(i * d + j));
410 }
411 p_aposteriori(i) = normweight(i) * std::exp(s);
412 fx += p_aposteriori(i);
413 }
414 p_aposteriori /= fx;
415 LL = LL + std::log(fx);
416
417 for (i = 0;i < M;i++) {
418 wsum(i) += p_aposteriori(i);
419 for (j = 0;j < d;j++) {
420 msum(i*d + j) += p_aposteriori(i) * x1(j);
421 sigmasum(i*d + j) += p_aposteriori(i) * x2(j);
422 }
423 }
424 }
425 for (i = 0;i < M;i++) {
426 for (j = 0;j < d;j++) {
427 m(i*d + j) = msum(i * d + j) / wsum(i);
428 sigma(i*d + j) = sigmasum(i * d + j) / wsum(i) - sqr(m(i * d + j));
429 }
430 w(i) = wsum(i) / T;
431 }
432 LL = LL / T;
433
434 if (std::abs((LL - LLold) / LL) < 1e-6) break;
435 if (VERBOSE) {
436 std::cout << n << ": " << LL << " " << std::abs((LL - LLold) / LL) << " " << toc() << std::endl ;
437 std::cout << "---------------------------------------" << std::endl ;
438 tic();
439 }
440 else {
441 std::cout << n << ": LL = " << LL << " " << std::abs((LL - LLold) / LL) << "\r" ;
442 std::cout.flush();
443 }
444 }
445 for (i = 0;i < M;i++) {
446 gmm.set_mean(i, m.mid(i*d, d), false);
447 gmm.set_covariance(i, sigma.mid(i*d, d), false);
448 }
449 gmm.set_weight(w);
450 return gmm;
451}
452
453} // namespace itpp
454
Elementary mathematical functions - header file.
Definition of a Gaussian Mixture Model Class.
Mat< T > diag(const Vec< T > &v, const int K=0)
Create a diagonal matrix using vector v as its diagonal.
Definition: matfunc.h:557
#define it_error_if(t, s)
Abort if t is true.
Definition: itassert.h:117
Vec< T > cumsum(const Vec< T > &v)
Cumulative sum of all elements in the vector.
Definition: matfunc.h:157
T sum(const Vec< T > &v)
Sum of all elements in the vector.
Definition: matfunc.h:59
int length(const Vec< T > &v)
Length of vector.
Definition: matfunc.h:51
vec sqr(const cvec &data)
Absolute square of elements.
Definition: elem_math.cpp:36
vec sqrt(const vec &x)
Square root of the elements.
Definition: elem_math.h:123
double randu(void)
Generates a random uniform (0,1) number.
Definition: random.h:804
double randn(void)
Generates a random Gaussian (0,1) variable.
Definition: random.h:831
ITPP_EXPORT mat vqtrain(Array< vec > &DB, int SIZE, int NOITER, double STARTSTEP=0.2, bool VERBOSE=true)
Function for vector quantization training.
ITPP_EXPORT vec zeros(int size)
A double vector of zeros.
ITPP_EXPORT vec ones(int size)
A double vector of ones.
double mean(const vec &v)
The mean value.
Definition: misc_stat.cpp:36
void tic()
Reset and start timer.
Definition: timing.cpp:154
double toc()
Returns the elapsed time since last tic()
Definition: timing.cpp:159
Various functions on vectors and matrices - header file.
itpp namespace
Definition: itmex.h:37
const Array< T > concat(const Array< T > &a, const T &e)
Append element e to the end of the Array a.
Definition: array.h:486
const double pi
Constant Pi.
Definition: misc.h:103
Mat< Num_T > elem_mult(const Mat< Num_T > &m1, const Mat< Num_T > &m2)
Element wise multiplication of two matrices.
Definition: mat.h:1582
Mat< Num_T > outer_product(const Vec< Num_T > &v1, const Vec< Num_T > &v2, bool hermitian=false)
Outer product of two vectors v1 and v2.
Definition: vec.h:1021
int abs(const itpp::bin &inbin)
absolute value of bin
Definition: binary.h:186
Definition of classes for random number generators.
Definitions of special vectors and matrices.
Definitions of Timing classes.
Definitions of a vector quantizer training functions.
SourceForge Logo

Generated on Tue Jan 24 2023 00:00:00 for IT++ by Doxygen 1.9.6