
Commit 2c8d809

Merge pull request #45 from yuta-nakahara/develop-metatree
Develop metatree
2 parents (ecd526a + cc286c6) · commit 2c8d809

15 files changed: +1666 additions, -372 deletions

bayesml/__init__.py

Lines changed: 3 additions & 1 deletion

@@ -6,6 +6,7 @@
 from . import multivariate_normal
 from . import normal
 from . import poisson
+from . import metatree
 
 __all__ = ['bernoulli',
            'categorical',
@@ -14,5 +15,6 @@
            'linearregression',
            'multivariate_normal',
            'normal',
-           'poisson'
+           'poisson',
+           'metatree'
            ]
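With this hunk the new metatree subpackage loads with the package and appears in ``__all__``. A minimal sanity check, assuming only what the diff above shows:

import bayesml
from bayesml import metatree  # newly importable after this commit

print('metatree' in bayesml.__all__)  # True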

bayesml/_check.py

Lines changed: 8 additions & 1 deletion

@@ -65,6 +65,7 @@ def nonneg_int_vecs(val,val_name,exception_class):
         return val
     raise(exception_class(val_name + " must be a numpy.ndarray whose ndim >= 1 and dtype is int. Its values must be non-negative (including 0)."))
 
+
 def nonneg_float_vec(val,val_name,exception_class):
     if type(val) is np.ndarray:
         if np.issubdtype(val.dtype,np.floating) and val.ndim == 1 and np.all(val>=0):
@@ -242,8 +243,14 @@ def onehot_vecs(val,val_name,exception_class):
         return val
     raise(exception_class(val_name + " must be a numpy.ndarray whose dtype is int and whose last axis constitutes one-hot vectors."))
 
+def int_vecs(val,val_name,exception_class):
+    if type(val) is np.ndarray:
+        if np.issubdtype(val.dtype,np.integer) and val.ndim >= 1:
+            return val
+    raise(exception_class(val_name + " must be a numpy.ndarray whose dtype is int and ndim >= 1."))
+
 def shape_consistency(val: int, val_name: str, correct: int, correct_name: str, exception_class):
     if val != correct:
         message = (f"{val_name} must coincide with {correct_name}: "
                    + f"{val_name} = {val}, {correct_name} = {correct}")
-    raise(exception_class(message))
+        raise(exception_class(message))
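The new ``int_vecs`` checker follows the validate-or-raise pattern of the surrounding helpers: it returns the array untouched if it is an integer ndarray with ``ndim >= 1``, and otherwise raises the caller-supplied exception class. A short sketch, inferred directly from the code above:

import numpy as np
from bayesml import _check

# Any integer ndarray with ndim >= 1 (negative values included) passes through.
x = _check.int_vecs(np.array([[1, -2], [3, 4]]), 'x', ValueError)

# A float array fails the dtype test and raises the supplied exception class.
try:
    _check.int_vecs(np.array([1.0, 2.0]), 'x', ValueError)
except ValueError as e:
    print(e)  # x must be a numpy.ndarray whose dtype is int and ndim >= 1.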

bayesml/bernoulli/_bernoulli.py

Lines changed: 28 additions & 8 deletions

@@ -272,18 +272,20 @@ def update_posterior(self,x):
         self.hn_alpha += np.sum(x==1)
         self.hn_beta += np.sum(x==0)
 
-    def estimate_params(self,loss="squared"):
+    def estimate_params(self,loss="squared",dict_out=False):
         """Estimate the parameter of the stochastic data generative model under the given criterion.
 
         Parameters
         ----------
         loss : str, optional
             Loss function underlying the Bayes risk function, by default \"squared\".
             This function supports \"squared\", \"0-1\", \"abs\", and \"KL\".
+        dict_out : bool, optional
+            If ``True``, output will be a dict, by default ``False``.
 
         Returns
         -------
-        Estimator : {float, None, rv_frozen}
+        estimator : {float, None, rv_frozen} or dict of {str : float, None}
             The estimated values under the given loss function. If it does not exist, `None` will be returned.
             If the loss function is \"KL\", the posterior distribution itself will be returned
             as rv_frozen object of scipy.stats.
@@ -294,19 +296,37 @@ def estimate_params(self,loss="squared"):
         scipy.stats.rv_discrete
         """
         if loss == "squared":
-            return self.hn_alpha / (self.hn_alpha + self.hn_beta)
+            if dict_out:
+                return {'theta':self.hn_alpha / (self.hn_alpha + self.hn_beta)}
+            else:
+                return self.hn_alpha / (self.hn_alpha + self.hn_beta)
         elif loss == "0-1":
             if self.hn_alpha > 1.0 and self.hn_beta > 1.0:
-                return (self.hn_alpha - 1.0) / (self.hn_alpha + self.hn_beta - 2.0)
+                if dict_out:
+                    return {'theta':(self.hn_alpha - 1.0) / (self.hn_alpha + self.hn_beta - 2.0)}
+                else:
+                    return (self.hn_alpha - 1.0) / (self.hn_alpha + self.hn_beta - 2.0)
             elif self.hn_alpha > 1.0:
-                return 1.0
+                if dict_out:
+                    return {'theta':1.0}
+                else:
+                    return 1.0
             elif self.hn_beta > 1.0:
-                return 0.0
+                if dict_out:
+                    return {'theta':0.0}
+                else:
+                    return 0.0
             else:
                 warnings.warn("MAP estimate doesn't exist for the current hn_alpha and hn_beta.",ResultWarning)
-                return None
+                if dict_out:
+                    return {'theta':None}
+                else:
+                    return None
         elif loss == "abs":
-            return ss_beta.median(self.hn_alpha,self.hn_beta)
+            if dict_out:
+                return {'theta':ss_beta.median(self.hn_alpha,self.hn_beta)}
+            else:
+                return ss_beta.median(self.hn_alpha,self.hn_beta)
         elif loss == "KL":
             return ss_beta(self.hn_alpha,self.hn_beta)
         else:
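The pattern of the new ``dict_out`` flag is the same in every branch: the default keeps the old bare return value, and ``dict_out=True`` wraps it in a dict keyed by the parameter name. A usage sketch (``LearnModel`` and its methods are assumed from the BayesML API, not shown in this diff):

import numpy as np
from bayesml import bernoulli

model = bernoulli.LearnModel()
model.update_posterior(np.array([0, 1, 1, 0, 1]))

print(model.estimate_params(loss="squared"))                 # bare float
print(model.estimate_params(loss="squared", dict_out=True))  # {'theta': ...}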

bayesml/categorical/_categorical.py

Lines changed: 12 additions & 4 deletions

@@ -356,18 +356,20 @@ def update_posterior(self, x):
         for k in range(self.degree):
             self.hn_alpha_vec[k] += x[:,k].sum()
 
-    def estimate_params(self, loss="squared"):
+    def estimate_params(self, loss="squared",dict_out=False):
         """Estimate the parameter of the stochastic data generative model under the given criterion.
 
         Parameters
         ----------
         loss : str, optional
             Loss function underlying the Bayes risk function, by default \"squared\".
             This function supports \"squared\", \"0-1\", and \"KL\".
+        dict_out : bool, optional
+            If ``True``, output will be a dict, by default ``False``.
 
         Returns
         -------
-        Estimates : {numpy ndarray, float, None, or rv_frozen}
+        estimates : {numpy ndarray, float, None, or rv_frozen}
             The estimated values under the given loss function. If it does not exist, `None` will be returned.
             If the loss function is \"KL\", the posterior distribution itself will be returned
             as rv_frozen object of scipy.stats.
@@ -378,10 +380,16 @@ def estimate_params(self, loss="squared"):
         scipy.stats.rv_discrete
         """
         if loss == "squared":
-            return self.hn_alpha_vec / np.sum(self.hn_alpha_vec)
+            if dict_out:
+                return {'theta_vec':self.hn_alpha_vec / np.sum(self.hn_alpha_vec)}
+            else:
+                return self.hn_alpha_vec / np.sum(self.hn_alpha_vec)
         elif loss == "0-1":
             if np.all(self.hn_alpha_vec > 1):
-                return (self.hn_alpha_vec - 1) / (np.sum(self.hn_alpha_vec) - self.degree)
+                if dict_out:
+                    return {'theta_vec':(self.hn_alpha_vec - 1) / (np.sum(self.hn_alpha_vec) - self.degree)}
+                else:
+                    return (self.hn_alpha_vec - 1) / (np.sum(self.hn_alpha_vec) - self.degree)
             else:
                 warnings.warn("MAP estimate of lambda_mat doesn't exist for the current hn_alpha_vec.",ResultWarning)
                 return None
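Note that the 0-1 branch here still returns a bare ``None`` when the MAP estimate does not exist, even with ``dict_out=True``, unlike the bernoulli version above, which returns ``{'theta': None}``. A sketch of the dict output (the ``LearnModel`` constructor and its ``degree`` argument are assumptions, not shown in this diff):

import numpy as np
from bayesml import categorical

model = categorical.LearnModel(degree=3)      # hypothetical constructor argument
model.update_posterior(np.eye(3, dtype=int))  # three one-hot observations

print(model.estimate_params(loss="squared", dict_out=True))  # {'theta_vec': array([...])}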

bayesml/exponential/_exponential.py

Lines changed: 21 additions & 7 deletions

@@ -66,7 +66,7 @@ def gen_params(self):
 
         The generated value is set at ``self.lambda_``.
         """
-        self.lambda_ = self.rng.gamma(self.h_alpha,1.0/self.h_beta, 1)
+        self.lambda_ = self.rng.gamma(self.h_alpha,1.0/self.h_beta)
 
     def set_params(self,lambda_):
         """Set the parameter of the stochastic data generative model.
@@ -277,18 +277,20 @@ def update_posterior(self,x):
         self.hn_alpha += x.size
         self.hn_beta += np.sum(x)
 
-    def estimate_params(self,loss="squared"):
+    def estimate_params(self,loss="squared",dict_out=False):
         """Estimate the parameter of the stochastic data generative model under the given criterion.
 
         Parameters
         ----------
         loss : str, optional
             Loss function underlying the Bayes risk function, by default \"squared\".
             This function supports \"squared\", \"0-1\", \"abs\", and \"KL\".
+        dict_out : bool, optional
+            If ``True``, output will be a dict, by default ``False``.
 
         Returns
         -------
-        Estimator : {float, None, rv_frozen}
+        estimator : {float, None, rv_frozen}
             The estimated values under the given loss function. If it does not exist, `None` will be returned.
             If the loss function is \"KL\", the posterior distribution itself will be returned
             as rv_frozen object of scipy.stats.
@@ -299,14 +301,26 @@ def estimate_params(self,loss="squared"):
         scipy.stats.rv_discrete
         """
         if loss == "squared":
-            return self.hn_alpha / self.hn_beta
+            if dict_out:
+                return {'lambda_':self.hn_alpha / self.hn_beta}
+            else:
+                return self.hn_alpha / self.hn_beta
         elif loss == "0-1":
             if self.hn_alpha > 1.0 :
-                return (self.hn_alpha - 1.0) / self.hn_beta
+                if dict_out:
+                    return {'lambda_':(self.hn_alpha - 1.0) / self.hn_beta}
+                else:
+                    return (self.hn_alpha - 1.0) / self.hn_beta
             else:
-                return 0.0
+                if dict_out:
+                    return {'lambda_':0.0}
+                else:
+                    return 0.0
         elif loss == "abs":
-            return ss_gamma.median(a=self.hn_alpha,scale=1/self.hn_beta)
+            if dict_out:
+                return {'lambda_':ss_gamma.median(a=self.hn_alpha,scale=1/self.hn_beta)}
+            else:
+                return ss_gamma.median(a=self.hn_alpha,scale=1/self.hn_beta)
         elif loss == "KL":
             return ss_gamma(a=self.hn_alpha,scale=1/self.hn_beta)
         else:
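The one-line fix in ``gen_params`` matters because ``numpy.random.Generator.gamma`` returns a length-1 ndarray when ``size=1`` is passed, but a bare scalar when ``size`` is omitted, and a single rate parameter ``lambda_`` should be a scalar. A standalone numpy illustration:

import numpy as np

rng = np.random.default_rng(0)
print(rng.gamma(2.0, 0.5, 1))  # ndarray of shape (1,), e.g. array([...])
print(rng.gamma(2.0, 0.5))     # bare float scalar, as self.lambda_ should be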

bayesml/linearregression/_linearregression.py

Lines changed: 25 additions & 7 deletions

@@ -25,6 +25,10 @@ class GenModel(base.Generative):
         a value consistent with ``theta_vec``, ``h_mu_vec``,
         and ``h_lambda_mat`` is used. If all of them are not given,
         degree is assumed to be 1.
+    theta_vec : numpy ndarray, optional
+        a vector of real numbers, by default [0.0, 0.0, ... , 0.0]
+    tau : float, optional
+        a positive real number, by default 1.0
     h_mu_vec : numpy ndarray, optional
         a vector of real numbers, by default [0.0, 0.0, ... , 0.0]
     h_lambda_mat : numpy ndarray, optional
@@ -558,7 +562,7 @@ def update_posterior(self, x, y):
         self.hn_beta += (-self.hn_mu_vec[np.newaxis,:] @ self.hn_lambda_mat @ self.hn_mu_vec[:,np.newaxis]
                          + y @ y + hn1_mu[np.newaxis,:] @ hn1_Lambda @ hn1_mu[:,np.newaxis])[0,0] /2.0
 
-    def estimate_params(self,loss="squared"):
+    def estimate_params(self,loss="squared",dict_out=False):
         """Estimate the parameter of the stochastic data generative model under the given criterion.
 
         Note that the criterion is applied to estimating ``theta_vec`` and ``tau`` independently.
@@ -569,10 +573,12 @@ def estimate_params(self,loss="squared"):
         loss : str, optional
             Loss function underlying the Bayes risk function, by default \"squared\".
             This function supports \"squared\", \"0-1\", \"abs\", and \"KL\".
+        dict_out : bool, optional
+            If ``True``, output will be a dict, by default ``False``.
 
         Returns
         -------
-        Estimates : tuple of {numpy ndarray, float, None, or rv_frozen}
+        estimates : tuple of {numpy ndarray, float, None, or rv_frozen}
             * ``theta_vec`` : the estimate for w
             * ``tau_hat`` : the estimate for tau
             The estimated values under the given loss function. If it does not exist, `None` will be returned.
@@ -584,15 +590,27 @@ def estimate_params(self,loss="squared"):
         scipy.stats.rv_continuous
         scipy.stats.rv_discrete
         """
-        if loss == "squared":
-            return self.hn_mu_vec, self.hn_alpha/self.hn_beta
+        if loss == "squared":
+            if dict_out:
+                return {'theta_vec':self.hn_mu_vec,'tau':self.hn_alpha/self.hn_beta}
+            else:
+                return self.hn_mu_vec, self.hn_alpha/self.hn_beta
         elif loss == "0-1":
             if self.hn_alpha >= 1.0:
-                return self.hn_mu_vec, (self.hn_alpha - 1.0) / self.hn_beta
+                if dict_out:
+                    return {'theta_vec':self.hn_mu_vec,'tau':(self.hn_alpha - 1.0) / self.hn_beta}
+                else:
+                    return self.hn_mu_vec, (self.hn_alpha - 1.0) / self.hn_beta
             else:
-                return self.hn_mu_vec, 0
+                if dict_out:
+                    return {'theta_vec':self.hn_mu_vec,'tau':0.0}
+                else:
+                    return self.hn_mu_vec, 0.0
         elif loss == "abs":
-            return self.hn_mu_vec, ss_gamma.median(a=self.hn_alpha,scale=1.0/self.hn_beta)
+            if dict_out:
+                return {'theta_vec':self.hn_mu_vec,'tau':ss_gamma.median(a=self.hn_alpha,scale=1.0/self.hn_beta)}
+            else:
+                return self.hn_mu_vec, ss_gamma.median(a=self.hn_alpha,scale=1.0/self.hn_beta)
         elif loss == "KL":
             return (ss_multivariate_t(loc=self.hn_mu_vec,
                                       shape=np.linalg.inv(self.hn_alpha / self.hn_beta * self.hn_lambda_mat),
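For the regression model the dict output bundles both parameters under their names, while the default remains the two-element tuple. A sketch of both call styles (the ``LearnModel`` constructor, its ``degree`` argument, and the expected shape of ``x`` are assumptions here):

import numpy as np
from bayesml import linearregression

model = linearregression.LearnModel(degree=2)       # hypothetical constructor argument
x = np.array([[1.0, 0.5], [1.0, 1.5], [1.0, 2.5]])  # explanatory variables
y = np.array([1.0, 2.1, 2.9])
model.update_posterior(x, y)

theta_vec, tau = model.estimate_params(loss="squared")       # tuple (default)
print(model.estimate_params(loss="squared", dict_out=True))  # {'theta_vec': ..., 'tau': ...}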
