|
1 |
| -from collections.abc import Iterable |
2 | 1 | from typing import Any
|
3 | 2 |
|
4 | 3 | import hail as hl
|
@@ -28,29 +27,25 @@ def __init__(
|
28 | 27 |
|
29 | 28 | def validate_allele_type(
|
30 | 29 | t: hl.Table | hl.MatrixTable,
|
31 |
| - dataset_type: DatasetType | Iterable[DatasetType], |
| 30 | + dataset_type: DatasetType, |
32 | 31 | **_: Any,
|
33 | 32 | ) -> None:
|
34 | 33 | ht = t.rows() if isinstance(t, hl.MatrixTable) else t
|
35 |
| - dataset_types = ( |
36 |
| - [dataset_type] if isinstance(dataset_type, DatasetType) else dataset_type |
| 34 | + ht = ht.filter( |
| 35 | + dataset_type.invalid_allele_types.contains( |
| 36 | + hl.numeric_allele_type(ht.alleles[0], ht.alleles[1]), |
| 37 | + ), |
37 | 38 | )
|
38 |
| - for dataset_type in dataset_types: |
39 |
| - ht = ht.filter( |
40 |
| - dataset_type.invalid_allele_types.contains( |
41 |
| - hl.numeric_allele_type(ht.alleles[0], ht.alleles[1]), |
42 |
| - ), |
| 39 | + if ht.count() > 0: |
| 40 | + collected_alleles = sorted( |
| 41 | + [tuple(x) for x in ht.aggregate(hl.agg.collect_as_set(ht.alleles))], |
43 | 42 | )
|
44 |
| - if ht.count() > 0: |
45 |
| - collected_alleles = sorted( |
46 |
| - [tuple(x) for x in ht.aggregate(hl.agg.collect_as_set(ht.alleles))], |
47 |
| - ) |
48 |
| - # Handle case where all invalid alleles are NON_REF, indicating a gvcf: |
49 |
| - if all('<NON_REF>' in alleles for alleles in collected_alleles): |
50 |
| - msg = 'Alleles with invalid allele <NON_REF> are present in the callset. This appears to be a GVCF containing records for sites with no variants.' |
51 |
| - raise SeqrValidationError(msg) |
52 |
| - msg = f'Alleles with invalid AlleleType are present in the callset: {collected_alleles[:10]}' |
| 43 | + # Handle case where all invalid alleles are NON_REF, indicating a gvcf: |
| 44 | + if all('<NON_REF>' in alleles for alleles in collected_alleles): |
| 45 | + msg = 'Alleles with invalid allele <NON_REF> are present in the callset. This appears to be a GVCF containing records for sites with no variants.' |
53 | 46 | raise SeqrValidationError(msg)
|
| 47 | + msg = f'Alleles with invalid AlleleType are present in the callset: {collected_alleles[:10]}' |
| 48 | + raise SeqrValidationError(msg) |
54 | 49 |
|
55 | 50 |
|
56 | 51 | def validate_no_duplicate_variants(
|
|
0 commit comments