@@ -301,7 +301,12 @@ def gen_params(self):
301
301
302
302
To confirm the generated vaules, use `self.get_params()`.
303
303
"""
304
- pass
304
+ self .pi_vec [:] = self .rng .dirichlet (self .h_eta_vec )
305
+ for k in range (self .c_num_classes ):
306
+ self .a_mat [k ] = self .rng .dirichlet (self .h_zeta_vecs [k ])
307
+ for k in range (self .c_num_classes ):
308
+ self .lambda_mats [k ] = ss_wishart .rvs (df = self .h_nus [k ],scale = self .h_w_mats [k ],random_state = self .rng )
309
+ self .mu_vecs [k ] = self .rng .multivariate_normal (mean = self .h_m_vecs [k ],cov = np .linalg .inv (self .h_kappas [k ]* self .lambda_mats [k ]))
305
310
306
311
def gen_sample (self ,sample_length ):
307
312
"""Generate a sample from the stochastic data generative model.
@@ -322,6 +327,21 @@ def gen_sample(self,sample_length):
322
327
``(sample_length,c_num_classes)``
323
328
whose rows are one-hot vectors.
324
329
"""
330
+ _check .pos_int (sample_length ,'sample_length' ,DataFormatError )
331
+ z = np .zeros ([sample_length ,self .c_num_classes ],dtype = int )
332
+ x = np .empty ([sample_length ,self .c_degree ])
333
+ _lambda_mats_inv = np .linalg .inv (self .lambda_mats )
334
+
335
+ # i=0
336
+ k = self .rng .choice (self .c_num_classes ,p = self .pi_vec )
337
+ z [0 ,k ] = 1
338
+ x [0 ] = self .rng .multivariate_normal (mean = self .mu_vecs [k ],cov = _lambda_mats_inv [k ])
339
+ # i>0
340
+ for i in range (1 ,sample_length ):
341
+ k = self .rng .choice (self .c_num_classes ,p = self .a_mat [np .argmax (z [i - 1 ])])
342
+ z [i ,k ] = 1
343
+ x [i ] = self .rng .multivariate_normal (mean = self .mu_vecs [k ],cov = _lambda_mats_inv [k ])
344
+ return x ,z
325
345
326
346
def save_sample (self ,filename ,sample_length ):
327
347
"""Save the generated sample as NumPy ``.npz`` format.
@@ -340,8 +360,10 @@ def save_sample(self,filename,sample_length):
340
360
--------
341
361
numpy.savez_compressed
342
362
"""
363
+ x ,z = self .gen_sample (sample_length )
364
+ np .savez_compressed (filename ,x = x ,z = z )
343
365
344
- def visualize_model (self ,sample_length = 100 ):
366
+ def visualize_model (self ,sample_length = 200 ):
345
367
"""Visualize the stochastic data generative model and generated samples.
346
368
347
369
Parameters
@@ -352,9 +374,61 @@ def visualize_model(self,sample_length=100):
352
374
Examples
353
375
--------
354
376
>>> from bayesml import hiddenmarkovnormal
355
- >>> model = hiddenmarkovnormal.GenModel(c_num_classes=2,c_degree=1)
377
+ >>> import numpy as np
378
+ >>> model = hiddenmarkovnormal.GenModel(
379
+ c_num_classes=2,
380
+ c_degree=1,
381
+ mu_vecs=np.array([[5],[-5]]),
382
+ a_mat=np.array([[0.95,0.05],[0.1,0.9]]))
356
383
>>> model.visualize_model()
384
+ pi_vec:
385
+ [0.5 0.5]
386
+ a_mat:
387
+ [[0.95 0.05]
388
+ [0.1 0.9 ]]
389
+ mu_vecs:
390
+ [[ 5.]
391
+ [-5.]]
392
+ lambda_mats:
393
+ [[[1.]]
394
+
395
+ [[1.]]]
396
+
397
+ .. image:: ./images/hiddenmarkovnormal_example.png
357
398
"""
399
+ if self .c_degree == 1 :
400
+ print (f"pi_vec:\n { self .pi_vec } " )
401
+ print (f"a_mat:\n { self .a_mat } " )
402
+ print (f"mu_vecs:\n { self .mu_vecs } " )
403
+ print (f"lambda_mats:\n { self .lambda_mats } " )
404
+ _lambda_mats_inv = np .linalg .inv (self .lambda_mats )
405
+ fig , axes = plt .subplots ()
406
+ sample , latent_vars = self .gen_sample (sample_length )
407
+
408
+ change_points = [0 ]
409
+ for i in range (1 ,sample_length ):
410
+ if np .any (latent_vars [i - 1 ] != latent_vars [i ]):
411
+ change_points .append (i )
412
+ change_points .append (sample_length )
413
+
414
+ cm = plt .get_cmap ('jet' )
415
+ for i in range (1 ,len (change_points )):
416
+ axes .axvspan (
417
+ change_points [i - 1 ],
418
+ change_points [i ],
419
+ color = cm (
420
+ int ((np .argmax (latent_vars [change_points [i - 1 ]])
421
+ / (self .c_num_classes - 1 )) * 255 )
422
+ ),
423
+ alpha = 0.3 ,
424
+ ls = '' ,
425
+ )
426
+ axes .plot (np .arange (sample .shape [0 ]),sample )
427
+ axes .set_xlabel ("time" )
428
+ axes .set_ylabel ("x" )
429
+ plt .show ()
430
+ else :
431
+ raise (ParameterFormatError ("if c_degree > 1, it is impossible to visualize the model by this function." ))
358
432
359
433
class LearnModel (base .Posterior ,base .PredictiveMixin ):
360
434
def __init__ (
0 commit comments