Skip to content

Commit 7a70a0b

Browse files
added regression forest benchmark
1 parent 775f0b7 commit 7a70a0b

File tree

1 file changed

+44
-1
lines changed

1 file changed

+44
-1
lines changed

asv_benchmarks/benchmarks/ensemble.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,58 @@
22
GradientBoostingClassifier,
33
HistGradientBoostingClassifier,
44
RandomForestClassifier,
5+
RandomForestRegressor
56
)
67

78
from .common import Benchmark, Estimator, Predictor
89
from .datasets import (
910
_20newsgroups_highdim_dataset,
1011
_20newsgroups_lowdim_dataset,
1112
_synth_classification_dataset,
13+
_synth_regression_dataset,
14+
_synth_regression_sparse_dataset
1215
)
13-
from .utils import make_gen_classif_scorers
16+
from .utils import make_gen_classif_scorers, make_gen_reg_scorers
17+
18+
19+
class RandomForestRegressorBenchmark(Predictor, Estimator, Benchmark):
20+
"""
21+
Benchmarks for RandomForestRegressor.
22+
"""
23+
24+
param_names = ["representation", "n_jobs"]
25+
params = (["dense", "sparse"], Benchmark.n_jobs_vals)
26+
27+
def setup_cache(self):
28+
super().setup_cache()
29+
30+
def make_data(self, params):
31+
representation, n_jobs = params
32+
33+
if representation == "sparse":
34+
data = _synth_regression_sparse_dataset()
35+
else:
36+
data = _synth_regression_dataset()
37+
38+
return data
39+
40+
def make_estimator(self, params):
41+
representation, n_jobs = params
42+
43+
n_estimators = 500 if Benchmark.data_size == "large" else 100
44+
45+
estimator = RandomForestRegressor(
46+
n_estimators=n_estimators,
47+
min_samples_split=10,
48+
max_features="log2",
49+
n_jobs=n_jobs,
50+
random_state=0,
51+
)
52+
53+
return estimator
54+
55+
def make_scorers(self):
56+
make_gen_reg_scorers(self)
1457

1558

1659
class RandomForestClassifierBenchmark(Predictor, Estimator, Benchmark):

0 commit comments

Comments
 (0)