Skip to content

Commit 56797e8

Browse files
author
Sebastian Benjamin
committed
Add cb correction
1 parent 6be1237 commit 56797e8

File tree

2 files changed

+176
-86
lines changed

2 files changed

+176
-86
lines changed

nimble/__main__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -424,9 +424,11 @@ def sort_input_bam(bam, cores, tmp_dir):
424424
fastq_to_bam_parser = subparsers.add_parser('fastq-to-bam')
425425
fastq_to_bam_parser.add_argument('--r1-fastq', help='Path to R1 FASTQ file.', type=str, required=True)
426426
fastq_to_bam_parser.add_argument('--r2-fastq', help='Path to R2 FASTQ file.', type=str, required=True)
427-
fastq_to_bam_parser.add_argument("--map", required=True, help="TSV(.gz) with columns: rawCB, correctedCB, rawUMI, correctedUMI")
427+
fastq_to_bam_parser.add_argument("--map", required=True, help="Cell barcode whitelist file (one CB per line, .gz or plain text)")
428428
fastq_to_bam_parser.add_argument('--output', help='Path for output BAM file.', type=str, required=True)
429429
fastq_to_bam_parser.add_argument('-c', '--num_cores', help='The number of cores to use for processing.', type=int, default=1)
430+
fastq_to_bam_parser.add_argument('--cb-length', help='Length of cell barcode (default: 16).', type=int, default=16)
431+
fastq_to_bam_parser.add_argument('--umi-length', help='Length of UMI (default: 12).', type=int, default=12)
430432

431433
args = parser.parse_args()
432434

@@ -445,7 +447,9 @@ def sort_input_bam(bam, cores, tmp_dir):
445447
args.r2_fastq,
446448
args.map,
447449
args.output,
448-
args.num_cores
450+
args.num_cores,
451+
args.cb_length,
452+
args.umi_length
449453
)
450454
elif args.subcommand == 'plot':
451455
if os.path.getsize(args.input_file) > 0:

nimble/fastq_barcode_processor.py

Lines changed: 170 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -6,56 +6,126 @@
66
import pysam
77
from Bio import SeqIO
88

9-
def load_cb_umi_mapping(tsv_path):
10-
"""
11-
Load a TSV with 4 columns:
12-
1) Raw CB
13-
2) Corrected CB
14-
3) Raw UMI
15-
4) Corrected UMI
169

10+
def hamming_distance(s1, s2):
    """Count the positions at which two equal-length strings differ.

    Args:
        s1: First string.
        s2: Second string.

    Returns:
        The number of mismatching positions, or ``float('inf')`` when the
        two strings do not have the same length (they are then treated as
        infinitely far apart rather than raising).
    """
    if len(s1) == len(s2):
        mismatches = 0
        for left, right in zip(s1, s2):
            if left != right:
                mismatches += 1
        return mismatches
    return float('inf')
15+
16+
17+
def build_hamming_index(whitelist):
    """Index every single-substitution variant of each whitelisted barcode.

    For each valid cell barcode, every position is replaced in turn by each
    of the other symbols in ``ACGTN``; the resulting variant is mapped back
    to the barcode it was derived from. Looking a raw barcode up in the
    returned mapping therefore yields all whitelist entries at Hamming
    distance exactly 1 from it.

    A variant may itself be a whitelisted barcode (two valid barcodes that
    differ at one position index each other).

    Args:
        whitelist: Iterable of valid cell barcode strings.

    Returns:
        defaultdict mapping ``variant -> set of valid barcodes`` one
        substitution away from that variant.
    """
    # 'N' is included so that raw barcodes containing an uncalled base can
    # still be matched to a whitelist entry.
    bases = "ACGTN"
    hamming_index = defaultdict(set)

    for valid_cb in whitelist:
        for pos, original_base in enumerate(valid_cb):
            # Slice once per position instead of once per substituted base.
            prefix = valid_cb[:pos]
            suffix = valid_cb[pos + 1:]
            for base in bases:
                if base != original_base:
                    hamming_index[prefix + base + suffix].add(valid_cb)

    return hamming_index
37+
38+
39+
def load_cb_whitelist(whitelist_path):
    """Read a cell barcode whitelist and index its one-substitution variants.

    The file is expected to hold one barcode per line; surrounding
    whitespace is stripped and blank lines are ignored. Paths ending in
    ``.gz`` are read through gzip in text mode, anything else as plain text.

    Any error while reading is reported on stderr and terminates the
    process with exit status 1.

    Args:
        whitelist_path: Path to the whitelist file (plain text or ``.gz``).

    Returns:
        (whitelist, hamming_index): the set of valid barcodes and the
        mapping produced by ``build_hamming_index`` for that set.
    """
    is_gzipped = whitelist_path.endswith('.gz')
    opener, read_mode = (gzip.open, 'rt') if is_gzipped else (open, 'r')

    try:
        with opener(whitelist_path, read_mode) as handle:
            stripped_lines = (raw_line.strip() for raw_line in handle)
            whitelist = {barcode for barcode in stripped_lines if barcode}
    except Exception as e:
        print(f"Error loading CB whitelist from {whitelist_path}: {e}", file=sys.stderr)
        sys.exit(1)

    print(f"Loaded whitelist from {whitelist_path}")
    print(f" Valid cell barcodes: {len(whitelist)}")

    print("Building Hamming distance index...")
    hamming_index = build_hamming_index(whitelist)
    print(f" Indexed {len(hamming_index)} variants")

    return whitelist, hamming_index
71+
72+
73+
def correct_cell_barcode(raw_cb, quality_scores, whitelist, hamming_index, correction_cache):
    """Map a raw cell barcode onto the whitelist, allowing one substitution.

    Resolution order:
      1. An exact whitelist match is returned unchanged.
      2. Otherwise the whitelist barcodes at Hamming distance 1 are looked
         up in ``hamming_index``.
      3. With a single candidate, that candidate is returned.
      4. With several candidates, the one whose differing position has the
         lowest quality score in *this* read wins. Candidates are examined
         in sorted order, so equal qualities resolve to the
         lexicographically smallest candidate instead of depending on set
         iteration order.

    Only outcomes that do not depend on ``quality_scores`` (steps 1-3 and
    "no candidate") are stored in ``correction_cache``. Ambiguous barcodes
    are resolved again for every read, because caching the first read's
    choice would apply its quality profile to all later reads.

    Args:
        raw_cb: Raw cell barcode string.
        quality_scores: Per-base quality scores for the barcode region,
            indexable by position in ``raw_cb``.
        whitelist: Set of valid cell barcodes.
        hamming_index: Mapping of one-substitution variants to the set of
            valid barcodes they derive from.
        correction_cache: Dict of ``raw_cb -> corrected barcode or None``;
            read and updated in place.

    Returns:
        The corrected barcode string, or None if no valid correction exists.
    """
    # Cached entries are quality-independent, so they are valid for any read.
    if raw_cb in correction_cache:
        return correction_cache[raw_cb]

    if raw_cb in whitelist:
        correction_cache[raw_cb] = raw_cb
        return raw_cb

    candidates = hamming_index.get(raw_cb)
    if not candidates:
        correction_cache[raw_cb] = None
        return None

    if len(candidates) == 1:
        corrected = next(iter(candidates))
        correction_cache[raw_cb] = corrected
        return corrected

    # Several whitelist barcodes are one substitution away: prefer the one
    # whose mismatching base is the least trustworthy in this read.
    best_candidate = None
    lowest_quality = float('inf')
    for candidate in sorted(candidates):
        mismatch_pos = next(
            (
                pos
                for pos, (raw_base, cand_base) in enumerate(zip(raw_cb, candidate))
                if raw_base != cand_base
            ),
            None,
        )
        if mismatch_pos is None:
            continue
        qual = quality_scores[mismatch_pos]
        # Strict '<' keeps the first (smallest) candidate on ties.
        if qual < lowest_quality:
            lowest_quality = qual
            best_candidate = candidate

    return best_candidate
59129

60130

61131
def parse_10x_barcode_from_r1(sequence, cb_length=16, umi_length=12):
@@ -71,13 +141,12 @@ def parse_10x_barcode_from_r1(sequence, cb_length=16, umi_length=12):
71141
return cell_barcode, umi, remaining_sequence
72142

73143

74-
def process_pair(r1_record, r2_record, cb_map, umi_map_by_raw_cb, stats, cb_length=16, umi_length=12):
144+
def process_pair(r1_record, r2_record, whitelist, hamming_index, correction_cache, stats, cb_length=16, umi_length=12):
75145
"""
76146
Process a single FASTQ pair and return (r1_bam, r2_bam), or None if skipped.
77-
Mapping logic:
78-
- Look up corrected CB by raw CB.
79-
- Look up corrected UMI by (raw CB, raw UMI).
80-
Both must exist to emit a pair.
147+
Correction logic:
148+
- Correct CB using 10x-style Hamming distance = 1 correction with quality scores
149+
- Use raw UMI (no correction)
81150
"""
82151
# Normalize names (handle /1 and /2 suffixes if present)
83152
r1_name = r1_record.id.removesuffix('/1')
@@ -87,33 +156,33 @@ def process_pair(r1_record, r2_record, cb_map, umi_map_by_raw_cb, stats, cb_leng
87156
return None
88157

89158
r1_seq = str(r1_record.seq)
90-
cell_barcode, umi, remaining_r1_seq = parse_10x_barcode_from_r1(r1_seq, cb_length, umi_length)
91-
if cell_barcode is None or umi is None:
159+
raw_cb, umi, remaining_r1_seq = parse_10x_barcode_from_r1(r1_seq, cb_length, umi_length)
160+
if raw_cb is None or umi is None:
92161
stats['too_short'] += 1
93162
return None
94163
if len(remaining_r1_seq) == 0:
95164
stats['no_remaining_seq'] += 1
96165
return None
97166

98-
# CB and UMI corrections (UMI in context of *raw* CB)
99-
corrected_cb = cb_map.get(cell_barcode)
167+
# Get quality scores for the CB region
168+
cb_quality_scores = r1_record.letter_annotations["phred_quality"][:cb_length]
169+
170+
# Correct cell barcode using 10x-style correction
171+
corrected_cb = correct_cell_barcode(raw_cb, cb_quality_scores, whitelist, hamming_index, correction_cache)
172+
100173
if corrected_cb is None:
101-
stats['cb_not_in_map'] += 1
102-
return None
103-
104-
umi_map_for_cb = umi_map_by_raw_cb.get(cell_barcode)
105-
if not umi_map_for_cb:
106-
stats['cb_has_no_umi_map'] += 1
107-
return None
108-
109-
corrected_umi = umi_map_for_cb.get(umi)
110-
if corrected_umi is None:
111-
stats['umi_not_in_map_for_cb'] += 1
174+
stats['cb_no_correction'] += 1
112175
return None
176+
177+
# Track correction statistics
178+
if corrected_cb == raw_cb:
179+
stats['cb_perfect_match'] += 1
180+
else:
181+
stats['cb_corrected'] += 1
113182

114183
barcode_length = cb_length + umi_length
115184

116-
# Build unaligned BAM records
185+
# Build unaligned BAM records with corrected CB and raw UMI
117186
r1_bam = pysam.AlignedSegment()
118187
r1_bam.query_name = r1_name
119188
r1_bam.query_sequence = remaining_r1_seq
@@ -124,7 +193,7 @@ def process_pair(r1_record, r2_record, cb_map, umi_map_by_raw_cb, stats, cb_leng
124193
r1_bam.reference_start = -1
125194
r1_bam.mapping_quality = 0
126195
r1_bam.set_tag("CB", corrected_cb)
127-
r1_bam.set_tag("UB", corrected_umi)
196+
r1_bam.set_tag("UB", umi) # Use raw UMI
128197

129198
r2_bam = pysam.AlignedSegment()
130199
r2_bam.query_name = r2_name
@@ -135,26 +204,31 @@ def process_pair(r1_record, r2_record, cb_map, umi_map_by_raw_cb, stats, cb_leng
135204
r2_bam.reference_start = -1
136205
r2_bam.mapping_quality = 0
137206
r2_bam.set_tag("CB", corrected_cb)
138-
r2_bam.set_tag("UB", corrected_umi)
207+
r2_bam.set_tag("UB", umi) # Use raw UMI
139208

140209
return r1_bam, r2_bam
141210

142211

143-
def fastq_to_bam_with_barcodes(r1_fastq, r2_fastq, cb_umi_mapping_file, output_bam, num_cores=1, cb_length=16, umi_length=12):
212+
def fastq_to_bam_with_barcodes(r1_fastq, r2_fastq, cb_whitelist_file, output_bam, num_cores=1, cb_length=16, umi_length=12):
144213
"""
145214
Convert paired FASTQ files to unaligned BAM with CB/UB tags using multiple threads.
215+
Performs 10x-style cell barcode correction using a whitelist.
146216
147217
Args:
148218
r1_fastq: path to R1 FASTQ(.gz)
149219
r2_fastq: path to R2 FASTQ(.gz)
150-
cb_umi_mapping_file: TSV(.gz) with 4 columns (rawCB, corrCB, rawUMI, corrUMI)
220+
cb_whitelist_file: path to cell barcode whitelist file (one CB per line, .gz or plain text)
151221
output_bam: path to output BAM
152222
num_cores: threads
153-
cb_length, umi_length: lengths for parsing R1
223+
cb_length: length of cell barcode (default 16)
224+
umi_length: length of UMI (default 12)
154225
"""
155-
print("Loading CB/UMI mapping...")
156-
cb_map, umi_map_by_raw_cb = load_cb_umi_mapping(cb_umi_mapping_file)
157-
226+
print("Loading cell barcode whitelist...")
227+
whitelist, hamming_index = load_cb_whitelist(cb_whitelist_file)
228+
229+
# Initialize correction cache (shared across threads, but thread-safe via GIL for dict operations)
230+
correction_cache = {}
231+
158232
stats = defaultdict(int)
159233

160234
r1_open_func = gzip.open if r1_fastq.endswith('.gz') else open
@@ -164,7 +238,7 @@ def fastq_to_bam_with_barcodes(r1_fastq, r2_fastq, cb_umi_mapping_file, output_b
164238

165239
header = {
166240
'HD': {'VN': '1.6', 'SO': 'queryname'},
167-
'PG': [{'ID': 'nimble-fastq-to-bam', 'PN': 'nimble', 'VN': '1.1', 'CL': 'single-tsv, cb-scoped-umi'}]
241+
'PG': [{'ID': 'nimble-fastq-to-bam', 'PN': 'nimble', 'VN': '1.2', 'CL': 'whitelist-based CB correction'}]
168242
}
169243

170244
print(f"Processing paired FASTQ files with {num_cores} threads...")
@@ -183,7 +257,8 @@ def fastq_to_bam_with_barcodes(r1_fastq, r2_fastq, cb_umi_mapping_file, output_b
183257
for idx, (r1_record, r2_record) in enumerate(zip(r1_iter, r2_iter), start=1):
184258
stats['total_pairs'] += 1
185259
fut = executor.submit(process_pair, r1_record, r2_record,
186-
cb_map, umi_map_by_raw_cb, stats, cb_length, umi_length)
260+
whitelist, hamming_index, correction_cache, stats,
261+
cb_length, umi_length)
187262
futures[fut] = True
188263

189264
# Throttle in-flight futures to limit memory use
@@ -218,17 +293,28 @@ def fastq_to_bam_with_barcodes(r1_fastq, r2_fastq, cb_umi_mapping_file, output_b
218293
print(f"Error during processing: {e}", file=sys.stderr)
219294
sys.exit(1)
220295

221-
# Print summary
296+
# Print summary with correction statistics
222297
print("\n=== Processing Statistics ===")
223-
ordered = [
224-
'total_pairs', 'written_pairs',
225-
'name_mismatch', 'too_short', 'no_remaining_seq',
226-
'cb_not_in_map', 'cb_has_no_umi_map', 'umi_not_in_map_for_cb'
227-
]
228-
for k in ordered:
229-
print(f"{k.replace('_', ' ').capitalize()}: {stats.get(k, 0)}")
230-
# Print any other counters encountered
231-
for k, v in stats.items():
232-
if k not in ordered:
233-
print(f"{k.replace('_', ' ').capitalize()}: {v}")
234-
print(f"Output BAM written to: {output_bam}")
298+
print(f"Total read pairs: {stats.get('total_pairs', 0)}")
299+
print(f"Written pairs: {stats.get('written_pairs', 0)}")
300+
print(f"\nCell Barcode Correction:")
301+
print(f" Perfect matches: {stats.get('cb_perfect_match', 0)}")
302+
print(f" Corrected (1-edit): {stats.get('cb_corrected', 0)}")
303+
print(f" No valid correction: {stats.get('cb_no_correction', 0)}")
304+
305+
total_cb_processed = stats.get('cb_perfect_match', 0) + stats.get('cb_corrected', 0) + stats.get('cb_no_correction', 0)
306+
if total_cb_processed > 0:
307+
perfect_pct = 100.0 * stats.get('cb_perfect_match', 0) / total_cb_processed
308+
corrected_pct = 100.0 * stats.get('cb_corrected', 0) / total_cb_processed
309+
dropped_pct = 100.0 * stats.get('cb_no_correction', 0) / total_cb_processed
310+
print(f" Correction rate: {perfect_pct:.2f}% perfect, {corrected_pct:.2f}% corrected, {dropped_pct:.2f}% dropped")
311+
312+
print(f"\nOther filters:")
313+
print(f" Name mismatch: {stats.get('name_mismatch', 0)}")
314+
print(f" Too short: {stats.get('too_short', 0)}")
315+
print(f" No remaining sequence: {stats.get('no_remaining_seq', 0)}")
316+
317+
# Print unique cache size
318+
print(f"\nCorrection cache size: {len(correction_cache)} unique raw CBs")
319+
320+
print(f"\nOutput BAM written to: {output_bam}")

0 commit comments

Comments
 (0)