@@ -127,6 +127,11 @@ def _read_args(argv: list[str] | None = None) -> argparse.Namespace:
127127 choices = STR_TO_CALIBRATE_SEARCH_BUFFER .keys (),
128128 default = "all" ,
129129 )
130+ parser .add_argument (
131+ "--shuffle_ground_truth" ,
132+ help = "Shuffle ground truth. See also svsbench.build --shuffle" ,
133+ action = "store_true" ,
134+ )
130135 return parser .parse_args (argv )
131136
132137
@@ -159,6 +164,8 @@ def search(
159164 lvq_strategy : svs .LVQStrategy | None = None ,
160165 train_prefetchers : bool = True ,
161166 search_buffer_optimization : svs .VamanaSearchBufferOptimization = svs .VamanaSearchBufferOptimization .All ,
167+ shuffle_ground_truth : bool = False ,
168+ seed : int = 42 ,
162169):
163170 logger .info ({"search_args" : locals ()})
164171 logger .info (utils .read_system_config ())
@@ -209,6 +216,8 @@ def search(
209216
210217 query = svs .read_vecs (str (query_path ))
211218 ground_truth = svs .read_vecs (str (ground_truth_path ))
219+ if shuffle_ground_truth :
220+ np .random .default_rng (seed ).shuffle (ground_truth )
212221
213222 for batch_size_idx , batch_size in enumerate (batch_sizes ):
214223 index .num_threads = min (max_threads , batch_size )
@@ -372,6 +381,8 @@ def main(argv: str | None = None) -> None:
372381 search_buffer_optimization = STR_TO_CALIBRATE_SEARCH_BUFFER [
373382 args .calibrate_search_buffer
374383 ],
384+ shuffle_ground_truth = args .shuffle_ground_truth ,
385+ seed = args .seed ,
375386 )
376387
377388
0 commit comments