Skip to content

Commit dc67df4

Browse files
committed
Add set, get, etc to LearnModel
1 parent e458e3c commit dc67df4

File tree

2 files changed

+115
-79
lines changed

2 files changed

+115
-79
lines changed

bayesml/hiddenmarkovnormal/_hiddenmarkovnormal.py

Lines changed: 112 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,8 @@ def set_params(
9898
if a_mat is not None:
9999
_check.float_vecs_sum_1(a_mat, "a_mat", ParameterFormatError)
100100
_check.shape_consistency(
101-
a_mat.shape[-1], "a_mat.shape[-1]",
102-
self.c_num_classes, "self.c_num_classes",
101+
a_mat.shape[-1],"a_mat.shape[-1]",
102+
self.c_num_classes,"self.c_num_classes",
103103
ParameterFormatError
104104
)
105105
self.a_mat[:] = a_mat
@@ -184,19 +184,15 @@ def get_h_params(self):
184184
'h_nus':self.h_nus,
185185
'h_w_mats':self.h_w_mats}
186186

187-
# まだ実装しなくてよい
188187
def gen_params(self):
189188
pass
190189

191-
# まだ実装しなくてよい
192190
def gen_sample(self):
193191
pass
194192

195-
# まだ実装しなくてよい
196193
def save_sample(self):
197194
pass
198195

199-
# まだ実装しなくてよい
200196
def visualize_model(self):
201197
pass
202198

@@ -207,12 +203,12 @@ def __init__(
207203
c_num_classes,
208204
c_degree,
209205
*,
206+
h0_eta_vec=None,
207+
h0_zeta_vecs=None,
210208
h0_m_vecs=None,
211209
h0_kappas=None,
212210
h0_nus=None,
213211
h0_w_mats=None,
214-
h0_eta_vec=None,
215-
h0_zeta_vecs=None,
216212
seed = None
217213
):
218214
# constants
@@ -221,22 +217,22 @@ def __init__(
221217
self.rng = np.random.default_rng(seed)
222218

223219
# h0_params
220+
self.h0_eta_vec = np.ones(self.c_num_classes) / 2.0
221+
self.h0_zeta_vecs = np.ones([self.c_num_classes,self.c_num_classes]) / 2.0
224222
self.h0_m_vecs = np.zeros([self.c_num_classes,self.c_degree])
225223
self.h0_kappas = np.ones([self.c_num_classes])
226224
self.h0_nus = np.ones(self.c_num_classes) * self.c_degree
227225
self.h0_w_mats = np.tile(np.identity(self.c_degree),[self.c_num_classes,1,1])
228226
self.h0_w_mats_inv = np.linalg.inv(self.h0_w_mats)
229-
self.h0_eta_vec = np.ones(self.c_num_classes) / 2.0
230-
self.h0_zeta_vecs = np.ones([self.c_num_classes,self.c_num_classes]) / 2.0
231227

232228
# hn_params
229+
self.hn_eta_vec = np.empty(self.c_num_classes)
230+
self.hn_zeta_vecs = np.empty([self.c_num_classes,self.c_num_classes])
233231
self.hn_m_vecs = np.empty([self.c_num_classes,self.c_degree])
234232
self.hn_kappas = np.empty([self.c_num_classes])
235233
self.hn_nus = np.empty(self.c_num_classes)
236234
self.hn_w_mats = np.empty([self.c_num_classes,self.c_degree,self.c_degree])
237235
self.hn_w_mats_inv = np.empty([self.c_num_classes,self.c_degree,self.c_degree])
238-
self.hn_eta_vec = np.empty(self.c_num_classes)
239-
self.hn_zeta_vecs = np.empty([self.c_num_classes,self.c_num_classes])
240236

241237
# p_params
242238
self.p_mu_vecs = np.empty([self.c_num_classes,self.c_degree])
@@ -245,118 +241,158 @@ def __init__(
245241
self.p_lambda_mats_inv = np.empty([self.c_num_classes,self.c_degree,self.c_degree])
246242

247243
self.set_h0_params(
244+
h0_eta_vec,
245+
h0_zeta_vecs,
248246
h0_m_vecs,
249247
h0_kappas,
250248
h0_nus,
251249
h0_w_mats,
252-
h0_eta_vec,
253-
h0_zeta_vecs,
254250
)
255251

256252
def set_h0_params(
257253
self,
254+
h0_eta_vec = None,
255+
h0_zeta_vecs = None,
258256
h0_m_vecs = None,
259257
h0_kappas = None,
260258
h0_nus = None,
261259
h0_w_mats = None,
262-
h0_eta_vec = None,
263-
h0_zeta_vecs = None,
264260
):
265-
# Noneでない入力について,以下をチェックする.
266-
# * それ単体として,モデルの仮定を満たすか(符号,行列の正定値性など)
267-
# * 配列のサイズなどがconstants(c_で始まる変数)と整合しているか.ただし,ブロードキャスト可能なものは認める
268-
# 例
269-
# if h0_m_vecs is not None:
270-
# _check.float_vecs(h0_m_vecs,'h0_m_vecs',ParameterFormatError)
271-
# if h0_m_vecs.shape[-1] != self.degree:
272-
# raise(ParameterFormatError(
273-
# "h0_m_vecs.shape[-1] must coincide with self.degree:"
274-
# +f"h0_m_vecs.shape[-1]={h0_m_vecs.shape[-1]}, self.degree={self.degree}"))
275-
# self.h0_m_vecs[:] = h0_m_vecs
276-
277-
# 最後にreset_hn_params()を呼ぶようにする
261+
if h0_eta_vec is not None:
262+
_check.pos_floats(h0_eta_vec,'h0_eta_vec',ParameterFormatError)
263+
self.h0_eta_vec[:] = h0_eta_vec
264+
265+
if h0_zeta_vecs is not None:
266+
_check.pos_floats(h0_zeta_vecs, 'h0_zeta_vecs', ParameterFormatError)
267+
self.h0_zeta_vecs[:] = h0_zeta_vecs
268+
269+
if h0_m_vecs is not None:
270+
_check.float_vecs(h0_m_vecs, "h0_m_vecs", ParameterFormatError)
271+
_check.shape_consistency(
272+
h0_m_vecs.shape[-1],"h0_m_vecs.shape[-1]",
273+
self.c_degree,"self.c_degree",
274+
ParameterFormatError
275+
)
276+
self.h0_m_vecs[:] = h0_m_vecs
277+
278+
if h0_kappas is not None:
279+
_check.pos_floats(h0_kappas, "h0_kappas", ParameterFormatError)
280+
self.h0_kappas[:] = h0_kappas
281+
282+
if h0_nus is not None:
283+
_check.floats(h0_nus, "h0_nus", ParameterFormatError)
284+
if np.all(h0_nus <= self.c_degree - 1):
285+
raise(ParameterFormatError(
286+
"All the values in h0_nus must be greater than self.c_degree - 1: "
287+
+ f"self.c_degree = {self.c_degree}, h0_nus = {h0_nus}"))
288+
self.h0_nus[:] = h0_nus
289+
290+
if h0_w_mats is not None:
291+
_check.pos_def_sym_mats(h0_w_mats,'h0_w_mats',ParameterFormatError)
292+
_check.shape_consistency(
293+
h0_w_mats.shape[-1],"h0_w_mats.shape[-1] and h0_w_mats.shape[-2]",
294+
self.c_degree,"self.c_degree",
295+
ParameterFormatError
296+
)
297+
self.h0_w_mats[:] = h0_w_mats
298+
299+
self.h0_w_mats_inv[:] = np.linalg.inv(self.h0_w_mats)
300+
278301
self.reset_hn_params()
279302

280303
def get_h0_params(self):
281-
# h0_paramsを辞書として返す関数.
282-
# 要素の順番はset_h_paramsの引数の順にそろえる.
283-
pass
304+
return {'h0_eta_vec':self.h0_eta_vec,
305+
'h0_zeta_vecs':self.h0_zeta_vecs,
306+
'h0_m_vecs':self.h0_m_vecs,
307+
'h0_kappas':self.h0_kappas,
308+
'h0_nus':self.h0_nus,
309+
'h0_w_mats':self.h0_w_mats}
284310

285311
def set_hn_params(
286312
self,
313+
hn_eta_vec = None,
314+
hn_zeta_vecs = None,
287315
hn_m_vecs = None,
288316
hn_kappas = None,
289317
hn_nus = None,
290318
hn_w_mats = None,
291-
hn_eta_vec = None,
292-
hn_zeta_vecs = None,
293319
):
294-
# Noneでない入力について,以下をチェックする.
295-
# * それ単体として,モデルの仮定を満たすか(符号,行列の正定値性など)
296-
# * 配列のサイズなどがconstants(c_で始まる変数)と整合しているか.ただし,ブロードキャスト可能なものは認める
297-
# 例
298-
# if h0_m_vecs is not None:
299-
# _check.float_vecs(h0_m_vecs,'h0_m_vecs',ParameterFormatError)
300-
# if h0_m_vecs.shape[-1] != self.degree:
301-
# raise(ParameterFormatError(
302-
# "h0_m_vecs.shape[-1] must coincide with self.degree:"
303-
# +f"h0_m_vecs.shape[-1]={h0_m_vecs.shape[-1]}, self.degree={self.degree}"))
304-
# self.h0_m_vecs[:] = h0_m_vecs
305-
306-
# 最後にcalc_pred_dist()を呼ぶようにする
320+
if hn_eta_vec is not None:
321+
_check.pos_floats(hn_eta_vec,'hn_eta_vec',ParameterFormatError)
322+
self.hn_eta_vec[:] = hn_eta_vec
323+
324+
if hn_zeta_vecs is not None:
325+
_check.pos_floats(hn_zeta_vecs, 'hn_zeta_vecs', ParameterFormatError)
326+
self.hn_zeta_vecs[:] = hn_zeta_vecs
327+
328+
if hn_m_vecs is not None:
329+
_check.float_vecs(hn_m_vecs, "hn_m_vecs", ParameterFormatError)
330+
_check.shape_consistency(
331+
hn_m_vecs.shape[-1],"hn_m_vecs.shape[-1]",
332+
self.c_degree,"self.c_degree",
333+
ParameterFormatError
334+
)
335+
self.hn_m_vecs[:] = hn_m_vecs
336+
337+
if hn_kappas is not None:
338+
_check.pos_floats(hn_kappas, "hn_kappas", ParameterFormatError)
339+
self.hn_kappas[:] = hn_kappas
340+
341+
if hn_nus is not None:
342+
_check.floats(hn_nus, "hn_nus", ParameterFormatError)
343+
if np.all(hn_nus <= self.c_degree - 1):
344+
raise(ParameterFormatError(
345+
"All the values in hn_nus must be greater than self.c_degree - 1: "
346+
+ f"self.c_degree = {self.c_degree}, hn_nus = {hn_nus}"))
347+
self.hn_nus[:] = hn_nus
348+
349+
if hn_w_mats is not None:
350+
_check.pos_def_sym_mats(hn_w_mats,'hn_w_mats',ParameterFormatError)
351+
_check.shape_consistency(
352+
hn_w_mats.shape[-1],"hn_w_mats.shape[-1] and hn_w_mats.shape[-2]",
353+
self.c_degree,"self.c_degree",
354+
ParameterFormatError
355+
)
356+
self.hn_w_mats[:] = hn_w_mats
357+
358+
self.hn_w_mats_inv[:] = np.linalg.inv(self.hn_w_mats)
359+
307360
self.calc_pred_dist()
308361

309362
def get_hn_params(self):
310-
# hn_paramsを辞書として返す関数.
311-
# 要素の順番はset_h_paramsの引数の順にそろえる.
312-
pass
363+
return {'hn_eta_vec':self.hn_eta_vec,
364+
'hn_zeta_vecs':self.hn_zeta_vecs,
365+
'hn_m_vecs':self.hn_m_vecs,
366+
'hn_kappas':self.hn_kappas,
367+
'hn_nus':self.hn_nus,
368+
'hn_w_mats':self.hn_w_mats}
313369

314370
def reset_hn_params(self):
315-
# h0_paramsの値をhn_paramsの値にそのままコピーする
316-
# 配列サイズを揃えてあるので,簡単に書けるはず.
317-
# 例
318-
# self.hn_alpha_vec[:] = self.h0_alpha_vec
319-
# self.hn_m_vecs[:] = self.h0_m_vecs
320-
# self.hn_kappas[:] = self.h0_kappas
321-
322-
# 最後にcalc_pred_distを呼ぶ.
323-
self.calc_pred_dist()
371+
self.set_hn_params(*self.get_h0_params().values())
324372

325373
def overwrite_h0_params(self):
326-
# hn_paramsの値をh0_paramsの値にそのままコピーする
327-
# 配列サイズを揃えてあるので,簡単に書けるはず.
328-
# 例
329-
# self.h0_alpha_vec[:] = self.hn_alpha_vec
330-
# self.h0_m_vecs[:] = self.hn_m_vecs
331-
# self.h0_kappas[:] = self.hn_kappas
332-
333-
# 最後にcalc_pred_distを呼ぶ.
334-
self.calc_pred_dist()
374+
self.set_h0_params(*self.get_hn_params().values())
335375

336-
# まだ実装しなくてよい
337376
def update_posterior():
338377
pass
339378

340-
# まだ実装しなくてよい
341379
def estimate_params(self,loss="squared"):
342380
pass
343381

344-
# まだ実装しなくてよい
345382
def visualize_posterior(self):
346383
pass
347384

348385
def get_p_params(self):
349-
# p_paramsを辞書として返す関数.
350-
pass
386+
return {'p_mu_vecs':self.p_mu_vecs,
387+
'p_nus':self.p_nus,
388+
'p_lambda_mats':self.p_lambda_mats,
389+
'p_lambda_mats_inv':self.p_lambda_mats_inv}
351390

352-
# まだ実装しなくてよい
353391
def calc_pred_dist(self):
354392
pass
355393

356-
# まだ実装しなくてよい
357394
def make_prediction(self,loss="squared"):
358395
pass
359396

360-
# まだ実装しなくてよい
361397
def pred_and_update(self,x,loss="squared"):
362398
pass

bayesml/hiddenmarkovnormal/test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from bayesml import hiddenmarkovnormal
22
import numpy as np
33

4-
model = hiddenmarkovnormal.GenModel(
5-
3,2,h_w_mats=np.tile(np.identity(2)*2,[4,1,1]))
4+
model = hiddenmarkovnormal.LearnModel(
5+
3,2,h0_eta_vec=2)
66

7-
print(model.get_h_params())
7+
print(model.get_hn_params())

0 commit comments

Comments
 (0)