@@ -176,10 +176,20 @@ def __init__(self, model, X, y=None):
176
176
177
177
if not isinstance (model , torch .nn .Module ):
178
178
raise ValueError (f"Expected PyTorch model, received { type (model )} ." )
179
- if not isinstance (X , (np .ndarray , torch .Tensor )):
180
- raise ValueError (f"Expected input data to be a numpy array or PyTorch tensor, received { type (X )} ." )
181
- if X .ndim != 2 :
182
- raise ValueError (f"Expected input date with shape (n_samples, n_dim), received shape { X .shape } ." )
179
+
180
+ # Some models may take multiple tensors as input. These can be passed as a tuple
181
+ # of tensors. To simplify processing, convert even single inputs into tuples.
182
+ if not isinstance (X , tuple ):
183
+ X = (X , )
184
+
185
+ for x in X :
186
+ if not isinstance (x , (np .ndarray , torch .Tensor )):
187
+ raise ValueError (f"Expected input data to be a numpy array or PyTorch tensor, received { type (X )} ." )
188
+ # if X.ndim != 2:
189
+ # raise ValueError(f"Expected input date with shape (n_samples, n_dim), received shape {X.shape}.")
190
+
191
+ # Ensure each input is a PyTorch Tensor
192
+ X = tuple (x if isinstance (x , torch .Tensor ) else torch .tensor (x ) for x in X )
183
193
184
194
# Store the current setting so that we can restore it later
185
195
is_training = model .training
@@ -188,7 +198,7 @@ def __init__(self, model, X, y=None):
188
198
model .eval ()
189
199
190
200
with torch .no_grad ():
191
- y = model (X )
201
+ y = model (* X )
192
202
193
203
if not isinstance (y , (np .ndarray , torch .Tensor )):
194
204
raise ValueError (f"Expected output data to be a numpy array or PyTorch tensor, received { type (y )} ." )
@@ -199,7 +209,17 @@ def __init__(self, model, X, y=None):
199
209
self ._X = X
200
210
self ._y = y
201
211
202
- self ._X_df = pd .DataFrame (X , columns = [f"Var{ i + 1 } " for i in range (X .shape [1 ])])
212
+ # Model Manager doesn't currently support arrays or vectors. Capture the first
213
+ # input tensor and reshape to 2 dimensions if necessary.
214
+ x0 = X [0 ]
215
+ if x0 .ndim > 2 :
216
+ x0 = x0 .reshape ((x0 .shape [0 ], - 1 ))
217
+ self ._X_df = pd .DataFrame (x0 , columns = [f"Var{ i + 1 } " for i in range (x0 .shape [1 ])])
218
+
219
+ # Flatten to 2 dimensions if necessary
220
+ if y .ndim > 2 :
221
+ y = y .reshape ((y .shape [0 ], - 1 ))
222
+
203
223
self ._y_df = pd .DataFrame (y , columns = [f"Out{ i + 1 } " for i in range (y .shape [1 ])])
204
224
205
225
self ._layer_info = self ._get_layer_info (model , X )
@@ -239,7 +259,7 @@ def hook(module, input, output, *args):
239
259
240
260
model .eval ()
241
261
with torch .no_grad ():
242
- model (X )
262
+ model (* X )
243
263
244
264
for handle in hooks :
245
265
handle .remove ()
0 commit comments