Skip to content

Commit 9b7dc1d

Browse files
committed
Add gen_ and visualize
1 parent cb81f22 commit 9b7dc1d

File tree

3 files changed

+87
-3
lines changed

3 files changed

+87
-3
lines changed

bayesml/hiddenmarkovnormal/_hiddenmarkovnormal.py

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,12 @@ def gen_params(self):
301301
302302
To confirm the generated vaules, use `self.get_params()`.
303303
"""
304-
pass
304+
self.pi_vec[:] = self.rng.dirichlet(self.h_eta_vec)
305+
for k in range(self.c_num_classes):
306+
self.a_mat[k] = self.rng.dirichlet(self.h_zeta_vecs[k])
307+
for k in range(self.c_num_classes):
308+
self.lambda_mats[k] = ss_wishart.rvs(df=self.h_nus[k],scale=self.h_w_mats[k],random_state=self.rng)
309+
self.mu_vecs[k] = self.rng.multivariate_normal(mean=self.h_m_vecs[k],cov=np.linalg.inv(self.h_kappas[k]*self.lambda_mats[k]))
305310

306311
def gen_sample(self,sample_length):
307312
"""Generate a sample from the stochastic data generative model.
@@ -322,6 +327,21 @@ def gen_sample(self,sample_length):
322327
``(sample_length,c_num_classes)``
323328
whose rows are one-hot vectors.
324329
"""
330+
_check.pos_int(sample_length,'sample_length',DataFormatError)
331+
z = np.zeros([sample_length,self.c_num_classes],dtype=int)
332+
x = np.empty([sample_length,self.c_degree])
333+
_lambda_mats_inv = np.linalg.inv(self.lambda_mats)
334+
335+
# i=0
336+
k = self.rng.choice(self.c_num_classes,p=self.pi_vec)
337+
z[0,k] = 1
338+
x[0] = self.rng.multivariate_normal(mean=self.mu_vecs[k],cov=_lambda_mats_inv[k])
339+
# i>0
340+
for i in range(1,sample_length):
341+
k = self.rng.choice(self.c_num_classes,p=self.a_mat[np.argmax(z[i-1])])
342+
z[i,k] = 1
343+
x[i] = self.rng.multivariate_normal(mean=self.mu_vecs[k],cov=_lambda_mats_inv[k])
344+
return x,z
325345

326346
def save_sample(self,filename,sample_length):
327347
"""Save the generated sample as NumPy ``.npz`` format.
@@ -340,8 +360,10 @@ def save_sample(self,filename,sample_length):
340360
--------
341361
numpy.savez_compressed
342362
"""
363+
x,z=self.gen_sample(sample_length)
364+
np.savez_compressed(filename,x=x,z=z)
343365

344-
def visualize_model(self,sample_length=100):
366+
def visualize_model(self,sample_length=200):
345367
"""Visualize the stochastic data generative model and generated samples.
346368
347369
Parameters
@@ -352,9 +374,61 @@ def visualize_model(self,sample_length=100):
352374
Examples
353375
--------
354376
>>> from bayesml import hiddenmarkovnormal
355-
>>> model = hiddenmarkovnormal.GenModel(c_num_classes=2,c_degree=1)
377+
>>> import numpy as np
378+
>>> model = hiddenmarkovnormal.GenModel(
379+
c_num_classes=2,
380+
c_degree=1,
381+
mu_vecs=np.array([[5],[-5]]),
382+
a_mat=np.array([[0.95,0.05],[0.1,0.9]]))
356383
>>> model.visualize_model()
384+
pi_vec:
385+
[0.5 0.5]
386+
a_mat:
387+
[[0.95 0.05]
388+
[0.1 0.9 ]]
389+
mu_vecs:
390+
[[ 5.]
391+
[-5.]]
392+
lambda_mats:
393+
[[[1.]]
394+
395+
[[1.]]]
396+
397+
.. image:: ./images/hiddenmarkovnormal_example.png
357398
"""
399+
if self.c_degree == 1:
400+
print(f"pi_vec:\n {self.pi_vec}")
401+
print(f"a_mat:\n {self.a_mat}")
402+
print(f"mu_vecs:\n {self.mu_vecs}")
403+
print(f"lambda_mats:\n {self.lambda_mats}")
404+
_lambda_mats_inv = np.linalg.inv(self.lambda_mats)
405+
fig, axes = plt.subplots()
406+
sample, latent_vars = self.gen_sample(sample_length)
407+
408+
change_points = [0]
409+
for i in range(1,sample_length):
410+
if np.any(latent_vars[i-1] != latent_vars[i]):
411+
change_points.append(i)
412+
change_points.append(sample_length)
413+
414+
cm = plt.get_cmap('jet')
415+
for i in range(1,len(change_points)):
416+
axes.axvspan(
417+
change_points[i-1],
418+
change_points[i],
419+
color=cm(
420+
int((np.argmax(latent_vars[change_points[i-1]])
421+
/ (self.c_num_classes-1)) * 255)
422+
),
423+
alpha=0.3,
424+
ls='',
425+
)
426+
axes.plot(np.arange(sample.shape[0]),sample)
427+
axes.set_xlabel("time")
428+
axes.set_ylabel("x")
429+
plt.show()
430+
else:
431+
raise(ParameterFormatError("if c_degree > 1, it is impossible to visualize the model by this function."))
358432

359433
class LearnModel(base.Posterior,base.PredictiveMixin):
360434
def __init__(
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from bayesml import hiddenmarkovnormal
2+
import numpy as np
3+
4+
model = hiddenmarkovnormal.GenModel(3,1)
5+
6+
print(model.get_params())
7+
8+
model.set_params(mu_vecs=np.ones([3,1]))
9+
10+
print(model.get_params())
41 KB
Loading

0 commit comments

Comments
 (0)