@@ -1896,21 +1896,21 @@ def update_posterior(self,x_continuous=None,x_categorical=None,y=None,alg_type='
1896
1896
self .c_dim_continuous ,'self.c_dim_continuous' ,
1897
1897
ParameterFormatError
1898
1898
)
1899
- x_continuous .reshape ([- 1 ,self .c_dim_continuous ])
1899
+ x_continuous = x_continuous .reshape ([- 1 ,self .c_dim_continuous ])
1900
1900
1901
1901
_check .nonneg_int_vecs (x_categorical ,'x_categorical' ,DataFormatError )
1902
1902
_check .shape_consistency (
1903
1903
x_categorical .shape [- 1 ],'x_categorical.shape[-1]' ,
1904
1904
self .c_dim_categorical ,'self.c_dim_categorical' ,
1905
1905
ParameterFormatError
1906
1906
)
1907
+ x_categorical = x_categorical .reshape ([- 1 ,self .c_dim_categorical ])
1907
1908
for i in range (self .c_dim_categorical ):
1908
1909
if x_categorical [:,i ].max () >= self .c_num_children_vec [self .c_dim_continuous + i ]:
1909
1910
raise (DataFormatError (
1910
1911
f"x_categorical[:,{ i } ].max() must smaller than "
1911
1912
+ f"self.c_num_children_vec[{ self .c_dim_continuous + i } ]: "
1912
1913
+ f"{ self .c_num_children_vec [self .c_dim_continuous + i ]} " ))
1913
- x_categorical .reshape ([- 1 ,self .c_dim_categorical ])
1914
1914
1915
1915
_check .shape_consistency (
1916
1916
x_continuous .shape [0 ],'x_continuous.shape[0]' ,
@@ -1930,7 +1930,7 @@ def update_posterior(self,x_continuous=None,x_categorical=None,y=None,alg_type='
1930
1930
self .c_dim_continuous ,'self.c_dim_continuous' ,
1931
1931
ParameterFormatError
1932
1932
)
1933
- x_continuous .reshape ([- 1 ,self .c_dim_continuous ])
1933
+ x_continuous = x_continuous .reshape ([- 1 ,self .c_dim_continuous ])
1934
1934
1935
1935
_check .shape_consistency (
1936
1936
x_continuous .shape [0 ],'x_continuous.shape[0]' ,
@@ -1947,13 +1947,13 @@ def update_posterior(self,x_continuous=None,x_categorical=None,y=None,alg_type='
1947
1947
self .c_dim_categorical ,'self.c_dim_categorical' ,
1948
1948
ParameterFormatError
1949
1949
)
1950
+ x_categorical = x_categorical .reshape ([- 1 ,self .c_dim_categorical ])
1950
1951
for i in range (self .c_dim_categorical ):
1951
1952
if x_categorical [:,i ].max () >= self .c_num_children_vec [self .c_dim_continuous + i ]:
1952
1953
raise (DataFormatError (
1953
1954
f"x_categorical[:,{ i } ].max() must smaller than "
1954
1955
+ f"self.c_num_children_vec[{ self .c_dim_continuous + i } ]: "
1955
1956
+ f"{ self .c_num_children_vec [self .c_dim_continuous + i ]} " ))
1956
- x_categorical .reshape ([- 1 ,self .c_dim_categorical ])
1957
1957
1958
1958
_check .shape_consistency (
1959
1959
x_categorical .shape [0 ],'x_categorical.shape[0]' ,
@@ -2372,52 +2372,41 @@ def calc_pred_dist(self,x_continuous=None,x_categorical=None):
2372
2372
self ._calc_pred_dist_recursion (root ,self ._tmp_x_continuous ,self ._tmp_x_categorical )
2373
2373
2374
2374
def _make_prediction_recursion_squared (self ,node :_Node ):
2375
- if node .leaf == False : # inner node
2376
- return ((1 - node .h_g ) * node .sub_model .make_prediction (loss = 'squared' )
2377
- + node .h_g * self ._make_prediction_recursion_squared (node .children [self ._tmp_x [node .k ]]))
2378
- else : # leaf node
2375
+ if node .leaf : # leaf node
2379
2376
return node .sub_model .make_prediction (loss = 'squared' )
2377
+ else : # inner node
2378
+ if node .k < self .c_dim_continuous :
2379
+ for i in range (self .c_num_children_vec [node .k ]):
2380
+ if node .thresholds [i ] < self ._tmp_x_continuous [node .k ] and self ._tmp_x_continuous [node .k ] < node .thresholds [i + 1 ]:
2381
+ index = i
2382
+ break
2383
+ else :
2384
+ index = self ._tmp_x_categorical [node .k - self .c_dim_continuous ]
2385
+ return ((1 - node .h_g ) * node .sub_model .make_prediction (loss = 'squared' )
2386
+ + node .h_g * self ._make_prediction_recursion_squared (node .children [index ]))
2380
2387
2381
- def _make_prediction_leaf_01 (self ,node :_Node ):
2382
- mode = node .sub_model .make_prediction (loss = '0-1' )
2383
- pred_dist = node .sub_model .make_prediction (loss = 'KL' )
2384
- if type (pred_dist ) is np .ndarray :
2385
- mode_prob = pred_dist [mode ]
2386
- else :
2387
- try :
2388
- mode_prob = pred_dist .pdf (mode )
2389
- except :
2390
- try :
2391
- mode_prob = pred_dist .pmf (mode )
2392
- except :
2393
- mode_prob = None
2394
- # elif hasattr(pred_dist,'pdf'):
2395
- # mode_prob = pred_dist.pdf(mode)
2396
- # elif hasattr(pred_dist,'pmf'):
2397
- # mode_prob = pred_dist.pmf(mode)
2398
- # else:
2399
- # mode_prob = None
2400
- return mode , mode_prob
2401
-
2402
- def _make_prediction_recursion_01 (self ,node :_Node ):
2403
- if node .leaf == False : # inner node
2404
- mode1 ,mode_prob1 = self ._make_prediction_leaf_01 (node )
2405
- mode2 ,mode_prob2 = self ._make_prediction_recursion_01 (node .children [self ._tmp_x [node .k ]])
2406
- if (1 - node .h_g ) * mode_prob1 > node .h_g * mode_prob2 :
2407
- return mode1 ,mode_prob1
2388
+ def _make_prediction_recursion_kl (self ,node :_Node ):
2389
+ if node .leaf : # leaf node
2390
+ return node .sub_model .make_prediction (loss = 'KL' )
2391
+ else : # inner node
2392
+ if node .k < self .c_dim_continuous :
2393
+ for i in range (self .c_num_children_vec [node .k ]):
2394
+ if node .thresholds [i ] < self ._tmp_x_continuous [node .k ] and self ._tmp_x_continuous [node .k ] < node .thresholds [i + 1 ]:
2395
+ index = i
2396
+ break
2408
2397
else :
2409
- return mode2 , mode_prob2
2410
- else : # leaf node
2411
- return self ._make_prediction_leaf_01 (node )
2398
+ index = self . _tmp_x_categorical [ node . k - self . c_dim_continuous ]
2399
+ return (( 1 - node . h_g ) * node . sub_model . make_prediction ( loss = 'KL' )
2400
+ + node . h_g * self ._make_prediction_recursion_kl (node . children [ index ]) )
2412
2401
2413
- def make_prediction (self ,loss = "0-1 " ):
2402
+ def make_prediction (self ,loss = "squared " ):
2414
2403
"""Predict a new data point under the given criterion.
2415
2404
2416
2405
Parameters
2417
2406
----------
2418
2407
loss : str, optional
2419
- Loss function underlying the Bayes risk function, by default \" 0-1 \" .
2420
- This function supports \" squared\" , \" 0-1\" .
2408
+ Loss function underlying the Bayes risk function, by default \" squared \" .
2409
+ This function supports \" squared\" , \" 0-1\" , and \" KL \" .
2421
2410
2422
2411
Returns
2423
2412
-------
@@ -2430,39 +2419,49 @@ def make_prediction(self,loss="0-1"):
2430
2419
tmp_pred_vec [i ] = self ._make_prediction_recursion_squared (metatree )
2431
2420
return self .hn_metatree_prob_vec @ tmp_pred_vec
2432
2421
elif loss == "0-1" :
2433
- tmp_mode = np .empty (len (self .hn_metatree_list ))
2434
- tmp_mode_prob_vec = np .empty (len (self .hn_metatree_list ))
2422
+ if self .SubModel is not bernoulli :
2423
+ raise (CriteriaError ("Unsupported loss function! "
2424
+ + "\" 0-1\" is supported only when self.SubModel is bernoulli." ))
2425
+ tmp_pred_dist_vec = np .empty ([len (self .hn_metatree_list ),2 ])
2426
+ for i ,metatree in enumerate (self .hn_metatree_list ):
2427
+ tmp_pred_dist_vec [i ] = self ._make_prediction_recursion_kl (metatree )
2428
+ return np .argmax (self .hn_metatree_prob_vec @ tmp_pred_dist_vec )
2429
+ elif loss == "KL" :
2430
+ if self .SubModel is not bernoulli :
2431
+ raise (CriteriaError ("Unsupported loss function! "
2432
+ + "\" KL\" is supported only when self.SubModel is bernoulli." ))
2433
+ tmp_pred_dist_vec = np .empty ([len (self .hn_metatree_list ),2 ])
2435
2434
for i ,metatree in enumerate (self .hn_metatree_list ):
2436
- tmp_mode [i ], tmp_mode_prob_vec [ i ] = self ._make_prediction_recursion_01 (metatree )
2437
- return tmp_mode [ np . argmax ( self .hn_metatree_prob_vec * tmp_mode_prob_vec )]
2435
+ tmp_pred_dist_vec [i ] = self ._make_prediction_recursion_kl (metatree )
2436
+ return self .hn_metatree_prob_vec @ tmp_pred_dist_vec
2438
2437
else :
2439
2438
raise (CriteriaError ("Unsupported loss function! "
2440
- + "This function supports \" squared\" and \" 0-1\" ." ))
2439
+ + "This function supports \" squared\" , \" 0-1\" , and \" KL \" ." ))
2441
2440
2442
- def pred_and_update (self ,x , y ,loss = "0-1 " ):
2441
+ def pred_and_update (self ,x_continuous = None , x_categorical = None , y = None ,loss = "squared " ):
2443
2442
"""Predict a new data point and update the posterior sequentially.
2444
2443
2445
2444
Parameters
2446
2445
----------
2447
- x : numpy.ndarray
2448
- It must be a degree-dimensional vector
2446
+ x_continuous : numpy ndarray, optional
2447
+ A float vector whose length is ``self.c_dim_continuous``,
2448
+ by default None.
2449
+ x_categorical : numpy ndarray, optional
2450
+ A int vector whose length is ``self.c_dim_categorical``,
2451
+ by default None. Each element x_categorical[i] must satisfy
2452
+ 0 <= x_categorical[i] < self.c_num_children_vec[self.c_dim_continuous+i].
2449
2453
y : numpy ndarray
2450
2454
values of objective variable whose dtype may be int or float
2451
2455
loss : str, optional
2452
- Loss function underlying the Bayes risk function, by default \" 0-1 \" .
2456
+ Loss function underlying the Bayes risk function, by default \" squared \" .
2453
2457
This function supports \" squared\" , \" 0-1\" , and \" KL\" .
2454
2458
2455
2459
Returns
2456
2460
-------
2457
2461
predicted_value : {float, numpy.ndarray}
2458
2462
The predicted value under the given loss function.
2459
2463
"""
2460
- _check .nonneg_int_vec (x ,'x' ,DataFormatError )
2461
- if x .shape [- 1 ] != self .c_k :
2462
- raise (DataFormatError (f"x.shape[-1] must equal to c_k:{ self .c_k } " ))
2463
- if x .max () >= self .c_num_children :
2464
- raise (DataFormatError (f"x.max() must smaller than c_num_children:{ self .c_num_children } " ))
2465
- self .calc_pred_dist (x )
2464
+ self .calc_pred_dist (x_continuous ,x_categorical )
2466
2465
prediction = self .make_prediction (loss = loss )
2467
- self .update_posterior (x ,y ,alg_type = 'given_MT' )
2466
+ self .update_posterior (x_continuous , x_categorical ,y ,alg_type = 'given_MT' )
2468
2467
return prediction
0 commit comments