88import argparse
99
1010import numpy as np
11+ import pandas as pd
1112
12- from covid_xprize .validation .scenario_generator import get_raw_data
1313from covid_xprize .validation .scenario_generator import NPI_COLUMNS as IP_COLUMNS
1414
1515ROOT_DIR = os .path .dirname (os .path .abspath (__file__ ))
16- FIXTURES_PATH = os .path .join (ROOT_DIR , 'data' )
17- DATA_FILE = os .path .join (FIXTURES_PATH , "OxCGRT_latest.csv" )
16+ DEFAULT_GEOS = os .path .join (ROOT_DIR , '..' , '..' , "countries_regions.csv" )
1817
1918
2019def generate_costs (distribution = 'ones' ):
2120 """
22- Returns df of costs for each IP for each geo according to distribution.
21+ Returns a df of costs for each IP for default list of geos according to distribution.
22+ """
23+ return generate_costs_for_geos_file (DEFAULT_GEOS , distribution )
24+
25+
26+ def generate_costs_for_geos_file (geos_file , distribution = 'ones' ):
27+ """
28+ Returns a df of costs for each IP for geos in geos_file according to distribution.
29+ """
30+ geos_df = load_geos (geos_file )
31+ return generate_costs_for_geos_df (geos_df , distribution )
32+
33+
34+ def generate_costs_for_geos_df (geos_df , distribution = 'ones' ):
35+ """
36+ Returns df of costs for each IP for each geo in geos_df according to distribution.
2337
2438 Costs always sum to #IPS (i.e., len(IP_COLUMNS)).
2539
@@ -28,18 +42,12 @@ def generate_costs(distribution='ones'):
2842 - 'uniform': costs are sampled uniformly across IPs independently
2943 for each geo.
3044 """
45+ # Copy the countries and regions dataset in order to add IP columns
46+ df = geos_df .copy ()
47+
3148 assert distribution in ['ones' , 'uniform' ], \
3249 f'Unsupported distribution { distribution } '
3350
34-
35- df = get_raw_data (DATA_FILE , latest = False )
36-
37- # Reduce df to one row per geo
38- df = df .groupby (['CountryName' , 'RegionName' ]).mean ().reset_index ()
39-
40- # Reduce to geo id info
41- df = df [['CountryName' , 'RegionName' ]]
42-
4351 if distribution == 'ones' :
4452 df [IP_COLUMNS ] = 1
4553
@@ -52,12 +60,20 @@ def generate_costs(distribution='ones'):
5260 weights = nb_ips * samples / samples .sum (axis = 0 )
5361 df [IP_COLUMNS ] = weights .T
5462
55- # Round weights for better readability with neglible loss of generality.
63+ # Round weights for better readability with negligible loss of generality.
5664 df = df .round (2 )
5765
5866 return df
5967
6068
69+ def load_geos (path_to_geo_file ):
70+ print (f"Loading countries and regions from { path_to_geo_file } " )
71+ geos_df = pd .read_csv (path_to_geo_file ,
72+ encoding = "ISO-8859-1" ,
73+ dtype = {"RegionName" : str })
74+ return geos_df
75+
76+
6177if __name__ == '__main__' :
6278
6379 parser = argparse .ArgumentParser ()
@@ -66,14 +82,22 @@ def generate_costs(distribution='ones'):
6682 required = True ,
6783 help = "Distribution to generate weights from. Current"
6884 "options are 'ones', and 'uniform'." )
85+ parser .add_argument ("-c" , "--countries_path" ,
86+ dest = "countries_path" ,
87+ type = str ,
88+ required = False ,
89+ default = DEFAULT_GEOS ,
90+ help = "The path to a csv file containing the list of countries and regions to use. "
91+ "The csv file must contain the following columns: CountryName,RegionName "
92+ "and names must match latest Oxford's ones" )
6993 parser .add_argument ("-o" , "--output_file" ,
7094 type = str ,
7195 required = True ,
7296 help = "Name of csv file to write generated weights to." )
7397 args = parser .parse_args ()
7498
7599 print (f"Generating weights with distribution { args .distribution } ..." )
76- weights_df = generate_costs ( args .distribution )
100+ weights_df = generate_costs_for_geos_file ( args . countries_path , args .distribution )
77101 print ("Writing weights to file..." )
78102 weights_df .to_csv (args .output_file , index = False )
79103 print ("Done. Thank you." )
0 commit comments