def handle_paired_segment(
    segments: "tuple[pysam.AlignedSegment, pysam.AlignedSegment]",
    bed: dict,
    args: argparse.Namespace,
    min_mapq: int,
    report_writer: csv.DictWriter = False,
):
    """Filter, read-group and primer-trim one read pair.

    The pair is dropped (``False`` is returned) when either mate is missing,
    unmapped, supplementary or below ``min_mapq``, when no primer is found
    for either end, when ``args.remove_incorrect_pairs`` is set and the two
    primers belong to different amplicons, when soft masking fails, or when
    either mate has no ``M`` CIGAR operation left after masking.

    Args:
        segments (tuple): The two mates of a pair, as yielded by
            ``read_pair_generator`` (read1, read2). Either may be ``None``.
        bed (dict): The primer scheme, passed through to ``find_primer``.
        args (argparse.Namespace): Command line arguments. Reads ``verbose``,
            ``primer_match_threshold``, ``no_read_groups``, ``report``,
            ``remove_incorrect_pairs`` and ``trim_primers``.
        min_mapq (int): Minimum mapping quality required of both mates.
        report_writer (csv.DictWriter): Writer receiving one row per pair
            when ``args.report`` is set.

    Returns:
        tuple | bool: ``(amplicon, segments)`` where ``segments`` is the
        input tuple (mates are masked in place), or ``False`` if the pair is
        to be skipped.
    """

    def skip(reason: str) -> bool:
        # Log why the pair is dropped (verbose mode only) and signal "skip".
        if args.verbose:
            print(reason, file=sys.stderr)
        return False

    segment1, segment2 = segments

    if segment1 is None or segment2 is None:
        return skip(
            "Segment pair skipped as at least one segment in pair does not exist"
        )

    query_name = segment1.query_name

    # filter out unmapped and supplementary alignment segments
    if segment1.is_unmapped or segment2.is_unmapped:
        return skip(f"Segment pair: {query_name} skipped as unmapped")

    if segment1.is_supplementary or segment2.is_supplementary:
        return skip(f"Segment pair: {query_name} skipped as supplementary")

    if segment1.mapping_quality < min_mapq or segment2.mapping_quality < min_mapq:
        return skip(
            f"Segment pair: {query_name} skipped as mapping quality below threshold"
        )

    # "First in pair" says nothing about strand. Put the forward-strand mate
    # first so the LEFT primer is located from the leftmost alignment start
    # and the RIGHT primer from the rightmost alignment end.
    if segment1.is_reverse and not segment2.is_reverse:
        segment1, segment2 = segment2, segment1

    # locate the nearest primers to this alignment pair
    p1 = find_primer(
        bed=bed,
        pos=segment1.reference_start,
        direction="+",
        chrom=segment1.reference_name,
        threshold=args.primer_match_threshold,
    )
    p2 = find_primer(
        bed=bed,
        pos=segment2.reference_end,
        direction="-",
        chrom=segment2.reference_name,
        threshold=args.primer_match_threshold,
    )

    if not p1 or not p2:
        return skip(
            f"Paired segment: {query_name} skipped as no primer found for segment"
        )

    left_primer = p1[2]
    right_primer = p2[2]

    # check if primers are correctly paired and then assign read group
    # TODO: will try improving this / moving it to the primer scheme processing code
    correctly_paired = left_primer["Primer_ID"].replace("_LEFT", "") == right_primer[
        "Primer_ID"
    ].replace("_RIGHT", "")

    if not args.no_read_groups:
        read_group = left_primer["PoolName"] if correctly_paired else "unmatched"
        segment1.set_tag("RG", read_group)
        segment2.set_tag("RG", read_group)

    # get the amplicon number
    amplicon = left_primer["Primer_ID"].split("_")[1]

    # Built unconditionally: the verbose dump below needs it even when no
    # report file was requested.
    report = {
        "chrom": segment1.reference_name,
        "QueryName": query_name,
        "ReferenceStart": segment1.reference_start,
        "ReferenceEnd": segment2.reference_end,
        "PrimerPair": f"{left_primer['Primer_ID']}_{right_primer['Primer_ID']}",
        "Primer1": left_primer["Primer_ID"],
        "Primer1Start": abs(p1[1]),
        "Primer2": right_primer["Primer_ID"],
        "Primer2Start": abs(p2[1]),
        "IsSecondary": segment1.is_secondary,
        "IsSupplementary": segment1.is_supplementary,
        "Start": left_primer["start"],
        "End": right_primer["end"],
        "CorrectlyPaired": correctly_paired,
    }

    if args.report:
        report_writer.writerow(report)

    if args.remove_incorrect_pairs and not correctly_paired:
        return skip(f"Paired segment: {query_name} skipped as not correctly paired")

    if args.verbose:
        # Dont screw with the order of the dict
        print("\t".join(str(x) for x in report.values()), file=sys.stderr)

    # get the primer positions: mask up to the inner primer edge when the
    # primers themselves are to be trimmed, otherwise up to the outer edge
    if args.trim_primers:
        p1_position = left_primer["end"]
        p2_position = right_primer["start"]
    else:
        p1_position = left_primer["start"]
        p2_position = right_primer["end"]

    # softmask the alignment if left primer start/end inside alignment
    if segment1.reference_start < p1_position:
        try:
            trim(segment1, p1_position, False, args.verbose)
        except Exception as e:  # best effort: a pair that cannot be masked is dropped
            print(
                "problem soft masking left primer in {} (error: {}), skipping".format(
                    query_name, e
                ),
                file=sys.stderr,
            )
            return False
        if args.verbose:
            print(
                "ref start %s >= primer_position %s"
                % (segment1.reference_start, p1_position),
                file=sys.stderr,
            )

    # softmask the alignment if right primer start/end inside alignment
    if segment2.reference_end > p2_position:
        try:
            trim(segment2, p2_position, True, args.verbose)
        except Exception as e:  # best effort: a pair that cannot be masked is dropped
            print(
                "problem soft masking right primer in {} (error: {}), skipping".format(
                    query_name, e
                ),
                file=sys.stderr,
            )
            return False
        if args.verbose:
            print(
                "ref end %s <= primer_position %s"
                % (segment2.reference_end, p2_position),
                file=sys.stderr,
            )

    # check that both alignments still contain bases matching the reference;
    # cigarstring is None when no CIGAR is left, so fall back to ""
    if "M" not in (segment1.cigarstring or "") or "M" not in (
        segment2.cigarstring or ""
    ):
        return skip(
            f"Paired segment: {query_name} dropped as does not match reference post masking"
        )

    return (amplicon, segments)
def normalise_paired(trimmed_segments: dict, normalise: int, bed: list):
    """Normalise the depth of the trimmed read pairs to a given value.

    Normalisation is per amplicon. Pairs are visited in random order and a
    pair is kept only if adding the coverage of both mates moves the mean
    absolute distance between the amplicon depth vector and the desired
    depth closer to zero.

    Args:
        trimmed_segments (dict): Maps chrom -> amplicon -> list of
            (segment, segment) tuples. Each list is shuffled in place.
        normalise (int): Desired normalised depth
        bed (list): Primer scheme list, passed to ``generate_amplicons``

    Raises:
        ValueError: Amplicon assigned to a pair not found in primer scheme file

    Returns:
        tuple: (list of segments to output, with both mates of every kept
        pair adjacent; dict mapping (chrom, amplicon) to mean depth)
    """

    amplicons = generate_amplicons(bed)

    output_segments = []

    mean_depths = {}

    for chrom, amplicon_dict in trimmed_segments.items():
        for amplicon, segments in amplicon_dict.items():
            if amplicon not in amplicons[chrom]:
                raise ValueError(f"Segment {amplicon} not found in primer scheme file")

            if not segments:
                print(
                    f"No segments assigned to amplicon {amplicon}, skipping",
                    file=sys.stderr,
                )
                continue

            length = amplicons[chrom][amplicon]["length"]
            p_start = amplicons[chrom][amplicon]["p_start"]

            desired_depth = np.full(length, normalise, dtype=int)
            amplicon_depth = np.zeros(length, dtype=int)

            random.shuffle(segments)

            distance = np.mean(np.abs(amplicon_depth - desired_depth))

            for segment1, segment2 in segments:
                test_depths = amplicon_depth.copy()

                for segment in (segment1, segment2):
                    # Clamp BOTH ends at 0: a negative slice bound would be
                    # read as "from the end" and add depth across almost the
                    # whole amplicon for a mate lying before its start.
                    relative_start = max(segment.reference_start - p_start, 0)
                    relative_end = max(segment.reference_end - p_start, 0)

                    test_depths[relative_start:relative_end] += 1

                test_distance = np.mean(np.abs(test_depths - desired_depth))

                if test_distance < distance:
                    amplicon_depth = test_depths
                    distance = test_distance
                    output_segments.append(segment1)
                    output_segments.append(segment2)

            mean_depths[(chrom, amplicon)] = np.mean(amplicon_depth)

    return output_segments, mean_depths


def read_pair_generator(bam, region_string=None):
    """Generate (read1, read2) pairs from a BAM file or a region of it.

    Only primary alignments flagged as properly paired are considered. A
    read is held until its mate is seen, so both name-sorted and
    coordinate-sorted input work. Reads whose mate never appears are
    dropped.

    Args:
        bam: Iterable of alignment records. Must provide
            ``fetch(region=...)`` when ``region_string`` is given.
        region_string (str, optional): Region to restrict iteration to. When
            ``None`` the records are iterated directly.

    Yields:
        tuple: (first-in-pair record, second-in-pair record)
    """
    reads = bam if region_string is None else bam.fetch(region=region_string)

    pending = {}
    for read in reads:
        # Secondary/supplementary records share the query name of the
        # primary alignment and would otherwise be paired in its place.
        if not read.is_proper_pair or read.is_secondary or read.is_supplementary:
            continue

        qname = read.query_name
        mate = pending.get(qname)

        if mate is None:
            pending[qname] = read
            continue

        if bool(mate.is_read1) == bool(read.is_read1):
            # Same end seen twice: keep waiting for the opposite end rather
            # than emitting a half-empty pair.
            continue

        del pending[qname]
        if read.is_read1:
            yield read, mate
        else:
            yield mate, read
@@ -518,56 +814,102 @@ def go(args): trimmed_segments = {x: {} for x in chroms} - # iterate over the alignment segments in the input SAM file - for segment in infile: - if args.report: - trimming_tuple = handle_segment( - segment=segment, - bed=bed, - args=args, - report_writer=report_writer, - min_mapq=args.min_mapq, - ) - else: - trimming_tuple = handle_segment( - segment=segment, - bed=bed, - args=args, - min_mapq=args.min_mapq, - ) - if not trimming_tuple: - continue + if args.paired: + read_pairs = read_pair_generator(infile) + + for segments in read_pairs: + if args.report: + trimming_tuple = handle_paired_segment( + segments=segments, + bed=bed, + args=args, + report_writer=report_writer, + min_mapq=args.min_mapq, + ) + else: + trimming_tuple = handle_paired_segment( + segments=segments, + bed=bed, + args=args, + min_mapq=args.min_mapq, + ) + if not trimming_tuple: + continue - # unpack the trimming tuple since segment passed trimming - amplicon, trimmed_segment = trimming_tuple - trimmed_segments[trimmed_segment.reference_name].setdefault(amplicon, []) + # unpack the trimming tuple since segment passed trimming + amplicon, trimmed_pair = trimming_tuple + trimmed_segments[trimmed_pair[0].reference_name].setdefault(amplicon, []) - if trimmed_segment: - trimmed_segments[trimmed_segment.reference_name][amplicon].append( - trimmed_segment + if trimmed_segments: + trimmed_segments[trimmed_pair[0].reference_name][amplicon].append( + trimmed_pair + ) + + if args.normalise: + output_segments, mean_amp_depths = normalise_paired( + trimmed_segments, args.normalise, bed ) - # normalise if requested - if args.normalise: - output_segments, mean_amp_depths = normalise( - trimmed_segments, args.normalise, bed, args.verbose - ) + # write mean amplicon depths to file + if args.amp_depth_report: + with open(args.amp_depth_report, "w") as amp_depth_report_fh: + amp_depth_report_fh.write("chrom\tamplicon\tmean_depth\n") + for (chrom, amplicon), depth in mean_amp_depths.items(): 
+ amp_depth_report_fh.write(f"{chrom}\t{amplicon}\t{depth}\n") + + for output_segment in output_segments: + outfile.write(output_segment) + else: + # iterate over the alignment segments in the input SAM file + for segment in infile: + if args.report: + trimming_tuple = handle_segment( + segment=segment, + bed=bed, + args=args, + report_writer=report_writer, + min_mapq=args.min_mapq, + ) + else: + trimming_tuple = handle_segment( + segment=segment, + bed=bed, + args=args, + min_mapq=args.min_mapq, + ) + if not trimming_tuple: + continue - # write mean amplicon depths to file - if args.amp_depth_report: - with open(args.amp_depth_report, "w") as amp_depth_report_fh: - amp_depth_report_fh.write("chrom\tamplicon\tmean_depth\n") - for (chrom, amplicon), depth in mean_amp_depths.items(): - amp_depth_report_fh.write(f"{chrom}\t{amplicon}\t{depth}\n") + # unpack the trimming tuple since segment passed trimming + amplicon, trimmed_segment = trimming_tuple + trimmed_segments[trimmed_segment.reference_name].setdefault(amplicon, []) - for output_segment in output_segments: - outfile.write(output_segment) + if trimmed_segment: + trimmed_segments[trimmed_segment.reference_name][amplicon].append( + trimmed_segment + ) - else: - for chrom, amplicon_dict in trimmed_segments.items(): - for amplicon, segments in amplicon_dict.items(): - for segment in segments: - outfile.write(segment) + # normalise if requested + if args.normalise: + output_segments, mean_amp_depths = normalise( + trimmed_segments, args.normalise, bed, args.verbose + ) + + # write mean amplicon depths to file + if args.amp_depth_report: + with open(args.amp_depth_report, "w") as amp_depth_report_fh: + amp_depth_report_fh.write("chrom\tamplicon\tmean_depth\n") + for (chrom, amplicon), depth in mean_amp_depths.items(): + amp_depth_report_fh.write(f"{chrom}\t{amplicon}\t{depth}\n") + + for output_segment in output_segments: + outfile.write(output_segment) + + else: + for chrom, amplicon_dict in 
trimmed_segments.items(): + for amplicon, segments in amplicon_dict.items(): + for segment in segments: + outfile.write(segment) # close up the file handles infile.close() @@ -602,6 +944,11 @@ def main(): action="store_true", help="Trims primers from reads", ) + parser.add_argument( + "--paired", + action="store_true", + help="Process paired-end reads", + ) parser.add_argument( "--no-read-groups", dest="no_read_groups", diff --git a/setup.cfg b/setup.cfg index af30012..a693feb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = artic -version = 1.6.2 +version = 1.6.3 author = Nick Loman author_email = n.j.loman@bham.ac.uk maintainer = Sam Wilkinson