Skip to content

Commit a88a791

Browse files
authored
Merge pull request #12 from rafa-rod/patch-1
train_test_split test_size and shuffle parameters
2 parents 9e87082 + 7b54647 commit a88a791

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

imodelsx/kan/kan_sklearn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def __init__(self,
1818
hidden_layer_size: int = 64,
1919
hidden_layer_sizes: List[int] = None,
2020
regularize_activation: float = 1.0, regularize_entropy: float = 1.0, regularize_ridge: float = 0.0,
21+
test_size=0.2, random_state=42, shuffle=True,
2122
device: str = 'cpu',
2223
**kwargs):
2324
'''
@@ -82,7 +83,7 @@ def fit(self, X, y, batch_size=512, lr=1e-3, weight_decay=1e-4, gamma=0.8):
8283
).to(self.device)
8384

8485
X_train, X_tune, y_train, y_tune = train_test_split(
85-
X, y, test_size=0.2, random_state=42)
86+
X, y, test_size=test_size, random_state=random_state, shuffle=shuffle)
8687

8788
dset_train = torch.utils.data.TensorDataset(X_train, y_train)
8889
dset_tune = torch.utils.data.TensorDataset(X_tune, y_tune)

0 commit comments

Comments
 (0)