Skip to content

Commit c8dde11

Browse files
authored
Merge pull request #201 from leaf-ai/cost_gen
#200 Specify list of countries and regions for cost_generator
2 parents 5150e36 + 2cfc6a6 commit c8dde11

File tree

1 file changed

+39
-15
lines changed

1 file changed

+39
-15
lines changed

covid_xprize/validation/cost_generator.py

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,32 @@
88
import argparse
99

1010
import numpy as np
11+
import pandas as pd
1112

12-
from covid_xprize.validation.scenario_generator import get_raw_data
1313
from covid_xprize.validation.scenario_generator import NPI_COLUMNS as IP_COLUMNS
1414

1515
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
16-
FIXTURES_PATH = os.path.join(ROOT_DIR, 'data')
17-
DATA_FILE = os.path.join(FIXTURES_PATH, "OxCGRT_latest.csv")
16+
DEFAULT_GEOS = os.path.join(ROOT_DIR, '..', '..', "countries_regions.csv")
1817

1918

2019
def generate_costs(distribution='ones'):
    """
    Returns a df of costs for each IP for default list of geos according to distribution.

    Convenience entry point: forwards DEFAULT_GEOS and the requested
    distribution to generate_costs_for_geos_file and returns its result
    unchanged.

    :param distribution: name of the distribution to generate costs from
    :return: the value returned by generate_costs_for_geos_file for the
        default geos file
    """
    return generate_costs_for_geos_file(DEFAULT_GEOS, distribution)
24+
25+
26+
def generate_costs_for_geos_file(geos_file, distribution='ones'):
    """
    Returns a df of costs for each IP for geos in geos_file according to distribution.

    The geos file is read with load_geos and the resulting DataFrame is
    handed to generate_costs_for_geos_df together with the distribution.

    :param geos_file: path to a csv file listing the geos to generate costs for
    :param distribution: name of the distribution to generate costs from
    :return: the value returned by generate_costs_for_geos_df
    """
    geos_df = load_geos(geos_file)
    return generate_costs_for_geos_df(geos_df, distribution)
32+
33+
34+
def generate_costs_for_geos_df(geos_df, distribution='ones'):
35+
"""
36+
Returns df of costs for each IP for each geo in geos_df according to distribution.
2337
2438
Costs always sum to #IPS (i.e., len(IP_COLUMNS)).
2539
@@ -28,18 +42,12 @@ def generate_costs(distribution='ones'):
2842
- 'uniform': costs are sampled uniformly across IPs independently
2943
for each geo.
3044
"""
45+
# Copy the countries and regions dataset in order to add IP columns
46+
df = geos_df.copy()
47+
3148
assert distribution in ['ones', 'uniform'], \
3249
f'Unsupported distribution {distribution}'
3350

34-
35-
df = get_raw_data(DATA_FILE, latest=False)
36-
37-
# Reduce df to one row per geo
38-
df = df.groupby(['CountryName', 'RegionName']).mean().reset_index()
39-
40-
# Reduce to geo id info
41-
df = df[['CountryName', 'RegionName']]
42-
4351
if distribution == 'ones':
4452
df[IP_COLUMNS] = 1
4553

@@ -52,12 +60,20 @@ def generate_costs(distribution='ones'):
5260
weights = nb_ips * samples / samples.sum(axis=0)
5361
df[IP_COLUMNS] = weights.T
5462

55-
# Round weights for better readability with neglible loss of generality.
63+
# Round weights for better readability with negligible loss of generality.
5664
df = df.round(2)
5765

5866
return df
5967

6068

69+
def load_geos(path_to_geo_file):
    """
    Loads the list of countries and regions from a csv file.

    Prints the path being loaded, then reads the file with pandas using the
    ISO-8859-1 encoding.

    :param path_to_geo_file: path to the csv file to read
    :return: a pandas DataFrame with the content of the csv file
    """
    print(f"Loading countries and regions from {path_to_geo_file}")
    # RegionName is forced to a string dtype; presumably so that a column
    # that is mostly empty is not inferred as numeric -- TODO confirm.
    geos_df = pd.read_csv(path_to_geo_file,
                          encoding="ISO-8859-1",
                          dtype={"RegionName": str})
    return geos_df
75+
76+
6177
if __name__ == '__main__':
6278

6379
parser = argparse.ArgumentParser()
@@ -66,14 +82,22 @@ def generate_costs(distribution='ones'):
6682
required=True,
6783
help="Distribution to generate weights from. Current"
6884
"options are 'ones', and 'uniform'.")
85+
parser.add_argument("-c", "--countries_path",
86+
dest="countries_path",
87+
type=str,
88+
required=False,
89+
default=DEFAULT_GEOS,
90+
help="The path to a csv file containing the list of countries and regions to use. "
91+
"The csv file must contain the following columns: CountryName,RegionName "
92+
"and names must match latest Oxford's ones")
6993
parser.add_argument("-o", "--output_file",
7094
type=str,
7195
required=True,
7296
help="Name of csv file to write generated weights to.")
7397
args = parser.parse_args()
7498

7599
print(f"Generating weights with distribution {args.distribution}...")
76-
weights_df = generate_costs(args.distribution)
100+
weights_df = generate_costs_for_geos_file(args.countries_path, args.distribution)
77101
print("Writing weights to file...")
78102
weights_df.to_csv(args.output_file, index=False)
79103
print("Done. Thank you.")

0 commit comments

Comments
 (0)