1
1
# # THE CONSTANT REGRESSOR
2
2
3
+ const MMI = MLJModelInterface
3
4
export ConstantClassifier, ConstantRegressor,
4
- DeterministicConstantClassifier,
5
- ProbabilisticConstantClassifer
5
+ DeterministicConstantClassifier,
6
+ ProbabilisticConstantClassifer
6
7
7
- import MLJBase
8
8
import Distributions
9
9
10
10
"""
@@ -14,49 +14,61 @@ A regressor that, for any new input pattern, predicts the univariate
14
14
probability distribution best fitting the training target data. Use
15
15
`predict_mean` to predict the mean value instead.
16
16
"""
17
# Parametric singleton model: the distribution type `D` is carried in
# the type parameter rather than in a field, so instances hold no data.
struct ConstantRegressor{D} <: MMI.Probabilistic end
20
18
21
19
# Keyword constructor: lifts the requested distribution type into the
# model's type parameter, then runs the standard hyperparameter check.
function ConstantRegressor(; distribution_type=Distributions.Normal)
    new_model = ConstantRegressor{distribution_type}()
    warning = clean!(new_model)
    isempty(warning) || @warn warning
    return new_model
end
27
25
28
# Hyperparameter check invoked by the keyword constructor. Note that an
# invalid `distribution_type` throws (via `error`) rather than being
# repaired; a valid one yields an empty warning message.
function MMI.clean!(model::ConstantRegressor{D}) where D
    if !(D <: Distributions.Sampleable)
        error("$model.distribution_type is not a valid distribution_type.")
    end
    return ""
end
34
32
35
# Data front-end: convert the table to a matrix once, up front, and
# resample rows with non-copying views.
function MMI.reformat(::ConstantRegressor, X)
    return (MMI.matrix(X),)
end
function MMI.reformat(::ConstantRegressor, X, y)
    return (MMI.matrix(X), y)
end
function MMI.selectrows(::ConstantRegressor, I, A)
    return (view(A, I, :),)
end
function MMI.selectrows(::ConstantRegressor, I, A, y)
    return (view(A, I, :), y[I])
end
38
# Fitting ignores the input matrix `A` entirely: the fitresult is just
# the distribution of type `D` best fitting the target `y`.
function MMI.fit(::ConstantRegressor{D}, verbosity::Int, A, y) where D
    distribution = Distributions.fit(D, y)
    return distribution, nothing, NamedTuple()
end
41
44
42
# Expose the fitted target distribution to the user.
function MMI.fitted_params(::ConstantRegressor, fitresult)
    return (target_distribution=fitresult,)
end
43
47
44
# Every new input row gets the same prediction: the fitted distribution.
function MMI.predict(::ConstantRegressor, fitresult, Xnew)
    return fill(fitresult, nrows(Xnew))
end
45
50
46
51
# #
47
52
# # THE CONSTANT DETERMINISTIC REGRESSOR (FOR TESTING)
48
53
# #
49
54
50
# Deterministic (point-prediction) analogue of `ConstantRegressor`,
# provided for testing purposes.
struct DeterministicConstantRegressor <: MMI.Deterministic end
51
56
52
# The fitresult is simply the mean of the training target; features are
# ignored.
function MMI.fit(::DeterministicConstantRegressor, verbosity::Int, X, y)
    return mean(y), nothing, NamedTuple()
end
58
63
59
# Data front-end mirroring `ConstantRegressor`'s: matrix conversion up
# front, view-based row resampling.
function MMI.reformat(::DeterministicConstantRegressor, X)
    return (MMI.matrix(X),)
end
function MMI.reformat(::DeterministicConstantRegressor, X, y)
    return (MMI.matrix(X), y)
end
function MMI.selectrows(::DeterministicConstantRegressor, I, A)
    return (view(A, I, :),)
end
function MMI.selectrows(::DeterministicConstantRegressor, I, A, y)
    return (view(A, I, :), y[I])
end

# Every new input row gets the same point prediction (the stored mean).
function MMI.predict(::DeterministicConstantRegressor, fitresult, Xnew)
    return fill(fitresult, nrows(Xnew))
end
60
72
61
73
# #
62
74
# # THE CONSTANT CLASSIFIER
@@ -71,39 +83,89 @@ training target data. So, `pdf(d, level)` is the proportion of levels
71
83
in the training data coinciding with `level`. Use `predict_mode` to
72
84
obtain the training target mode instead.
73
85
"""
74
# `testing=true` turns on `@info` logging in the data front-end methods
# below (presumably so tests can observe reformat/selectrows calls —
# confirm against the test suite); `bogus` appears to be a dummy
# hyperparameter with no effect on the visible logic.
mutable struct ConstantClassifier <: MMI.Probabilistic
    testing::Bool
    bogus::Int
end

ConstantClassifier(; testing=false, bogus=0) =
    ConstantClassifier(testing, bogus)
93
+
94
# Data front-end for the classifier. Each method optionally announces
# itself (when `testing` is set) before converting the table to a matrix.
function MMI.reformat(clf::ConstantClassifier, X)
    clf.testing && @info "reformatting X"
    return (MMI.matrix(X),)
end

function MMI.reformat(clf::ConstantClassifier, X, y)
    clf.testing && @info "reformatting X, y"
    return (MMI.matrix(X), y)
end

function MMI.reformat(clf::ConstantClassifier, X, y, w)
    clf.testing && @info "reformatting X, y, w"
    return (MMI.matrix(X), y, w)
end
108
+
109
# Row-resampling front-end for the classifier. Rows of the reformatted
# matrix `A` are selected with non-copying views; each method optionally
# announces itself (when `testing` is set).
function MMI.selectrows(model::ConstantClassifier, I, A)
    model.testing && @info "resampling X"
    return (view(A, I, :),)
end

function MMI.selectrows(model::ConstantClassifier, I, A, y)
    model.testing && @info "resampling X, y"
    return (view(A, I, :), y[I])
end

function MMI.selectrows(model::ConstantClassifier, I, A, y, ::Nothing)
    model.testing && @info "resampling X, y, nothing"
    return (view(A, I, :), y[I], nothing)
end

function MMI.selectrows(model::ConstantClassifier, I, A, y, w)
    # BUGFIX: previously logged "resampling X, y, nothing" (copy-paste of
    # the ::Nothing method), making the weighted path indistinguishable
    # in the logs.
    model.testing && @info "resampling X, y, w"
    return (view(A, I, :), y[I], w[I])
end
75
128
76
129
# Fit the empirical class distribution; `w` is an optional vector of
# per-observation weights.
function MMI.fit(::ConstantClassifier, verbosity::Int, A, y, w=nothing)
    # BUGFIX: was `MLJBase.UnivariateFinite`, but `import MLJBase` has
    # been removed from this file; use the MLJModelInterface alias.
    fitresult = Distributions.fit(MMI.UnivariateFinite, y, w)
    cache = nothing
    # BUGFIX: was `report = NamedTuple` — the *type*, not an instance —
    # inconsistent with every sibling `fit`, which returns `NamedTuple()`.
    report = NamedTuple()
    return fitresult, cache, report
end
83
136
84
# Expose the fitted class distribution to the user.
function MMI.fitted_params(::ConstantClassifier, fitresult)
    return (target_distribution=fitresult,)
end
85
139
86
# Every new input row gets the same prediction: the fitted class
# distribution.
function MMI.predict(::ConstantClassifier, fitresult, Xnew)
    return fill(fitresult, nrows(Xnew))
end
87
142
88
143
# #
89
144
# # DETERMINISTIC CONSTANT CLASSIFIER (FOR TESTING)
90
145
# #
91
146
92
# Deterministic (mode-prediction) analogue of `ConstantClassifier`,
# provided for testing purposes.
struct DeterministicConstantClassifier <: MMI.Deterministic end
93
148
94
function MMI.fit(::DeterministicConstantClassifier, verbosity::Int, X, y)
    # Drop missing target values, materialize, and take the mode
    # (a CategoricalValue); features are ignored.
    fitresult = mode(collect(skipmissing(y)))
    return fitresult, nothing, NamedTuple()
end
101
156
102
# Data front-end mirroring `ConstantClassifier`'s (without logging):
# matrix conversion up front, view-based row resampling.
function MMI.reformat(::DeterministicConstantClassifier, X)
    return (MMI.matrix(X),)
end
function MMI.reformat(::DeterministicConstantClassifier, X, y)
    return (MMI.matrix(X), y)
end
function MMI.selectrows(::DeterministicConstantClassifier, I, A)
    return (view(A, I, :),)
end
function MMI.selectrows(::DeterministicConstantClassifier, I, A, y)
    return (view(A, I, :), y[I])
end

# Every new input row gets the same point prediction (the stored mode).
function MMI.predict(::DeterministicConstantClassifier, fitresult, Xnew)
    return fill(fitresult, nrows(Xnew))
end
103
162
104
- # #
105
- # # METADATA
106
- # #
163
+ MMI. predict (:: DeterministicConstantClassifier , fitresult, Xnew) =
164
+ fill (fitresult, nrows (Xnew))
165
+
166
+ #
167
+ # METADATA
168
+ #
107
169
108
170
metadata_pkg .((ConstantRegressor, ConstantClassifier,
109
171
DeterministicConstantRegressor, DeterministicConstantClassifier),
@@ -115,29 +177,29 @@ metadata_pkg.((ConstantRegressor, ConstantClassifier,
115
177
is_wrapper= false )
116
178
117
179
# Trait declarations for the four constant models: input/target scitypes,
# weight support, description, and load path.
metadata_model(ConstantRegressor,
               input=MMI.Table,
               target=AbstractVector{MMI.Continuous},
               weights=false,
               descr="Constant regressor (Probabilistic).",
               path="MLJModels.ConstantRegressor")

metadata_model(DeterministicConstantRegressor,
               input=MMI.Table,
               target=AbstractVector{MMI.Continuous},
               weights=false,
               descr="Constant regressor (Deterministic).",
               path="MLJModels.DeterministicConstantRegressor")

metadata_model(ConstantClassifier,
               input=MMI.Table,
               target=AbstractVector{<:MMI.Finite},
               weights=true,
               descr="Constant classifier (Probabilistic).",
               path="MLJModels.ConstantClassifier")

metadata_model(DeterministicConstantClassifier,
               input=MMI.Table,
               target=AbstractVector{<:MMI.Finite},
               weights=false,
               descr="Constant classifier (Deterministic).",
               path="MLJModels.DeterministicConstantClassifier")
0 commit comments