Commit 7b20449

gen, visualize
1 parent 5a9ccff

File tree: 3 files changed (+97 -70 lines)

bayesml/gaussianmixture/_gaussianmixture.py

Lines changed: 91 additions & 62 deletions
@@ -204,7 +204,7 @@ def get_h_params(self):
             * ``"h_m_vec"`` : The value of ``self.h_mu_vec``
             * ``"h_kappa"`` : The value of ``self.h_kappa``
             * ``"h_nu"`` : The value of ``self.h_nu``
-            * ``"h_w_mat"`` : The value of ``self.h_lambda_mat``
+            * ``"h_w_mat"`` : The value of ``self.h_w_mat``
             """
         return {"h_alpha_vec":self.h_alpha_vec,
                 "h_m_vec":self.h_m_vec,
@@ -215,11 +215,12 @@ def get_h_params(self):
     def gen_params(self):
         """Generate the parameter from the prior distribution.

-        The generated vaule is set at ``self.mu_vec`` and ``self.lambda_mat``.
+        The generated values are set at ``self.pi_vec``, ``self.mu_vecs`` and ``self.lambda_mats``.
         """
-        pass
-        # self.lambda_mat[:] = ss_wishart.rvs(df=self.h_nu,scale=self.h_w_mat,random_state=self.rng)
-        # self.mu_vec[:] = self.rng.multivariate_normal(mean=self.h_m_vec,cov=np.linalg.inv(self.h_kappa*self.lambda_mat))
+        self.pi_vec[:] = self.rng.dirichlet(self.h_alpha_vec)
+        for k in range(self.num_classes):
+            self.lambda_mats[k] = ss_wishart.rvs(df=self.h_nu,scale=self.h_w_mat,random_state=self.rng)
+            self.mu_vecs[k] = self.rng.multivariate_normal(mean=self.h_m_vec,cov=np.linalg.inv(self.h_kappa*self.lambda_mats[k]))

     def set_params(self,pi_vec,mu_vecs,lambda_mats):
         """Set the parameter of the stochastic data generative model.
@@ -285,16 +286,24 @@ def gen_sample(self,sample_size):
         Returns
         -------
         x : numpy ndarray
-            2-dimensional array whose shape is ``(sammple_size,degree)`` and its elements are real number.
+            2-dimensional array whose shape is ``(sample_size,degree)`` and its elements are real numbers.
+        z : numpy ndarray
+            2-dimensional array whose shape is ``(sample_size,num_classes)`` and whose rows are one-hot vectors.
         """
-        pass
-        # _check.pos_int(sample_size,'sample_size',DataFormatError)
-        # return self.rng.multivariate_normal(mean=self.mu_vec,cov=np.linalg.inv(self.lambda_mat),size=sample_size)
+        _check.pos_int(sample_size,'sample_size',DataFormatError)
+        z = np.zeros([sample_size,self.num_classes],dtype=int)
+        x = np.empty([sample_size,self.degree])
+        _lambda_mats_inv = np.linalg.inv(self.lambda_mats)
+        for i in range(sample_size):
+            k = self.rng.choice(self.num_classes,p=self.pi_vec)
+            z[i,k] = 1
+            x[i] = self.rng.multivariate_normal(mean=self.mu_vecs[k],cov=_lambda_mats_inv[k])
+        return x,z

     def save_sample(self,filename,sample_size):
         """Save the generated sample as NumPy ``.npz`` format.

-        It is saved as a NpzFile with keyword: \"x\".
+        It is saved as a NpzFile with keywords: \"x\", \"z\".

         Parameters
         ----------
@@ -308,71 +317,91 @@ def save_sample(self,filename,sample_size):
         --------
         numpy.savez_compressed
         """
-        pass
-        # np.savez_compressed(filename,x=self.gen_sample(sample_size))
+        x,z=self.gen_sample(sample_size)
+        np.savez_compressed(filename,x=x,z=z)

     def visualize_model(self,sample_size=100):
         """Visualize the stochastic data generative model and generated samples.

         Parameters
         ----------
         sample_size : int, optional
-            A positive integer, by default 1
+            A positive integer, by default 100

         Examples
         --------
-        >>> from bayesml import multivariate_normal
-        >>> model = multivariate_normal.GenModel()
+        >>> from bayesml import gaussianmixture
+        >>> import numpy as np
+        >>> model = gaussianmixture.GenModel(
+        >>>     pi_vec=np.array([0.444,0.444,0.112]),
+        >>>     mu_vecs=np.array([[-2.8],[-0.8],[2]]),
+        >>>     lambda_mats=np.array([[[6.25]],[[6.25]],[[100]]])
+        >>> )
         >>> model.visualize_model()
-        mu:
-        [0. 0.]
-        lambda_mat:
-        [[1. 0.]
-         [0. 1.]]
+        pi_vec:
+        [0.444 0.444 0.112]
+        mu_vecs:
+        [[-2.8]
+         [-0.8]
+         [ 2. ]]
+        lambda_mats:
+        [[[  6.25]]
+
+         [[  6.25]]
+
+         [[100.  ]]]

-        .. image:: ./images/multivariate_normal_example.png
+        .. image:: ./images/gaussianmixture_example.png
         """
-        pass
-        # if self.degree == 1:
-        #     print(f"mu: {self.mu_vec}")
-        #     print(f"lambda_mat: {self.lambda_mat}")
-        #     lambda_mat_inv = np.linalg.inv(self.lambda_mat)
-        #     fig, axes = plt.subplots()
-        #     sample = self.gen_sample(sample_size)
-        #     x = np.linspace(sample.min()-(sample.max()-sample.min())*0.25,
-        #                     sample.max()+(sample.max()-sample.min())*0.25,
-        #                     100)
-        #     axes.plot(x,ss_multivariate_normal.pdf(x,self.mu_vec,lambda_mat_inv))
-        #     axes.hist(sample,density=True)
-        #     axes.set_xlabel("x")
-        #     axes.set_ylabel("Density or frequency")
-        #     plt.show()
-
-        # elif self.degree == 2:
-        #     print(f"mu:\n{self.mu_vec}")
-        #     print(f"lambda_mat:\n{self.lambda_mat}")
-        #     lambda_mat_inv = np.linalg.inv(self.lambda_mat)
-        #     fig, axes = plt.subplots()
-        #     sample = self.gen_sample(sample_size)
-        #     x = np.linspace(sample[:,0].min()-(sample[:,0].max()-sample[:,0].min())*0.25,
-        #                     sample[:,0].max()+(sample[:,0].max()-sample[:,0].min())*0.25,
-        #                     100)
-        #     y = np.linspace(sample[:,1].min()-(sample[:,1].max()-sample[:,1].min())*0.25,
-        #                     sample[:,1].max()+(sample[:,1].max()-sample[:,1].min())*0.25,
-        #                     100)
-        #     xx, yy = np.meshgrid(x,y)
-        #     grid = np.empty((100,100,2))
-        #     grid[:,:,0] = xx
-        #     grid[:,:,1] = yy
-        #     axes.contourf(xx,yy,ss_multivariate_normal.pdf(grid,self.mu_vec,lambda_mat_inv),cmap='Blues')
-        #     axes.plot(self.mu_vec[0],self.mu_vec[1],marker="x",color='red')
-        #     axes.set_xlabel("x[0]")
-        #     axes.set_ylabel("x[1]")
-        #     axes.scatter(sample[:,0],sample[:,1],color='tab:orange')
-        #     plt.show()
-
-        # else:
-        #     raise(ParameterFormatError("if degree > 2, it is impossible to visualize the model by this function."))
+        if self.degree == 1:
+            print(f"pi_vec:\n {self.pi_vec}")
+            print(f"mu_vecs:\n {self.mu_vecs}")
+            print(f"lambda_mats:\n {self.lambda_mats}")
+            _lambda_mats_inv = np.linalg.inv(self.lambda_mats)
+            fig, axes = plt.subplots()
+            sample, _ = self.gen_sample(sample_size)
+            x = np.linspace(sample.min()-(sample.max()-sample.min())*0.25,
+                            sample.max()+(sample.max()-sample.min())*0.25,
+                            1000)
+            y = np.zeros(1000)
+            for k in range(self.num_classes):
+                y += self.pi_vec[k] * ss_multivariate_normal.pdf(x,self.mu_vecs[k],_lambda_mats_inv[k])
+            axes.plot(x,y)
+            axes.hist(sample,density=True)
+            axes.set_xlabel("x")
+            axes.set_ylabel("Density or frequency")
+            plt.show()
+
+        elif self.degree == 2:
+            print(f"pi_vec:\n {self.pi_vec}")
+            print(f"mu_vecs:\n {self.mu_vecs}")
+            print(f"lambda_mats:\n {self.lambda_mats}")
+            _lambda_mats_inv = np.linalg.inv(self.lambda_mats)
+            fig, axes = plt.subplots()
+            sample, _ = self.gen_sample(sample_size)
+            x = np.linspace(sample[:,0].min()-(sample[:,0].max()-sample[:,0].min())*0.25,
+                            sample[:,0].max()+(sample[:,0].max()-sample[:,0].min())*0.25,
+                            1000)
+            y = np.linspace(sample[:,1].min()-(sample[:,1].max()-sample[:,1].min())*0.25,
+                            sample[:,1].max()+(sample[:,1].max()-sample[:,1].min())*0.25,
+                            1000)
+            xx, yy = np.meshgrid(x,y)
+            grid = np.empty((1000,1000,2))
+            grid[:,:,0] = xx
+            grid[:,:,1] = yy
+            z = np.zeros([1000,1000])
+            for k in range(self.num_classes):
+                z += self.pi_vec[k] * ss_multivariate_normal.pdf(grid,self.mu_vecs[k],_lambda_mats_inv[k])
+            axes.contourf(xx,yy,z,cmap='Blues')
+            for k in range(self.num_classes):
+                axes.plot(self.mu_vecs[k,0],self.mu_vecs[k,1],marker="x",color='red')
+            axes.set_xlabel("x[0]")
+            axes.set_ylabel("x[1]")
+            axes.scatter(sample[:,0],sample[:,1],color='tab:orange')
+            plt.show()
+
+        else:
+            raise(ParameterFormatError("if degree > 2, it is impossible to visualize the model by this function."))

 # class LearnModel(base.Posterior,base.PredictiveMixin):
 #     """The posterior distribution and the predictive distribution.

bayesml/gaussianmixture/test.py

Lines changed: 6 additions & 8 deletions
@@ -1,11 +1,9 @@
 from bayesml import gaussianmixture
 import numpy as np

-model = gaussianmixture.GenModel(num_classes=2,degree=3)
-print(model.get_params())
-params = model.get_params()
-params['pi_vec'] = np.ones(6)/6
-params['mu_vecs'] = np.ones([2,3,3])
-params['lambda_mats'] = np.tile(np.identity(4),[2,3,1,1])
-model.set_params(*params.values())
-print(model.get_params())
+# model = gaussianmixture.GenModel(pi_vec=np.array([0.444,0.444,0.112]),
+#                                  mu_vecs=np.array([[-2.8],[-0.8],[2]]),
+#                                  lambda_mats=np.array([[[6.25]],[[6.25]],[[100]]]))
+model = gaussianmixture.GenModel(mu_vecs=np.array([[2,2],[-2,-2]]))
+# model.gen_params()
+model.visualize_model()
Third changed file: 27 KB (diff not rendered here)
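
For context, the committed gen_params/gen_sample pair follows the usual Gauss-Wishart mixture hierarchy: pi ~ Dirichlet(h_alpha_vec); for each class k, Lambda_k ~ Wishart(h_nu, h_w_mat) and mu_k ~ N(h_m_vec, (h_kappa * Lambda_k)^-1); then z_i ~ Categorical(pi) and x_i ~ N(mu_{z_i}, Lambda_{z_i}^-1). The standalone sketch below mirrors that logic outside the class; all hyperparameter values (K, D, n, the Dirichlet and Wishart settings) are illustrative choices, not library defaults.

import numpy as np
from scipy.stats import wishart

rng = np.random.default_rng(0)
K, D, n = 3, 2, 100                # classes, dimension, sample size (illustrative)
h_alpha_vec = np.ones(K)           # Dirichlet concentration (illustrative)
h_m_vec = np.zeros(D)              # prior mean (illustrative)
h_kappa, h_nu = 1.0, float(D)      # precision scale and Wishart dof (illustrative)
h_w_mat = np.identity(D)           # Wishart scale matrix (illustrative)

# Prior draw, mirroring gen_params: weights, then per-class precision and mean.
pi_vec = rng.dirichlet(h_alpha_vec)
lambda_mats = np.array([wishart.rvs(df=h_nu, scale=h_w_mat, random_state=rng)
                        for _ in range(K)])
mu_vecs = np.array([rng.multivariate_normal(h_m_vec, np.linalg.inv(h_kappa*lambda_mats[k]))
                    for k in range(K)])

# Data draw, mirroring gen_sample: pick a class, record a one-hot label, sample x.
z = np.zeros((n, K), dtype=int)
x = np.empty((n, D))
lambda_mats_inv = np.linalg.inv(lambda_mats)
for i in range(n):
    k = rng.choice(K, p=pi_vec)
    z[i, k] = 1
    x[i] = rng.multivariate_normal(mu_vecs[k], lambda_mats_inv[k])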
