Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions nimble/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,9 +424,11 @@ def sort_input_bam(bam, cores, tmp_dir):
fastq_to_bam_parser = subparsers.add_parser('fastq-to-bam')
fastq_to_bam_parser.add_argument('--r1-fastq', help='Path to R1 FASTQ file.', type=str, required=True)
fastq_to_bam_parser.add_argument('--r2-fastq', help='Path to R2 FASTQ file.', type=str, required=True)
fastq_to_bam_parser.add_argument("--map", required=True, help="TSV(.gz) with columns: rawCB, correctedCB, rawUMI, correctedUMI")
fastq_to_bam_parser.add_argument("--map", required=True, help="Cell barcode whitelist file (one CB per line, .gz or plain text)")
fastq_to_bam_parser.add_argument('--output', help='Path for output BAM file.', type=str, required=True)
fastq_to_bam_parser.add_argument('-c', '--num_cores', help='The number of cores to use for processing.', type=int, default=1)
fastq_to_bam_parser.add_argument('--cb-length', help='Length of cell barcode (default: 16).', type=int, default=16)
fastq_to_bam_parser.add_argument('--umi-length', help='Length of UMI (default: 12).', type=int, default=12)

args = parser.parse_args()

Expand All @@ -445,7 +447,9 @@ def sort_input_bam(bam, cores, tmp_dir):
args.r2_fastq,
args.map,
args.output,
args.num_cores
args.num_cores,
args.cb_length,
args.umi_length
)
elif args.subcommand == 'plot':
if os.path.getsize(args.input_file) > 0:
Expand Down
254 changes: 170 additions & 84 deletions nimble/fastq_barcode_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,56 +6,126 @@
import pysam
from Bio import SeqIO

def load_cb_umi_mapping(tsv_path):
"""
Load a TSV with 4 columns:
1) Raw CB
2) Corrected CB
3) Raw UMI
4) Corrected UMI

def hamming_distance(s1, s2):
"""Calculate Hamming distance between two strings of equal length."""
if len(s1) != len(s2):
return float('inf')
return sum(c1 != c2 for c1, c2 in zip(s1, s2))


def build_hamming_index(whitelist):
"""
Build an index mapping each valid CB to all possible 1-edit variants.
This allows efficient lookup of correction candidates.

Returns:
cb_map: dict[str raw_cb] = str corrected_cb
umi_map_by_raw_cb: dict[str raw_cb] = dict[str raw_umi] = str corrected_umi
dict[str variant] = set[str valid_cb]
"""
cb_map = {}
umi_map_by_raw_cb = defaultdict(dict)

open_func = gzip.open if tsv_path.endswith('.gz') else open
mode = 'rt' if tsv_path.endswith('.gz') else 'r'
bases = ['A', 'C', 'G', 'T', 'N']
hamming_index = defaultdict(set)

for valid_cb in whitelist:
# For each position, generate all possible single-base substitutions
for i in range(len(valid_cb)):
for base in bases:
if base != valid_cb[i]:
variant = valid_cb[:i] + base + valid_cb[i+1:]
hamming_index[variant].add(valid_cb)

return hamming_index


def load_cb_whitelist(whitelist_path):
"""
Load a cell barcode whitelist (one CB per line).
Build a Hamming distance = 1 index for efficient correction.

line_ct = 0
malformed = 0
pairs_ct = 0
unique_cb = set()
Returns:
whitelist: set of valid cell barcodes
hamming_index: dict mapping variants to valid CBs
"""
whitelist = set()

open_func = gzip.open if whitelist_path.endswith('.gz') else open
mode = 'rt' if whitelist_path.endswith('.gz') else 'r'

try:
with open_func(tsv_path, mode) as f:
for line_num, line in enumerate(f, start=1):
with open_func(whitelist_path, mode) as f:
for line in f:
line = line.strip()
if not line:
continue
parts = line.split('\t')
if len(parts) < 4:
malformed += 1
continue

raw_cb, corr_cb, raw_umi, corr_umi = parts[0], parts[1], parts[2], parts[3]
cb_map[raw_cb] = corr_cb
umi_map_by_raw_cb[raw_cb][raw_umi] = corr_umi
pairs_ct += 1
unique_cb.add(raw_cb)
line_ct += 1
if line:
whitelist.add(line)
except Exception as e:
print(f"Error loading CB/UMI mapping from {tsv_path}: {e}", file=sys.stderr)
print(f"Error loading CB whitelist from {whitelist_path}: {e}", file=sys.stderr)
sys.exit(1)

print(f"Loaded mapping from {tsv_path}")
print(f" Lines read: {line_ct + malformed}")
print(f" Malformed lines skipped (<4 cols): {malformed}")
print(f" Unique raw CBs: {len(unique_cb)}")
print(f" (raw CB, raw UMI) pairs: {pairs_ct}")
return cb_map, umi_map_by_raw_cb
print(f"Loaded whitelist from {whitelist_path}")
print(f" Valid cell barcodes: {len(whitelist)}")

print("Building Hamming distance index...")
hamming_index = build_hamming_index(whitelist)
print(f" Indexed {len(hamming_index)} variants")

return whitelist, hamming_index


def correct_cell_barcode(raw_cb, quality_scores, whitelist, hamming_index, correction_cache):
"""
Correct a raw cell barcode using 10x-style correction:
1. Check for perfect match
2. If not, find all candidates with Hamming distance = 1
3. Among candidates, select the one where the differing base has the lowest quality score

Args:
raw_cb: Raw cell barcode string
quality_scores: List of quality scores for the CB region
whitelist: Set of valid cell barcodes
hamming_index: Dict mapping variants to valid CBs
correction_cache: Dict for caching corrections

Returns:
Corrected CB string, or None if no valid correction found
"""
# Check cache first
if raw_cb in correction_cache:
return correction_cache[raw_cb]

# Check for perfect match
if raw_cb in whitelist:
correction_cache[raw_cb] = raw_cb
return raw_cb

# Find candidates with Hamming distance = 1
candidates = hamming_index.get(raw_cb, set())

if not candidates:
correction_cache[raw_cb] = None
return None

if len(candidates) == 1:
# Only one candidate, use it
corrected = next(iter(candidates))
correction_cache[raw_cb] = corrected
return corrected

# Multiple candidates: select based on quality scores
# Find the differing position and pick candidate with lowest quality at that position
best_candidate = None
lowest_quality = float('inf')

for candidate in candidates:
# Find the position where they differ
for i, (raw_base, cand_base) in enumerate(zip(raw_cb, candidate)):
if raw_base != cand_base:
qual = quality_scores[i]
if qual < lowest_quality:
lowest_quality = qual
best_candidate = candidate
break

correction_cache[raw_cb] = best_candidate
return best_candidate


def parse_10x_barcode_from_r1(sequence, cb_length=16, umi_length=12):
Expand All @@ -71,13 +141,12 @@ def parse_10x_barcode_from_r1(sequence, cb_length=16, umi_length=12):
return cell_barcode, umi, remaining_sequence


def process_pair(r1_record, r2_record, cb_map, umi_map_by_raw_cb, stats, cb_length=16, umi_length=12):
def process_pair(r1_record, r2_record, whitelist, hamming_index, correction_cache, stats, cb_length=16, umi_length=12):
"""
Process a single FASTQ pair and return (r1_bam, r2_bam), or None if skipped.
Mapping logic:
- Look up corrected CB by raw CB.
- Look up corrected UMI by (raw CB, raw UMI).
Both must exist to emit a pair.
Correction logic:
- Correct CB using 10x-style Hamming distance = 1 correction with quality scores
- Use raw UMI (no correction)
"""
# Normalize names (handle /1 and /2 suffixes if present)
r1_name = r1_record.id.removesuffix('/1')
Expand All @@ -87,33 +156,33 @@ def process_pair(r1_record, r2_record, cb_map, umi_map_by_raw_cb, stats, cb_leng
return None

r1_seq = str(r1_record.seq)
cell_barcode, umi, remaining_r1_seq = parse_10x_barcode_from_r1(r1_seq, cb_length, umi_length)
if cell_barcode is None or umi is None:
raw_cb, umi, remaining_r1_seq = parse_10x_barcode_from_r1(r1_seq, cb_length, umi_length)
if raw_cb is None or umi is None:
stats['too_short'] += 1
return None
if len(remaining_r1_seq) == 0:
stats['no_remaining_seq'] += 1
return None

# CB and UMI corrections (UMI in context of *raw* CB)
corrected_cb = cb_map.get(cell_barcode)
# Get quality scores for the CB region
cb_quality_scores = r1_record.letter_annotations["phred_quality"][:cb_length]

# Correct cell barcode using 10x-style correction
corrected_cb = correct_cell_barcode(raw_cb, cb_quality_scores, whitelist, hamming_index, correction_cache)

if corrected_cb is None:
stats['cb_not_in_map'] += 1
return None

umi_map_for_cb = umi_map_by_raw_cb.get(cell_barcode)
if not umi_map_for_cb:
stats['cb_has_no_umi_map'] += 1
return None

corrected_umi = umi_map_for_cb.get(umi)
if corrected_umi is None:
stats['umi_not_in_map_for_cb'] += 1
stats['cb_no_correction'] += 1
return None

# Track correction statistics
if corrected_cb == raw_cb:
stats['cb_perfect_match'] += 1
else:
stats['cb_corrected'] += 1

barcode_length = cb_length + umi_length

# Build unaligned BAM records
# Build unaligned BAM records with corrected CB and raw UMI
r1_bam = pysam.AlignedSegment()
r1_bam.query_name = r1_name
r1_bam.query_sequence = remaining_r1_seq
Expand All @@ -124,7 +193,7 @@ def process_pair(r1_record, r2_record, cb_map, umi_map_by_raw_cb, stats, cb_leng
r1_bam.reference_start = -1
r1_bam.mapping_quality = 0
r1_bam.set_tag("CB", corrected_cb)
r1_bam.set_tag("UB", corrected_umi)
r1_bam.set_tag("UB", umi) # Use raw UMI

r2_bam = pysam.AlignedSegment()
r2_bam.query_name = r2_name
Expand All @@ -135,26 +204,31 @@ def process_pair(r1_record, r2_record, cb_map, umi_map_by_raw_cb, stats, cb_leng
r2_bam.reference_start = -1
r2_bam.mapping_quality = 0
r2_bam.set_tag("CB", corrected_cb)
r2_bam.set_tag("UB", corrected_umi)
r2_bam.set_tag("UB", umi) # Use raw UMI

return r1_bam, r2_bam


def fastq_to_bam_with_barcodes(r1_fastq, r2_fastq, cb_umi_mapping_file, output_bam, num_cores=1, cb_length=16, umi_length=12):
def fastq_to_bam_with_barcodes(r1_fastq, r2_fastq, cb_whitelist_file, output_bam, num_cores=1, cb_length=16, umi_length=12):
"""
Convert paired FASTQ files to unaligned BAM with CB/UB tags using multiple threads.
Performs 10x-style cell barcode correction using a whitelist.

Args:
r1_fastq: path to R1 FASTQ(.gz)
r2_fastq: path to R2 FASTQ(.gz)
cb_umi_mapping_file: TSV(.gz) with 4 columns (rawCB, corrCB, rawUMI, corrUMI)
cb_whitelist_file: path to cell barcode whitelist file (one CB per line, .gz or plain text)
output_bam: path to output BAM
num_cores: threads
cb_length, umi_length: lengths for parsing R1
cb_length: length of cell barcode (default 16)
umi_length: length of UMI (default 12)
"""
print("Loading CB/UMI mapping...")
cb_map, umi_map_by_raw_cb = load_cb_umi_mapping(cb_umi_mapping_file)

print("Loading cell barcode whitelist...")
whitelist, hamming_index = load_cb_whitelist(cb_whitelist_file)

# Initialize correction cache (shared across threads, but thread-safe via GIL for dict operations)
correction_cache = {}

stats = defaultdict(int)

r1_open_func = gzip.open if r1_fastq.endswith('.gz') else open
Expand All @@ -164,7 +238,7 @@ def fastq_to_bam_with_barcodes(r1_fastq, r2_fastq, cb_umi_mapping_file, output_b

header = {
'HD': {'VN': '1.6', 'SO': 'queryname'},
'PG': [{'ID': 'nimble-fastq-to-bam', 'PN': 'nimble', 'VN': '1.1', 'CL': 'single-tsv, cb-scoped-umi'}]
'PG': [{'ID': 'nimble-fastq-to-bam', 'PN': 'nimble', 'VN': '1.2', 'CL': 'whitelist-based CB correction'}]
}

print(f"Processing paired FASTQ files with {num_cores} threads...")
Expand All @@ -183,7 +257,8 @@ def fastq_to_bam_with_barcodes(r1_fastq, r2_fastq, cb_umi_mapping_file, output_b
for idx, (r1_record, r2_record) in enumerate(zip(r1_iter, r2_iter), start=1):
stats['total_pairs'] += 1
fut = executor.submit(process_pair, r1_record, r2_record,
cb_map, umi_map_by_raw_cb, stats, cb_length, umi_length)
whitelist, hamming_index, correction_cache, stats,
cb_length, umi_length)
futures[fut] = True

# Throttle in-flight futures to limit memory use
Expand Down Expand Up @@ -218,17 +293,28 @@ def fastq_to_bam_with_barcodes(r1_fastq, r2_fastq, cb_umi_mapping_file, output_b
print(f"Error during processing: {e}", file=sys.stderr)
sys.exit(1)

# Print summary
# Print summary with correction statistics
print("\n=== Processing Statistics ===")
ordered = [
'total_pairs', 'written_pairs',
'name_mismatch', 'too_short', 'no_remaining_seq',
'cb_not_in_map', 'cb_has_no_umi_map', 'umi_not_in_map_for_cb'
]
for k in ordered:
print(f"{k.replace('_', ' ').capitalize()}: {stats.get(k, 0)}")
# Print any other counters encountered
for k, v in stats.items():
if k not in ordered:
print(f"{k.replace('_', ' ').capitalize()}: {v}")
print(f"Output BAM written to: {output_bam}")
print(f"Total read pairs: {stats.get('total_pairs', 0)}")
print(f"Written pairs: {stats.get('written_pairs', 0)}")
print(f"\nCell Barcode Correction:")
print(f" Perfect matches: {stats.get('cb_perfect_match', 0)}")
print(f" Corrected (1-edit): {stats.get('cb_corrected', 0)}")
print(f" No valid correction: {stats.get('cb_no_correction', 0)}")

total_cb_processed = stats.get('cb_perfect_match', 0) + stats.get('cb_corrected', 0) + stats.get('cb_no_correction', 0)
if total_cb_processed > 0:
perfect_pct = 100.0 * stats.get('cb_perfect_match', 0) / total_cb_processed
corrected_pct = 100.0 * stats.get('cb_corrected', 0) / total_cb_processed
dropped_pct = 100.0 * stats.get('cb_no_correction', 0) / total_cb_processed
print(f" Correction rate: {perfect_pct:.2f}% perfect, {corrected_pct:.2f}% corrected, {dropped_pct:.2f}% dropped")

print(f"\nOther filters:")
print(f" Name mismatch: {stats.get('name_mismatch', 0)}")
print(f" Too short: {stats.get('too_short', 0)}")
print(f" No remaining sequence: {stats.get('no_remaining_seq', 0)}")

# Print unique cache size
print(f"\nCorrection cache size: {len(correction_cache)} unique raw CBs")

print(f"\nOutput BAM written to: {output_bam}")