Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 0 additions & 239 deletions diskann-benchmark-runner/src/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
* Licensed under the MIT license.
*/

use crate::dispatcher::{DispatchRule, FailureScore, MatchScore};

/// An refinement of [`std::any::Any`] with an associated name (tag) and serialization.
///
/// This type represents deserialized inputs returned from [`crate::Input::try_deserialize`]
Expand All @@ -15,9 +13,6 @@ pub struct Any {
tag: &'static str,
}

/// The score given unsuccessful downcasts in [`Any::try_match`].
pub const MATCH_FAIL: FailureScore = FailureScore(10_000);

impl Any {
/// Construct a new [`Any`] around `any` and associate it with the name `tag`.
///
Expand Down Expand Up @@ -97,158 +92,6 @@ impl Any {
self.any.as_any().downcast_ref::<T>()
}

/// Attempt to downcast self to `T` and if succssful, try matching `&T` with `U` using
/// [`crate::dispatcher::DispatchRule`].
///
/// Otherwise, return `Err(diskann_benchmark_runner::any::MATCH_FAIL)`.
///
/// ```rust
/// use diskann_benchmark_runner::{
/// any::Any,
/// dispatcher::{self, MatchScore, FailureScore},
/// utils::datatype::{self, DataType, Type},
/// };
///
/// let value = Any::new(DataType::Float32, "datatype");
///
/// // A successful down cast and successful match.
/// assert_eq!(
/// value.try_match::<DataType, Type<f32>>().unwrap(),
/// MatchScore(0),
/// );
///
/// // A successful down cast but unsuccessful match.
/// assert_eq!(
/// value.try_match::<DataType, Type<f64>>().unwrap_err(),
/// datatype::MATCH_FAIL,
/// );
///
/// // An unsuccessful down cast.
/// let value = Any::new(0usize, "usize");
/// assert_eq!(
/// value.try_match::<DataType, Type<f32>>().unwrap_err(),
/// diskann_benchmark_runner::any::MATCH_FAIL,
/// );
/// ```
pub fn try_match<'a, T, U>(&'a self) -> Result<MatchScore, FailureScore>
where
U: DispatchRule<&'a T>,
T: 'static,
{
if let Some(cast) = self.downcast_ref::<T>() {
U::try_match(&cast)
} else {
Err(MATCH_FAIL)
}
}

/// Attempt to downcast self to `T` and if succssful, try converting `&T` with `U` using
/// [`crate::dispatcher::DispatchRule`].
///
/// If unsuccessful, returns an error.
///
/// ```rust
/// use diskann_benchmark_runner::{
/// any::Any,
/// dispatcher::{self, MatchScore, FailureScore},
/// utils::datatype::{self, DataType, Type},
/// };
///
/// let value = Any::new(DataType::Float32, "datatype");
///
/// // A successful down cast and successful conversion.
/// let _: Type<f32> = value.convert::<DataType, _>().unwrap();
/// ```
pub fn convert<'a, T, U>(&'a self) -> anyhow::Result<U>
where
U: DispatchRule<&'a T>,
anyhow::Error: From<U::Error>,
T: 'static,
{
if let Some(cast) = self.downcast_ref::<T>() {
Ok(U::convert(cast)?)
} else {
Err(anyhow::Error::msg("invalid dispatch"))
}
}

/// A wrapper for [`DispatchRule::description`].
///
/// If `from` is `None` - document the expected tag for the input and return
/// `<U as DispatchRule<&T>>::description(f, None)`.
///
/// If `from` is `Some` - attempt to downcast to `T`. If successful, return the dispatch
/// rule description for `U` on the doncast reference. Otherwise, return the expected tag.
///
/// ```rust
/// use diskann_benchmark_runner::{
/// any::Any,
/// utils::datatype::{self, DataType, Type},
/// };
///
/// use std::io::Write;
///
/// struct Display(Option<Any>);
///
/// impl std::fmt::Display for Display {
/// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
/// match &self.0 {
/// Some(v) => Any::description::<DataType, Type<f32>>(f, Some(&&v), "my-tag"),
/// None => Any::description::<DataType, Type<f32>>(f, None, "my-tag"),
/// }
/// }
/// }
///
/// // No contained value - document the expected conversion.
/// assert_eq!(
/// Display(None).to_string(),
/// "tag \"my-tag\"\nfloat32",
/// );
///
/// // Matching contained value.
/// assert_eq!(
/// Display(Some(Any::new(DataType::Float32, "datatype"))).to_string(),
/// "successful match",
/// );
///
/// // Successful down cast - unsuccessful match.
/// assert_eq!(
/// Display(Some(Any::new(DataType::UInt64, "datatype"))).to_string(),
/// "expected \"float32\" but found \"uint64\"",
/// );
///
/// // Unsuccessful down cast.
/// assert_eq!(
/// Display(Some(Any::new(0usize, "another-tag"))).to_string(),
/// "expected tag \"my-tag\" - instead got \"another-tag\"",
/// );
/// ```
pub fn description<'a, T, U>(
f: &mut std::fmt::Formatter<'_>,
from: Option<&&'a Self>,
tag: impl std::fmt::Display,
) -> std::fmt::Result
where
U: DispatchRule<&'a T>,
T: 'static,
{
match from {
Some(this) => match this.downcast_ref::<T>() {
Some(a) => U::description(f, Some(&a)),
None => write!(
f,
"expected tag \"{}\" - instead got \"{}\"",
tag,
this.tag(),
),
},
None => {
writeln!(f, "tag \"{}\"", tag)?;
U::description(f, None::<&&T>)
}
}
}

/// Serialize the contained object to a [`serde_json::Value`].
pub fn serialize(&self) -> Result<serde_json::Value, serde_json::Error> {
self.any.dump()
Expand Down Expand Up @@ -308,8 +151,6 @@ where
mod tests {
use super::*;

use crate::utils::datatype::{self, DataType, Type};

#[test]
fn test_new() {
let x = Any::new(42usize, "my-tag");
Expand Down Expand Up @@ -352,84 +193,4 @@ mod tests {
serde_json::Value::Number(serde_json::value::Number::from_f64(1.5).unwrap())
);
}

#[test]
fn test_try_match() {
let value = Any::new(DataType::Float32, "random-tag");

// A successful down cast and successful match.
assert_eq!(
value.try_match::<DataType, Type<f32>>().unwrap(),
MatchScore(0),
);

// A successful down cast but unsuccessful match.
assert_eq!(
value.try_match::<DataType, Type<f64>>().unwrap_err(),
datatype::MATCH_FAIL,
);

// An unsuccessful down cast.
let value = Any::new(0usize, "");
assert_eq!(
value.try_match::<DataType, Type<f32>>().unwrap_err(),
MATCH_FAIL,
);
}

#[test]
fn test_convert() {
let value = Any::new(DataType::Float32, "random-tag");

// A successful down cast and successful conversion.
let _: Type<f32> = value.convert::<DataType, _>().unwrap();

// An invalid match should return an error.
let value = Any::new(0usize, "random-rag");
let err = value.convert::<DataType, Type<f32>>().unwrap_err();
let msg = err.to_string();
assert!(msg.contains("invalid dispatch"), "{}", msg);
}

#[test]
#[should_panic(expected = "invalid dispatch")]
fn test_convert_inner_error() {
let value = Any::new(DataType::Float32, "random-tag");
let _ = value.convert::<DataType, Type<u64>>();
}

#[test]
fn test_description() {
struct Display(Option<Any>);

impl std::fmt::Display for Display {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.0 {
Some(v) => Any::description::<DataType, Type<f32>>(f, Some(&v), "my-tag"),
None => Any::description::<DataType, Type<f32>>(f, None, "my-tag"),
}
}
}

// No contained value - document the expected conversion.
assert_eq!(Display(None).to_string(), "tag \"my-tag\"\nfloat32",);

// Matching contained value.
assert_eq!(
Display(Some(Any::new(DataType::Float32, ""))).to_string(),
"successful match",
);

// Successful down cast - unsuccessful match.
assert_eq!(
Display(Some(Any::new(DataType::UInt64, ""))).to_string(),
"expected \"float32\" but found \"uint64\"",
);

// Unsuccessful down cast.
assert_eq!(
Display(Some(Any::new(0usize, ""))).to_string(),
"expected tag \"my-tag\" - instead got \"\"",
);
}
}
30 changes: 26 additions & 4 deletions diskann-benchmark-runner/src/benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@

use serde::{Deserialize, Serialize};

use crate::{
dispatcher::{FailureScore, MatchScore},
Any, Checkpoint, Input, Output,
};
use crate::{Any, Checkpoint, Input, Output};

/// A registered benchmark.
///
Expand Down Expand Up @@ -60,6 +57,31 @@ pub trait Benchmark: 'static {
) -> anyhow::Result<Self::Output>;
}

/// Successful matches from [`Benchmark::try_match`] will return `MatchScores`.
///
/// A lower numerical value indicates a better match for purposes of overload resolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct MatchScore(pub u32);

impl std::fmt::Display for MatchScore {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "success ({})", self.0)
}
}

/// Successful matches from [`Benchmark::try_match`] will return `FailureScores`.
///
/// A lower numerical value indicates a better match, which can help when compiling a
/// list of considered and rejected candidates.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct FailureScore(pub u32);

impl std::fmt::Display for FailureScore {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "fail ({})", self.0)
}
}

/// A refinement of [`Benchmark`], that supports before/after comparison of generated results.
///
/// Benchmarks are associated with a "tolerance" input, which may contain runtime values
Expand Down
Loading
Loading