
Commit f3da427

Support categorical and linearregression
1 parent 8574e98 commit f3da427


10 files changed: +441 -312 lines changed


bayesml/bernoulli/_bernoulli.py

Lines changed: 4 additions & 1 deletion
@@ -290,6 +290,9 @@ def get_hn_params(self):
         """
         return {"hn_alpha":self.hn_alpha, "hn_beta":self.hn_beta}
 
+    def _check_sample(self,x):
+        return _check.ints_of_01(x,'x',DataFormatError)
+
     def update_posterior(self,x):
         """Update the hyperparameters of the posterior distribution using traning data.
 
@@ -298,7 +301,7 @@ def update_posterior(self,x):
         x : numpy.ndarray
             All the elements must be 0 or 1.
         """
-        _check.ints_of_01(x,'x',DataFormatError)
+        x = self._check_sample(x)
         self.hn_alpha += np.count_nonzero(x==1)
         self.hn_beta += np.count_nonzero(x==0)
         return self
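
The pattern in this file: update_posterior keeps its public behavior but delegates input validation to the new private _check_sample helper, so internal callers can reuse the same check. A minimal usage sketch, assuming these methods live on bayesml.bernoulli.LearnModel and that its default prior hyperparameters are usable as-is:

import numpy as np
from bayesml import bernoulli

# Sketch only: bernoulli.LearnModel and its no-argument constructor are assumptions
# not shown in this diff; get_hn_params and update_posterior are shown above.
model = bernoulli.LearnModel()

x = np.array([0, 1, 1, 0, 1])
model.update_posterior(x)        # validation now happens inside self._check_sample(x)
print(model.get_hn_params())     # {'hn_alpha': ..., 'hn_beta': ...}

# Invalid elements are still rejected; DataFormatError is simply raised
# from _check_sample instead of from update_posterior itself.
try:
    model.update_posterior(np.array([0, 2, 1]))
except Exception as err:
    print(type(err).__name__, err)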

bayesml/categorical/_categorical.py

Lines changed: 24 additions & 12 deletions
@@ -321,6 +321,22 @@ def get_hn_params(self):
         """
         return {"hn_alpha_vec": self.hn_alpha_vec}
 
+    # default onehot option is False because it is used in metatree
+    def _check_sample(self,x,onehot=False):
+        if onehot:
+            _check.onehot_vecs(x,'x',DataFormatError)
+            if x.shape[-1] != self.c_degree:
+                raise(DataFormatError(f"x.shape[-1] must be c_degree:{self.c_degree}"))
+            return x.reshape(-1,self.c_degree)
+        else:
+            _check.nonneg_ints(x,'x',DataFormatError)
+            if np.max(x) >= self.c_degree:
+                raise(DataFormatError(
+                    'np.max(x) must be smaller than self.c_degree: '
+                    +f'np.max(x) = {np.max(x)}, self.c_degree = {self.c_degree}'
+                ))
+            return x
+
     def update_posterior(self,x,onehot=True):
         """Update the hyperparameters of the posterior distribution using traning data.
 
@@ -336,22 +352,12 @@ def update_posterior(self,x,onehot=True):
             If True, the input sample must be one-hot encoded,
             by default True.
         """
+        x = self._check_sample(x,onehot)
         if onehot:
-            _check.onehot_vecs(x,'x',DataFormatError)
-            if x.shape[-1] != self.c_degree:
-                raise(DataFormatError(f"x.shape[-1] must be c_degree:{self.c_degree}"))
-            x = x.reshape(-1,self.c_degree)
             self.hn_alpha_vec[:] += x.sum(axis=0)
         else:
-            _check.nonneg_ints(x,'x',DataFormatError)
-            if np.max(x) >= self.c_degree:
-                raise(DataFormatError(
-                    'np.max(x) must be smaller than self.c_degree: '
-                    +f'np.max(x) = {np.max(x)}, self.c_degree = {self.c_degree}'
-                ))
             for k in range(self.c_degree):
                 self.hn_alpha_vec[k] += np.count_nonzero(x==k)
-
         return self
 
     def _update_posterior(self,x):
@@ -396,7 +402,10 @@ def estimate_params(self, loss="squared",dict_out=False):
                 return (self.hn_alpha_vec - 1) / (np.sum(self.hn_alpha_vec) - self.c_degree)
             else:
                 warnings.warn("MAP estimate of lambda_mat doesn't exist for the current hn_alpha_vec.",ResultWarning)
-                return None
+                if dict_out:
+                    return {'theta_vec':None}
+                else:
+                    return None
         elif loss == "KL":
             return ss_dirichlet(alpha=self.hn_alpha_vec)
         else:
@@ -476,6 +485,9 @@ def calc_pred_dist(self):
         """Calculate the parameters of the predictive distribution."""
         self.p_theta_vec[:] = self.hn_alpha_vec / self.hn_alpha_vec.sum()
         return self
+
+    def _calc_pred_density(self,x):
+        return self.p_theta_vec[x]
 
     def make_prediction(self,loss="squared",onehot=True):
         """Predict a new data point under the given criterion.

bayesml/exponential/_exponential.py

Lines changed: 4 additions & 1 deletion
@@ -292,6 +292,9 @@ def get_hn_params(self):
         """
         return {"hn_alpha":self.hn_alpha, "hn_beta":self.hn_beta}
 
+    def _check_sample(self,x):
+        return _check.pos_floats(x, 'x', DataFormatError)
+
     def update_posterior(self,x):
         """Update the hyperparameters of the posterior distribution using traning data.
 
@@ -300,7 +303,7 @@ def update_posterior(self,x):
         x : numpy.ndarray
             All the elements must be positive real numbers.
         """
-        _check.pos_floats(x, 'x', DataFormatError)
+        x = self._check_sample(x)
         try:
             self.hn_alpha += x.size
         except:

bayesml/linearregression/_linearregression.py

Lines changed: 31 additions & 12 deletions
@@ -501,6 +501,24 @@ def get_hn_params(self):
         """
         return {"hn_mu_vec":self.hn_mu_vec, "hn_lambda_mat":self.hn_lambda_mat, "hn_alpha":self.hn_alpha, "hn_beta":self.hn_beta}
 
+    def _check_sample_x(self,x):
+        _check.float_vecs(x,'x',DataFormatError)
+        if x.shape[-1] != self.c_degree:
+            raise(DataFormatError(f"x.shape[-1] must be c_degree:{self.c_degree}"))
+
+    def _check_sample_y(self,y):
+        _check.floats(y,'y',DataFormatError)
+
+    def _check_sample(self,x,y):
+        self._check_sample_x(x)
+        self._check_sample_y(y)
+        if type(y) is np.ndarray:
+            if x.shape[:-1] != y.shape:
+                raise(DataFormatError(f"x.shape[:-1] and y.shape must be same."))
+        elif x.shape[:-1] != ():
+            raise(DataFormatError(f"If y is a scaler, x.shape[:-1] must be the empty tuple ()."))
+        return x.reshape(-1,self.c_degree), np.ravel(y)
+
     def update_posterior(self, x, y):
         """Update the hyperparameters of the posterior distribution using traning data.
 
@@ -512,18 +530,7 @@ def update_posterior(self, x, y):
         y : numpy ndarray
             float array.
         """
-        _check.float_vecs(x,'x',DataFormatError)
-        if x.shape[-1] != self.c_degree:
-            raise(DataFormatError(f"x.shape[-1] must be c_degree:{self.c_degree}"))
-        _check.floats(y,'y',DataFormatError)
-        if type(y) is np.ndarray:
-            if x.shape[:-1] != y.shape:
-                raise(DataFormatError(f"x.shape[:-1] and y.shape must be same."))
-        elif x.shape[:-1] != ():
-            raise(DataFormatError(f"If y is a scaler, x.shape[:-1] must be the empty tuple ()."))
-
-        x = x.reshape(-1,self.c_degree)
-        y = np.ravel(y)
+        x,y = self._check_sample(x,y)
 
         hn1_Lambda = np.array(self.hn_lambda_mat)
         hn1_mu = np.array(self.hn_mu_vec)
@@ -536,6 +543,8 @@ def update_posterior(self, x, y):
 
     def _update_posterior(self, x, y):
         """Update opsterior without input check."""
+        x = x.reshape(-1,self.c_degree)
+        y = np.ravel(y)
         hn1_Lambda = np.array(self.hn_lambda_mat)
         hn1_mu = np.array(self.hn_mu_vec)
         self.hn_lambda_mat += x.T @ x
@@ -703,6 +712,16 @@ def calc_pred_dist(self, x):
         self.p_nu = 2.0 * self.hn_alpha
         return self
 
+    def _calc_pred_dist(self, x):
+        """Calculate predictive distribution without check."""
+        self.p_m = x @ self.hn_mu_vec
+        self.p_lambda = self.hn_alpha / self.hn_beta / (1.0 + x @ np.linalg.solve(self.hn_lambda_mat,x))
+        self.p_nu = 2.0 * self.hn_alpha
+        return self
+
+    def _calc_pred_density(self,y):
+        return ss_t.pdf(y,loc=self.p_m, scale=1.0/np.sqrt(self.p_lambda), df=self.p_nu)
+
     def make_prediction(self,loss="squared"):
         """Predict a new data point under the given criterion.
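
For the regression model, _check_sample now both validates and reshapes (x, y), _update_posterior reshapes its own inputs, and the new _calc_pred_dist / _calc_pred_density pair exposes an unchecked Student-t predictive density for internal use. A sketch of the public, checked path, assuming linearregression.LearnModel accepts c_degree and that calc_pred_dist takes a single explanatory vector, as the body of _calc_pred_dist suggests:

import numpy as np
from bayesml import linearregression

# Sketch only: the constructor argument c_degree=2 and the data shapes below are
# assumptions, chosen to satisfy the checks added in this commit
# (x.shape[-1] == c_degree, x.shape[:-1] == y.shape).
model = linearregression.LearnModel(c_degree=2)

rng = np.random.default_rng(0)
x = rng.normal(size=(30, 2))                                   # explanatory vectors
y = x @ np.array([1.5, -0.5]) + rng.normal(scale=0.1, size=30)

model.update_posterior(x, y)                # x, y validated and reshaped by _check_sample
model.calc_pred_dist(np.array([0.3, 0.7]))  # checked counterpart of _calc_pred_dist
print(model.make_prediction(loss="squared"))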
