Skip to content

Commit 12f196e

Browse files
committed
rerun
1 parent 7318afb commit 12f196e

File tree

1 file changed

+62
-119
lines changed

1 file changed

+62
-119
lines changed

MTL_Train_TCGA_Test_SCLC.ipynb

Lines changed: 62 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
},
3333
{
3434
"cell_type": "code",
35-
"execution_count": 2,
35+
"execution_count": 1,
3636
"metadata": {},
3737
"outputs": [],
3838
"source": [
@@ -115,11 +115,11 @@
115115
"output_type": "stream",
116116
"text": [
117117
"BRCA\n",
118-
"{'ACTB', 'EFTUD2', 'PLAU', 'DDX23', 'GSK3B', 'HSPA8', 'MKI67', 'ESR1', 'SHMT2', 'ERBB2', 'STAU1', 'SSR1', 'UBXN6', 'PRKRA', 'TUBA1C', 'SNIP1', 'YWHAB', 'PGR', 'BTRC', 'SRSF5'} 20\n",
118+
"{'SHMT2', 'PLAU', 'TUBA1C', 'STAU1', 'MKI67', 'HSPA8', 'ESR1', 'PRKRA', 'DDX23', 'YWHAB', 'GSK3B', 'PGR', 'UBXN6', 'SNIP1', 'ACTB', 'BTRC', 'SRSF5', 'ERBB2', 'EFTUD2', 'SSR1'} 20\n",
119119
"LUAD\n",
120-
"{'SSR1', 'PUM1', 'SERBP1', 'SLC2A1', 'CADM1', 'ALCAM', 'HNRNPU', 'PRKRA', 'PTK7', 'KDM1A', 'KRR1', 'STAU1', 'CDC73', 'OCIAD1', 'HIF1A', 'DHX9', 'CLTC', 'EPCAM'} 18\n",
120+
"{'PTK7', 'KRR1', 'OCIAD1', 'SLC2A1', 'DHX9', 'CLTC', 'PRKRA', 'EPCAM', 'KDM1A', 'STAU1', 'HIF1A', 'SERBP1', 'HNRNPU', 'CADM1', 'ALCAM', 'PUM1', 'SSR1', 'CDC73'} 18\n",
121121
"COAD\n",
122-
"{'PROM1', 'CD44', 'HNRNPK', 'ZBTB2', 'TFCP2', 'RNF4', 'SERBP1', 'HNRNPR', 'EPCAM', 'ABCG2', 'HNRNPU', 'ABCB1', 'HNRNPL', 'DHX9', 'RPL4', 'PUM1', 'ALCAM', 'ALDH1A1', 'HNRNPA1', 'ABCC1'} 20\n",
122+
"{'ZBTB2', 'DHX9', 'HNRNPL', 'SERBP1', 'ABCC1', 'HNRNPA1', 'ABCG2', 'HNRNPR', 'RNF4', 'ABCB1', 'HNRNPU', 'RPL4', 'TFCP2', 'CD44', 'PROM1', 'EPCAM', 'PUM1', 'HNRNPK', 'ALDH1A1', 'ALCAM'} 20\n",
123123
"all three\n",
124124
"set()\n"
125125
]
@@ -875,7 +875,7 @@
875875
},
876876
{
877877
"cell_type": "code",
878-
"execution_count": 3,
878+
"execution_count": 14,
879879
"metadata": {},
880880
"outputs": [
881881
{
@@ -921,7 +921,7 @@
921921
},
922922
{
923923
"cell_type": "code",
924-
"execution_count": 4,
924+
"execution_count": 15,
925925
"metadata": {},
926926
"outputs": [],
927927
"source": [
@@ -959,37 +959,7 @@
959959
},
960960
{
961961
"cell_type": "code",
962-
"execution_count": 5,
963-
"metadata": {},
964-
"outputs": [],
965-
"source": [
966-
"def cross_validation(manager, config, log_path, external_testing_dataloader):\n",
967-
" for key, values in manager['TCGA_BLC']['dataloaders'].items():\n",
968-
" if isinstance(key, int) and config['cross_validation']:\n",
969-
" models, optimizers = create_models_and_optimizers(config)\n",
970-
" lit_model = LitFullModel(models, optimizers, config)\n",
971-
" trainer = pl.Trainer( # Create sub-folders for each fold.\n",
972-
" default_root_dir=log_path,\n",
973-
" max_epochs=config['max_epochs'],\n",
974-
" log_every_n_steps=1,\n",
975-
" enable_model_summary=False,\n",
976-
" enable_checkpointing=False,\n",
977-
" \n",
978-
" )\n",
979-
" \n",
980-
" trainer.fit(lit_model, train_dataloaders=values['train'])\n",
981-
" \n",
982-
" \n",
983-
" elif key == 'train':\n",
984-
" train = values\n",
985-
" elif key == 'test':\n",
986-
" test = external_testing_dataloader #values\n",
987-
" return train, test"
988-
]
989-
},
990-
{
991-
"cell_type": "code",
992-
"execution_count": 6,
962+
"execution_count": 16,
993963
"metadata": {},
994964
"outputs": [],
995965
"source": [
@@ -1001,32 +971,28 @@
1001971
" \n",
1002972
" with open(config_file, 'r') as f:\n",
1003973
" config = yaml.load(f, Loader=yaml.FullLoader)\n",
1004-
" override_n_genes(config) \n",
974+
" override_n_genes(config) # For multi-task graph models.\n",
1005975
" config_name = Path(config_file).stem\n",
1006976
"\n",
1007-
" # Setup logging.\n",
977+
" # Setup logging.\n",
1008978
" log_path = f'Logs/{config_name}/{datetime.now():%Y-%m-%dT%H:%M:%S}/'\n",
1009979
" setup_logging(log_path)\n",
1010-
" \n",
980+
" #setup_logging(log_path := f'Logs/{config_name}/{datetime.now():%Y-%m-%dT%H:%M:%S}/')\n",
1011981
" logger = get_logger(config_name)\n",
1012982
" logger.info(f'Using Random Seed {SEED} for this experiment')\n",
1013-
" \n",
1014983
" get_logger('lightning.pytorch.accelerators.cuda', log_level='WARNING') # Disable cuda logging.\n",
1015984
" filterwarnings('ignore', r'.*Skipping val loop.*') # Disable val loop warning.\n",
1016-
" filterwarnings('ignore', r\".*Your `test_dataloader`'s sampler has shuffling enabled*\") # Disable val shuffle warning.\n",
1017-
"\n",
985+
" filterwarnings('ignore', r\".*Your `test_dataloader`'s sampler has shuffling enabled`*\") # Disable val shuffle warning.\n",
1018986
"\n",
987+
" # Create dataset manager for training data.\n",
988+
" data = {'TCGA_BLC': TCGA_Program_Dataset(**config['datasets'])}\n",
989+
" \n",
1019990
" #add the external data\n",
1020991
" external_testing_data = ExternalDataModule(**config['external_datasets']) \n",
1021992
"\n",
1022993
" external_testing_data.setup()\n",
1023994
"\n",
1024995
" external_testing_dataloader = external_testing_data.test_dataloader()\n",
1025-
"\n",
1026-
" # Create dataset manager for training data.\n",
1027-
" data = {'TCGA_BLC': TCGA_Program_Dataset(**config['datasets'])}\n",
1028-
" \n",
1029-
" \n",
1030996
" \n",
1031997
" if 'TCGA_Balanced_Datasets_Manager' == config['datasets_manager']['type']:\n",
1032998
" manager = TCGA_Balanced_Datasets_Manager(datasets=data, config=config_add_subdict_key(config))\n",
@@ -1048,13 +1014,14 @@
10481014
" )\n",
10491015
" \n",
10501016
" trainer.fit(lit_model, train_dataloaders=values['train'])\n",
1051-
" #trainer.test(lit_model, dataloaders=test, verbose=True) \n",
1017+
" \n",
10521018
" \n",
10531019
" \n",
10541020
" elif key == 'train':\n",
10551021
" train = values\n",
10561022
" elif key == 'test':\n",
1057-
" test = external_testing_dataloader\n",
1023+
" test = external_testing_dataloader #values\n",
1024+
"\n",
10581025
" # Train the final model from scratch with all the training data.\n",
10591026
" models, optimizers = create_models_and_optimizers(config)\n",
10601027
" lit_model = LitFullModel(models, optimizers, config)\n",
@@ -1069,11 +1036,8 @@
10691036
"\n",
10701037
" # Test the final model.\n",
10711038
" bootstrap_results = []\n",
1072-
" for _ in tqdm(range(config['bootstrap_repeats']), desc='Bootstrapping'):\n",
1073-
" \n",
1039+
" for _ in tqdm(range(config['bootstrap_repeats']), desc='Bootstrapping'): \n",
10741040
" bootstrap_results.append(trainer.test(lit_model, dataloaders=test, verbose=False)[0]) \n",
1075-
" \n",
1076-
"\n",
10771041
" bootstrap_results = pd.DataFrame.from_records(bootstrap_results)\n",
10781042
" for key, value in bootstrap_results.describe().loc[['mean', 'std']].to_dict().items():\n",
10791043
" logger.info(f'| {key.ljust(10).upper()} | {value[\"mean\"]:.5f} ± {value[\"std\"]:.5f} |')\n",
@@ -1094,32 +1058,14 @@
10941058
},
10951059
{
10961060
"cell_type": "code",
1097-
"execution_count": 7,
1061+
"execution_count": 17,
10981062
"metadata": {},
10991063
"outputs": [
11001064
{
11011065
"name": "stdout",
11021066
"output_type": "stream",
11031067
"text": [
1104-
"[INFO]\tUsing Random Seed 1126 for this experiment\n"
1105-
]
1106-
},
1107-
{
1108-
"name": "stdout",
1109-
"output_type": "stream",
1110-
"text": [
1111-
"[INFO]\tExternal DS - Total 88 patients\n",
1112-
"[INFO]\tNormalize clinical numerical data using all samples\n",
1113-
"[INFO]\tExternal DS - Total 88 samples after removing missing values\n",
1114-
"[INFO]\tExternal DS - Batch size 128\n",
1115-
"[INFO]\tExternal DS - Total 88 patients, 20 genomic features and 14 clinical features\n",
1116-
"[INFO]\tExternal DS - Target Type overall_survival\n",
1117-
"[INFO]\tNormalize clinical numerical data using all samples\n",
1118-
"[INFO]\tExternal DS - Total 88 samples after removing missing values\n",
1119-
"[INFO]\tExternal DS - Total 81 samples\n",
1120-
"[INFO]\tExternal DS - Total 39 features\n",
1121-
"[INFO]\tExternal DS - Overall survival imbalance ratio 81.48148148148148 %\n",
1122-
"[INFO]\tSplitting data into test set...\n",
1068+
"[INFO]\tUsing Random Seed 1126 for this experiment\n",
11231069
"[INFO]\tCreating a TCGA Program Dataset with 3 Projects...\n",
11241070
"Case metadata {}\n",
11251071
"[INFO]\tNo files to download for project TCGA-BRCA\n",
@@ -1160,8 +1106,20 @@
11601106
"[INFO]\tSaving train and test indices to Cache\n",
11611107
"[INFO]\tTotal 2059 patients, 20 genomic features and 14 clinical features\n",
11621108
"[INFO]\tOverall survival imbalance ratio 17.678484701311316 %\n",
1109+
"[INFO]\tExternal DS - Total 88 patients\n",
1110+
"[INFO]\tNormalize clinical numerical data using all samples\n",
1111+
"[INFO]\tExternal DS - Total 88 samples after removing missing values\n",
1112+
"[INFO]\tExternal DS - Batch size 128\n",
1113+
"[INFO]\tExternal DS - Total 88 patients, 20 genomic features and 14 clinical features\n",
1114+
"[INFO]\tExternal DS - Target Type overall_survival\n",
1115+
"[INFO]\tNormalize clinical numerical data using all samples\n",
1116+
"[INFO]\tExternal DS - Total 88 samples after removing missing values\n",
1117+
"[INFO]\tExternal DS - Total 81 samples\n",
1118+
"[INFO]\tExternal DS - Total 39 features\n",
1119+
"[INFO]\tExternal DS - Overall survival imbalance ratio 81.48148148148148 %\n",
1120+
"[INFO]\tSplitting data into test set...\n",
11631121
"[INFO]\tInitializing a TCGA Balanced Datasets Manager containing 1 Datasets...\n",
1164-
"[INFO]\tUsing indices cache files created at 2024-06-05 17:58:21 from Cache\n"
1122+
"[INFO]\tUsing indices cache files created at 2024-06-21 18:36:29 from Cache\n"
11651123
]
11661124
},
11671125
{
@@ -1172,14 +1130,14 @@
11721130
"TPU available: False, using: 0 TPU cores\n",
11731131
"IPU available: False, using: 0 IPUs\n",
11741132
"HPU available: False, using: 0 HPUs\n",
1175-
"Missing logger folder: Logs/MTL_train_SCLC_test/2024-06-05T17:57:57/lightning_logs\n"
1133+
"Missing logger folder: Logs/MTL_train_SCLC_test/2024-06-21T18:36:06/lightning_logs\n"
11761134
]
11771135
},
11781136
{
11791137
"name": "stdout",
11801138
"output_type": "stream",
11811139
"text": [
1182-
"Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 11.87it/s, v_num=0]"
1140+
"Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 36.20it/s, v_num=0]"
11831141
]
11841142
},
11851143
{
@@ -1193,7 +1151,7 @@
11931151
"name": "stdout",
11941152
"output_type": "stream",
11951153
"text": [
1196-
"Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 11.84it/s, v_num=0]"
1154+
"Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 35.83it/s, v_num=0]"
11971155
]
11981156
},
11991157
{
@@ -1211,7 +1169,7 @@
12111169
"output_type": "stream",
12121170
"text": [
12131171
"\n",
1214-
"Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 26.08it/s, v_num=1]"
1172+
"Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 36.85it/s, v_num=1]"
12151173
]
12161174
},
12171175
{
@@ -1225,7 +1183,7 @@
12251183
"name": "stdout",
12261184
"output_type": "stream",
12271185
"text": [
1228-
"Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 25.93it/s, v_num=1]\n"
1186+
"Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 36.48it/s, v_num=1]"
12291187
]
12301188
},
12311189
{
@@ -1242,7 +1200,8 @@
12421200
"name": "stdout",
12431201
"output_type": "stream",
12441202
"text": [
1245-
"Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 12.11it/s, v_num=2]"
1203+
"\n",
1204+
"Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 36.89it/s, v_num=2]"
12461205
]
12471206
},
12481207
{
@@ -1256,7 +1215,7 @@
12561215
"name": "stdout",
12571216
"output_type": "stream",
12581217
"text": [
1259-
"Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 12.08it/s, v_num=2]"
1218+
"Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 36.51it/s, v_num=2]"
12601219
]
12611220
},
12621221
{
@@ -1274,7 +1233,7 @@
12741233
"output_type": "stream",
12751234
"text": [
12761235
"\n",
1277-
"Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 11.89it/s, v_num=3]"
1236+
"Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 33.32it/s, v_num=3]"
12781237
]
12791238
},
12801239
{
@@ -1288,7 +1247,7 @@
12881247
"name": "stdout",
12891248
"output_type": "stream",
12901249
"text": [
1291-
"Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 11.86it/s, v_num=3]\n"
1250+
"Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 33.05it/s, v_num=3]"
12921251
]
12931252
},
12941253
{
@@ -1308,18 +1267,31 @@
13081267
"952 Trainable params\n",
13091268
"0 Non-trainable params\n",
13101269
"952 Total params\n",
1311-
"0.004 Total estimated model params size (MB)\n",
1270+
"0.004 Total estimated model params size (MB)\n"
1271+
]
1272+
},
1273+
{
1274+
"name": "stdout",
1275+
"output_type": "stream",
1276+
"text": [
1277+
"\n"
1278+
]
1279+
},
1280+
{
1281+
"name": "stderr",
1282+
"output_type": "stream",
1283+
"text": [
13121284
"`Trainer.fit` stopped: `max_epochs=50` reached.\n",
1313-
"Bootstrapping: 100%|██████████| 1000/1000 [14:56<00:00, 1.12it/s]"
1285+
"Bootstrapping: 100%|██████████| 1000/1000 [03:38<00:00, 4.57it/s]"
13141286
]
13151287
},
13161288
{
13171289
"name": "stdout",
13181290
"output_type": "stream",
13191291
"text": [
1320-
"[INFO]\t| AUC_2.0 | 0.41925 ± 0.08806 |\n",
1321-
"[INFO]\t| PRC_2.0 | 0.80238 ± 0.05609 |\n",
1322-
"[INFO]\t| C-INDEX_2.0 | 0.47474 ± 0.04398 |\n"
1292+
"[INFO]\t| AUC_1.0 | 0.50015 ± 0.08427 |\n",
1293+
"[INFO]\t| PRC_1.0 | 0.83636 ± 0.05430 |\n",
1294+
"[INFO]\t| C-INDEX_1.0 | 0.49382 ± 0.03842 |\n"
13231295
]
13241296
},
13251297
{
@@ -1336,35 +1308,6 @@
13361308
"trainer, train, test, models, optimizers, config, lit_model, logger = main(\"config/light/MTL_train_SCLC_test.yaml\")\n",
13371309
"\n"
13381310
]
1339-
},
1340-
{
1341-
"cell_type": "code",
1342-
"execution_count": 9,
1343-
"metadata": {},
1344-
"outputs": [],
1345-
"source": [
1346-
"# save the model\n",
1347-
"\n",
1348-
"torch.save(lit_model.state_dict(), 'model.pth')"
1349-
]
1350-
},
1351-
{
1352-
"cell_type": "code",
1353-
"execution_count": null,
1354-
"metadata": {},
1355-
"outputs": [],
1356-
"source": [
1357-
"# load model\n",
1358-
"\n",
1359-
"lit_model.load_state_dict(torch.load('model.pth'))\n"
1360-
]
1361-
},
1362-
{
1363-
"cell_type": "code",
1364-
"execution_count": null,
1365-
"metadata": {},
1366-
"outputs": [],
1367-
"source": []
13681311
}
13691312
],
13701313
"metadata": {
@@ -1383,7 +1326,7 @@
13831326
"name": "python",
13841327
"nbconvert_exporter": "python",
13851328
"pygments_lexer": "ipython3",
1386-
"version": "3.9.undefined"
1329+
"version": "3.9.17"
13871330
}
13881331
},
13891332
"nbformat": 4,

0 commit comments

Comments
 (0)