Skip to content

Commit e458e3c

Browse files
committed
Added get_params and get_h_params
1 parent d066598 commit e458e3c

File tree

2 files changed

+30
-11
lines changed

2 files changed

+30
-11
lines changed

bayesml/hiddenmarkovnormal/_hiddenmarkovnormal.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@ def __init__(
2525
a_mat=None,
2626
mu_vecs=None,
2727
lambda_mats=None,
28+
h_eta_vec=None,
29+
h_zeta_vecs=None,
2830
h_m_vecs=None,
2931
h_kappas=None,
3032
h_nus=None,
3133
h_w_mats=None,
32-
h_eta_vec=None,
33-
h_zeta_vecs=None,
3434
seed=None
3535
):
3636
# constants
@@ -52,8 +52,19 @@ def __init__(
5252
self.h_nus = np.ones(self.c_num_classes) * self.c_degree
5353
self.h_w_mats = np.tile(np.identity(self.c_degree),[self.c_num_classes,1,1])
5454

55-
self.set_params(pi_vec,a_mat,mu_vecs,lambda_mats)
56-
self.set_h_params(h_eta_vec,h_zeta_vecs,h_m_vecs,h_kappas,h_nus,h_w_mats)
55+
self.set_params(
56+
pi_vec,
57+
a_mat,
58+
mu_vecs,
59+
lambda_mats)
60+
61+
self.set_h_params(
62+
h_eta_vec,
63+
h_zeta_vecs,
64+
h_m_vecs,
65+
h_kappas,
66+
h_nus,
67+
h_w_mats)
5768

5869
def set_params(
5970
self,
@@ -160,14 +171,18 @@ def set_h_params(
160171
self.h_w_mats[:] = h_w_mats
161172

162173
def get_params(self):
163-
# paramsを辞書として返す関数.
164-
# 要素の順番はset_paramsの引数の順にそろえる.
165-
pass
174+
return {'pi_vec':self.pi_vec,
175+
'a_mat':self.a_mat,
176+
'mu_vecs':self.mu_vecs,
177+
'lambda_mats': self.lambda_mats}
166178

167179
def get_h_params(self):
168-
# h_paramsを辞書として返す関数.
169-
# 要素の順番はset_h_paramsの引数の順にそろえる.
170-
pass
180+
return {'h_eta_vec':self.h_eta_vec,
181+
'h_zeta_vecs':self.h_zeta_vecs,
182+
'h_m_vecs':self.h_m_vecs,
183+
'h_kappas':self.h_kappas,
184+
'h_nus':self.h_nus,
185+
'h_w_mats':self.h_w_mats}
171186

172187
# まだ実装しなくてよい
173188
def gen_params(self):

bayesml/hiddenmarkovnormal/test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
11
from bayesml import hiddenmarkovnormal
2+
import numpy as np
23

3-
model = hiddenmarkovnormal.LearnModel(2,2)
4+
model = hiddenmarkovnormal.GenModel(
5+
3,2,h_w_mats=np.tile(np.identity(2)*2,[4,1,1]))
6+
7+
print(model.get_h_params())

0 commit comments

Comments
 (0)