|
445 | 445 | },
|
446 | 446 | "outputs": [],
|
447 | 447 | "source": [
|
448 |
| - "def evaluate_and_save_results(models, architectures, batch_size=1024, num_workers=6, folder_name=\"results\"):\n", |
| 448 | + "def evaluate_and_save_results(\n", |
| 449 | + " models, architectures, valid_loader, folder_name=\"results\"\n", |
| 450 | + "):\n", |
| 451 | + " \"\"\"\n", |
| 452 | + " Оценивает модели на валидационном наборе данных и сохраняет результаты в файлы JSON.\n", |
| 453 | + " Аргументы:\n", |
| 454 | + " models (list): Список обученных моделей.\n", |
| 455 | + " architectures (list): Список архитектур моделей.\n", |
| 456 | + " valid_loader (DataLoader): DataLoader для валидационных данных.\n", |
| 457 | + " folder_name (str, необязательно): Имя папки для сохранения результатов. По умолчанию \"results\".\n", |
| 458 | + " Исключения:\n", |
| 459 | + " ValueError: Если количество моделей и архитектур не совпадает.\n", |
| 460 | + " Результаты:\n", |
| 461 | + " Для каждой модели создается файл JSON с результатами, содержащий:\n", |
| 462 | + " - architecture: Архитектура модели.\n", |
| 463 | + " - valid_predictions: Предсказания модели на валидационном наборе данных.\n", |
| 464 | + " - valid_accuracy: Точность модели на валидационном наборе данных.\n", |
| 465 | + " \"\"\"\n", |
449 | 466 | " if len(models) != len(architectures):\n",
|
450 | 467 | " raise ValueError(\"Количество моделей и архитектур должно совпадать\")\n",
|
451 | 468 | "\n",
|
452 | 469 | " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
453 | 470 | " os.makedirs(folder_name, exist_ok=True)\n",
|
454 | 471 | "\n",
|
455 |
| - " transform = transforms.Compose([\n", |
456 |
| - " transforms.ToTensor(),\n", |
457 |
| - " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),\n", |
458 |
| - " ])\n", |
459 |
| - " test_dataset = CIFAR10(root=\"./data\", train=False, download=True, transform=transform)\n", |
460 |
| - " test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)\n", |
461 |
| - "\n", |
462 | 472 | " for i, (model, architecture) in enumerate(zip(models, architectures)):\n",
|
463 | 473 | " model.to(device)\n",
|
464 | 474 | " model.eval()\n",
|
465 | 475 | "\n",
|
466 |
| - " test_correct = 0\n", |
467 |
| - " test_total = 0\n", |
468 |
| - " test_preds = []\n", |
| 476 | + " valid_correct = 0\n", |
| 477 | + " valid_total = 0\n", |
| 478 | + " valid_preds = []\n", |
469 | 479 | "\n",
|
470 |
| - " # Оценка на тестовом датасете\n", |
471 | 480 | " with torch.no_grad():\n",
|
472 |
| - " for images, labels in test_loader:\n", |
| 481 | + " for images, labels in valid_loader:\n", |
473 | 482 | " images, labels = images.to(device), labels.to(device)\n",
|
474 | 483 | " outputs = model(images)\n",
|
475 | 484 | " _, predicted = torch.max(outputs, 1)\n",
|
476 |
| - " test_preds.extend(predicted.cpu().tolist())\n", |
477 |
| - " test_correct += (predicted == labels).sum().item()\n", |
478 |
| - " test_total += labels.size(0)\n", |
479 |
| - " \n", |
480 |
| - " test_accuracy = test_correct / test_total\n", |
| 485 | + " valid_preds.extend(predicted.cpu().tolist())\n", |
| 486 | + " valid_correct += (predicted == labels).sum().item()\n", |
| 487 | + " valid_total += labels.size(0)\n", |
| 488 | + "\n", |
| 489 | + " valid_accuracy = valid_correct / valid_total\n", |
481 | 490 | "\n",
|
482 | 491 | " result = {\n",
|
483 | 492 | " \"architecture\": architecture,\n",
|
484 |
| - " \"test_predictions\": test_preds,\n", |
485 |
| - " \"test_accuracy\": test_accuracy\n", |
| 493 | + " \"valid_predictions\": valid_preds,\n", |
| 494 | + " \"valid_accuracy\": valid_accuracy,\n", |
486 | 495 | " }\n",
|
487 | 496 | "\n",
|
488 | 497 | " file_name = f\"model_{i+1}_results.json\"\n",
|
|
491 | 500 | " with open(file_path, \"w\") as f:\n",
|
492 | 501 | " json.dump(result, f, indent=4)\n",
|
493 | 502 | "\n",
|
494 |
| - " print(f\"Results for model_{i + 1} saved to {file_path}\")\n" |
| 503 | + " print(f\"Results for model_{i + 1} saved to {file_path}\")" |
495 | 504 | ]
|
496 | 505 | },
|
497 | 506 | {
|
|
630 | 639 | "sourceType": "notebook"
|
631 | 640 | },
|
632 | 641 | "kernelspec": {
|
633 |
| - "display_name": "Python 3", |
| 642 | + "display_name": "usr", |
634 | 643 | "language": "python",
|
635 | 644 | "name": "python3"
|
636 | 645 | },
|
|
644 | 653 | "name": "python",
|
645 | 654 | "nbconvert_exporter": "python",
|
646 | 655 | "pygments_lexer": "ipython3",
|
647 |
| - "version": "3.10.12" |
| 656 | + "version": "3.12.3" |
648 | 657 | }
|
649 | 658 | },
|
650 | 659 | "nbformat": 4,
|
|
0 commit comments