Skip to content

Commit ce74c53

Browse files
committed
feat: reshape 3+d tensors
1 parent 8c87f9e commit ce74c53

File tree

1 file changed

+27
-7
lines changed

1 file changed

+27
-7
lines changed

src/sasctl/utils/model_info.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -176,10 +176,20 @@ def __init__(self, model, X, y=None):
176176

177177
if not isinstance(model, torch.nn.Module):
178178
raise ValueError(f"Expected PyTorch model, received {type(model)}.")
179-
if not isinstance(X, (np.ndarray, torch.Tensor)):
180-
raise ValueError(f"Expected input data to be a numpy array or PyTorch tensor, received {type(X)}.")
181-
if X.ndim != 2:
182-
raise ValueError(f"Expected input date with shape (n_samples, n_dim), received shape {X.shape}.")
179+
180+
# Some models may take multiple tensors as input. These can be passed as a tuple
181+
# of tensors. To simplify processing, convert even single inputs into tuples.
182+
if not isinstance(X, tuple):
183+
X = (X, )
184+
185+
for x in X:
186+
if not isinstance(x, (np.ndarray, torch.Tensor)):
187+
raise ValueError(f"Expected input data to be a numpy array or PyTorch tensor, received {type(X)}.")
188+
# if X.ndim != 2:
189+
# raise ValueError(f"Expected input date with shape (n_samples, n_dim), received shape {X.shape}.")
190+
191+
# Ensure each input is a PyTorch Tensor
192+
X = tuple(x if isinstance(x, torch.Tensor) else torch.tensor(x) for x in X)
183193

184194
# Store the current setting so that we can restore it later
185195
is_training = model.training
@@ -188,7 +198,7 @@ def __init__(self, model, X, y=None):
188198
model.eval()
189199

190200
with torch.no_grad():
191-
y = model(X)
201+
y = model(*X)
192202

193203
if not isinstance(y, (np.ndarray, torch.Tensor)):
194204
raise ValueError(f"Expected output data to be a numpy array or PyTorch tensor, received {type(y)}.")
@@ -199,7 +209,17 @@ def __init__(self, model, X, y=None):
199209
self._X = X
200210
self._y = y
201211

202-
self._X_df = pd.DataFrame(X, columns=[f"Var{i+1}" for i in range(X.shape[1])])
212+
# Model Manager doesn't currently support arrays or vectors. Capture the first
213+
# input tensor and reshape to 2 dimensions if necessary.
214+
x0 = X[0]
215+
if x0.ndim > 2:
216+
x0 = x0.reshape((x0.shape[0], -1))
217+
self._X_df = pd.DataFrame(x0, columns=[f"Var{i+1}" for i in range(x0.shape[1])])
218+
219+
# Flatten to 2 dimensions if necessary
220+
if y.ndim > 2:
221+
y = y.reshape((y.shape[0], -1))
222+
203223
self._y_df = pd.DataFrame(y, columns=[f"Out{i+1}" for i in range(y.shape[1])])
204224

205225
self._layer_info = self._get_layer_info(model, X)
@@ -239,7 +259,7 @@ def hook(module, input, output, *args):
239259

240260
model.eval()
241261
with torch.no_grad():
242-
model(X)
262+
model(*X)
243263

244264
for handle in hooks:
245265
handle.remove()

0 commit comments

Comments
 (0)