Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 82 additions & 41 deletions core/layers/throttle/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#![cfg_attr(docsrs, feature(doc_cfg))]
#![deny(missing_docs)]

use std::future::Future;
use std::num::NonZeroU32;
use std::sync::Arc;

Expand All @@ -32,6 +33,33 @@ use governor::state::NotKeyed;
use opendal_core::raw::*;
use opendal_core::*;

/// ThrottleRateLimiter abstracts a rate-limit primitive used by
/// [`ThrottleLayer`].
pub trait ThrottleRateLimiter: Send + Sync + Clone + Unpin + 'static {
/// Block until `n` units of capacity are available.
///
/// Returns an error when the request can never be satisfied, for
/// example when `n` exceeds the limiter's burst/capacity.
fn until_n_ready(&self, n: NonZeroU32) -> impl Future<Output = Result<()>> + MaybeSend;
}

/// Share an atomic RateLimiter instance across all threads in one operator.
/// If want to add more observability in the future, replace the default NoOpMiddleware with other middleware types.
/// Read more about [Middleware](https://docs.rs/governor/latest/governor/middleware/index.html)
pub type SharedRateLimiter =
Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock, NoOpMiddleware>>;

impl ThrottleRateLimiter for SharedRateLimiter {
async fn until_n_ready(&self, n: NonZeroU32) -> Result<()> {
self.as_ref().until_n_ready(n).await.map_err(|_| {
Error::new(
ErrorKind::RateLimited,
"burst size is smaller than the request size",
)
})
}
}

/// Add a bandwidth rate limiter to the underlying services.
///
/// # Throttle
Expand Down Expand Up @@ -67,56 +95,77 @@ use opendal_core::*;
/// # }
/// ```
#[derive(Clone)]
pub struct ThrottleLayer {
bandwidth: NonZeroU32,
burst: NonZeroU32,
pub struct ThrottleLayer<L: ThrottleRateLimiter = SharedRateLimiter> {
rate_limiter: L,
}

impl ThrottleLayer {
impl ThrottleLayer<SharedRateLimiter> {
/// Create a new `ThrottleLayer` with given bandwidth and burst.
///
/// - bandwidth: the maximum number of bytes allowed to pass through per second.
/// - burst: the maximum number of bytes allowed to pass through at once.
pub fn new(bandwidth: u32, burst: u32) -> Self {
assert!(bandwidth > 0);
assert!(burst > 0);
Self {
bandwidth: NonZeroU32::new(bandwidth).unwrap(),
burst: NonZeroU32::new(burst).unwrap(),
}
let bandwidth = NonZeroU32::new(bandwidth).unwrap();
let burst = NonZeroU32::new(burst).unwrap();
let rate_limiter = Arc::new(RateLimiter::direct(
Quota::per_second(bandwidth).allow_burst(burst),
));
Self { rate_limiter }
}
}

impl<L: ThrottleRateLimiter> ThrottleLayer<L> {
/// Create a layer with any [`ThrottleRateLimiter`] implementation.
///
/// ```
/// # use std::num::NonZeroU32;
/// # use std::sync::Arc;
/// # use governor::Quota;
/// # use governor::RateLimiter;
/// # use opendal_layer_throttle::SharedRateLimiter;
/// # use opendal_layer_throttle::ThrottleLayer;
/// let limiter: SharedRateLimiter = Arc::new(RateLimiter::direct(
/// Quota::per_second(NonZeroU32::new(1024).unwrap())
/// .allow_burst(NonZeroU32::new(1024 * 1024).unwrap()),
/// ));
/// let _layer = ThrottleLayer::with_limiter(limiter);
/// ```
pub fn with_limiter(rate_limiter: L) -> Self {
Self { rate_limiter }
}
}

impl<A: Access> Layer<A> for ThrottleLayer {
type LayeredAccess = ThrottleAccessor<A>;
impl<A: Access, L: ThrottleRateLimiter> Layer<A> for ThrottleLayer<L> {
type LayeredAccess = ThrottleAccessor<A, L>;

fn layer(&self, inner: A) -> Self::LayeredAccess {
let rate_limiter = Arc::new(RateLimiter::direct(
Quota::per_second(self.bandwidth).allow_burst(self.burst),
));
ThrottleAccessor {
inner,
rate_limiter,
rate_limiter: self.rate_limiter.clone(),
}
}
}

/// Share an atomic RateLimiter instance across all threads in one operator.
/// If want to add more observability in the future, replace the default NoOpMiddleware with other middleware types.
/// Read more about [Middleware](https://docs.rs/governor/latest/governor/middleware/index.html)
type SharedRateLimiter = Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock, NoOpMiddleware>>;

#[doc(hidden)]
#[derive(Debug)]
pub struct ThrottleAccessor<A: Access> {
pub struct ThrottleAccessor<A: Access, L: ThrottleRateLimiter> {
inner: A,
rate_limiter: SharedRateLimiter,
rate_limiter: L,
}

impl<A: Access> LayeredAccess for ThrottleAccessor<A> {
impl<A: Access, L: ThrottleRateLimiter> std::fmt::Debug for ThrottleAccessor<A, L> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ThrottleAccessor")
.field("inner", &self.inner)
.finish_non_exhaustive()
}
}

impl<A: Access, L: ThrottleRateLimiter> LayeredAccess for ThrottleAccessor<A, L> {
type Inner = A;
type Reader = ThrottleWrapper<A::Reader>;
type Writer = ThrottleWrapper<A::Writer>;
type Reader = ThrottleWrapper<A::Reader, L>;
type Writer = ThrottleWrapper<A::Writer, L>;
type Lister = A::Lister;
type Deleter = A::Deleter;

Expand Down Expand Up @@ -152,27 +201,24 @@ impl<A: Access> LayeredAccess for ThrottleAccessor<A> {
}

#[doc(hidden)]
pub struct ThrottleWrapper<R> {
pub struct ThrottleWrapper<R, L> {
inner: R,
limiter: SharedRateLimiter,
limiter: L,
}

impl<R> ThrottleWrapper<R> {
fn new(inner: R, rate_limiter: SharedRateLimiter) -> Self {
Self {
inner,
limiter: rate_limiter,
}
impl<R, L> ThrottleWrapper<R, L> {
fn new(inner: R, limiter: L) -> Self {
Self { inner, limiter }
}
}

impl<R: oio::Read> oio::Read for ThrottleWrapper<R> {
impl<R: oio::Read, L: ThrottleRateLimiter> oio::Read for ThrottleWrapper<R, L> {
async fn read(&mut self) -> Result<Buffer> {
self.inner.read().await
}
}

impl<R: oio::Write> oio::Write for ThrottleWrapper<R> {
impl<R: oio::Write, L: ThrottleRateLimiter> oio::Write for ThrottleWrapper<R, L> {
async fn write(&mut self, bs: Buffer) -> Result<()> {
let len = bs.len();
if len == 0 {
Expand All @@ -189,12 +235,7 @@ impl<R: oio::Write> oio::Write for ThrottleWrapper<R> {
let buf_length =
NonZeroU32::new(len as u32).expect("len is non-zero so NonZeroU32 must exist");

self.limiter.until_n_ready(buf_length).await.map_err(|_| {
Error::new(
ErrorKind::RateLimited,
"burst size is smaller than the request size",
)
})?;
self.limiter.until_n_ready(buf_length).await?;

self.inner.write(bs).await
}
Expand Down
Loading