5
5
# SPDX-License-Identifier: Apache-2.0
6
6
7
7
from abc import ABC , abstractmethod
8
+ from collections import OrderedDict
8
9
from typing import Any , Callable , Dict , List , Union
9
10
11
+ import numpy as np
10
12
import pandas as pd
11
13
14
+ try :
15
+ import torch
16
+ except ImportError :
17
+ torch = None
12
18
13
- def get_model_info (model , X , y ):
19
+
20
+ def get_model_info (model , X , y = None ):
14
21
"""Extracts metadata about the model and associated data sets.
15
22
16
23
Parameters
@@ -32,10 +39,17 @@ def get_model_info(model, X, y):
32
39
If `model` is not a recognized type.
33
40
34
41
"""
42
+
43
+ # Don't need to import sklearn, just check if the class is part of that module.
35
44
if model .__class__ .__module__ .startswith ("sklearn." ):
36
45
return SklearnModelInfo (model , X , y )
37
46
38
- raise ValueError (f"Unrecognized model type { model } received." )
47
+ # Most PyTorch models are actually subclasses of torch.nn.Module, so checking module
48
+ # name alone is not sufficient.
49
+ elif torch and isinstance (model , torch .nn .Module ):
50
+ return PyTorchModelInfo (model , X , y )
51
+
52
+ raise ValueError (f"Unrecognized model type { type (model )} received." )
39
53
40
54
41
55
class ModelInfo (ABC ):
@@ -63,6 +77,10 @@ class ModelInfo(ABC):
63
77
threshold : float or None
64
78
The cutoff value used in a binary classification model to determine which class an
65
79
observation belongs to. Returns None if not a binary classification model.
80
+ X : pandas.DataFrame
81
+ A sample of the input data used to train the model.
82
+ y : pandas.DataFrame
83
+ A sample of the output data produced by the model.
66
84
67
85
"""
68
86
@@ -113,27 +131,169 @@ def model_params(self) -> Dict[str, Any]:
113
131
return
114
132
115
133
@property
def output_column_names(self) -> List[str]:
    """Names of the columns in the sample output data.

    Returns
    -------
    List[str]
        The column labels of `self.y`, in order.

    """
    output_columns = self.y.columns
    return output_columns.tolist()
119
136
120
137
@property
@abstractmethod
def predict_function(self) -> Callable:
    """The callable used to generate output from the model.

    Abstract: subclasses must override this property.

    """
    return None
124
141
142
@property
@abstractmethod
def target_column(self):
    """Name of the output column treated as the model's target.

    Abstract: subclasses must override this property.

    """
    return None
146
+
125
147
@property
@abstractmethod
def target_values(self):
    """The "target event" value(s) for the model.

    This is the value that indicates the target event has occurred in
    binary classification.

    Abstract: subclasses must override this property.

    """
    return None
131
153
132
154
@property
@abstractmethod
def threshold(self) -> Union[float, None]:
    """Cutoff value used by a binary classification model.

    Abstract: subclasses must override this property.

    Returns
    -------
    float or None
        The cutoff used to decide which class an observation belongs to,
        or None if the model is not a binary classifier.

    """
    # The annotation previously read Union[str, None]; the cutoff is numeric.
    return None
136
158
159
@property
@abstractmethod
def X(self) -> pd.DataFrame:
    """A sample of the input data used to train the model.

    Abstract: subclasses must override this property.

    """
    return None
163
+
164
@property
@abstractmethod
def y(self) -> pd.DataFrame:
    """A sample of the output data produced by the model.

    Abstract: subclasses must override this property.

    """
    return None
168
+
169
+
170
class PyTorchModelInfo(ModelInfo):
    """Stores model information for a PyTorch model instance.

    Parameters
    ----------
    model : torch.nn.Module
        The model to describe.
    X : numpy.ndarray or torch.Tensor
        Sample input data with shape (n_samples, n_dim).
    y : numpy.ndarray or torch.Tensor, optional
        Sample output data.  If omitted, it is computed by running `X`
        through `model` in evaluation mode with gradients disabled.

    Raises
    ------
    RuntimeError
        If the PyTorch library is not installed.
    ValueError
        If `model`, `X` or `y` is not of a supported type, or `X` is not
        two-dimensional.

    """

    def __init__(self, model, X, y=None):
        if torch is None:
            raise RuntimeError(
                "The PyTorch library must be installed to work with PyTorch models.  Please `pip install torch`."
            )

        if not isinstance(model, torch.nn.Module):
            raise ValueError(f"Expected PyTorch model, received {type(model)}.")
        if not isinstance(X, (np.ndarray, torch.Tensor)):
            raise ValueError(f"Expected input data to be a numpy array or PyTorch tensor, received {type(X)}.")
        if X.ndim != 2:
            raise ValueError(f"Expected input data with shape (n_samples, n_dim), received shape {X.shape}.")

        # NumPy input passes validation, so convert it before it reaches the
        # model.  No dtype/device conversion is attempted.
        X_tensor = torch.as_tensor(X)

        # Store the current setting so that we can restore it later
        is_training = model.training

        try:
            if y is None:
                model.eval()

                with torch.no_grad():
                    y = model(X_tensor)

            if not isinstance(y, (np.ndarray, torch.Tensor)):
                raise ValueError(
                    f"Expected output data to be a numpy array or PyTorch tensor, received {type(y)}."
                )

            layer_info = self._get_layer_info(model, X_tensor)
        finally:
            # Reset the model to its original training state, even on failure
            model.train(is_training)

        self._model = model

        # Keep the raw data exactly as it was provided / produced
        self._X = X
        self._y = y

        X_values = self._to_numpy(X)
        y_values = self._to_numpy(y)

        # A 1-D output is treated as a single output column
        if y_values.ndim == 1:
            y_values = y_values.reshape(-1, 1)

        # Tabular views of the data with generated column names
        self._X_df = pd.DataFrame(X_values, columns=[f"Var{i + 1}" for i in range(X_values.shape[1])])
        self._y_df = pd.DataFrame(y_values, columns=[f"Out{i + 1}" for i in range(y_values.shape[1])])

        self._layer_info = layer_info

    @staticmethod
    def _to_numpy(data):
        """Return `data` as a numpy array, detaching/copying tensors to the CPU."""
        if isinstance(data, np.ndarray):
            return data
        return data.detach().cpu().numpy()

    @staticmethod
    def _get_layer_info(model, X):
        """Run data through the model to determine layer types and tensor shapes.

        Forward hooks are registered on every sub-module for the duration of
        a single forward pass and removed again afterwards.  The model's
        training mode is restored before returning.

        Parameters
        ----------
        model : torch.nn.Module
        X : torch.Tensor or numpy.ndarray

        Returns
        -------
        List[Tuple[torch.nn.Module, Tuple[torch.Tensor, ...], torch.Tensor]]
            One `(module, inputs, output)` entry per module call, in the
            order the hooks fired.

        """
        was_training = model.training
        layers = []

        def hook(module, inputs, output):
            layers.append((module, inputs, output))

        handles = []
        try:
            for module in model.modules():
                handles.append(module.register_forward_hook(hook))

            model.eval()
            with torch.no_grad():
                model(torch.as_tensor(X))
        finally:
            # Leaving hooks attached would make every later forward pass keep
            # appending to `layers`.
            for handle in handles:
                handle.remove()
            model.train(was_training)

        return layers

    @property
    def algorithm(self):
        """Name of the algorithm; always "PyTorch"."""
        return "PyTorch"

    @property
    def is_binary_classifier(self):
        """Always False; the model type is not inferred for PyTorch models."""
        return False

    @property
    def is_classifier(self):
        """Always False; the model type is not inferred for PyTorch models."""
        return False

    @property
    def is_clusterer(self):
        """Always False; the model type is not inferred for PyTorch models."""
        return False

    @property
    def is_regressor(self):
        """Always False; the model type is not inferred for PyTorch models."""
        return False

    @property
    def model(self):
        """The wrapped `torch.nn.Module` instance."""
        return self._model

    @property
    def model_params(self) -> Dict[str, Any]:
        """The model's instance `__dict__` (the live mapping, not a copy)."""
        return self.model.__dict__

    @property
    def output_column_names(self):
        """Names of the generated output columns ("Out1", "Out2", ...)."""
        return list(self.y.columns)

    @property
    def predict_function(self):
        """The model's `forward` method."""
        return self.model.forward

    @property
    def target_column(self):
        """Name of the first output column."""
        return self.y.columns[0]

    @property
    def target_values(self):
        """Always an empty list."""
        return []

    @property
    def threshold(self):
        """Always None."""
        return None

    @property
    def X(self):
        """Sample input data as a DataFrame with columns "Var1", "Var2", ..."""
        return self._X_df

    @property
    def y(self):
        """Sample output data as a DataFrame with columns "Out1", "Out2", ..."""
        return self._y_df
296
+
137
297
138
298
class SklearnModelInfo (ModelInfo ):
139
299
"""Stores model information for a scikit-learn model instance."""
@@ -163,7 +323,7 @@ def __init__(self, model, X, y):
163
323
164
324
# If not a classifier or a clustering algorithm and output is a single column, then
165
325
# assume it's a regression algorithm
166
- is_regressor = not is_classifier and not is_clusterer and y_df .shape [1 ] == 1
326
+ is_regressor = not is_classifier and not is_clusterer and ( y_df .shape [1 ] == 1 or "Regress" in type ( model ). __name__ )
167
327
168
328
if not is_classifier and not is_regressor and not is_clusterer :
169
329
raise ValueError (f"Unexpected model type { model } received." )
@@ -182,7 +342,8 @@ def __init__(self, model, X, y):
182
342
elif self .is_classifier :
183
343
# Output is probability of each label. Name columns according to classes.
184
344
y_df .columns = [f"P_{ class_ } " for class_ in model .classes_ ]
185
- else :
345
+ elif not y_df .empty :
346
+ # If we were passed data for `y` but we don't know the format raise an error.
186
347
# This *shouldn't* happen unless a cluster algorithm somehow produces wide output.
187
348
raise ValueError (f"Unrecognized model output format." )
188
349
@@ -236,6 +397,10 @@ def predict_function(self):
236
397
# Otherwise it's the single value from .predict()
237
398
return self .model .predict
238
399
400
@property
def target_column(self):
    """Name of the first column of the sample output data."""
    output_columns = self.y.columns
    return output_columns[0]
403
+
239
404
@property
240
405
def target_values (self ):
241
406
if self .is_binary_classifier :
@@ -248,3 +413,11 @@ def threshold(self):
248
413
# sklearn seems to always use 0.5 as a cutoff for .predict()
249
414
if self .is_binary_classifier :
250
415
return 0.5
416
+
417
@property
def X(self):
    """The stored sample of input data (`self._X`)."""
    stored_inputs = self._X
    return stored_inputs
420
+
421
@property
def y(self):
    """The stored sample of output data (`self._y`)."""
    stored_outputs = self._y
    return stored_outputs
0 commit comments