     GradientBoostingClassifier,
     HistGradientBoostingClassifier,
     RandomForestClassifier,
+    RandomForestRegressor,
 )

 from .common import Benchmark, Estimator, Predictor
 from .datasets import (
     _20newsgroups_highdim_dataset,
     _20newsgroups_lowdim_dataset,
     _synth_classification_dataset,
+    _synth_regression_dataset,
+    _synth_regression_sparse_dataset,
 )
-from .utils import make_gen_classif_scorers
+from .utils import make_gen_classif_scorers, make_gen_reg_scorers
+
+
+class RandomForestRegressorBenchmark(Predictor, Estimator, Benchmark):
+    """
+    Benchmarks for RandomForestRegressor.
+    """
+
+    param_names = ["representation", "n_jobs"]
+    params = (["dense", "sparse"], Benchmark.n_jobs_vals)
+
+    def setup_cache(self):
+        super().setup_cache()
+
+    def make_data(self, params):
+        representation, n_jobs = params
+
+        # Benchmark on both sparse and dense synthetic regression data.
+        if representation == "sparse":
+            data = _synth_regression_sparse_dataset()
+        else:
+            data = _synth_regression_dataset()
+
+        return data
+
+    def make_estimator(self, params):
+        representation, n_jobs = params
+
+        # Grow a larger forest when benchmarking the "large" data size.
+        n_estimators = 500 if Benchmark.data_size == "large" else 100
+
+        estimator = RandomForestRegressor(
+            n_estimators=n_estimators,
+            min_samples_split=10,
+            max_features="log2",
+            n_jobs=n_jobs,
+            random_state=0,
+        )
+
+        return estimator
+
+    def make_scorers(self):
+        make_gen_reg_scorers(self)


 class RandomForestClassifierBenchmark(Predictor, Estimator, Benchmark):
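
For readers who want to reproduce the new benchmark's fit/predict path outside the asv harness, here is a minimal standalone sketch. The data generation is an assumption for illustration: `_synth_regression_dataset` and `_synth_regression_sparse_dataset` are defined in the suite's `.datasets` module and are not shown in this diff, so the sample counts, sparsification, and train/validation split below are placeholders. The estimator hyperparameters match the diff.

```python
# Standalone approximation of RandomForestRegressorBenchmark's fit path.
# NOTE: the data generation below is assumed for illustration; the real
# helpers (_synth_regression_dataset, _synth_regression_sparse_dataset)
# live in the suite's .datasets module and may use different shapes.
import numpy as np
import scipy.sparse as sp
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split


def synth_regression_data(representation="dense", n_samples=10_000, n_features=100):
    X, y = make_regression(
        n_samples=n_samples, n_features=n_features, noise=0.1, random_state=0
    )
    if representation == "sparse":
        X[np.abs(X) < 0.5] = 0.0  # zero out small entries to get real sparsity
        X = sp.csr_matrix(X)      # forests accept scipy sparse input
    return train_test_split(X, y, test_size=0.1, random_state=0)


for representation in ("dense", "sparse"):
    X_train, X_val, y_train, y_val = synth_regression_data(representation)
    estimator = RandomForestRegressor(
        n_estimators=100,      # the benchmark uses 500 for the "large" data size
        min_samples_split=10,
        max_features="log2",
        n_jobs=-1,             # the benchmark sweeps Benchmark.n_jobs_vals
        random_state=0,
    )
    estimator.fit(X_train, y_train)
    print(representation, estimator.score(X_val, y_val))  # R^2 on held-out split
```

With airspeed velocity installed, the new class can typically be selected by a regex on its name, e.g. `asv run -b RandomForestRegressorBenchmark` (asv's `-b`/`--bench` filter), though the exact invocation depends on the suite's asv configuration.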
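
The diff also imports `make_gen_reg_scorers` from `.utils` without showing it. Assuming it mirrors `make_gen_classif_scorers` by attaching generic train/test scorers to the benchmark instance, a plausible minimal implementation would be:

```python
# Hypothetical sketch of make_gen_reg_scorers; the real implementation is in
# the suite's .utils module and is not part of this diff.
from sklearn.metrics import r2_score


def make_gen_reg_scorers(caller):
    # Attach generic regression scorers (R^2) to the benchmark instance,
    # analogous to the accuracy scorers used for classification benchmarks.
    caller.train_scorer = r2_score
    caller.test_scorer = r2_score
```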