By default, fusilli randomly splits your data into train/test or cross-validation splits, based on the test size or number of folds you specify in the :func:`~.fusilli.data.prepare_fusion_data` function.

You can remove the randomness by specifying the data indices yourself, either for the train and test sets or for each cross-validation fold, by passing optional arguments to :func:`~.fusilli.data.prepare_fusion_data`.

For train/test splitting, the argument ``test_indices`` should be a list of indices for the test set. To make the test set the first 6 data points in the overall dataset, follow the example below:

.. code-block:: python

    from fusilli.data import prepare_fusion_data
    from fusilli.train import train_and_save_models

    # Indices of the first 6 data points, used as the test set
    test_indices = [0, 1, 2, 3, 4, 5]

    # example_model, data_paths and output_path are assumed to be defined earlier
    datamodule = prepare_fusion_data(
        prediction_task="binary",
        fusion_model=example_model,
        data_paths=data_paths,
        output_paths=output_path,
        test_indices=test_indices,
    )

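
If you would rather keep a random-looking test set that is still identical across runs, one option is to draw the indices yourself with a seeded generator and pass the result as ``test_indices``. The sketch below uses only the Python standard library. The helper ``make_test_indices``, its parameters, and the dataset size of 30 are hypothetical names and values chosen for illustration; they are not part of fusilli.

```python
import random


def make_test_indices(num_samples, test_size, seed=42):
    """Pick a reproducible set of test indices.

    Hypothetical helper for illustration, not a fusilli function.
    """
    num_test = round(num_samples * test_size)
    # A dedicated seeded generator returns the same sample on every run
    rng = random.Random(seed)
    return sorted(rng.sample(range(num_samples), num_test))


num_samples = 30  # assumed dataset size
test_indices = make_test_indices(num_samples, test_size=0.2)

# Every index that is not in the test set belongs to the training set
test_set = set(test_indices)
train_indices = [i for i in range(num_samples) if i not in test_set]
```

Because the seed is fixed, calling the helper again with the same arguments gives the same list, so the split stays stable between experiments.
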

To specify your own cross-validation folds, pass the argument ``own_kfold_indices``, which should be a list with one entry per fold. As in the example below, each entry is a ``(train_index, test_index)`` pair of indices for that fold.

For non-random cross-validation folds, you can either write the folds out by hand or generate them automatically with the scikit-learn `KFold functionality <https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html>`_ outside of the fusilli functions, like so:

.. code-block:: python

    from sklearn.model_selection import KFold

    num_folds = 5

    # `dataset` is assumed to be your full dataset, defined elsewhere.
    # Without shuffle=True, KFold yields the same contiguous folds on every run.
    own_kfold_indices = list(KFold(n_splits=num_folds).split(range(len(dataset))))

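
Each entry in the list built above is a ``(train_index, test_index)`` pair. A minimal sketch of building folds in the same shape by hand, without scikit-learn, is shown below for 3 folds over 12 data points. The helper ``contiguous_folds`` is hypothetical and not a fusilli function, and it is an untested assumption here that fusilli accepts plain Python lists in place of the NumPy arrays that ``KFold`` returns.

```python
def contiguous_folds(num_samples, num_folds):
    """Split range(num_samples) into contiguous, non-shuffled folds.

    Returns one (train_indices, test_indices) pair per fold.
    Hypothetical helper for illustration, not a fusilli function.
    """
    if not 2 <= num_folds <= num_samples:
        raise ValueError("num_folds must be between 2 and num_samples")

    base, extra = divmod(num_samples, num_folds)
    folds = []
    start = 0
    for fold in range(num_folds):
        # The first `extra` folds take one additional sample each
        stop = start + base + (1 if fold < extra else 0)
        test = list(range(start, stop))
        train = list(range(0, start)) + list(range(stop, num_samples))
        folds.append((train, test))
        start = stop
    return folds


# 12 data points in 3 folds: the test sets are 0-3, 4-7 and 8-11
own_kfold_indices = contiguous_folds(num_samples=12, num_folds=3)
```

The resulting list can then be passed as ``own_kfold_indices`` in the same way as the scikit-learn version.
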