Skip to content

Commit e952da7

Browse files
authored
Add shuffle for search (#22)
1 parent 5d211d9 commit e952da7

1 file changed

Lines changed: 11 additions & 0 deletions

File tree

src/svsbench/search.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,11 @@ def _read_args(argv: list[str] | None = None) -> argparse.Namespace:
127127
choices=STR_TO_CALIBRATE_SEARCH_BUFFER.keys(),
128128
default="all",
129129
)
130+
parser.add_argument(
131+
"--shuffle_ground_truth",
132+
help="Shuffle ground truth. See also svsbench.build --shuffle",
133+
action="store_true",
134+
)
130135
return parser.parse_args(argv)
131136

132137

@@ -159,6 +164,8 @@ def search(
159164
lvq_strategy: svs.LVQStrategy | None = None,
160165
train_prefetchers: bool = True,
161166
search_buffer_optimization: svs.VamanaSearchBufferOptimization = svs.VamanaSearchBufferOptimization.All,
167+
shuffle_ground_truth: bool = False,
168+
seed: int = 42,
162169
):
163170
logger.info({"search_args": locals()})
164171
logger.info(utils.read_system_config())
@@ -209,6 +216,8 @@ def search(
209216

210217
query = svs.read_vecs(str(query_path))
211218
ground_truth = svs.read_vecs(str(ground_truth_path))
219+
if shuffle_ground_truth:
220+
np.random.default_rng(seed).shuffle(ground_truth)
212221

213222
for batch_size_idx, batch_size in enumerate(batch_sizes):
214223
index.num_threads = min(max_threads, batch_size)
@@ -372,6 +381,8 @@ def main(argv: str | None = None) -> None:
372381
search_buffer_optimization=STR_TO_CALIBRATE_SEARCH_BUFFER[
373382
args.calibrate_search_buffer
374383
],
384+
shuffle_ground_truth=args.shuffle_ground_truth,
385+
seed=args.seed,
375386
)
376387

377388

0 commit comments

Comments
 (0)