|
| 1 | +from collections.abc import Iterable |
1 | 2 | from typing import Any
|
2 | 3 |
|
3 | 4 | import hail as hl
|
@@ -27,25 +28,29 @@ def __init__(
|
27 | 28 |
|
def validate_allele_type(
    t: hl.Table | hl.MatrixTable,
    dataset_type: DatasetType | Iterable[DatasetType],
    **_: Any,
) -> None:
    """Raise if the callset contains alleles that are invalid for a dataset type.

    Args:
        t: Table or MatrixTable to validate. For a MatrixTable only the rows
            are inspected.
        dataset_type: A single DatasetType or an iterable of them. Each one is
            checked independently against the same, unfiltered set of rows.
            NOTE(review): assumes a callset must be valid for *every* supplied
            dataset type -- confirm against callers.
        **_: Extra keyword arguments, accepted and ignored.

    Raises:
        SeqrValidationError: If, for any supplied dataset type, at least one
            row has a numeric allele type contained in that type's
            ``invalid_allele_types``. A dedicated message is used when every
            offending allele set contains ``<NON_REF>``.
    """
    rows_ht = t.rows() if isinstance(t, hl.MatrixTable) else t
    dataset_types = (
        [dataset_type] if isinstance(dataset_type, DatasetType) else dataset_type
    )
    for current_type in dataset_types:
        # Filter into a separate name: rebinding the row table here would leave
        # it empty after the first passing dataset type, so every later type
        # would be checked against zero rows and could never fail.
        invalid_ht = rows_ht.filter(
            current_type.invalid_allele_types.contains(
                hl.numeric_allele_type(rows_ht.alleles[0], rows_ht.alleles[1]),
            ),
        )
        if invalid_ht.count() == 0:
            continue
        collected_alleles = sorted(
            [
                tuple(x)
                for x in invalid_ht.aggregate(
                    hl.agg.collect_as_set(invalid_ht.alleles),
                )
            ],
        )
        # Handle case where all invalid alleles are NON_REF, indicating a gvcf:
        if all('<NON_REF>' in alleles for alleles in collected_alleles):
            msg = 'Alleles with invalid allele <NON_REF> are present in the callset. This appears to be a GVCF containing records for sites with no variants.'
            raise SeqrValidationError(msg)
        # Only the first ten offending allele tuples are reported.
        msg = f'Alleles with invalid AlleleType are present in the callset: {collected_alleles[:10]}'
        raise SeqrValidationError(msg)
49 | 54 |
|
50 | 55 |
|
51 | 56 | def validate_no_duplicate_variants(
|
|
0 commit comments