Skip to content

Commit 75bc4f3

Browse files
Merge pull request #3 from yuta-nakahara/develop
Develop categorical model
2 parents f81434e + 2abf906 commit 75bc4f3

File tree

9 files changed

+542
-310
lines changed

9 files changed

+542
-310
lines changed

bayesml/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from . import bernoulli
2+
from . import categorical
23
from . import autoregressive
34
from . import exponential
45
from . import linearregression
@@ -7,6 +8,7 @@
78
from . import poisson
89

910
__all__ = ['bernoulli',
11+
'categorical',
1012
'autoregressive',
1113
'exponential',
1214
'linearregression',

bayesml/_check.py

Lines changed: 60 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,35 +3,34 @@
33
# Yuji Iikubo <yuji-iikubo.8@fuji.waseda.jp>
44
import numpy as np
55

6-
FLOATS = list({'float128','float64','float32','float16'} & set(dir(np)) | {float})
7-
INTS = list({'int64','int32','int16','int8'} & set(dir(np)) | {int})
6+
# Absolute tolerance: the square root of float64 machine epsilon.
# float_vec_sum_1 compares abs(val.sum() - 1.) against this value.
_EPSILON = np.sqrt(np.finfo(np.float64).eps)
87

98
def float_in_closed01(val, val_name, exception_class):
    """Validate that ``val`` is a real scalar lying in the closed interval [0, 1].

    Parameters
    ----------
    val : object
        Value to validate.
    val_name : str
        Name of the value; it is prepended to the error message.
    exception_class : type
        Exception class instantiated and raised when validation fails.

    Returns
    -------
    float
        ``val`` itself when its type is a floating type, or ``float(val)``
        when its type is an integer type.

    Raises
    ------
    exception_class
        If the type of ``val`` is neither floating nor integer, or if the
        value lies outside [0, 1].
    """
    # The type is tested through the NumPy scalar hierarchy so that both
    # Python built-ins and NumPy scalars of any width are accepted.
    if np.issubdtype(type(val), np.floating) and 0.0 <= val <= 1.0:
        return val
    if np.issubdtype(type(val), np.integer) and 0.0 <= val <= 1.0:
        return float(val)
    raise exception_class(val_name + " must be in [0,1].")
1716

1817
def pos_float(val, val_name, exception_class):
    """Validate that ``val`` is a strictly positive real scalar.

    Parameters
    ----------
    val : object
        Value to validate.
    val_name : str
        Name of the value; it is prepended to the error message.
    exception_class : type
        Exception class instantiated and raised when validation fails.

    Returns
    -------
    float
        ``val`` itself when its type is a floating type, or ``float(val)``
        when its type is an integer type.

    Raises
    ------
    exception_class
        If the type of ``val`` is neither floating nor integer, or if the
        value is not greater than zero.
    """
    if np.issubdtype(type(val), np.floating) and val > 0.0:
        return val
    if np.issubdtype(type(val), np.integer) and val > 0.0:
        return float(val)
    raise exception_class(val_name + " must be positive (not including 0.0).")
2625

2726
def pos_int(val, val_name, exception_class):
    """Validate that ``val`` is a strictly positive integer scalar.

    Parameters
    ----------
    val : object
        Value to validate.
    val_name : str
        Name of the value; it is prepended to the error message.
    exception_class : type
        Exception class instantiated and raised when validation fails.

    Returns
    -------
    int
        ``val`` unchanged.

    Raises
    ------
    exception_class
        If the type of ``val`` is not an integer type or the value is not
        greater than zero.
    """
    if np.issubdtype(type(val), np.integer) and val > 0:
        return val
    raise exception_class(val_name + " must be int. Its value must be positive (not including 0).")
3231

3332
def nonneg_int(val, val_name, exception_class):
    """Validate that ``val`` is a non-negative integer scalar.

    Parameters
    ----------
    val : object
        Value to validate.
    val_name : str
        Name of the value; it is prepended to the error message.
    exception_class : type
        Exception class instantiated and raised when validation fails.

    Returns
    -------
    int
        ``val`` unchanged.

    Raises
    ------
    exception_class
        If the type of ``val`` is not an integer type or the value is
        negative.
    """
    if np.issubdtype(type(val), np.integer) and val >= 0:
        return val
    raise exception_class(val_name + " must be int. Its value must be non-negative (including 0).")
@@ -42,18 +41,18 @@ def nonneg_ints(val,val_name,exception_class):
4241
except:
4342
pass
4443
if type(val) is np.ndarray:
45-
if val.dtype in INTS and np.all(val>=0):
44+
if np.issubdtype(val.dtype,np.integer) and np.all(val>=0):
4645
return val
4746
raise(exception_class(val_name + " must be int or a numpy.ndarray whose dtype is int. Its values must be non-negative (including 0)."))
4847

4948
def nonneg_int_vec(val, val_name, exception_class):
    """Validate that ``val`` is a 1-dimensional array of non-negative integers.

    Parameters
    ----------
    val : object
        Value to validate.
    val_name : str
        Name of the value; it is prepended to the error message.
    exception_class : type
        Exception class instantiated and raised when validation fails.

    Returns
    -------
    numpy.ndarray
        ``val`` unchanged (the same object, not a copy).

    Raises
    ------
    exception_class
        If ``val`` is not exactly a ``numpy.ndarray``, its dtype is not an
        integer dtype, it is not 1-dimensional, or any element is negative.
    """
    # Exact type match: objects whose type is not numpy.ndarray itself are rejected.
    if type(val) is np.ndarray:
        if np.issubdtype(val.dtype, np.integer) and val.ndim == 1 and np.all(val >= 0):
            return val
    raise exception_class(val_name + " must be a 1-dimensional numpy.ndarray whose dtype is int. Its values must be non-negative (including 0).")
5453

5554
def int_of_01(val, val_name, exception_class):
    """Validate that ``val`` is an integer scalar equal to 0 or 1.

    Parameters
    ----------
    val : object
        Value to validate.
    val_name : str
        Name of the value; it is prepended to the error message.
    exception_class : type
        Exception class instantiated and raised when validation fails.

    Returns
    -------
    int
        ``val`` unchanged.

    Raises
    ------
    exception_class
        If the type of ``val`` is not an integer type or the value is
        neither 0 nor 1.
    """
    if np.issubdtype(type(val), np.integer) and (val == 0 or val == 1):
        return val
    raise exception_class(val_name + " must be int. Its value must be 0 or 1.")
@@ -64,23 +63,23 @@ def ints_of_01(val,val_name,exception_class):
6463
except:
6564
pass
6665
if type(val) is np.ndarray:
67-
if val.dtype in INTS and np.all(val >= 0) and np.all(val <= 1):
66+
if np.issubdtype(val.dtype,np.integer) and np.all(val >= 0) and np.all(val <= 1):
6867
return val
6968
raise(exception_class(val_name + " must be int or a numpy.ndarray whose dtype is int. Its values must be 0 or 1."))
7069

7170
def int_vec_of_01(val, val_name, exception_class):
    """Validate that ``val`` is a 1-dimensional integer array of zeros and ones.

    Parameters
    ----------
    val : object
        Value to validate.
    val_name : str
        Name of the value; it is prepended to the error message.
    exception_class : type
        Exception class instantiated and raised when validation fails.

    Returns
    -------
    numpy.ndarray
        ``val`` unchanged (the same object, not a copy).

    Raises
    ------
    exception_class
        If ``val`` is not exactly a ``numpy.ndarray``, its dtype is not an
        integer dtype, it is not 1-dimensional, or any element lies outside
        {0, 1}.
    """
    if type(val) is np.ndarray:
        # For an integer dtype, 0 <= x <= 1 is the same as x being 0 or 1.
        if (np.issubdtype(val.dtype, np.integer) and val.ndim == 1
                and np.all(val >= 0) and np.all(val <= 1)):
            return val
    raise exception_class(val_name + " must be a 1-dimensional numpy.ndarray whose dtype is int. Its values must be 0 or 1.")
7675

7776
def scalar(val, val_name, exception_class):
    """Validate that ``val`` is an integer or floating-point scalar.

    Parameters
    ----------
    val : object
        Value to validate.
    val_name : str
        Name of the value; it is prepended to the error message.
    exception_class : type
        Exception class instantiated and raised when validation fails.

    Returns
    -------
    int or float
        ``val`` unchanged (no conversion is applied).

    Raises
    ------
    exception_class
        If the type of ``val`` is neither an integer nor a floating type.
    """
    val_type = type(val)
    if np.issubdtype(val_type, np.integer) or np.issubdtype(val_type, np.floating):
        return val
    raise exception_class(val_name + " must be a scalar.")
8180

8281
def pos_scalar(val, val_name, exception_class):
    """Validate that ``val`` is a strictly positive integer or floating scalar.

    Parameters
    ----------
    val : object
        Value to validate.
    val_name : str
        Name of the value; it is prepended to the error message.
    exception_class : type
        Exception class instantiated and raised when validation fails.

    Returns
    -------
    int or float
        ``val`` unchanged (no conversion is applied).

    Raises
    ------
    exception_class
        If the type of ``val`` is neither an integer nor a floating type, or
        the value is not greater than zero.
    """
    val_type = type(val)
    is_real = np.issubdtype(val_type, np.integer) or np.issubdtype(val_type, np.floating)
    if is_real and val > 0.0:
        return val
    raise exception_class(val_name + " must be a positive scalar.")
@@ -102,9 +101,9 @@ def pos_def_sym_mat(val,val_name,exception_class):
102101
raise(exception_class(val_name + " must be a positive definite symmetric 2-dimensional numpy.ndarray."))
103102

104103
def float_(val, val_name, exception_class):
    """Validate that ``val`` is a real scalar and return it as a float.

    Parameters
    ----------
    val : object
        Value to validate.
    val_name : str
        Name of the value; it is prepended to the error message.
    exception_class : type
        Exception class instantiated and raised when validation fails.

    Returns
    -------
    float
        ``val`` itself when its type is a floating type, or ``float(val)``
        when its type is an integer type.

    Raises
    ------
    exception_class
        If the type of ``val`` is neither a floating nor an integer type.
    """
    if np.issubdtype(type(val), np.floating):
        return val
    if np.issubdtype(type(val), np.integer):
        return float(val)
    raise exception_class(val_name + " must be a scalar.")
110109

@@ -114,9 +113,9 @@ def floats(val,val_name,exception_class):
114113
except:
115114
pass
116115
if type(val) is np.ndarray:
117-
if val.dtype in INTS:
116+
if np.issubdtype(val.dtype,np.integer):
118117
return val.astype(float)
119-
if val.dtype in FLOATS:
118+
if np.issubdtype(val.dtype,np.floating):
120119
return val
121120
raise(exception_class(val_name + " must be float or a numpy.ndarray."))
122121

@@ -126,33 +125,67 @@ def pos_floats(val,val_name,exception_class):
126125
except:
127126
pass
128127
if type(val) is np.ndarray:
129-
if val.dtype in INTS and np.all(val>0):
128+
if np.issubdtype(val.dtype,np.integer) and np.all(val>0):
130129
return val.astype(float)
131-
if val.dtype in FLOATS and np.all(val>0.0):
130+
if np.issubdtype(val.dtype,np.floating) and np.all(val>0.0):
132131
return val
133132
raise(exception_class(val_name + " must be float or a numpy.ndarray. Its values must be positive (not including 0)"))
134133

135134
def float_vec(val, val_name, exception_class):
    """Validate that ``val`` is a 1-dimensional real-valued array.

    Parameters
    ----------
    val : object
        Value to validate.
    val_name : str
        Name of the value; it is prepended to the error message.
    exception_class : type
        Exception class instantiated and raised when validation fails.

    Returns
    -------
    numpy.ndarray
        ``val`` itself when its dtype is a floating dtype, or a copy
        converted with ``astype(float)`` when its dtype is an integer dtype.

    Raises
    ------
    exception_class
        If ``val`` is not exactly a ``numpy.ndarray``, is not 1-dimensional,
        or has a dtype that is neither integer nor floating.
    """
    if type(val) is np.ndarray and val.ndim == 1:
        if np.issubdtype(val.dtype, np.integer):
            return val.astype(float)
        if np.issubdtype(val.dtype, np.floating):
            return val
    raise exception_class(val_name + " must be a 1-dimensional numpy.ndarray.")
142141

143142
def pos_float_vec(val, val_name, exception_class):
    """Validate that ``val`` is a 1-dimensional array of strictly positive reals.

    Parameters
    ----------
    val : object
        Value to validate.
    val_name : str
        Name of the value; it is prepended to the error message.
    exception_class : type
        Exception class instantiated and raised when validation fails.

    Returns
    -------
    numpy.ndarray
        ``val`` itself when its dtype is a floating dtype, or a copy
        converted with ``astype(float)`` when its dtype is an integer dtype.

    Raises
    ------
    exception_class
        If ``val`` is not exactly a ``numpy.ndarray``, is not 1-dimensional,
        has a dtype that is neither integer nor floating, or contains an
        element that is not greater than zero.
    """
    if type(val) is np.ndarray and val.ndim == 1:
        is_int = np.issubdtype(val.dtype, np.integer)
        # The dtype is checked before the element-wise comparison so that
        # non-numeric arrays are never compared against zero.
        if (is_int or np.issubdtype(val.dtype, np.floating)) and np.all(val > 0):
            return val.astype(float) if is_int else val
    raise exception_class(val_name + " must be a 1-dimensional numpy.ndarray. Its values must be positive (not including 0)")
150149

151150
def float_vecs(val, val_name, exception_class):
    """Validate that ``val`` is a real-valued array with at least one dimension.

    Parameters
    ----------
    val : object
        Value to validate.
    val_name : str
        Name of the value; it is prepended to the error message.
    exception_class : type
        Exception class instantiated and raised when validation fails.

    Returns
    -------
    numpy.ndarray
        ``val`` itself when its dtype is a floating dtype, or a copy
        converted with ``astype(float)`` when its dtype is an integer dtype.

    Raises
    ------
    exception_class
        If ``val`` is not exactly a ``numpy.ndarray``, is 0-dimensional, or
        has a dtype that is neither integer nor floating.
    """
    if type(val) is np.ndarray and val.ndim >= 1:
        if np.issubdtype(val.dtype, np.integer):
            return val.astype(float)
        if np.issubdtype(val.dtype, np.floating):
            return val
    raise exception_class(val_name + " must be a numpy.ndarray whose ndim >= 1.")
158157

158+
def float_vec_sum_1(val, val_name, exception_class):
    """Validate that ``val`` is a 1-dimensional real array whose elements sum to 1.

    The sum is compared with 1 using the absolute tolerance ``_EPSILON``
    rather than exact equality. The sign of individual elements is not
    checked.

    Parameters
    ----------
    val : object
        Value to validate.
    val_name : str
        Name of the value; it is prepended to the error message.
    exception_class : type
        Exception class instantiated and raised when validation fails.

    Returns
    -------
    numpy.ndarray
        ``val`` itself when its dtype is a floating dtype, or a copy
        converted with ``astype(float)`` when its dtype is an integer dtype.

    Raises
    ------
    exception_class
        If ``val`` is not exactly a ``numpy.ndarray``, is not 1-dimensional,
        has a dtype that is neither integer nor floating, or its sum differs
        from 1 by more than ``_EPSILON``.
    """
    if type(val) is np.ndarray and val.ndim == 1:
        is_int = np.issubdtype(val.dtype, np.integer)
        # The dtype is checked before summing so non-numeric arrays are never summed.
        if (is_int or np.issubdtype(val.dtype, np.floating)) and abs(val.sum() - 1.) <= _EPSILON:
            return val.astype(float) if is_int else val
    raise exception_class(val_name + " must be a 1-dimensional numpy.ndarray, and the sum of its elements must equal to 1.")
165+
166+
def int_(val, val_name, exception_class):
    """Validate that ``val`` is an integer scalar of any sign.

    Parameters
    ----------
    val : object
        Value to validate.
    val_name : str
        Name of the value; it is prepended to the error message.
    exception_class : type
        Exception class instantiated and raised when validation fails.

    Returns
    -------
    int
        ``val`` unchanged.

    Raises
    ------
    exception_class
        If the type of ``val`` is not an integer type.
    """
    if np.issubdtype(type(val), np.integer):
        return val
    raise exception_class(val_name + " must be an integer.")
170+
171+
def ints(val, val_name, exception_class):
    """Validate that ``val`` is an integer scalar or an integer-dtype array.

    Parameters
    ----------
    val : object
        Value to validate.
    val_name : str
        Name of the value; it is prepended to the error message.
    exception_class : type
        Exception class instantiated and raised when validation fails.

    Returns
    -------
    int or numpy.ndarray
        ``val`` unchanged (for arrays, the same object and not a copy).

    Raises
    ------
    exception_class
        If ``val`` is neither a scalar of an integer type nor a
        ``numpy.ndarray`` (exact type) with an integer dtype.
    """
    # Both accepted forms are tested directly instead of calling the scalar
    # validator inside a bare ``except: pass``, which would also swallow
    # unrelated errors such as KeyboardInterrupt.
    if type(val) is np.ndarray:
        if np.issubdtype(val.dtype, np.integer):
            return val
    elif np.issubdtype(type(val), np.integer):
        return val
    raise exception_class(val_name + " must be int or a numpy.ndarray whose dtype is int.")
180+
181+
def onehot_vec(val, val_name, exception_class):
    """Validate that ``val`` is a 1-dimensional one-hot integer vector.

    Parameters
    ----------
    val : object
        Value to validate.
    val_name : str
        Name of the value; it is prepended to the error message.
    exception_class : type
        Exception class instantiated and raised when validation fails.

    Returns
    -------
    numpy.ndarray
        ``val`` unchanged (the same object, not a copy).

    Raises
    ------
    exception_class
        If ``val`` is not exactly a ``numpy.ndarray``, its dtype is not an
        integer dtype, it is not 1-dimensional, any element is negative, or
        the elements do not sum to 1.
    """
    if type(val) is np.ndarray:
        # Non-negative integers that sum to 1 contain exactly one 1 and
        # zeros elsewhere, which is the one-hot condition.
        if (np.issubdtype(val.dtype, np.integer) and val.ndim == 1
                and np.all(val >= 0) and val.sum() == 1):
            return val
    raise exception_class(val_name + " must be a one-hot vector (1-dimensional ndarray) whose dtype must be int.")
186+
187+
def onehot_vecs(val, val_name, exception_class):
    """Validate that ``val`` is an integer array whose last axis holds one-hot vectors.

    Parameters
    ----------
    val : object
        Value to validate.
    val_name : str
        Name of the value; it is prepended to the error message.
    exception_class : type
        Exception class instantiated and raised when validation fails.

    Returns
    -------
    numpy.ndarray
        ``val`` unchanged (the same object, not a copy).

    Raises
    ------
    exception_class
        If ``val`` is not exactly a ``numpy.ndarray``, its dtype is not an
        integer dtype, it is 0-dimensional, any element is negative, or any
        sum taken along the last axis differs from 1.
    """
    if type(val) is np.ndarray:
        # ndim is tested before summing along axis -1, which needs at least
        # one dimension.
        if (np.issubdtype(val.dtype, np.integer) and val.ndim >= 1
                and np.all(val >= 0) and np.all(val.sum(axis=-1) == 1)):
            return val
    raise exception_class(val_name + " must be a numpy.ndarray whose dtype is int and whose last axis constitutes one-hot vectors.")

bayesml/bernoulli/_bernoulli.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
# Yuta Nakahara <yuta.nakahara@aoni.waseda.jp>
55
import warnings
66
import numpy as np
7-
import os
8-
import sys
97
from scipy.stats import beta as ss_beta
108
# from scipy.stats import betabino as ss_betabinom
119
import matplotlib.pyplot as plt
@@ -155,29 +153,33 @@ def visualize_model(self,sample_size=20,sample_num=5):
155153
>>> model = bernoulli.GenModel()
156154
>>> model.visualize_model()
157155
p:0.5
158-
x0:[0 0 1 1 1 1 1 1 1 0 1 0 1 1 1 0 1 0 1 0]
159-
x1:[1 0 0 1 1 0 1 0 0 0 0 0 0 0 0 1 0 1 0 1]
160-
x2:[1 1 1 1 0 1 0 1 0 0 0 1 1 0 0 1 1 1 0 1]
161-
x3:[0 0 1 1 1 0 0 1 1 0 0 1 0 0 1 0 0 1 0 1]
162-
x4:[0 1 0 1 1 0 1 0 1 1 1 1 1 0 1 0 0 1 1 0]
163-
156+
x0:[1 1 0 0 0 1 0 1 0 0 0 1 0 1 0 1 0 1 0 0]
157+
x1:[1 1 0 0 0 0 0 1 1 0 0 0 1 0 1 0 0 0 0 0]
158+
x2:[0 1 0 1 0 0 1 0 0 0 1 0 1 1 1 0 1 0 1 1]
159+
x3:[0 0 0 1 1 0 1 0 1 0 0 0 1 0 1 0 1 0 1 1]
160+
x4:[1 0 1 1 1 1 0 1 0 0 1 1 0 0 0 0 0 0 1 1]
161+
164162
.. image:: ./images/bernoulli_example.png
165163
"""
166164
_check.pos_int(sample_size,'sample_size',DataFormatError)
167165
_check.pos_int(sample_num,'sample_num',DataFormatError)
168166
print(f"p:{self.p}")
169-
fig, ax = plt.subplots(figsize=(5,sample_num))
167+
fig, ax = plt.subplots(2,1,figsize=(5, sample_num+1),gridspec_kw={'height_ratios': [1,sample_num]})
168+
ax[0].set_title("True distribution")
169+
ax[0].barh(0,self.p,label=1,color="C0")
170+
ax[0].barh(0,1.0-self.p,left=self.p,label=0,color="C1")
171+
ax[1].set_title("Generated sample")
170172
for i in range(sample_num):
171173
x = self.gen_sample(sample_size)
172174
print(f"x{i}:{x}")
173175
if i == 0:
174-
ax.barh(i,x.sum(),label=1,color="C0")
175-
ax.barh(i,sample_size-x.sum(),left=x.sum(),label=0,color="C1")
176+
ax[1].barh(i,x.sum(),label=1,color="C0")
177+
ax[1].barh(i,sample_size-x.sum(),left=x.sum(),label=0,color="C1")
176178
else:
177-
ax.barh(i,x.sum(),color="C0")
178-
ax.barh(i,sample_size-x.sum(),left=x.sum(),color="C1")
179-
ax.legend()
180-
ax.set_xlabel("Number of occurrences")
179+
ax[1].barh(i,x.sum(),color="C0")
180+
ax[1].barh(i,sample_size-x.sum(),left=x.sum(),color="C1")
181+
ax[1].legend()
182+
ax[1].set_xlabel("Number of occurrences")
181183
plt.show()
182184

183185
class LearnModel(base.Posterior,base.PredictiveMixin):

0 commit comments

Comments
 (0)