Skip to content

Commit 650cbd6

Browse files
authored
Merge pull request #17 from leaf-ai/plotly
Plotly
2 parents 0617d98 + a88a231 commit 650cbd6

File tree

2 files changed

+211
-1
lines changed

2 files changed

+211
-1
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ notebook==6.1.4
44
scikit-learn==0.23.2
55
tensorflow==2.3.0
66
keras==2.4.3
7+
plotly==4.9.0

robojudge.ipynb

Lines changed: 210 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,21 @@
324324
"metadata": {},
325325
"outputs": [],
326326
"source": [
327-
"actual_df.head()"
327+
"actual_df.head(8)"
328+
]
329+
},
330+
{
331+
"cell_type": "code",
332+
"execution_count": null,
333+
"metadata": {},
334+
"outputs": [],
335+
"source": [
336+
"# actual_as_pred_df = actual_df.copy()\n",
337+
"# actual_as_pred_df[\"PredictorName\"] = \"Ground truth\"\n",
338+
"# actual_as_pred_df[\"Prediction\"] = False\n",
339+
"# actual_as_pred_df = actual_as_pred_df.rename(columns={\"ActualDailyNewCases\": \"PredictedDailyNewCases\",\n",
340+
"# \"ActualDailyNewCases7DMA\": \"PredictedDailyNewCases7DMA\"})\n",
341+
"# actual_as_pred_df.head(8)"
328342
]
329343
},
330344
{
@@ -575,6 +589,201 @@
575589
"cr_df[(cr_df.CountryName.isin(NORTH_AMERICA)) & (cr_df.RegionName == \"\")]"
576590
]
577591
},
592+
{
593+
"cell_type": "markdown",
594+
"metadata": {},
595+
"source": [
596+
"# Plots"
597+
]
598+
},
599+
{
600+
"cell_type": "code",
601+
"execution_count": null,
602+
"metadata": {},
603+
"outputs": [],
604+
"source": [
605+
"default_country = \"Italy\""
606+
]
607+
},
608+
{
609+
"cell_type": "markdown",
610+
"metadata": {},
611+
"source": [
612+
"## Prediction vs actual"
613+
]
614+
},
615+
{
616+
"cell_type": "code",
617+
"execution_count": null,
618+
"metadata": {},
619+
"outputs": [],
620+
"source": [
621+
"country_df = ranking_df[ranking_df.CountryName == default_country]"
622+
]
623+
},
624+
{
625+
"cell_type": "code",
626+
"execution_count": null,
627+
"metadata": {},
628+
"outputs": [],
629+
"source": [
630+
"predictor_names = list(country_df.PredictorName.unique())\n",
631+
"country_names = list(ranking_df.CountryName.unique())"
632+
]
633+
},
634+
{
635+
"cell_type": "code",
636+
"execution_count": null,
637+
"metadata": {},
638+
"outputs": [],
639+
"source": [
640+
"country_df[country_df[\"PredictorName\"] == 'Predictor #27']"
641+
]
642+
},
643+
{
644+
"cell_type": "code",
645+
"execution_count": null,
646+
"metadata": {},
647+
"outputs": [],
648+
"source": [
649+
"import plotly.graph_objects as go\n",
650+
"\n",
651+
"fig = go.Figure(layout=dict(title=dict(text=f'Predicted New Cases 7-day Moving Average in {default_country}',\n",
652+
" y=0.9,\n",
653+
" x=0.5,\n",
654+
" xanchor='center',\n",
655+
" yanchor='top'\n",
656+
" ),\n",
657+
" plot_bgcolor='#f2f2f2',\n",
658+
" xaxis_title=\"Date\",\n",
659+
" yaxis_title=\"New Cases\"\n",
660+
" ))\n",
661+
"\n",
662+
"# Add 1 trace per predictor\n",
663+
"for predictor_name in predictor_names:\n",
664+
" pred_country_df = country_df[country_df[\"PredictorName\"] == predictor_name]\n",
665+
" fig.add_trace(go.Scatter(x=pred_country_df.Date,\n",
666+
" y=pred_country_df.PredictedDailyNewCases7DMA,\n",
667+
" name=predictor_name)\n",
668+
" )\n",
669+
"\n",
670+
"# Add 1 trace for the true number of cases\n",
671+
"country_actual_df = actual_df[(actual_df.CountryName == default_country) &\n",
672+
" (actual_df.Date >= start_date)]\n",
673+
"fig.add_trace(go.Scatter(x=country_actual_df.Date,\n",
674+
" y=country_actual_df.ActualDailyNewCases7DMA,\n",
675+
" name=\"Ground Truth\",\n",
676+
" line=dict(color='orange', width=4, dash='dash'))\n",
677+
" )\n",
678+
"# Format x axis\n",
679+
"fig.update_xaxes(\n",
680+
"dtick=\"D1\", # Means 1 day\n",
681+
"tickformat=\"%d\\n%b\")\n",
682+
"\n",
683+
"fig.show()"
684+
]
685+
},
686+
{
687+
"cell_type": "markdown",
688+
"metadata": {},
689+
"source": [
690+
"## Filter by country"
691+
]
692+
},
693+
{
694+
"cell_type": "code",
695+
"execution_count": null,
696+
"metadata": {},
697+
"outputs": [],
698+
"source": [
699+
"import plotly.graph_objects as go\n",
700+
"\n",
701+
"fig = go.Figure(layout=dict(title=dict(text=f'Predicted 7-day Moving Average of New Cases in {default_country}',\n",
702+
" y=0.9,\n",
703+
" x=0.5,\n",
704+
" xanchor='center',\n",
705+
" yanchor='top'\n",
706+
" ),\n",
707+
" plot_bgcolor='#f2f2f2',\n",
708+
" xaxis_title=\"Date\",\n",
709+
" yaxis_title=\"New Cases\"\n",
710+
" ))\n",
711+
"\n",
712+
"# Keep track of trace visibility by country name\n",
713+
"country_plot_names = []\n",
714+
"\n",
715+
"# Add 1 trace per predictor, per country\n",
716+
"for predictor_name in predictor_names:\n",
717+
" for country_name in country_names:\n",
718+
" country_df = ranking_df[ranking_df.CountryName == country_name]\n",
719+
" pred_country_df = country_df[country_df[\"PredictorName\"] == predictor_name]\n",
720+
" fig.add_trace(go.Scatter(x=pred_country_df.Date,\n",
721+
" y=pred_country_df.PredictedDailyNewCases7DMA,\n",
722+
" name=predictor_name,\n",
723+
" visible= (country_name == default_country))\n",
724+
" )\n",
725+
" country_plot_names.append(country_name)\n",
726+
"\n",
727+
"# For each country\n",
728+
"# Add 1 trace for the true number of cases\n",
729+
"for country_name in country_names:\n",
730+
" country_actual_df = actual_df[(actual_df.CountryName == country_name) &\n",
731+
" (actual_df.Date >= start_date)]\n",
732+
" fig.add_trace(go.Scatter(x=country_actual_df.Date,\n",
733+
" y=country_actual_df.ActualDailyNewCases7DMA,\n",
734+
" name=\"Ground Truth\",\n",
735+
" visible= (country_name == default_country),\n",
736+
" line=dict(color='orange', width=4, dash='dash'))\n",
737+
" )\n",
738+
" country_plot_names.append(country_name)\n",
739+
"\n",
740+
"# Format x axis\n",
741+
"fig.update_xaxes(\n",
742+
"dtick=\"D1\", # Means 1 day\n",
743+
"tickformat=\"%d\\n%b\")\n",
744+
"\n",
745+
"# Filter\n",
746+
"buttons=[]\n",
747+
"for country_name in country_names:\n",
748+
" buttons.append(dict(method='update',\n",
749+
" label=country_name,\n",
750+
" args = [{'visible': [country_name==r for r in country_plot_names]},\n",
751+
" {'title': \"Predicted 7-day Moving Average of New Cases in \" + country_name}]))\n",
752+
"fig.update_layout(showlegend=True,\n",
753+
" updatemenus=[{\"buttons\": buttons,\n",
754+
" \"direction\": \"down\",\n",
755+
" \"active\": country_names.index(default_country),\n",
756+
" \"showactive\": True,\n",
757+
" \"x\": 0.1,\n",
758+
" \"y\": 1.15}])\n",
759+
"\n",
760+
"fig.show()"
761+
]
762+
},
763+
{
764+
"cell_type": "markdown",
765+
"metadata": {},
766+
"source": [
767+
"## Daily diff in 7 days moving average"
768+
]
769+
},
770+
{
771+
"cell_type": "code",
772+
"execution_count": null,
773+
"metadata": {},
774+
"outputs": [],
775+
"source": [
776+
"fig = px.line(ranking_df[ranking_df.CountryName == \"Italy\"], x=\"Date\", y=\"Diff7DMA\", color='PredictorName')\n",
777+
"fig.show()"
778+
]
779+
},
780+
{
781+
"cell_type": "code",
782+
"execution_count": null,
783+
"metadata": {},
784+
"outputs": [],
785+
"source": []
786+
},
578787
{
579788
"cell_type": "code",
580789
"execution_count": null,

0 commit comments

Comments
 (0)