Skip to content

Commit a24b903

Browse files
committed
feat: simple parsing of pytorch models
1 parent d5d6e39 commit a24b903

File tree

1 file changed

+180
-7
lines changed

1 file changed

+180
-7
lines changed

src/sasctl/utils/model_info.py

Lines changed: 180 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,19 @@
55
# SPDX-License-Identifier: Apache-2.0
66

77
from abc import ABC, abstractmethod
8+
from collections import OrderedDict
89
from typing import Any, Callable, Dict, List, Union
910

11+
import numpy as np
1012
import pandas as pd
1113

14+
try:
15+
import torch
16+
except ImportError:
17+
torch = None
1218

13-
def get_model_info(model, X, y):
19+
20+
def get_model_info(model, X, y=None):
1421
"""Extracts metadata about the model and associated data sets.
1522
1623
Parameters
@@ -32,10 +39,17 @@ def get_model_info(model, X, y):
3239
If `model` is not a recognized type.
3340
3441
"""
42+
43+
# Don't need to import sklearn, just check if the class is part of that module.
3544
if model.__class__.__module__.startswith("sklearn."):
3645
return SklearnModelInfo(model, X, y)
3746

38-
raise ValueError(f"Unrecognized model type {model} received.")
47+
# Most PyTorch models are actually subclasses of torch.nn.Module, so checking module
48+
# name alone is not sufficient.
49+
elif torch and isinstance(model, torch.nn.Module):
50+
return PyTorchModelInfo(model, X, y)
51+
52+
raise ValueError(f"Unrecognized model type {type(model)} received.")
3953

4054

4155
class ModelInfo(ABC):
@@ -63,6 +77,10 @@ class ModelInfo(ABC):
6377
threshold : float or None
6478
The cutoff value used in a binary classification model to determine which class an
6579
observation belongs to. Returns None if not a binary classification model.
80+
X : pandas.DataFrame
81+
A sample of the input data used to train the model.
82+
y : pandas.DataFrame
83+
A sample of the output data produced by the model.
6684
6785
"""
6886

@@ -113,27 +131,169 @@ def model_params(self) -> Dict[str, Any]:
113131
return
114132

115133
@property
116-
@abstractmethod
117134
def output_column_names(self) -> List[str]:
118-
return
135+
return self.y.columns.tolist()
119136

120137
@property
121138
@abstractmethod
122139
def predict_function(self) -> Callable:
123140
return
124141

142+
@property
143+
@abstractmethod
144+
def target_column(self):
145+
return
146+
125147
@property
126148
@abstractmethod
127149
def target_values(self):
128150
# "target event"
129-
# value that indicates the target event has occurred in bianry classi
151+
# value that indicates the target event has occurred in binary classification
130152
return
131153

132154
@property
@abstractmethod
def threshold(self) -> Union[float, None]:
    """Cutoff used by a binary classification model to assign a class.

    Returns
    -------
    float or None
        The cutoff value, or None if the model is not a binary
        classification model.

    """
    return
136158

159+
@property
160+
@abstractmethod
161+
def X(self) -> pd.DataFrame:
162+
return
163+
164+
@property
165+
@abstractmethod
166+
def y(self) -> pd.DataFrame:
167+
return
168+
169+
170+
class PyTorchModelInfo(ModelInfo):
    """Stores model information for a PyTorch model instance.

    Parameters
    ----------
    model : torch.nn.Module
        The model to describe.
    X : numpy.ndarray or torch.Tensor
        Sample input data with shape (n_samples, n_dim).
    y : numpy.ndarray or torch.Tensor, optional
        Sample output data.  If omitted, it is computed by running `X` through
        `model` in evaluation mode with gradients disabled.

    Raises
    ------
    RuntimeError
        If the PyTorch library is not installed.
    ValueError
        If `model`, `X` or `y` is not of a supported type or shape.

    """

    def __init__(self, model, X, y=None):
        if torch is None:
            raise RuntimeError(
                "The PyTorch library must be installed to work with PyTorch models. "
                "Please `pip install torch`."
            )

        if not isinstance(model, torch.nn.Module):
            raise ValueError(f"Expected PyTorch model, received {type(model)}.")
        if not isinstance(X, (np.ndarray, torch.Tensor)):
            raise ValueError(
                f"Expected input data to be a numpy array or PyTorch tensor, received {type(X)}."
            )
        if X.ndim != 2:
            raise ValueError(
                f"Expected input data with shape (n_samples, n_dim), received shape {X.shape}."
            )

        # The model itself must be fed a tensor, even if the caller supplied numpy.
        X_tensor = self._to_tensor(model, X)

        # Store the current setting so that it is restored even if a forward
        # pass or the validation below raises.
        is_training = model.training
        try:
            if y is None:
                model.eval()
                with torch.no_grad():
                    y = model(X_tensor)

            if not isinstance(y, (np.ndarray, torch.Tensor)):
                raise ValueError(
                    f"Expected output data to be a numpy array or PyTorch tensor, received {type(y)}."
                )

            layer_info = self._get_layer_info(model, X_tensor)
        finally:
            # Reset the model to its original training state
            model.train(is_training)

        y_values = self._to_numpy(y)
        if y_values.ndim == 1:
            # A single output per observation becomes a single column.
            y_values = y_values.reshape(-1, 1)
        if y_values.ndim != 2:
            raise ValueError(
                f"Expected output data with shape (n_samples, n_dim), received shape {y_values.shape}."
            )
        X_values = self._to_numpy(X)

        self._model = model
        self._X = X
        self._y = y
        self._layer_info = layer_info

        # Neither arrays nor tensors carry column names, so generate them.
        self._X_df = pd.DataFrame(
            X_values, columns=[f"Var{i + 1}" for i in range(X_values.shape[1])]
        )
        self._y_df = pd.DataFrame(
            y_values, columns=[f"Out{i + 1}" for i in range(y_values.shape[1])]
        )

    @staticmethod
    def _to_tensor(model, data):
        """Return `data` as a tensor that can be passed to `model`.

        Tensors are returned unchanged.  Arrays are converted and moved to the
        device of the model's first parameter; floating point arrays are also
        cast to that parameter's dtype when it is a floating point type.
        """
        if isinstance(data, torch.Tensor):
            return data

        tensor = torch.as_tensor(data)
        reference = next(iter(model.parameters()), None)
        if reference is None:
            return tensor
        if tensor.is_floating_point() and reference.is_floating_point():
            return tensor.to(device=reference.device, dtype=reference.dtype)
        return tensor.to(device=reference.device)

    @staticmethod
    def _to_numpy(data):
        """Return `data` as a numpy array, detaching and copying tensors to the CPU."""
        if isinstance(data, torch.Tensor):
            return data.detach().cpu().numpy()
        return np.asarray(data)

    @staticmethod
    def _get_layer_info(model, X):
        """Run data through the model to determine layer types and tensor shapes.

        A forward hook is registered on every module returned by
        `model.modules()` for the duration of a single forward pass.  The
        hooks are removed and the model's training mode is restored before
        returning, including when the forward pass raises.

        Parameters
        ----------
        model : torch.nn.Module
        X : torch.Tensor

        Returns
        -------
        List[Tuple[torch.nn.Module, torch.Tensor, torch.Tensor]]
            One (module, input, output) entry per hook invocation, in the
            order the hooks fired.

        """
        is_training = model.training
        layers = []
        handles = []

        def hook(module, inputs, output, *args):
            layers.append((module, inputs, output))

        try:
            for module in model.modules():
                handles.append(module.register_forward_hook(hook))

            model.eval()
            with torch.no_grad():
                model(X)
        finally:
            # Leaving the hooks attached would keep appending to `layers` on
            # every later forward pass of the caller's model.
            for handle in handles:
                handle.remove()
            model.train(is_training)

        return layers

    @property
    def algorithm(self):
        return "PyTorch"

    @property
    def is_binary_classifier(self):
        return False

    @property
    def is_classifier(self):
        return False

    @property
    def is_clusterer(self):
        return False

    @property
    def is_regressor(self):
        return False

    @property
    def model(self):
        return self._model

    @property
    def model_params(self) -> Dict[str, Any]:
        return self.model.__dict__

    @property
    def output_column_names(self):
        return list(self.y.columns)

    @property
    def predict_function(self):
        return self.model.forward

    @property
    def target_column(self):
        return self.y.columns[0]

    @property
    def target_values(self):
        return []

    @property
    def threshold(self):
        return None

    @property
    def X(self):
        return self._X_df

    @property
    def y(self):
        return self._y_df
296+
137297

138298
class SklearnModelInfo(ModelInfo):
139299
"""Stores model information for a scikit-learn model instance."""
@@ -163,7 +323,7 @@ def __init__(self, model, X, y):
163323

164324
# If not a classifier or a clustering algorithm and output is a single column, then
165325
# assume it's a regression algorithm
166-
is_regressor = not is_classifier and not is_clusterer and y_df.shape[1] == 1
326+
is_regressor = not is_classifier and not is_clusterer and (y_df.shape[1] == 1 or "Regress" in type(model).__name__)
167327

168328
if not is_classifier and not is_regressor and not is_clusterer:
169329
raise ValueError(f"Unexpected model type {model} received.")
@@ -182,7 +342,8 @@ def __init__(self, model, X, y):
182342
elif self.is_classifier:
183343
# Output is probability of each label. Name columns according to classes.
184344
y_df.columns = [f"P_{class_}" for class_ in model.classes_]
185-
else:
345+
elif not y_df.empty:
346+
# If we were passed data for `y` but we don't know the format raise an error.
186347
# This *shouldn't* happen unless a cluster algorithm somehow produces wide output.
187348
raise ValueError(f"Unrecognized model output format.")
188349

@@ -236,6 +397,10 @@ def predict_function(self):
236397
# Otherwise it's the single value from .predict()
237398
return self.model.predict
238399

400+
@property
401+
def target_column(self):
402+
return self.y.columns[0]
403+
239404
@property
240405
def target_values(self):
241406
if self.is_binary_classifier:
@@ -248,3 +413,11 @@ def threshold(self):
248413
# sklearn seems to always use 0.5 as a cutoff for .predict()
249414
if self.is_binary_classifier:
250415
return 0.5
416+
417+
@property
418+
def X(self):
419+
return self._X
420+
421+
@property
422+
def y(self):
423+
return self._y

0 commit comments

Comments
 (0)