|
324 | 324 | "metadata": {}, |
325 | 325 | "outputs": [], |
326 | 326 | "source": [ |
327 | | - "actual_df.head()" |
| 327 | + "actual_df.head(8)" |
| 328 | + ] |
| 329 | + }, |
| 330 | + { |
| 331 | + "cell_type": "code", |
| 332 | + "execution_count": null, |
| 333 | + "metadata": {}, |
| 334 | + "outputs": [], |
| 335 | + "source": [ |
| 336 | + "# actual_as_pred_df = actual_df.copy()\n", |
| 337 | + "# actual_as_pred_df[\"PredictorName\"] = \"Ground truth\"\n", |
| 338 | + "# actual_as_pred_df[\"Prediction\"] = False\n", |
| 339 | + "# actual_as_pred_df = actual_as_pred_df.rename(columns={\"ActualDailyNewCases\": \"PredictedDailyNewCases\",\n", |
| 340 | + "# \"ActualDailyNewCases7DMA\": \"PredictedDailyNewCases7DMA\"})\n", |
| 341 | + "# actual_as_pred_df.head(8)" |
328 | 342 | ] |
329 | 343 | }, |
330 | 344 | { |
|
575 | 589 | "cr_df[(cr_df.CountryName.isin(NORTH_AMERICA)) & (cr_df.RegionName == \"\")]" |
576 | 590 | ] |
577 | 591 | }, |
| 592 | + { |
| 593 | + "cell_type": "markdown", |
| 594 | + "metadata": {}, |
| 595 | + "source": [ |
| 596 | + "# Plots" |
| 597 | + ] |
| 598 | + }, |
| 599 | + { |
| 600 | + "cell_type": "code", |
| 601 | + "execution_count": null, |
| 602 | + "metadata": {}, |
| 603 | + "outputs": [], |
| 604 | + "source": [ |
| 605 | + "default_country = \"Italy\"" |
| 606 | + ] |
| 607 | + }, |
| 608 | + { |
| 609 | + "cell_type": "markdown", |
| 610 | + "metadata": {}, |
| 611 | + "source": [ |
| 612 | + "## Prediction vs actual" |
| 613 | + ] |
| 614 | + }, |
| 615 | + { |
| 616 | + "cell_type": "code", |
| 617 | + "execution_count": null, |
| 618 | + "metadata": {}, |
| 619 | + "outputs": [], |
| 620 | + "source": [ |
| 621 | + "country_df = ranking_df[ranking_df.CountryName == default_country]" |
| 622 | + ] |
| 623 | + }, |
| 624 | + { |
| 625 | + "cell_type": "code", |
| 626 | + "execution_count": null, |
| 627 | + "metadata": {}, |
| 628 | + "outputs": [], |
| 629 | + "source": [ |
| 630 | + "predictor_names = list(country_df.PredictorName.unique())\n", |
| 631 | + "country_names = list(ranking_df.CountryName.unique())" |
| 632 | + ] |
| 633 | + }, |
| 634 | + { |
| 635 | + "cell_type": "code", |
| 636 | + "execution_count": null, |
| 637 | + "metadata": {}, |
| 638 | + "outputs": [], |
| 639 | + "source": [ |
| 640 | + "country_df[country_df[\"PredictorName\"] == 'Predictor #27']" |
| 641 | + ] |
| 642 | + }, |
| 643 | + { |
| 644 | + "cell_type": "code", |
| 645 | + "execution_count": null, |
| 646 | + "metadata": {}, |
| 647 | + "outputs": [], |
| 648 | + "source": [ |
| 649 | + "import plotly.graph_objects as go\n", |
| 650 | + "\n", |
| 651 | + "fig = go.Figure(layout=dict(title=dict(text=f'Predicted New Cases 7-day Moving Average in {default_country}',\n", |
| 652 | + " y=0.9,\n", |
| 653 | + " x=0.5,\n", |
| 654 | + " xanchor='center',\n", |
| 655 | + " yanchor='top'\n", |
| 656 | + " ),\n", |
| 657 | + " plot_bgcolor='#f2f2f2',\n", |
| 658 | + " xaxis_title=\"Date\",\n", |
| 659 | + " yaxis_title=\"New Cases\"\n", |
| 660 | + " ))\n", |
| 661 | + "\n", |
| 662 | + "# Add 1 trace per predictor\n", |
| 663 | + "for predictor_name in predictor_names:\n", |
| 664 | + " pred_country_df = country_df[country_df[\"PredictorName\"] == predictor_name]\n", |
| 665 | + " fig.add_trace(go.Scatter(x=pred_country_df.Date,\n", |
| 666 | + " y=pred_country_df.PredictedDailyNewCases7DMA,\n", |
| 667 | + " name=predictor_name)\n", |
| 668 | + " )\n", |
| 669 | + "\n", |
| 670 | + "# Add 1 trace for the true number of cases\n", |
| 671 | + "country_actual_df = actual_df[(actual_df.CountryName == default_country) &\n", |
| 672 | + " (actual_df.Date >= start_date)]\n", |
| 673 | + "fig.add_trace(go.Scatter(x=country_actual_df.Date,\n", |
| 674 | + " y=country_actual_df.ActualDailyNewCases7DMA,\n", |
| 675 | + " name=\"Ground Truth\",\n", |
| 676 | + " line=dict(color='orange', width=4, dash='dash'))\n", |
| 677 | + " )\n", |
| 678 | + "# Format x axis\n", |
| 679 | + "fig.update_xaxes(\n", |
| 680 | + "dtick=\"D1\", # Means 1 day\n", |
| 681 | + "tickformat=\"%d\\n%b\")\n", |
| 682 | + "\n", |
| 683 | + "fig.show()" |
| 684 | + ] |
| 685 | + }, |
| 686 | + { |
| 687 | + "cell_type": "markdown", |
| 688 | + "metadata": {}, |
| 689 | + "source": [ |
| 690 | + "## Filter by country" |
| 691 | + ] |
| 692 | + }, |
| 693 | + { |
| 694 | + "cell_type": "code", |
| 695 | + "execution_count": null, |
| 696 | + "metadata": {}, |
| 697 | + "outputs": [], |
| 698 | + "source": [ |
| 699 | + "import plotly.graph_objects as go\n", |
| 700 | + "\n", |
| 701 | + "fig = go.Figure(layout=dict(title=dict(text=f'Predicted 7-day Moving Average of New Cases in {default_country}',\n", |
| 702 | + " y=0.9,\n", |
| 703 | + " x=0.5,\n", |
| 704 | + " xanchor='center',\n", |
| 705 | + " yanchor='top'\n", |
| 706 | + " ),\n", |
| 707 | + " plot_bgcolor='#f2f2f2',\n", |
| 708 | + " xaxis_title=\"Date\",\n", |
| 709 | + " yaxis_title=\"New Cases\"\n", |
| 710 | + " ))\n", |
| 711 | + "\n", |
| 712 | + "# Keep track of trace visibility by country name\n", |
| 713 | + "country_plot_names = []\n", |
| 714 | + "\n", |
| 715 | + "# Add 1 trace per predictor, per country\n", |
| 716 | + "for predictor_name in predictor_names:\n", |
| 717 | + " for country_name in country_names:\n", |
| 718 | + " country_df = ranking_df[ranking_df.CountryName == country_name]\n", |
| 719 | + " pred_country_df = country_df[country_df[\"PredictorName\"] == predictor_name]\n", |
| 720 | + " fig.add_trace(go.Scatter(x=pred_country_df.Date,\n", |
| 721 | + " y=pred_country_df.PredictedDailyNewCases7DMA,\n", |
| 722 | + " name=predictor_name,\n", |
| 723 | + " visible= (country_name == default_country))\n", |
| 724 | + " )\n", |
| 725 | + " country_plot_names.append(country_name)\n", |
| 726 | + "\n", |
| 727 | + "# For each country\n", |
| 728 | + "# Add 1 trace for the true number of cases\n", |
| 729 | + "for country_name in country_names:\n", |
| 730 | + " country_actual_df = actual_df[(actual_df.CountryName == country_name) &\n", |
| 731 | + " (actual_df.Date >= start_date)]\n", |
| 732 | + " fig.add_trace(go.Scatter(x=country_actual_df.Date,\n", |
| 733 | + " y=country_actual_df.ActualDailyNewCases7DMA,\n", |
| 734 | + " name=\"Ground Truth\",\n", |
| 735 | + " visible= (country_name == default_country),\n", |
| 736 | + " line=dict(color='orange', width=4, dash='dash'))\n", |
| 737 | + " )\n", |
| 738 | + " country_plot_names.append(country_name)\n", |
| 739 | + "\n", |
| 740 | + "# Format x axis\n", |
| 741 | + "fig.update_xaxes(\n", |
| 742 | + "dtick=\"D1\", # Means 1 day\n", |
| 743 | + "tickformat=\"%d\\n%b\")\n", |
| 744 | + "\n", |
| 745 | + "# Filter\n", |
| 746 | + "buttons=[]\n", |
| 747 | + "for country_name in country_names:\n", |
| 748 | + " buttons.append(dict(method='update',\n", |
| 749 | + " label=country_name,\n", |
| 750 | + " args = [{'visible': [country_name==r for r in country_plot_names]},\n", |
| 751 | + " {'title': \"Predicted 7-day Moving Average of New Cases in \" + country_name}]))\n", |
| 752 | + "fig.update_layout(showlegend=True,\n", |
| 753 | + " updatemenus=[{\"buttons\": buttons,\n", |
| 754 | + " \"direction\": \"down\",\n", |
| 755 | + " \"active\": country_names.index(default_country),\n", |
| 756 | + " \"showactive\": True,\n", |
| 757 | + " \"x\": 0.1,\n", |
| 758 | + " \"y\": 1.15}])\n", |
| 759 | + "\n", |
| 760 | + "fig.show()" |
| 761 | + ] |
| 762 | + }, |
| 763 | + { |
| 764 | + "cell_type": "markdown", |
| 765 | + "metadata": {}, |
| 766 | + "source": [ |
| 767 | + "## Daily diff in 7 days moving average" |
| 768 | + ] |
| 769 | + }, |
| 770 | + { |
| 771 | + "cell_type": "code", |
| 772 | + "execution_count": null, |
| 773 | + "metadata": {}, |
| 774 | + "outputs": [], |
| 775 | + "source": [ |
| 776 | + "fig = px.line(ranking_df[ranking_df.CountryName == \"Italy\"], x=\"Date\", y=\"Diff7DMA\", color='PredictorName')\n", |
| 777 | + "fig.show()" |
| 778 | + ] |
| 779 | + }, |
| 780 | + { |
| 781 | + "cell_type": "code", |
| 782 | + "execution_count": null, |
| 783 | + "metadata": {}, |
| 784 | + "outputs": [], |
| 785 | + "source": [] |
| 786 | + }, |
578 | 787 | { |
579 | 788 | "cell_type": "code", |
580 | 789 | "execution_count": null, |
|
0 commit comments