Skip to content

Commit 4a282f3

Browse files
committed
Merge branch 'develop' into develop-contexttree-GenModel
2 parents ae01e36 + 40aef2c commit 4a282f3

File tree

7 files changed

+566
-39
lines changed

7 files changed

+566
-39
lines changed

bayesml/_check.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Code Author
22
# Yuta Nakahara <yuta.nakahara@aoni.waseda.jp>
33
# Yuji Iikubo <yuji-iikubo.8@fuji.waseda.jp>
4+
# Yasushi Esaki <esakiful@gmail.com>
5+
# Jun Nishikawa <jun.b.nishikawa@gmail.com>
46
import numpy as np
57

68
_EPSILON = np.sqrt(np.finfo(np.float64).eps)
@@ -179,6 +181,14 @@ def float_vec_sum_1(val,val_name,exception_class):
179181
return val
180182
raise(exception_class(val_name + " must be a 1-dimensional numpy.ndarray, and the sum of its elements must equal to 1."))
181183

184+
def float_vecs_sum_1(val,val_name,exception_class):
185+
if type(val) is np.ndarray:
186+
if np.issubdtype(val.dtype,np.integer) and val.ndim >= 1 and np.all(np.abs(np.sum(val, axis=-1) - 1.) <= _EPSILON):
187+
return val.astype(float)
188+
if np.issubdtype(val.dtype,np.floating) and val.ndim >= 1 and np.all(np.abs(np.sum(val, axis=-1) - 1.) <= _EPSILON):
189+
return val
190+
raise(exception_class(val_name + " must be a numpy.ndarray whose ndim >= 1, and the sum along the last dimension must equal to 1."))
191+
182192
def int_(val,val_name,exception_class):
183193
if np.issubdtype(type(val),np.integer):
184194
return val
@@ -205,3 +215,9 @@ def onehot_vecs(val,val_name,exception_class):
205215
if np.issubdtype(val.dtype,np.integer) and val.ndim >= 1 and np.all(val >= 0) and np.all(val.sum(axis=-1)==1):
206216
return val
207217
raise(exception_class(val_name + " must be a numpy.ndarray whose dtype is int and whose last axis constitutes one-hot vectors."))
218+
219+
def shape_consistency(val: int, val_name: str, correct: int, correct_name: str, exception_class):
220+
if val != correct:
221+
message = (f"{val_name} must coincide with {correct_name}: "
222+
+ f"{val_name} = {val}, {correct_name} = {correct}")
223+
raise(exception_class(message))

bayesml/base.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -257,13 +257,21 @@ def load_hn_params(self,filename):
257257
+'or ``LearnModel.save_hn_params()``.')
258258
)
259259

260-
@abstractmethod
261260
def reset_hn_params(self):
262-
pass
263-
264-
@abstractmethod
261+
"""Reset the hyperparameters of the posterior distribution to their initial values.
262+
263+
They are reset to the output of `self.get_h0_params()`.
264+
Note that the parameters of the predictive distribution are also calculated from them.
265+
"""
266+
self.set_hn_params(*self.get_h0_params().values())
267+
265268
def overwrite_h0_params(self):
266-
pass
269+
"""Overwrite the initial values of the hyperparameters of the posterior distribution by the learned values.
270+
271+
They are overwitten by the output of `self.get_hn_params()`.
272+
Note that the parameters of the predictive distribution are also calculated from them.
273+
"""
274+
self.set_h0_params(*self.get_hn_params().values())
267275

268276
@abstractmethod
269277
def update_posterior(self):

doc/devdoc/abstract_class.md

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
<img src="../logos/BayesML_logo.png" width="200">
22

3-
# 抽象クラス概要 Ver.4
3+
# 抽象クラス概要 Ver.5
44
<div style="text-align:right">
55
作成:中原
66
</div>
@@ -10,7 +10,7 @@
1010
データ生成観測確率モデルとその事前分布の抽象基底クラス.GenModelクラスに継承することで,以下の名前のメソッドの実装を強いる.
1111

1212
* `def set_h_params(self):`
13-
* 事前分布のハイパーパラメータを設定するためのメソッド.入力されたハイパーパラメータが理論上の仮定(分散共分散行列の正定値性等)を満たさない時はエラーを返すようにする.よく使う入力値チェックは`bayesml/_check.py`に書いておく.GenModelの`get_h_params()`,LearnModelの`get_h0_params()`, `get_hn_params()`で得られる辞書の順序と位置引数の順序を揃えるようにする.Python 3.7以降では辞書の要素の順序が保持されるようになったため,これにより`genmodel.set_h_params(*learnmodel.get_hn_params().vaules())`といったアンパック渡しの機能を活かした使い方が可能となる.
13+
* 事前分布のハイパーパラメータを設定するためのメソッド.`c_`で始まる変数は変更しなくてよく,それに整合しない場合はエラーを返す.入力されたハイパーパラメータが理論上の仮定(分散共分散行列の正定値性等)を満たさない時はエラーを返すようにする.よく使う入力値チェックは`bayesml/_check.py`に書いておく.GenModelの`get_h_params()`,LearnModelの`get_h0_params()`, `get_hn_params()`で得られる辞書の順序と位置引数の順序を揃えるようにする.Python 3.7以降では辞書の要素の順序が保持されるようになったため,これにより`genmodel.set_h_params(*learnmodel.get_hn_params().vaules())`といったアンパック渡しの機能を活かした使い方が可能となる.
1414
* `def get_h_params(self):`
1515
* 事前分布のハイパーパラメータを返すメソッド.ハイパーパラメータ名をキーとする辞書を返す.
1616
* `def save_h_params(self):`(抽象クラスではない)
@@ -39,7 +39,7 @@
3939
データ生成観測確率モデルのパラメータ事後分布の抽象基底クラス.LearnModelクラスに継承することで以下のメソッドの実装を強いる.
4040

4141
* `def set_h0_params(self):`
42-
* 事後分布のハイパーパラメータの初期値を設定するためのメソッド(`reset_hn_params()`を呼ぶことで,`hn_`で始まる事後分布ハイパーパラメータや`p_`で始まる予測分布パラメータも同時に初期化する).入力されたハイパーパラメータが理論上の仮定(分散共分散行列の正定値性等)を満たさない時はエラーを返すようにする.よく使う入力値チェックは`bayesml/_check.py`に書いておく.GenModelの`get_h_params()`,LearnModelの`get_h0_params()`, `get_hn_params()`で得られる辞書の順序と位置引数の順序を揃えるようにする.Python 3.7以降では辞書の要素の順序が保持されるようになったため,これにより`learnmodel.set_h_params(*genmodel.get_hn_params().vaules())`といったアンパック渡しの機能を活かした使い方が可能となる.
42+
* 事後分布のハイパーパラメータの初期値を設定するためのメソッド(`reset_hn_params()`を呼ぶことで,`hn_`で始まる事後分布ハイパーパラメータや`p_`で始まる予測分布パラメータも同時に初期化する).`c_`で始まる変数は変更しなくてよく,それに整合しない場合はエラーを返す.入力されたハイパーパラメータが理論上の仮定(分散共分散行列の正定値性等)を満たさない時はエラーを返すようにする.よく使う入力値チェックは`bayesml/_check.py`に書いておく.GenModelの`get_h_params()`,LearnModelの`get_h0_params()`, `get_hn_params()`で得られる辞書の順序と位置引数の順序を揃えるようにする.Python 3.7以降では辞書の要素の順序が保持されるようになったため,これにより`learnmodel.set_h_params(*genmodel.get_hn_params().vaules())`といったアンパック渡しの機能を活かした使い方が可能となる.
4343
* `def get_h0_params(self):`
4444
* 事後分布のハイパーパラメータの初期値を返すメソッド.ハイパーパラメータ名をキーとする辞書を返す.
4545
* `def save_h0_params(self):`(抽象クラスではない)
@@ -49,17 +49,17 @@
4949
* `def get_hn_params(self):`
5050
* データに基づいて更新された事後分布のハイパーパラメータを返すメソッド.ハイパーパラメータ名をキーとする辞書を返す.
5151
* `def set_hn_params(self):`
52-
* 更新後の事後分布のハイパーパラメータを直接設定するためのメソッド(`calc_pred_dist()`を用いて`p_`で始まる予測分布パラメータも同時に初期化する).入力されたハイパーパラメータが理論上の仮定(分散共分散行列の正定値性等)を満たさない時はエラーを返すようにする.よく使う入力値チェックは`bayesml/_check.py`に書いておく.GenModelの`get_h_params()`,LearnModelの`get_h0_params()`, `get_hn_params()`で得られる辞書の順序と位置引数の順序を揃えるようにする.Python 3.7以降では辞書の要素の順序が保持されるようになったため,これにより`learnmodel.set_h_params(*genmodel.get_hn_params().vaules())`といったアンパック渡しの機能を活かした使い方が可能となる.
52+
* 更新後の事後分布のハイパーパラメータを直接設定するためのメソッド(`calc_pred_dist()`を用いて`p_`で始まる予測分布パラメータも同時に初期化する).`c_`で始まる変数は変更しなくてよく,それに整合しない場合はエラーを返す.入力されたハイパーパラメータが理論上の仮定(分散共分散行列の正定値性等)を満たさない時はエラーを返すようにする.よく使う入力値チェックは`bayesml/_check.py`に書いておく.GenModelの`get_h_params()`,LearnModelの`get_h0_params()`, `get_hn_params()`で得られる辞書の順序と位置引数の順序を揃えるようにする.Python 3.7以降では辞書の要素の順序が保持されるようになったため,これにより`learnmodel.set_h_params(*genmodel.get_hn_params().vaules())`といったアンパック渡しの機能を活かした使い方が可能となる.
5353
* `def save_hn_params(self):`(抽象クラスではない)
5454
* データに基づいて更新された事後分布のハイパーパラメータをファイルに保存するメソッド.`get_hn_params(self):`さえ正しく実装されていれば汎用的に機能するよう`bayesml/base.py`に実装済みなので,基本的に個別のモデルでオーバーライドする必要はない.
5555
* `def load_hn_params(self):`(抽象クラスではない)
5656
* `save_hn_params`で保存したハイパーパラメータを読み込んで`set_hn_params`で設定するメソッド.`get_hn_params(self):`さえ正しく実装されていれば汎用的に機能するよう`bayesml/base.py`に実装済みなので,基本的に個別のモデルでオーバーライドする必要はない.
57-
* `def reset_hn_params(self):`
58-
* 更新後の事後分布ハイパーパラメータ(`hn_`で始まるハイパーパラメータの値)を初期値(`h0_`で始まるハイパーパラメータの値)に設定し直すメソッド.`calc_pred_dist()`を用いて`p_`で始まる予測分布パラメータも同時に初期化する
59-
* `def overwrite_h0_params(self):`
60-
* 事後分布のハイパーパラメータの初期値(`h0_`で始まるハイパーパラメータの値)を更新後の事後分布ハイパーパラメータ(`hn_`で始まるハイパーパラメータの値)で上書きするメソッド.`calc_pred_dist()`を用いて`p_`で始まる予測分布パラメータも同時に初期化する
57+
* `def reset_hn_params(self):`(抽象クラスではない)
58+
* 更新後の事後分布ハイパーパラメータ(`hn_`で始まるハイパーパラメータの値)を初期値(`h0_`で始まるハイパーパラメータの値)に設定し直すメソッド.`get_h0_params(self):`, `set_hn_params(self):`さえ正しく実装されていれば汎用的に機能するよう`bayesml/base.py`に実装済みなので,基本的に個別のモデルでオーバーライドする必要はない
59+
* `def overwrite_h0_params(self):`(抽象クラスではない)
60+
* 事後分布のハイパーパラメータの初期値(`h0_`で始まるハイパーパラメータの値)を更新後の事後分布ハイパーパラメータ(`hn_`で始まるハイパーパラメータの値)で上書きするメソッド.`get_hn_params(self):`, `set_h0_params(self):`さえ正しく実装されていれば汎用的に機能するよう`bayesml/base.py`に実装済みなので,基本的に個別のモデルでオーバーライドする必要はない
6161
* `def update_posterior(self):`
62-
* データに基づいて事後分布のハイパーパラメータを更新するためのメソッド.データは引数として渡し,変数として保持しない.データが理論上の仮定(整数かどうか等)を満たさない時はエラーを返すようにする.よく使う入力値チェックは`bayesml/_check.py`に書いておく.
62+
* データに基づいて事後分布のハイパーパラメータを更新するためのメソッド.in-placeな処理を心がける.データは引数として渡し,変数として保持しない.データが理論上の仮定(整数かどうか等)を満たさない時はエラーを返すようにする.よく使う入力値チェックは`bayesml/_check.py`に書いておく.
6363
* `def estimate_params(self):`
6464
* データに基づいて更新された事後分布を用いてパラメータを推定するためのメソッド.推定の評価基準をオプション`loss="criteria"`として指定することで,出力値が変わる.事後分布の種類によってはmodeが存在しない場合などもあるので,そういった場合には`None`を返し,警告を表示するようにする.
6565
* `def visualize_posterior(self):`
@@ -78,7 +78,7 @@
7878
* `def load_p_params(self):`(抽象クラスではない)
7979
* `save_p_params`で保存したハイパーパラメータを読み込んで`set_p_params`で設定するメソッド.`get_p_params(self):`さえ正しく実装されていれば汎用的に機能するよう`bayesml/base.py`に実装済みなので,基本的に個別のモデルでオーバーライドする必要はない. -->
8080
* `def calc_pred_dist(self):`
81-
* 事後分布のハイパーパラメータと新規データから予測分布のパラメータを計算するためのメソッド.新規データが理論上の仮定(整数かどうか等)を満たさない時はエラーを返すようにする.よく使う入力値チェックは`bayesml/_check.py`に書いておく.
81+
* 事後分布のハイパーパラメータと新規データから予測分布のパラメータを計算するためのメソッド.in-placeな処理を心がける.新規データが理論上の仮定(整数かどうか等)を満たさない時はエラーを返すようにする.よく使う入力値チェックは`bayesml/_check.py`に書いておく.
8282
* `def make_prediction(self):`
8383
* 予測分布を用いて新規データを予測するためのメソッド.予測の評価基準をオプション`loss="criteria"`として指定することで,出力値が変わる.予測分布の種類によってはmodeが存在しない場合などもあるので,そういった場合には`None`を返し,警告を表示するようにする.
8484
* `def pred_and_update(self):`

0 commit comments

Comments
 (0)