9 changes: 6 additions & 3 deletions deduplication/__main__.py
@@ -1,17 +1,20 @@
from deduplication.workflows import *
from deduplication.args import parse_args


args = parse_args()

args.sim_threshold = float(args.sim_threshold)

if args.mode == "bloom":
if args.single:
assert len(args.input) == 1 and len(args.minhash_dir) == 1 and len(args.name) == 1, "Expected single input argument but got a list"
dedup_single_bloom(args.input[0], args.minhash_dir[0], args.num, args.fp, args.output_file, args.name[0], args.sim_threshold, args.num_perm, args.save_dir, not args.skip_minhashing)
dedup_single_bloom(args.input[0], args.minhash_dir[0], args.num, args.fp, args.output_file, args.name[0], args.sim_threshold, args.num_perm, args.save_dir, not args.skip_minhashing, skip_insertion=args.skip_insertion)
elif args.multi:
dedup_multi_bloom(args.input, args.minhash_dir, args.num, args.fp, args.output_file, args.name, args.sim_threshold, args.num_perm, args.save_dir, not args.skip_minhashing)
dedup_multi_bloom(args.input, args.minhash_dir, args.num, args.fp, args.output_file, args.name, args.sim_threshold, args.num_perm, args.save_dir, not args.skip_minhashing, skip_insertion=args.skip_insertion)
else:
assert len(args.input) == 1 and len(args.minhash_dir) == 1 and len(args.name) == 1, "Expected single input argument but got a list"
dedup_single_file_bloom(args.input[0], args.minhash_dir[0], args.num, args.fp, args.output_file, args.name[0], args.sim_threshold, args.num_perm, args.save_dir, not args.skip_minhashing)
dedup_single_file_bloom(args.input[0], args.minhash_dir[0], args.num, args.fp, args.output_file, args.name[0], args.sim_threshold, args.num_perm, args.save_dir, not args.skip_minhashing, skip_insertion=args.skip_insertion)
else:
if args.single:
assert len(args.input) == 1 and len(args.minhash_dir) == 1 and len(args.name) == 1, "Expected single input argument but got a list"
7 changes: 7 additions & 0 deletions deduplication/args.py
@@ -50,11 +50,13 @@ def parse_args():
parser.add_argument(
"--sim-threshold",
help="Jaccard Similarity threshold for deduplication, should be in [0, 1]. Default is 0.8",
type=float,
default=0.8,
)
parser.add_argument(
"--num-perm",
help="Number of hash functions for MinHashing. Default is 128",
type=int,
default=128,
)
parser.add_argument(
@@ -97,5 +99,10 @@ def parse_args():
help="If set, will skip the minhashing step of each workflow (useful if minhashes have been precomputed at minhash_dir)",
action="store_true"
)
parser.add_argument(
"--skip-insertion",
help="If set, will skip inserting unique documents into the index (works only with LSHBloom)",
action="store_true"
)

return parser.parse_args()
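A hypothetical invocation exercising the new flag. Only --sim-threshold, --num-perm, --skip-insertion and the skip-minhashing option described above are visible in this diff; the remaining flag spellings, and which arguments are required, are assumptions based on the attribute names used in __main__.py.

python -m deduplication \
    --mode bloom --single \
    --input /data/corpus \
    --minhash-dir /data/minhashes \
    --name my_corpus \
    --sim-threshold 0.8 \
    --num-perm 128 \
    --skip-minhashing \
    --skip-insertion

Combined with --skip-minhashing, --skip-insertion turns the run into a read-only query against an existing Bloom index: duplicates are reported, but unique documents are not added to the index.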
13 changes: 7 additions & 6 deletions deduplication/lshbloom.py
@@ -31,7 +31,7 @@ def __init__(self, minhash_dir: str, lsh_params: Dict):
self.minhash_dir = minhash_dir
self.lsh = MinHashLSHBloom(**lsh_params)

def deduplicate_corpus(self) -> List[Tuple[str]]:
def deduplicate_corpus(self, skip_insertion: bool = False) -> List[Tuple[str]]:
"""
Deduplicates documents in the given corpus and adds them to the LSH index if appropriate.
Documents without existing duplicates will be stored in the LSH index for future deduplication.
@@ -45,12 +45,12 @@ def deduplicate_corpus(self) -> List[Tuple[str]]:
if f.endswith(".pkl")
]
for minhashfile in minhash_files:
dups = self.deduplicate_minhash_file(minhashfile)
dups = self.deduplicate_minhash_file(minhashfile, skip_insertion=skip_insertion)
duplicate_list.extend(dups)

return duplicate_list

def deduplicate_and_insert(self, params: Tuple) -> List[Tuple[str]]:
def deduplicate_and_insert(self, params: Tuple, skip_insertion: bool = False) -> List[Tuple[str]]:
"""
Deduplicates a MinHash signature corresponding to a document using the provided LSH index.
If the document is not duplicated in the LSH index, it is added to the index.
@@ -67,12 +67,13 @@ def deduplicate_and_insert(self, params: Tuple) -> List[Tuple[str]]:

# insert if not duplicated in index
if not result:
self.lsh.insert(m_query)
if not skip_insertion:
self.lsh.insert(m_query)
return None

return [(key,)]

def deduplicate_minhash_file(self, minhashfile: str) -> List[Tuple[str]]:
def deduplicate_minhash_file(self, minhashfile: str, skip_insertion: bool = False) -> List[Tuple[str]]:
"""
        Deduplicates documents in the given minhash file and adds them to the LSH index if appropriate.
Documents without existing duplicates will be stored in the LSH index for future deduplication.
@@ -91,7 +92,7 @@ def deduplicate_minhash_file(self, minhashfile: str) -> List[Tuple[str]]:
# can't multiprocess here as insertion requires C++ dependencies that are not compatible with pickle
with tqdm(total=len(minhash_list), desc=fname) as pbar:
for i in range(len(minhash_list)):
result = self.deduplicate_and_insert(minhash_list[i])
result = self.deduplicate_and_insert(minhash_list[i], skip_insertion=skip_insertion)
if result:
duplicate_list.extend(result)
pbar.update()
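A minimal sketch of the query-only pattern the new skip_insertion parameter enables. The LSHBloom constructor matches the diff above, but the lsh_params keys shown here are placeholders; the real dict is assembled in workflows.py and its key names are not visible in this diff.

from deduplication.lshbloom import LSHBloom

# Placeholder parameters; they are forwarded to MinHashLSHBloom(**lsh_params).
lsh_params = {"threshold": 0.8, "num_perm": 128}

index = LSHBloom("/data/minhashes", lsh_params)

# With skip_insertion=True the corpus is only queried against the existing
# Bloom index: duplicates are reported, but unique documents are not added,
# so the index on disk is left unchanged.
duplicates = index.deduplicate_corpus(skip_insertion=True)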
103 changes: 102 additions & 1 deletion deduplication/minhash.py
@@ -1,15 +1,116 @@
from tqdm.autonotebook import tqdm
from multiprocessing import Pool
from datasketch import MinHash
from typing import Optional
from typing import Optional, cast
from glob import glob
import logging

import pickle
import json
from functools import partial
import os
import collections.abc
from pathlib import Path
import zstandard
import gzip
import io
import pyarrow.parquet
import itertools

# TODO check if minhashes already exist, recompute only if forced

def compute_minhash_text(t: tuple[int, str], fname: str, num_perm: int) -> tuple[str, MinHash] | None:
lineNo, line = t
s = set(line.split())
if not s:
return None
m = MinHash(num_perm=num_perm)
for d in s:
        m.update(d.encode("utf8"))  # MinHash.update expects bytes, not str
# generate a unique key for this document
key = f"{fname}-{lineNo}"
return (key, m)

SUPPORTED_FILETYPES = set([
*[
((format, encoding) if encoding != "" else (format,))
for format, encoding
in itertools.product(
[".jsonl", ".json"],
[".gz", ".zstd", ".zst", ""]
)],
(".parquet",)
])

def is_supported_dataset(p: Path) -> bool:
return tuple(p.suffixes[-2:]) in SUPPORTED_FILETYPES


def open_anydataset(p: Path) -> collections.abc.Generator[tuple[int, str]]:
match p.suffixes[-2:]:
case [(".json" |".jsonl")]:
with open(p, "r") as f:
for lineNo, line in enumerate(f):
try:
yield lineNo, cast(str, json.loads(line)["text"])
except GeneratorExit as e:
raise e
                    except Exception:
logging.exception("failed to parse lineNo %s of \"%s\"", lineNo, p)
case [(".jsonl" | ".json"), (".zst"| ".zstd")]:
with open(p, "rb") as f:
dctx = zstandard.ZstdDecompressor()
stream_reader = dctx.stream_reader(f)
text_stream = io.TextIOWrapper(stream_reader, encoding='utf-8')
for lineNo, line in enumerate(text_stream):
try:
yield lineNo, cast(str, json.loads(line)["text"])
except GeneratorExit as e:
raise e
                    except Exception:
logging.exception("failed to parse lineNo %s of \"%s\"", lineNo, p)
case [(".jsonl" | ".json"), ".gz"]:
with gzip.open(p, "r") as f:
for lineNo, line in enumerate(f):
try:
yield lineNo, cast(str, json.loads(line)["text"])
except GeneratorExit as e:
raise e
                    except Exception:
logging.exception("failed to parse lineNo %s of \"%s\"", lineNo, p)
case ['.parquet']:
pq = pyarrow.parquet.ParquetFile(p)
idx = 0
for row_group in range(pq.num_row_groups):
table = pq.read_row_group(row_group, columns=["text"])
for v in table["text"]:
try:
yield idx, v.as_py()
idx += 1
except GeneratorExit as e:
raise e
                    except Exception:
logging.exception("failed to parse lineNo %s of \"%s\"", idx, p)

case _:
raise NotImplementedError(f"{p.suffixes[-2:]} is not a supported filetype")

def compute_minhash_for_anyfile(infile: str, output_dir: str, num_perm: int):
Review thread on compute_minhash_for_anyfile:

Collaborator: Since this covers all filetypes, perhaps we should delete compute_minhash_for_file below and refactor our minhasher to use this function instead.

Author: I wasn't sure if this was used elsewhere or upstream, or whether it was something you wrote. If you wrote compute_minhash_for_file, I'll remove/refactor it.

Collaborator: Yeah, I think we can refactor this here. The only other thing to note is that the MinHasher class has another compute_minhash_for_file method that needs to be updated to reflect the new change - that MinHasher method is used elsewhere in workflows.py.

    n = 50000  # rough per-file document count, used only as the tqdm progress total
path = Path(infile)
fin = open_anydataset(path)
with Pool(32) as p, tqdm(total=n, desc=path.stem) as pbar:
minhash_list = []
partial_compute_minhash = partial(compute_minhash_text, fname=path.stem, num_perm=num_perm)
for result in p.imap_unordered(partial_compute_minhash, fin):
if result:
minhash_list.append(result)
pbar.update()
    # strip the dataset suffix(es) (".jsonl", ".jsonl.zst", ".parquet", ...) to name the output pickle
    base = path.name[: -len("".join(path.suffixes[-2:]))] if is_supported_dataset(path) else path.stem
    with open(f"{output_dir}/{base}.pkl", "wb") as fp:
pickle.dump(minhash_list, fp)
print(f"Generated MinHash for {len(minhash_list):,} documents in {path.stem}")


def compute_minhash_jsonl(t, fname, num_perm):
lineNo, line = t
lineNo += 1
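A short usage sketch of the new format-agnostic entry point; the paths are placeholders and the loop is only illustrative.

from pathlib import Path

from deduplication.minhash import compute_minhash_for_anyfile, is_supported_dataset

input_dir = Path("/data/corpus")    # placeholder
minhash_dir = "/data/minhashes"     # placeholder

# .jsonl/.json (optionally .gz, .zst or .zstd compressed) and .parquet shards
# now go through the same reader and the same minhashing entry point.
for shard in sorted(input_dir.iterdir()):
    if is_supported_dataset(shard):
        compute_minhash_for_anyfile(str(shard), minhash_dir, num_perm=128)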
10 changes: 7 additions & 3 deletions deduplication/workflows.py
@@ -121,6 +121,7 @@ def dedup_single_bloom(
save_dir: str = "./",
compute_minhashes: bool = True,
clear: bool = False,
skip_insertion: bool = False,
):
if clear:
clear_dir(save_dir)
@@ -138,7 +139,7 @@ def dedup_single_bloom(
m.process()

index = LSHBloom(minhash_dir, lsh_params)
duplicates = index.deduplicate_corpus()
duplicates = index.deduplicate_corpus(skip_insertion=skip_insertion)
write_duplicates_to_csv(duplicates, csvfile, corpus_name, header=["dup_key"])


@@ -155,6 +156,7 @@ def dedup_multi_bloom(
save_dir: str = "./",
compute_minhashes: bool = True,
clear: bool = False,
skip_insertion: bool = False,
):
assert len(input_dirs) == len(minhash_dirs) == len(corpus_names), \
f"Expected len(input_dirs) == len(minhash_dirs) == len(corpus_names), got {len(input_dirs)}, {len(minhash_dirs)}, {len(corpus_names)}"
@@ -174,7 +176,8 @@ def dedup_multi_bloom(
n_hash_funcs,
save_dir,
compute_minhashes,
clear=False
clear=False,
skip_insertion=skip_insertion
)

def dedup_single_file_bloom(
@@ -189,6 +192,7 @@ def dedup_single_file_bloom(
save_dir: str = "./",
compute_minhashes: bool = True,
clear: bool = False,
skip_insertion: bool = False,
):
if clear:
clear_dir(save_dir)
@@ -208,5 +212,5 @@ def dedup_single_file_bloom(
fname = input_file.split("/")[-1]
minhash_file = f"{minhash_dir}/{fname[:-6]}.pkl"
index = LSHBloom(minhash_dir, lsh_params)
duplicates = index.deduplicate_minhash_file(minhash_file)
duplicates = index.deduplicate_minhash_file(minhash_file, skip_insertion=skip_insertion)
write_duplicates_to_csv(duplicates, csvfile, corpus_name, header=["dup_key"])
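A hedged sketch of a read-only single-corpus run with the updated signature. The positional arguments mirror the call in __main__.py; the paths and sizes are placeholders, and the meanings noted for num and fp are assumptions, since those parameter names are not visible in this diff.

from deduplication.workflows import dedup_single_bloom

dedup_single_bloom(
    "/data/corpus",        # input directory (placeholder)
    "/data/minhashes",     # minhash_dir (placeholder)
    1_000_000,             # num -- assumed Bloom index capacity
    0.001,                 # fp -- assumed Bloom false-positive rate
    "duplicates.csv",      # output CSV of duplicate keys
    "my_corpus",           # corpus name
    0.8,                   # Jaccard similarity threshold
    128,                   # number of MinHash permutations
    save_dir="./bloom_index",
    compute_minhashes=False,   # minhashes were precomputed into minhash_dir
    skip_insertion=True,       # report duplicates without growing the index
)

The same keyword is threaded through dedup_multi_bloom and dedup_single_file_bloom, so multi-corpus and single-file passes can be made read-only in the same way.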
4 changes: 3 additions & 1 deletion setup.py
@@ -34,5 +34,7 @@
'datasketch @ git+https://github.com/123epsilon/datasketch.git@060a32b4b4a2272d77480dd633a1bf770678ba49',
'pybloomfiltermmap3==0.5.7',
'tqdm>=4.60.0',
'zstandard>=0.23.0',
'pyarrow>=18.0.0',
]
)
)