|
32 | 32 | },
|
33 | 33 | {
|
34 | 34 | "cell_type": "code",
|
35 |
| - "execution_count": 2, |
| 35 | + "execution_count": 1, |
36 | 36 | "metadata": {},
|
37 | 37 | "outputs": [],
|
38 | 38 | "source": [
|
|
115 | 115 | "output_type": "stream",
|
116 | 116 | "text": [
|
117 | 117 | "BRCA\n",
|
118 |
| - "{'ACTB', 'EFTUD2', 'PLAU', 'DDX23', 'GSK3B', 'HSPA8', 'MKI67', 'ESR1', 'SHMT2', 'ERBB2', 'STAU1', 'SSR1', 'UBXN6', 'PRKRA', 'TUBA1C', 'SNIP1', 'YWHAB', 'PGR', 'BTRC', 'SRSF5'} 20\n", |
| 118 | + "{'SHMT2', 'PLAU', 'TUBA1C', 'STAU1', 'MKI67', 'HSPA8', 'ESR1', 'PRKRA', 'DDX23', 'YWHAB', 'GSK3B', 'PGR', 'UBXN6', 'SNIP1', 'ACTB', 'BTRC', 'SRSF5', 'ERBB2', 'EFTUD2', 'SSR1'} 20\n", |
119 | 119 | "LUAD\n",
|
120 |
| - "{'SSR1', 'PUM1', 'SERBP1', 'SLC2A1', 'CADM1', 'ALCAM', 'HNRNPU', 'PRKRA', 'PTK7', 'KDM1A', 'KRR1', 'STAU1', 'CDC73', 'OCIAD1', 'HIF1A', 'DHX9', 'CLTC', 'EPCAM'} 18\n", |
| 120 | + "{'PTK7', 'KRR1', 'OCIAD1', 'SLC2A1', 'DHX9', 'CLTC', 'PRKRA', 'EPCAM', 'KDM1A', 'STAU1', 'HIF1A', 'SERBP1', 'HNRNPU', 'CADM1', 'ALCAM', 'PUM1', 'SSR1', 'CDC73'} 18\n", |
121 | 121 | "COAD\n",
|
122 |
| - "{'PROM1', 'CD44', 'HNRNPK', 'ZBTB2', 'TFCP2', 'RNF4', 'SERBP1', 'HNRNPR', 'EPCAM', 'ABCG2', 'HNRNPU', 'ABCB1', 'HNRNPL', 'DHX9', 'RPL4', 'PUM1', 'ALCAM', 'ALDH1A1', 'HNRNPA1', 'ABCC1'} 20\n", |
| 122 | + "{'ZBTB2', 'DHX9', 'HNRNPL', 'SERBP1', 'ABCC1', 'HNRNPA1', 'ABCG2', 'HNRNPR', 'RNF4', 'ABCB1', 'HNRNPU', 'RPL4', 'TFCP2', 'CD44', 'PROM1', 'EPCAM', 'PUM1', 'HNRNPK', 'ALDH1A1', 'ALCAM'} 20\n", |
123 | 123 | "all three\n",
|
124 | 124 | "set()\n"
|
125 | 125 | ]
|
|
875 | 875 | },
|
876 | 876 | {
|
877 | 877 | "cell_type": "code",
|
878 |
| - "execution_count": 3, |
| 878 | + "execution_count": 14, |
879 | 879 | "metadata": {},
|
880 | 880 | "outputs": [
|
881 | 881 | {
|
|
921 | 921 | },
|
922 | 922 | {
|
923 | 923 | "cell_type": "code",
|
924 |
| - "execution_count": 4, |
| 924 | + "execution_count": 15, |
925 | 925 | "metadata": {},
|
926 | 926 | "outputs": [],
|
927 | 927 | "source": [
|
|
959 | 959 | },
|
960 | 960 | {
|
961 | 961 | "cell_type": "code",
|
962 |
| - "execution_count": 5, |
963 |
| - "metadata": {}, |
964 |
| - "outputs": [], |
965 |
| - "source": [ |
966 |
| - "def cross_validation(manager, config, log_path, external_testing_dataloader):\n", |
967 |
| - " for key, values in manager['TCGA_BLC']['dataloaders'].items():\n", |
968 |
| - " if isinstance(key, int) and config['cross_validation']:\n", |
969 |
| - " models, optimizers = create_models_and_optimizers(config)\n", |
970 |
| - " lit_model = LitFullModel(models, optimizers, config)\n", |
971 |
| - " trainer = pl.Trainer( # Create sub-folders for each fold.\n", |
972 |
| - " default_root_dir=log_path,\n", |
973 |
| - " max_epochs=config['max_epochs'],\n", |
974 |
| - " log_every_n_steps=1,\n", |
975 |
| - " enable_model_summary=False,\n", |
976 |
| - " enable_checkpointing=False,\n", |
977 |
| - " \n", |
978 |
| - " )\n", |
979 |
| - " \n", |
980 |
| - " trainer.fit(lit_model, train_dataloaders=values['train'])\n", |
981 |
| - " \n", |
982 |
| - " \n", |
983 |
| - " elif key == 'train':\n", |
984 |
| - " train = values\n", |
985 |
| - " elif key == 'test':\n", |
986 |
| - " test = external_testing_dataloader #values\n", |
987 |
| - " return train, test" |
988 |
| - ] |
989 |
| - }, |
990 |
| - { |
991 |
| - "cell_type": "code", |
992 |
| - "execution_count": 6, |
| 962 | + "execution_count": 16, |
993 | 963 | "metadata": {},
|
994 | 964 | "outputs": [],
|
995 | 965 | "source": [
|
|
1001 | 971 | " \n",
|
1002 | 972 | " with open(config_file, 'r') as f:\n",
|
1003 | 973 | " config = yaml.load(f, Loader=yaml.FullLoader)\n",
|
1004 |
| - " override_n_genes(config) \n", |
| 974 | + " override_n_genes(config) # For multi-task graph models.\n", |
1005 | 975 | " config_name = Path(config_file).stem\n",
|
1006 | 976 | "\n",
|
1007 |
| - " # Setup logging.\n", |
| 977 | + " # Setup logging.\n", |
1008 | 978 | " log_path = f'Logs/{config_name}/{datetime.now():%Y-%m-%dT%H:%M:%S}/'\n",
|
1009 | 979 | " setup_logging(log_path)\n",
|
1010 |
| - " \n", |
| 980 | + " #setup_logging(log_path := f'Logs/{config_name}/{datetime.now():%Y-%m-%dT%H:%M:%S}/')\n", |
1011 | 981 | " logger = get_logger(config_name)\n",
|
1012 | 982 | " logger.info(f'Using Random Seed {SEED} for this experiment')\n",
|
1013 |
| - " \n", |
1014 | 983 | " get_logger('lightning.pytorch.accelerators.cuda', log_level='WARNING') # Disable cuda logging.\n",
|
1015 | 984 | " filterwarnings('ignore', r'.*Skipping val loop.*') # Disable val loop warning.\n",
|
1016 |
| - " filterwarnings('ignore', r\".*Your `test_dataloader`'s sampler has shuffling enabled*\") # Disable val shuffle warning.\n", |
1017 |
| - "\n", |
| 985 | + " filterwarnings('ignore', r\".*Your `test_dataloader`'s sampler has shuffling enabled`*\") # Disable val shuffle warning.\n", |
1018 | 986 | "\n",
|
| 987 | + " # Create dataset manager for training data.\n", |
| 988 | + " data = {'TCGA_BLC': TCGA_Program_Dataset(**config['datasets'])}\n", |
| 989 | + " \n", |
1019 | 990 | " #add the external data\n",
|
1020 | 991 | " external_testing_data = ExternalDataModule(**config['external_datasets']) \n",
|
1021 | 992 | "\n",
|
1022 | 993 | " external_testing_data.setup()\n",
|
1023 | 994 | "\n",
|
1024 | 995 | " external_testing_dataloader = external_testing_data.test_dataloader()\n",
|
1025 |
| - "\n", |
1026 |
| - " # Create dataset manager for training data.\n", |
1027 |
| - " data = {'TCGA_BLC': TCGA_Program_Dataset(**config['datasets'])}\n", |
1028 |
| - " \n", |
1029 |
| - " \n", |
1030 | 996 | " \n",
|
1031 | 997 | " if 'TCGA_Balanced_Datasets_Manager' == config['datasets_manager']['type']:\n",
|
1032 | 998 | " manager = TCGA_Balanced_Datasets_Manager(datasets=data, config=config_add_subdict_key(config))\n",
|
|
1048 | 1014 | " )\n",
|
1049 | 1015 | " \n",
|
1050 | 1016 | " trainer.fit(lit_model, train_dataloaders=values['train'])\n",
|
1051 |
| - " #trainer.test(lit_model, dataloaders=test, verbose=True) \n", |
| 1017 | + " \n", |
1052 | 1018 | " \n",
|
1053 | 1019 | " \n",
|
1054 | 1020 | " elif key == 'train':\n",
|
1055 | 1021 | " train = values\n",
|
1056 | 1022 | " elif key == 'test':\n",
|
1057 |
| - " test = external_testing_dataloader\n", |
| 1023 | + " test = external_testing_dataloader #values\n", |
| 1024 | + "\n", |
1058 | 1025 | " # Train the final model from scratch with all the training data.\n",
|
1059 | 1026 | " models, optimizers = create_models_and_optimizers(config)\n",
|
1060 | 1027 | " lit_model = LitFullModel(models, optimizers, config)\n",
|
|
1069 | 1036 | "\n",
|
1070 | 1037 | " # Test the final model.\n",
|
1071 | 1038 | " bootstrap_results = []\n",
|
1072 |
| - " for _ in tqdm(range(config['bootstrap_repeats']), desc='Bootstrapping'):\n", |
1073 |
| - " \n", |
| 1039 | + " for _ in tqdm(range(config['bootstrap_repeats']), desc='Bootstrapping'): \n", |
1074 | 1040 | " bootstrap_results.append(trainer.test(lit_model, dataloaders=test, verbose=False)[0]) \n",
|
1075 |
| - " \n", |
1076 |
| - "\n", |
1077 | 1041 | " bootstrap_results = pd.DataFrame.from_records(bootstrap_results)\n",
|
1078 | 1042 | " for key, value in bootstrap_results.describe().loc[['mean', 'std']].to_dict().items():\n",
|
1079 | 1043 | " logger.info(f'| {key.ljust(10).upper()} | {value[\"mean\"]:.5f} ± {value[\"std\"]:.5f} |')\n",
|
|
1094 | 1058 | },
|
1095 | 1059 | {
|
1096 | 1060 | "cell_type": "code",
|
1097 |
| - "execution_count": 7, |
| 1061 | + "execution_count": 17, |
1098 | 1062 | "metadata": {},
|
1099 | 1063 | "outputs": [
|
1100 | 1064 | {
|
1101 | 1065 | "name": "stdout",
|
1102 | 1066 | "output_type": "stream",
|
1103 | 1067 | "text": [
|
1104 |
| - "[INFO]\tUsing Random Seed 1126 for this experiment\n" |
1105 |
| - ] |
1106 |
| - }, |
1107 |
| - { |
1108 |
| - "name": "stdout", |
1109 |
| - "output_type": "stream", |
1110 |
| - "text": [ |
1111 |
| - "[INFO]\tExternal DS - Total 88 patients\n", |
1112 |
| - "[INFO]\tNormalize clinical numerical data using all samples\n", |
1113 |
| - "[INFO]\tExternal DS - Total 88 samples after removing missing values\n", |
1114 |
| - "[INFO]\tExternal DS - Batch size 128\n", |
1115 |
| - "[INFO]\tExternal DS - Total 88 patients, 20 genomic features and 14 clinical features\n", |
1116 |
| - "[INFO]\tExternal DS - Target Type overall_survival\n", |
1117 |
| - "[INFO]\tNormalize clinical numerical data using all samples\n", |
1118 |
| - "[INFO]\tExternal DS - Total 88 samples after removing missing values\n", |
1119 |
| - "[INFO]\tExternal DS - Total 81 samples\n", |
1120 |
| - "[INFO]\tExternal DS - Total 39 features\n", |
1121 |
| - "[INFO]\tExternal DS - Overall survival imbalance ratio 81.48148148148148 %\n", |
1122 |
| - "[INFO]\tSplitting data into test set...\n", |
| 1068 | + "[INFO]\tUsing Random Seed 1126 for this experiment\n", |
1123 | 1069 | "[INFO]\tCreating a TCGA Program Dataset with 3 Projects...\n",
|
1124 | 1070 | "Case metadata {}\n",
|
1125 | 1071 | "[INFO]\tNo files to download for project TCGA-BRCA\n",
|
|
1160 | 1106 | "[INFO]\tSaving train and test indices to Cache\n",
|
1161 | 1107 | "[INFO]\tTotal 2059 patients, 20 genomic features and 14 clinical features\n",
|
1162 | 1108 | "[INFO]\tOverall survival imbalance ratio 17.678484701311316 %\n",
|
| 1109 | + "[INFO]\tExternal DS - Total 88 patients\n", |
| 1110 | + "[INFO]\tNormalize clinical numerical data using all samples\n", |
| 1111 | + "[INFO]\tExternal DS - Total 88 samples after removing missing values\n", |
| 1112 | + "[INFO]\tExternal DS - Batch size 128\n", |
| 1113 | + "[INFO]\tExternal DS - Total 88 patients, 20 genomic features and 14 clinical features\n", |
| 1114 | + "[INFO]\tExternal DS - Target Type overall_survival\n", |
| 1115 | + "[INFO]\tNormalize clinical numerical data using all samples\n", |
| 1116 | + "[INFO]\tExternal DS - Total 88 samples after removing missing values\n", |
| 1117 | + "[INFO]\tExternal DS - Total 81 samples\n", |
| 1118 | + "[INFO]\tExternal DS - Total 39 features\n", |
| 1119 | + "[INFO]\tExternal DS - Overall survival imbalance ratio 81.48148148148148 %\n", |
| 1120 | + "[INFO]\tSplitting data into test set...\n", |
1163 | 1121 | "[INFO]\tInitializing a TCGA Balanced Datasets Manager containing 1 Datasets...\n",
|
1164 |
| - "[INFO]\tUsing indices cache files created at 2024-06-05 17:58:21 from Cache\n" |
| 1122 | + "[INFO]\tUsing indices cache files created at 2024-06-21 18:36:29 from Cache\n" |
1165 | 1123 | ]
|
1166 | 1124 | },
|
1167 | 1125 | {
|
|
1172 | 1130 | "TPU available: False, using: 0 TPU cores\n",
|
1173 | 1131 | "IPU available: False, using: 0 IPUs\n",
|
1174 | 1132 | "HPU available: False, using: 0 HPUs\n",
|
1175 |
| - "Missing logger folder: Logs/MTL_train_SCLC_test/2024-06-05T17:57:57/lightning_logs\n" |
| 1133 | + "Missing logger folder: Logs/MTL_train_SCLC_test/2024-06-21T18:36:06/lightning_logs\n" |
1176 | 1134 | ]
|
1177 | 1135 | },
|
1178 | 1136 | {
|
1179 | 1137 | "name": "stdout",
|
1180 | 1138 | "output_type": "stream",
|
1181 | 1139 | "text": [
|
1182 |
| - "Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 11.87it/s, v_num=0]" |
| 1140 | + "Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 36.20it/s, v_num=0]" |
1183 | 1141 | ]
|
1184 | 1142 | },
|
1185 | 1143 | {
|
|
1193 | 1151 | "name": "stdout",
|
1194 | 1152 | "output_type": "stream",
|
1195 | 1153 | "text": [
|
1196 |
| - "Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 11.84it/s, v_num=0]" |
| 1154 | + "Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 35.83it/s, v_num=0]" |
1197 | 1155 | ]
|
1198 | 1156 | },
|
1199 | 1157 | {
|
|
1211 | 1169 | "output_type": "stream",
|
1212 | 1170 | "text": [
|
1213 | 1171 | "\n",
|
1214 |
| - "Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 26.08it/s, v_num=1]" |
| 1172 | + "Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 36.85it/s, v_num=1]" |
1215 | 1173 | ]
|
1216 | 1174 | },
|
1217 | 1175 | {
|
|
1225 | 1183 | "name": "stdout",
|
1226 | 1184 | "output_type": "stream",
|
1227 | 1185 | "text": [
|
1228 |
| - "Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 25.93it/s, v_num=1]\n" |
| 1186 | + "Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 36.48it/s, v_num=1]" |
1229 | 1187 | ]
|
1230 | 1188 | },
|
1231 | 1189 | {
|
|
1242 | 1200 | "name": "stdout",
|
1243 | 1201 | "output_type": "stream",
|
1244 | 1202 | "text": [
|
1245 |
| - "Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 12.11it/s, v_num=2]" |
| 1203 | + "\n", |
| 1204 | + "Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 36.89it/s, v_num=2]" |
1246 | 1205 | ]
|
1247 | 1206 | },
|
1248 | 1207 | {
|
|
1256 | 1215 | "name": "stdout",
|
1257 | 1216 | "output_type": "stream",
|
1258 | 1217 | "text": [
|
1259 |
| - "Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 12.08it/s, v_num=2]" |
| 1218 | + "Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 36.51it/s, v_num=2]" |
1260 | 1219 | ]
|
1261 | 1220 | },
|
1262 | 1221 | {
|
|
1274 | 1233 | "output_type": "stream",
|
1275 | 1234 | "text": [
|
1276 | 1235 | "\n",
|
1277 |
| - "Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 11.89it/s, v_num=3]" |
| 1236 | + "Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 33.32it/s, v_num=3]" |
1278 | 1237 | ]
|
1279 | 1238 | },
|
1280 | 1239 | {
|
|
1288 | 1247 | "name": "stdout",
|
1289 | 1248 | "output_type": "stream",
|
1290 | 1249 | "text": [
|
1291 |
| - "Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 11.86it/s, v_num=3]\n" |
| 1250 | + "Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 33.05it/s, v_num=3]" |
1292 | 1251 | ]
|
1293 | 1252 | },
|
1294 | 1253 | {
|
|
1308 | 1267 | "952 Trainable params\n",
|
1309 | 1268 | "0 Non-trainable params\n",
|
1310 | 1269 | "952 Total params\n",
|
1311 |
| - "0.004 Total estimated model params size (MB)\n", |
| 1270 | + "0.004 Total estimated model params size (MB)\n" |
| 1271 | + ] |
| 1272 | + }, |
| 1273 | + { |
| 1274 | + "name": "stdout", |
| 1275 | + "output_type": "stream", |
| 1276 | + "text": [ |
| 1277 | + "\n" |
| 1278 | + ] |
| 1279 | + }, |
| 1280 | + { |
| 1281 | + "name": "stderr", |
| 1282 | + "output_type": "stream", |
| 1283 | + "text": [ |
1312 | 1284 | "`Trainer.fit` stopped: `max_epochs=50` reached.\n",
|
1313 |
| - "Bootstrapping: 100%|██████████| 1000/1000 [14:56<00:00, 1.12it/s]" |
| 1285 | + "Bootstrapping: 100%|██████████| 1000/1000 [03:38<00:00, 4.57it/s]" |
1314 | 1286 | ]
|
1315 | 1287 | },
|
1316 | 1288 | {
|
1317 | 1289 | "name": "stdout",
|
1318 | 1290 | "output_type": "stream",
|
1319 | 1291 | "text": [
|
1320 |
| - "[INFO]\t| AUC_2.0 | 0.41925 ± 0.08806 |\n", |
1321 |
| - "[INFO]\t| PRC_2.0 | 0.80238 ± 0.05609 |\n", |
1322 |
| - "[INFO]\t| C-INDEX_2.0 | 0.47474 ± 0.04398 |\n" |
| 1292 | + "[INFO]\t| AUC_1.0 | 0.50015 ± 0.08427 |\n", |
| 1293 | + "[INFO]\t| PRC_1.0 | 0.83636 ± 0.05430 |\n", |
| 1294 | + "[INFO]\t| C-INDEX_1.0 | 0.49382 ± 0.03842 |\n" |
1323 | 1295 | ]
|
1324 | 1296 | },
|
1325 | 1297 | {
|
|
1336 | 1308 | "trainer, train, test, models, optimizers, config, lit_model, logger = main(\"config/light/MTL_train_SCLC_test.yaml\")\n",
|
1337 | 1309 | "\n"
|
1338 | 1310 | ]
|
1339 |
| - }, |
1340 |
| - { |
1341 |
| - "cell_type": "code", |
1342 |
| - "execution_count": 9, |
1343 |
| - "metadata": {}, |
1344 |
| - "outputs": [], |
1345 |
| - "source": [ |
1346 |
| - "# save the model\n", |
1347 |
| - "\n", |
1348 |
| - "torch.save(lit_model.state_dict(), 'model.pth')" |
1349 |
| - ] |
1350 |
| - }, |
1351 |
| - { |
1352 |
| - "cell_type": "code", |
1353 |
| - "execution_count": null, |
1354 |
| - "metadata": {}, |
1355 |
| - "outputs": [], |
1356 |
| - "source": [ |
1357 |
| - "# load model\n", |
1358 |
| - "\n", |
1359 |
| - "lit_model.load_state_dict(torch.load('model.pth'))\n" |
1360 |
| - ] |
1361 |
| - }, |
1362 |
| - { |
1363 |
| - "cell_type": "code", |
1364 |
| - "execution_count": null, |
1365 |
| - "metadata": {}, |
1366 |
| - "outputs": [], |
1367 |
| - "source": [] |
1368 | 1311 | }
|
1369 | 1312 | ],
|
1370 | 1313 | "metadata": {
|
|
1383 | 1326 | "name": "python",
|
1384 | 1327 | "nbconvert_exporter": "python",
|
1385 | 1328 | "pygments_lexer": "ipython3",
|
1386 |
| - "version": "3.9.undefined" |
| 1329 | + "version": "3.9.17" |
1387 | 1330 | }
|
1388 | 1331 | },
|
1389 | 1332 | "nbformat": 4,
|
|
0 commit comments