Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions diskann-benchmark-runner/dev/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,14 @@
* Licensed under the MIT license.
*/

use diskann_benchmark_runner::{app::App, output, registry};
use diskann_benchmark_runner::{output, App, Registry};

fn main() -> anyhow::Result<()> {
// Parse the command line options.
let app = App::parse();

// Gather the test inputs and outputs.
let mut inputs = registry::Inputs::new();
diskann_benchmark_runner::test::register_inputs(&mut inputs)?;
let mut registry = Registry::new();
diskann_benchmark_runner::test::register_benchmarks(&mut registry)?;

let mut benchmarks = registry::Benchmarks::new();
diskann_benchmark_runner::test::register_benchmarks(&mut benchmarks);

app.run(&inputs, &benchmarks, &mut output::default())
app.run(&registry, &mut output::default())
}
54 changes: 19 additions & 35 deletions diskann-benchmark-runner/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,13 @@
//! use diskann_benchmark_runner::{App, registry};
//!
//! fn main() -> anyhow::Result<()> {
//! let mut inputs = registry::Inputs::new();
//! // inputs.register::<MyInput>()?;
//!
//! let mut benchmarks = registry::Benchmarks::new();
//! // benchmarks.register::<MyBenchmark>("my-bench");
//! // benchmarks.register_regression::<MyRegressionBenchmark>("my-regression");
//! let mut registry = registry::Registry::new();
//! // registry.register("my-bench", MyBenchmark::default())?;
//! // registry.register_regression("my-regression", MyRegressionBenchmark::default())?;
//!
//! let app = App::parse();
//! let mut output = diskann_benchmark_runner::output::default();
//! app.run(&inputs, &benchmarks, &mut output)
//! app.run(&registry, &mut output)
//! }
//! ```
//!
Expand Down Expand Up @@ -192,15 +189,14 @@ impl App {
/// Run the application using the inputs and benchmarks registered in `registry`.
pub fn run(
&self,
inputs: &registry::Inputs,
benchmarks: &registry::Benchmarks,
registry: &registry::Registry,
mut output: &mut dyn Output,
) -> anyhow::Result<()> {
match &self.command {
// If a named input isn't given, then list the available input kinds.
Commands::Inputs { describe } => {
if let Some(describe) = describe {
if let Some(input) = inputs.get(describe) {
if let Some(input) = registry.input(describe) {
let repr = jobs::Unprocessed::format_input(input)?;
writeln!(
output,
Expand All @@ -217,7 +213,7 @@ impl App {
}

writeln!(output, "Available input kinds are listed below:")?;
let mut tags: Vec<_> = inputs.tags().collect();
let mut tags: Vec<_> = registry.input_tags().collect();
tags.sort();
for i in tags.iter() {
writeln!(output, " {}", i)?;
Expand All @@ -226,7 +222,7 @@ impl App {
// List the available benchmarks.
Commands::Benchmarks {} => {
writeln!(output, "Registered Benchmarks:")?;
for (name, description) in benchmarks.names() {
for (name, description) in registry.names() {
write!(output, " {name}:")?;
if description.is_empty() {
writeln!(output)?;
Expand All @@ -248,11 +244,11 @@ impl App {
allow_debug,
} => {
// Parse and validate the input.
let run = Jobs::load(input_file, inputs)?;
let run = Jobs::load(input_file, registry)?;
// Check if we have a match for each benchmark.
for job in run.jobs().iter() {
const MAX_METHODS: usize = 3;
if let Err(mismatches) = benchmarks.debug(job, MAX_METHODS) {
if let Err(mismatches) = registry.debug(job, MAX_METHODS) {
let repr = serde_json::to_string_pretty(&job.serialize()?)?;

writeln!(
Expand Down Expand Up @@ -314,7 +310,7 @@ impl App {

// Run the specified job.
let checkpoint = Checkpoint::new(&serialized, &results, output_file)?;
let r = benchmarks.call(job, checkpoint, output)?;
let r = registry.call(job, checkpoint, output)?;

// Collect the results
results.push(r);
Expand All @@ -324,7 +320,7 @@ impl App {
}
}
// Extensions
Commands::Check(check) => return self.check(check, inputs, benchmarks, output),
Commands::Check(check) => return self.check(check, registry, output),
};
Ok(())
}
Expand All @@ -333,8 +329,7 @@ impl App {
fn check(
&self,
check: &Check,
inputs: &registry::Inputs,
benchmarks: &registry::Benchmarks,
registry: &registry::Registry,
mut output: &mut dyn Output,
) -> anyhow::Result<()> {
match check {
Expand All @@ -350,7 +345,7 @@ impl App {
Ok(())
}
Check::Tolerances { describe } => {
let tolerances = benchmarks.tolerances();
let tolerances = registry.tolerances();

match describe {
Some(name) => match tolerances.get(&**name) {
Expand Down Expand Up @@ -405,12 +400,7 @@ impl App {
tolerances,
input_file,
} => {
// For verification - we merely check that we can successfully construct
// the regression `Checks` struct. It performs all the necessary preflight
// checks.
let benchmarks = benchmarks.tolerances();
let _ =
internal::regression::Checks::new(tolerances, input_file, inputs, &benchmarks)?;
let _ = internal::regression::Checks::new(tolerances, input_file, registry)?;
Ok(())
}
Check::Run {
Expand All @@ -420,9 +410,7 @@ impl App {
after,
output_file,
} => {
let registered = benchmarks.tolerances();
let checks =
internal::regression::Checks::new(tolerances, input_file, inputs, &registered)?;
let checks = internal::regression::Checks::new(tolerances, input_file, registry)?;
let jobs = checks.jobs(before, after)?;
jobs.run(output, output_file.as_deref())?;
Ok(())
Expand Down Expand Up @@ -605,13 +593,9 @@ mod tests {
fn run(&self, tempdir: &Path) {
let apps = self.parse_stdin(tempdir);

// Register inputs
let mut inputs = registry::Inputs::new();
crate::test::register_inputs(&mut inputs).unwrap();

// Register benchmarks
let mut benchmarks = registry::Benchmarks::new();
crate::test::register_benchmarks(&mut benchmarks);
let mut registry = registry::Registry::new();
crate::test::register_benchmarks(&mut registry).unwrap();

// Run each app invocation - collecting the last output into a buffer.
//
Expand All @@ -631,7 +615,7 @@ mod tests {
&mut crate::output::Sink::new()
};

if let Err(err) = app.run(&inputs, &benchmarks, b) {
if let Err(err) = app.run(&registry, b) {
if is_last {
write!(b, "{:?}", err).unwrap();
} else {
Expand Down
4 changes: 2 additions & 2 deletions diskann-benchmark-runner/src/benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub trait Benchmark: 'static {
///
/// In the case of ties, the winner is chosen using an unspecified tie-breaking procedure.
///
/// On failure, returns `Err(FailureScore)`. In the [`crate::registry::Benchmarks`]
/// On failure, returns `Err(FailureScore)`. In the [`crate::Registry`]
/// registry, [`FailureScore`]s will be used to rank the "nearest misses". Implementations
/// are encouraged to generate ranked [`FailureScore`]s to assist in user level debugging.
fn try_match(&self, input: &Self::Input) -> Result<MatchScore, FailureScore>;
Expand Down Expand Up @@ -90,7 +90,7 @@ impl std::fmt::Display for FailureScore {
/// The semantics of pass or failure are left solely to the discretion of the [`Regression`]
/// implementation.
///
/// See: [`register_regression`](crate::registry::Benchmarks::register_regression).
/// See: [`register_regression`](crate::Registry::register_regression).
pub trait Regression: Benchmark<Output: for<'a> Deserialize<'a>> {
/// The tolerance [`Input`] associated with this regression check.
type Tolerances: Input + 'static;
Expand Down
14 changes: 12 additions & 2 deletions diskann-benchmark-runner/src/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub trait Input {
fn example() -> anyhow::Result<serde_json::Value>;
}

/// A registered input. See [`crate::registry::Inputs::get`].
/// A registered input. See [`crate::Registry::input`].
#[derive(Clone, Copy)]
pub struct Registered<'a>(pub(crate) &'a dyn DynInput);

Expand Down Expand Up @@ -110,11 +110,15 @@ pub(crate) trait DynInput {
checker: &mut Checker,
) -> anyhow::Result<Any>;
fn example(&self) -> anyhow::Result<serde_json::Value>;

// reflection
fn as_any(&self) -> &dyn std::any::Any;
fn type_name(&self) -> &'static str;
}

impl<T> DynInput for Wrapper<T>
where
T: Input,
T: Input + 'static,
{
fn tag(&self) -> &'static str {
T::tag()
Expand All @@ -129,4 +133,10 @@ where
fn example(&self) -> anyhow::Result<serde_json::Value> {
T::example()
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn type_name(&self) -> &'static str {
std::any::type_name::<T>()
}
}
13 changes: 6 additions & 7 deletions diskann-benchmark-runner/src/internal/regression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,20 +126,19 @@ impl<'a> Checks<'a> {
pub(crate) fn new(
tolerances: &Path,
input_file: &Path,
inputs: &registry::Inputs,
entries: &'a HashMap<&'static str, registry::RegisteredTolerance<'a>>,
registry: &'a registry::Registry,
) -> anyhow::Result<Self> {
// Load the raw input file.
let partial = jobs::Partial::load(input_file)?;

// Parse and validate the raw jobs against the registered inputs.
//
// This preserves the ordering of the jobs.
let inputs = jobs::Jobs::parse(&partial, inputs)?;
let inputs = jobs::Jobs::parse(&partial, registry)?;

// Now that the inputs have been fully parsed and validated, we then check that we
// can load the raw tolerance file.
let parsed = Raw::load(tolerances)?.parse(entries)?;
let parsed = Raw::load(tolerances)?.parse(&registry.tolerances())?;
Self::match_all(parsed, partial, inputs)
}

Expand Down Expand Up @@ -298,7 +297,7 @@ impl Raw {

fn parse<'a>(
self,
entries: &'a HashMap<&'static str, registry::RegisteredTolerance<'a>>,
entries: &HashMap<&'static str, registry::RegisteredTolerance<'a>>,
) -> anyhow::Result<Parsed<'a>> {
// Attempt to parse raw tolerances into registered tolerance inputs.
let num_checks = self.checks.len();
Expand Down Expand Up @@ -356,7 +355,7 @@ impl Raw {
.with_context(context)?;

Ok(ParsedInner {
entry,
entry: entry.clone(),
Comment thread
hildebrandmw marked this conversation as resolved.
tolerance: Rc::new(tolerance),
input: unprocessed.input,
})
Expand All @@ -382,7 +381,7 @@ impl Raw {
/// * The tag in `input` exists within at least one of the regressions in `entry`.
#[derive(Debug)]
struct ParsedInner<'a> {
entry: &'a registry::RegisteredTolerance<'a>,
entry: registry::RegisteredTolerance<'a>,
tolerance: Rc<Any>,
input: jobs::Unprocessed,
}
Expand Down
8 changes: 4 additions & 4 deletions diskann-benchmark-runner/src/jobs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::path::{Path, PathBuf};
use anyhow::Context;
use serde::{Deserialize, Serialize};

use crate::{checker::Checker, input, registry, Any};
use crate::{checker::Checker, input, Any, Registry};

#[derive(Debug)]
pub(crate) struct Jobs {
Expand All @@ -33,14 +33,14 @@ impl Jobs {
/// the post-load validation of the requested runs, including:
///
/// * Resolution of input files.
pub(crate) fn load(path: &Path, registry: &registry::Inputs) -> anyhow::Result<Self> {
pub(crate) fn load(path: &Path, registry: &Registry) -> anyhow::Result<Self> {
Self::parse(&Partial::load(path)?, registry)
}

/// Parse `self` from a [`Partial`].
///
/// This method also performs deserialization checks on the parsed inputs.
pub(crate) fn parse(partial: &Partial, registry: &registry::Inputs) -> anyhow::Result<Self> {
pub(crate) fn parse(partial: &Partial, registry: &Registry) -> anyhow::Result<Self> {
let mut checker = Checker::new(
partial
.search_directories
Expand All @@ -65,7 +65,7 @@ impl Jobs {
};

let input = registry
.get(&unprocessed.tag)
.input(&unprocessed.tag)
.ok_or_else(|| {
anyhow::anyhow!("Unrecognized input tag: \"{}\"", unprocessed.tag)
})
Expand Down
1 change: 1 addition & 0 deletions diskann-benchmark-runner/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub use benchmark::Benchmark;
pub use checker::{CheckDeserialization, Checker};
pub use input::Input;
pub use output::Output;
pub use registry::{Registry, RegistryError};
pub use result::Checkpoint;

#[cfg(any(test, feature = "test-app"))]
Expand Down
Loading
Loading