Skip to content

Commit b0ea23a

Browse files
authored
add complete() logic to UpdatedReferenceDatasetCollectionTask, including check globals against config (#669)
* first pass at update task * move stuff back to /lib file * add tests for former combine file * add tests for former combine file * comment w/ logic for complete method * task unit test * add missing test cases and default luigi datasets param * delete_reference_data_ht removed * pseudocode * implement match check functions * handle dataset not in joined globals? * oops * tests for all fns except version * add version tests * test validate function some refactor * add logging and tests for task * refactor without custom selects * eval in validate function, add dataclass * move checking dataset in ht into compare file, simplify complete logic * get_ht_path * raise error on dataset config version mismatch * rethink selects compare, use existing get field code * typo * review comments * review comments
1 parent c0bca33 commit b0ea23a

23 files changed

+784
-144
lines changed

v03_pipeline/bin/write_cached_reference_dataset_query_ht.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
valid_reference_dataset_collection_path,
1616
)
1717
from v03_pipeline.lib.reference_data.config import CONFIG
18+
from v03_pipeline.lib.reference_data.dataset_table_operations import (
19+
import_ht_from_config_path,
20+
)
1821

1922

2023
def get_ht(
@@ -25,11 +28,7 @@ def get_ht(
2528
# If the query is defined over an uncombined reference dataset, use the combiner config.
2629
if query.reference_dataset:
2730
config = CONFIG[query.reference_dataset][reference_genome.v02_value]
28-
return (
29-
config['custom_import'](config['source_path'], reference_genome)
30-
if 'custom_import' in config
31-
else hl.read_table(config['path'])
32-
)
31+
return import_ht_from_config_path(config, reference_genome)
3332
return hl.read_table(
3433
valid_reference_dataset_collection_path(
3534
reference_genome,
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
import logging
2+
from dataclasses import dataclass
3+
4+
import hail as hl
5+
6+
from v03_pipeline.lib.model import ReferenceGenome
7+
from v03_pipeline.lib.reference_data.config import CONFIG
8+
from v03_pipeline.lib.reference_data.dataset_table_operations import (
9+
get_all_select_fields,
10+
get_ht_path,
11+
import_ht_from_config_path,
12+
parse_dataset_version,
13+
)
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
@dataclass
class ReferenceDataGlobals:
    """Plain-dict view of the ``paths``, ``versions`` and ``enums`` globals.

    The instance is built from an already-evaluated globals struct; every
    (nested) ``hl.Struct`` found under those three fields is converted into
    a regular ``dict`` so callers can use ``.get`` and ``==`` on the result.
    """

    # Each mapping is keyed by the field names of the corresponding struct.
    # NOTE(review): keys are presumably dataset names and the path/version
    # values presumably strings -- confirm against the code writing the globals.
    paths: dict[str, str]
    versions: dict[str, str]
    enums: dict[str, dict[str, list[str]]]

    def __init__(self, globals_struct: hl.Struct):
        """Convert the three expected fields of ``globals_struct`` to dicts.

        Args:
            globals_struct: evaluated globals exposing ``paths``,
                ``versions`` and ``enums`` attributes.
        """
        # A hand-written __init__ is kept by @dataclass (it never overwrites
        # an attribute the class already defines), so construction takes the
        # struct rather than one argument per field.
        self.paths = self._struct_to_dict(globals_struct.paths)
        self.versions = self._struct_to_dict(globals_struct.versions)
        self.enums = self._struct_to_dict(globals_struct.enums)

    def _struct_to_dict(self, struct: hl.Struct) -> dict:
        """Recursively copy ``struct`` into a dict.

        Values that are themselves ``hl.Struct`` instances are converted
        recursively; every other value is stored unchanged.
        """
        result_dict = {}
        for field in struct:
            value = struct[field]
            if isinstance(value, hl.Struct):
                result_dict[field] = self._struct_to_dict(value)
            else:
                result_dict[field] = value
        return result_dict
37+
38+
39+
def get_datasets_to_update(
    joined_ht: hl.Table,
    datasets: list[str],
    reference_genome: ReferenceGenome,
) -> list[str]:
    """Return the entries of ``datasets`` that need to be updated.

    A dataset is selected when it is not found in ``joined_ht.row`` or when
    ``validate_joined_ht_globals_match_config`` returns a falsy result for
    it. The result keeps the order of ``datasets``.
    """
    # The table globals are evaluated once and shared by every validation.
    evaluated_globals = ReferenceDataGlobals(hl.eval(joined_ht.index_globals()))

    def _needs_update(dataset: str) -> bool:
        # A dataset absent from the row is selected without being validated.
        if dataset not in joined_ht.row:
            return True
        return not validate_joined_ht_globals_match_config(
            joined_ht,
            evaluated_globals,
            dataset,
            reference_genome,
        )

    return [dataset for dataset in datasets if _needs_update(dataset)]
59+
60+
61+
def validate_joined_ht_globals_match_config(
    joined_ht: hl.Table,
    joined_ht_globals: ReferenceDataGlobals,
    dataset: str,
    reference_genome: ReferenceGenome,
) -> bool:
    """Run the version, path, enum and select checks for one dataset.

    The dataset's config entry is looked up in ``CONFIG`` by dataset name
    and ``reference_genome.v02_value``, and the dataset table is loaded with
    ``import_ht_from_config_path``. All four checks are always evaluated;
    each one that returned ``False`` is logged at INFO level.

    Returns:
        ``True`` only if every check passed.
    """
    config_entry = CONFIG[dataset][reference_genome.v02_value]
    source_ht = import_ht_from_config_path(config_entry, reference_genome)

    # Evaluated eagerly, in this order, so that every mismatch is reported
    # rather than only the first one.
    version_ok = ht_version_matches_config(
        joined_ht_globals,
        dataset,
        config_entry,
        source_ht,
    )
    path_ok = ht_path_matches_config(joined_ht_globals, dataset, config_entry)
    enum_ok = ht_enums_match_config(joined_ht_globals, dataset, config_entry)
    select_ok = ht_selects_match_config(
        joined_ht,
        dataset,
        config_entry,
        source_ht,
    )

    outcomes = {
        'version': version_ok,
        'path': path_ok,
        'enum': enum_ok,
        'select': select_ok,
    }
    for check, passed in outcomes.items():
        if passed is False:
            logger.info(f'{check} mismatch for {dataset}')
    return all(outcomes.values())
92+
93+
94+
def ht_version_matches_config(
    joined_ht_globals: ReferenceDataGlobals,
    dataset: str,
    dataset_config: dict,
    dataset_ht: hl.Table,
) -> bool:
    """Compare the version stored in the joined globals with the expected one.

    Returns ``False`` when ``joined_ht_globals.versions`` has no value (or
    ``None``) for ``dataset``. Otherwise the expected version is obtained by
    evaluating ``parse_dataset_version(dataset_ht, dataset, dataset_config)``
    and compared for equality with the stored value.
    """
    stored_version = joined_ht_globals.versions.get(dataset)
    if stored_version is not None:
        # parse_dataset_version is only evaluated when a stored version
        # exists to compare against.
        expected_version = hl.eval(
            parse_dataset_version(dataset_ht, dataset, dataset_config),
        )
        return stored_version == expected_version
    return False
112+
113+
114+
def ht_path_matches_config(
    joined_ht_globals: ReferenceDataGlobals,
    dataset: str,
    dataset_config: dict,
) -> bool:
    """Compare the path stored in the joined globals with the config's path.

    Returns ``False`` when ``joined_ht_globals.paths`` has no value (or
    ``None``) for ``dataset``; otherwise returns whether the stored path
    equals ``get_ht_path(dataset_config)``.
    """
    stored_path = joined_ht_globals.paths.get(dataset)
    if stored_path is not None:
        return stored_path == get_ht_path(dataset_config)
    return False
125+
126+
127+
def ht_enums_match_config(
    joined_ht_globals: ReferenceDataGlobals,
    dataset: str,
    dataset_config: dict,
) -> bool:
    """Compare the enums stored for ``dataset`` with the config's ``enum_select``.

    A missing entry on either side is treated as an empty dict, so a dataset
    with no stored enums matches a config that defines no ``enum_select``.
    """
    enums_by_dataset = joined_ht_globals.enums
    stored_enums = enums_by_dataset.get(dataset, {})
    expected_enums = dataset_config.get('enum_select', {})
    return stored_enums == expected_enums
135+
136+
137+
def ht_selects_match_config(
    joined_ht: hl.Table,
    dataset: str,
    dataset_config: dict,
    dataset_ht: hl.Table,
) -> bool:
    """Check that the dataset's fields in ``joined_ht`` match the config selects.

    The names obtained by iterating ``joined_ht[dataset]`` are compared, as
    a set, with the keys returned by
    ``get_all_select_fields(dataset_ht, dataset_config)``.

    Returns:
        ``True`` when both sets contain exactly the same names.
    """
    stored_fields = set(joined_ht[dataset])
    expected_fields = set(get_all_select_fields(dataset_ht, dataset_config).keys())
    # Set equality holds exactly when the symmetric difference is empty.
    return expected_fields == stored_fields

0 commit comments

Comments
 (0)