Skip to content

Commit c2e2116

Browse files
committed
Only perform one shot encoding if the estimator does not support categorical data types
1 parent 40aaa91 commit c2e2116

File tree

1 file changed

+18
-14
lines changed

1 file changed

+18
-14
lines changed

src/Console/Commands/Train.php

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
use Rubix\ML\Classifiers\KNearestNeighbors;
1313
use Rubix\ML\Classifiers\MultilayerPerceptron;
1414
use Rubix\ML\Datasets\Labeled;
15+
use Rubix\ML\DataType;
1516
use Rubix\ML\Estimator;
1617
use Rubix\ML\NeuralNet\Layers\Dense;
1718
use Rubix\ML\NeuralNet\Layers\PReLU;
@@ -131,14 +132,7 @@ public function handle()
131132
private function getEstimator(string $modelPath, Estimator $baseEstimator): Estimator
132133
{
133134
$estimator = new PersistentModel(
134-
new Pipeline(
135-
[
136-
new MissingDataImputer(),
137-
new OneHotEncoder(),
138-
new ZScaleStandardizer(),
139-
],
140-
$baseEstimator
141-
),
135+
new Pipeline($this->getTransformers($baseEstimator), $baseEstimator),
142136
new Filesystem($modelPath)
143137
);
144138

@@ -151,12 +145,6 @@ private function getEstimator(string $modelPath, Estimator $baseEstimator): Esti
151145

152146
private function getDefaultBaseEstimator(bool $continuous): Estimator
153147
{
154-
// $layers = [
155-
// new Dense(100),
156-
// new Dense(100),
157-
// new Dense(100),
158-
// ];
159-
160148
$baseEstimator = new KDNeighbors();
161149

162150
if ($continuous) {
@@ -165,4 +153,20 @@ private function getDefaultBaseEstimator(bool $continuous): Estimator
165153

166154
return $baseEstimator;
167155
}
156+
157+
private function getTransformers(Estimator $estimator): array
158+
{
159+
$dataTypes = $estimator->compatibility();
160+
161+
$transformers = [];
162+
$transformers[] = new MissingDataImputer();
163+
164+
if (!in_array(DataType::categorical(), $dataTypes) && in_array(DataType::continuous(), $dataTypes)) {
165+
$transformers[] = new OneHotEncoder();
166+
}
167+
168+
$transformers[] = new ZScaleStandardizer();
169+
170+
return $transformers;
171+
}
168172
}

0 commit comments

Comments
 (0)