12
12
use Rubix \ML \Classifiers \KNearestNeighbors ;
13
13
use Rubix \ML \Classifiers \MultilayerPerceptron ;
14
14
use Rubix \ML \Datasets \Labeled ;
15
+ use Rubix \ML \DataType ;
15
16
use Rubix \ML \Estimator ;
16
17
use Rubix \ML \NeuralNet \Layers \Dense ;
17
18
use Rubix \ML \NeuralNet \Layers \PReLU ;
@@ -131,14 +132,7 @@ public function handle()
131
132
private function getEstimator (string $ modelPath , Estimator $ baseEstimator ): Estimator
132
133
{
133
134
$ estimator = new PersistentModel (
134
- new Pipeline (
135
- [
136
- new MissingDataImputer (),
137
- new OneHotEncoder (),
138
- new ZScaleStandardizer (),
139
- ],
140
- $ baseEstimator
141
- ),
135
+ new Pipeline ($ this ->getTransformers ($ baseEstimator ), $ baseEstimator ),
142
136
new Filesystem ($ modelPath )
143
137
);
144
138
@@ -151,12 +145,6 @@ private function getEstimator(string $modelPath, Estimator $baseEstimator): Esti
151
145
152
146
private function getDefaultBaseEstimator (bool $ continuous ): Estimator
153
147
{
154
- // $layers = [
155
- // new Dense(100),
156
- // new Dense(100),
157
- // new Dense(100),
158
- // ];
159
-
160
148
$ baseEstimator = new KDNeighbors ();
161
149
162
150
if ($ continuous ) {
@@ -165,4 +153,20 @@ private function getDefaultBaseEstimator(bool $continuous): Estimator
165
153
166
154
return $ baseEstimator ;
167
155
}
156
+
157
+ private function getTransformers (Estimator $ estimator ): array
158
+ {
159
+ $ dataTypes = $ estimator ->compatibility ();
160
+
161
+ $ transformers = [];
162
+ $ transformers [] = new MissingDataImputer ();
163
+
164
+ if (!in_array (DataType::categorical (), $ dataTypes ) && in_array (DataType::continuous (), $ dataTypes )) {
165
+ $ transformers [] = new OneHotEncoder ();
166
+ }
167
+
168
+ $ transformers [] = new ZScaleStandardizer ();
169
+
170
+ return $ transformers ;
171
+ }
168
172
}
0 commit comments