Skip to content

Commit 725ea62

Browse files
committed
Revise index search
1 parent ae6c960 commit 725ea62

File tree

2 files changed

+29
-20
lines changed

2 files changed

+29
-20
lines changed

bayesml/metatree/_metatree.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -393,10 +393,11 @@ def _gen_sample_recursion(self,node:_Node,x_continuous,x_categorical):
393393
return node.sub_model.gen_sample(sample_size=1)
394394
else:
395395
if node.k < self.c_dim_continuous:
396-
for i in range(self.c_num_children_vec[node.k]):
396+
index = 0
397+
for i in range(self.c_num_children_vec[node.k]-1):
397398
if x_continuous[node.k] < node.thresholds[i+1]:
398-
index = i
399399
break
400+
index += 1
400401
else:
401402
index = x_categorical[node.k-self.c_dim_continuous]
402403
return self._gen_sample_recursion(node.children[index],x_continuous,x_categorical)
@@ -1651,10 +1652,12 @@ def _update_posterior_recursion(self,node:_Node,x_continuous,x_categorical,y):
16511652
return self._update_posterior_leaf(node,y)
16521653
else: # inner node
16531654
if node.k < self.c_dim_continuous:
1654-
for i in range(self.c_num_children_vec[node.k]):
1655+
index = 0
1656+
for i in range(self.c_num_children_vec[node.k]-1):
16551657
if x_continuous[node.k] < node.thresholds[i+1]:
1656-
index = i
16571658
break
1659+
index += 1
1660+
# index = np.count_nonzero(node.thresholds[1:-1]<x_continuous[node.k]) # slower
16581661
else:
16591662
index = x_categorical[node.k-self.c_dim_continuous]
16601663
tmp1 = self._update_posterior_recursion(node.children[index],x_continuous,x_categorical,y)
@@ -1672,10 +1675,12 @@ def _update_posterior_recursion_lr(self,node:_Node,x_continuous,x_categorical,y)
16721675
return self._update_posterior_leaf_lr(node,x_continuous,y)
16731676
else: # inner node
16741677
if node.k < self.c_dim_continuous:
1675-
for i in range(self.c_num_children_vec[node.k]):
1678+
index = 0
1679+
for i in range(self.c_num_children_vec[node.k]-1):
16761680
if x_continuous[node.k] < node.thresholds[i+1]:
1677-
index = i
16781681
break
1682+
index += 1
1683+
# index = np.count_nonzero(node.thresholds[1:-1]<x_continuous[node.k]) # slower
16791684
else:
16801685
index = x_categorical[node.k-self.c_dim_continuous]
16811686
tmp1 = self._update_posterior_recursion_lr(node.children[index],x_continuous,x_categorical,y)
@@ -2302,10 +2307,11 @@ def _calc_pred_dist_recursion(self,node:_Node,x_continuous,x_categorical):
23022307
node.sub_model.calc_pred_dist()
23032308
if not node.leaf: # inner node
23042309
if node.k < self.c_dim_continuous:
2305-
for i in range(self.c_num_children_vec[node.k]):
2310+
index = 0
2311+
for i in range(self.c_num_children_vec[node.k]-1):
23062312
if x_continuous[node.k] < node.thresholds[i+1]:
2307-
index = i
23082313
break
2314+
index += 1
23092315
else:
23102316
index = x_categorical[node.k-self.c_dim_continuous]
23112317
self._calc_pred_dist_recursion(node.children[index],x_continuous,x_categorical)
@@ -2314,10 +2320,11 @@ def _calc_pred_dist_recursion_lr(self,node:_Node,x_continuous,x_categorical):
23142320
node.sub_model._calc_pred_dist(x_continuous)
23152321
if not node.leaf: # inner node
23162322
if node.k < self.c_dim_continuous:
2317-
for i in range(self.c_num_children_vec[node.k]):
2323+
index = 0
2324+
for i in range(self.c_num_children_vec[node.k]-1):
23182325
if x_continuous[node.k] < node.thresholds[i+1]:
2319-
index = i
23202326
break
2327+
index += 1
23212328
else:
23222329
index = x_categorical[node.k-self.c_dim_continuous]
23232330
self._calc_pred_dist_recursion_lr(node.children[index],x_continuous,x_categorical)
@@ -2352,10 +2359,11 @@ def _make_prediction_recursion_squared(self,node:_Node):
23522359
return node.sub_model.make_prediction(loss='squared')
23532360
else: # inner node
23542361
if node.k < self.c_dim_continuous:
2355-
for i in range(self.c_num_children_vec[node.k]):
2362+
index = 0
2363+
for i in range(self.c_num_children_vec[node.k]-1):
23562364
if self._tmp_x_continuous[node.k] < node.thresholds[i+1]:
2357-
index = i
23582365
break
2366+
index += 1
23592367
else:
23602368
index = self._tmp_x_categorical[node.k-self.c_dim_continuous]
23612369
return ((1 - node.h_g) * node.sub_model.make_prediction(loss='squared')
@@ -2366,10 +2374,11 @@ def _make_prediction_recursion_kl(self,node:_Node):
23662374
return node.sub_model.make_prediction(loss='KL')
23672375
else: # inner node
23682376
if node.k < self.c_dim_continuous:
2369-
for i in range(self.c_num_children_vec[node.k]):
2377+
index = 0
2378+
for i in range(self.c_num_children_vec[node.k]-1):
23702379
if self._tmp_x_continuous[node.k] < node.thresholds[i+1]:
2371-
index = i
23722380
break
2381+
index += 1
23732382
else:
23742383
index = self._tmp_x_categorical[node.k-self.c_dim_continuous]
23752384
return ((1 - node.h_g) * node.sub_model.make_prediction(loss='KL')

bayesml/metatree/metatree_test.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
c_dim_categorical=dim_categorical,
1717
c_max_depth=2,
1818
h_g=0.75,
19-
SubModel=linearregression,
20-
sub_constants={'c_degree':2},
21-
sub_h_params={'h_lambda_mat':np.eye(2)*0.01},
19+
SubModel=bernoulli,
20+
# sub_constants={'c_degree':2},
21+
# sub_h_params={'h_lambda_mat':np.eye(2)*0.01},
2222
# sub_h_params={'h_kappa':0.1})
2323
# sub_h_params={'h_alpha':0.3,'h_beta':0.3}
2424
)
@@ -33,9 +33,9 @@
3333
c_num_children_vec=2,
3434
c_max_depth=2,
3535
h0_g=0.75,
36-
SubModel=linearregression,
37-
sub_constants={'c_degree':2},
38-
sub_h0_params={'h0_lambda_mat':np.eye(2)*0.01},
36+
SubModel=bernoulli,
37+
# sub_constants={'c_degree':2},
38+
# sub_h0_params={'h0_lambda_mat':np.eye(2)*0.01},
3939
# sub_h0_params={'h0_kappa':0.1})
4040
# sub_h0_params={'h0_alpha':0.3,'h0_beta':0.3})
4141
)

0 commit comments

Comments
 (0)