Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 39 additions & 12 deletions crates/wasi-http/src/p3/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ pub(crate) enum Body {
}

impl Body {
/// Implementation of `consume-body` shared between requests and responses
pub(crate) fn consume<T>(
self,
mut store: Access<'_, T, WasiHttp>,
Expand Down Expand Up @@ -105,6 +106,7 @@ impl Body {
}
}

/// Implementation of `drop` shared between requests and responses
pub(crate) fn drop(self, mut store: impl AsContextMut) {
if let Body::Guest {
contents_rx,
Expand All @@ -120,7 +122,8 @@ impl Body {
}
}

pub(crate) enum GuestBodyKind {
/// The kind of body, used for error reporting
pub(crate) enum BodyKind {
Request,
Response,
}
Expand All @@ -141,20 +144,22 @@ impl ContentLength {
}
}

/// [StreamConsumer] implementation for bodies originating in the guest.
struct GuestBodyConsumer {
contents_tx: PollSender<Result<Bytes, ErrorCode>>,
result_tx: Option<oneshot::Sender<Result<(), ErrorCode>>>,
content_length: Option<ContentLength>,
kind: GuestBodyKind,
kind: BodyKind,
// `true` when the other side of `contents_tx` was unexpectedly closed
closed: bool,
}

impl GuestBodyConsumer {
/// Constructs the approprite body size error given the [BodyKind]
fn body_size_error(&self, n: Option<u64>) -> ErrorCode {
match self.kind {
GuestBodyKind::Request => ErrorCode::HttpRequestBodySize(n),
GuestBodyKind::Response => ErrorCode::HttpResponseBodySize(n),
BodyKind::Request => ErrorCode::HttpRequestBodySize(n),
BodyKind::Response => ErrorCode::HttpResponseBodySize(n),
}
}

Expand Down Expand Up @@ -235,20 +240,22 @@ impl<D> StreamConsumer<D> for GuestBodyConsumer {
}
}

/// [http_body::Body] implementation for bodies originating in the guest.
pub(crate) struct GuestBody {
contents_rx: Option<mpsc::Receiver<Result<Bytes, ErrorCode>>>,
trailers_rx: Option<oneshot::Receiver<Result<Option<Arc<http::HeaderMap>>, ErrorCode>>>,
content_length: Option<u64>,
}

impl GuestBody {
/// Construct a new [GuestBody]
pub(crate) fn new<T: 'static>(
mut store: impl AsContextMut<Data = T>,
contents_rx: Option<StreamReader<u8>>,
trailers_rx: FutureReader<Result<Option<Resource<Trailers>>, ErrorCode>>,
result_tx: oneshot::Sender<Result<(), ErrorCode>>,
content_length: Option<u64>,
kind: GuestBodyKind,
kind: BodyKind,
getter: fn(&mut T) -> WasiHttpCtxView<'_>,
) -> Self {
let (trailers_http_tx, trailers_http_rx) = oneshot::channel();
Expand Down Expand Up @@ -290,10 +297,15 @@ impl http_body::Body for GuestBody {
cx: &mut Context<'_>,
) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
if let Some(contents_rx) = self.contents_rx.as_mut() {
// `contents_rx` has not been closed yet, poll it
while let Some(res) = ready!(contents_rx.poll_recv(cx)) {
match res {
Ok(buf) => {
if let Some(n) = self.content_length.as_mut() {
// Substract frame length from `content_length`,
// [GuestBodyConsumer] already performs the validation, so
// just keep count as optimization for
// `is_end_stream` and `size_hint`
*n = n.saturating_sub(buf.len().try_into().unwrap_or(u64::MAX));
}
return Poll::Ready(Some(Ok(http_body::Frame::data(buf))));
Expand All @@ -303,14 +315,17 @@ impl http_body::Body for GuestBody {
}
}
}
// Record that `contents_rx` is closed
self.contents_rx = None;
}

let Some(trailers_rx) = self.trailers_rx.as_mut() else {
// `trailers_rx` has already terminated - this is the end of stream
return Poll::Ready(None);
};

let res = ready!(Pin::new(trailers_rx).poll(cx));
// Record that `trailers_rx` has terminated
self.trailers_rx = None;
match res {
Ok(Ok(Some(trailers))) => Poll::Ready(Some(Ok(http_body::Frame::trailers(
Expand All @@ -328,14 +343,18 @@ impl http_body::Body for GuestBody {
|| !contents_rx.is_closed()
|| self.content_length.is_some_and(|n| n > 0)
{
// `contents_rx` might still produce data frames
return false;
}
}
if let Some(trailers_rx) = self.trailers_rx.as_ref() {
if !trailers_rx.is_terminated() {
// `trailers_rx` has not terminated yet
return false;
}
}

// no data left
return true;
}

Expand All @@ -348,6 +367,7 @@ impl http_body::Body for GuestBody {
}
}

/// [http_body::Body] that has been consumed.
pub(crate) struct ConsumedBody;

impl http_body::Body for ConsumedBody {
Expand All @@ -372,9 +392,10 @@ impl http_body::Body for ConsumedBody {
}
}

pub(crate) struct GuestTrailerConsumer<T> {
pub(crate) tx: Option<oneshot::Sender<Result<Option<Arc<HeaderMap>>, ErrorCode>>>,
pub(crate) getter: fn(&mut T) -> WasiHttpCtxView<'_>,
/// [FutureConsumer] implementation for trailers originating in the guest.
struct GuestTrailerConsumer<T> {
tx: Option<oneshot::Sender<Result<Option<Arc<HeaderMap>>, ErrorCode>>>,
getter: fn(&mut T) -> WasiHttpCtxView<'_>,
}

impl<D> FutureConsumer<D> for GuestTrailerConsumer<D>
Expand All @@ -387,12 +408,13 @@ where
mut self: Pin<&mut Self>,
_: &mut Context<'_>,
mut store: StoreContextMut<D>,
mut source: Source<'_, Self::Item>,
mut src: Source<'_, Self::Item>,
_: bool,
) -> Poll<wasmtime::Result<()>> {
let value = &mut None;
source.read(store.as_context_mut(), value)?;
let res = match value.take().unwrap() {
let mut result = None;
src.read(store.as_context_mut(), &mut result)
.context("failed to read result")?;
let res = match result.context("result value missing")? {
Ok(Some(trailers)) => {
let WasiHttpCtxView { table, .. } = (self.getter)(store.data_mut());
let trailers = table
Expand All @@ -408,6 +430,7 @@ where
}
}

/// [StreamProducer] implementation for bodies originating in the host.
struct HostBodyStreamProducer<T> {
body: BoxBody<Bytes, ErrorCode>,
trailers: Option<oneshot::Sender<Result<Option<Resource<Trailers>>, ErrorCode>>>,
Expand Down Expand Up @@ -446,6 +469,8 @@ where
let cap = match dst.remaining(&mut store).map(NonZeroUsize::new) {
Some(Some(cap)) => Some(cap),
Some(None) => {
// On 0-length the best we can do is check that underlying stream has not
// reached the end yet
if self.body.is_end_stream() {
break 'result Ok(None);
} else {
Expand All @@ -462,11 +487,13 @@ where
let n = frame.len();
let cap = cap.into();
if n > cap {
// data frame does not fit in destination, fill it and buffer the rest
dst.set_buffer(Cursor::new(frame.split_off(cap)));
let mut dst = dst.as_direct(store, cap);
dst.remaining().copy_from_slice(&frame);
dst.mark_written(cap);
} else {
// copy the whole frame into the destination
let mut dst = dst.as_direct(store, n);
dst.remaining()[..n].copy_from_slice(&frame);
dst.mark_written(n);
Expand Down
116 changes: 103 additions & 13 deletions crates/wasi-http/src/p3/host/handler.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,74 @@
use crate::p3::bindings::http::handler::{Host, HostWithStore};
use crate::p3::bindings::http::types::{ErrorCode, Request, Response};
use crate::p3::body::{Body, ConsumedBody, GuestBody, GuestBodyKind};
use crate::p3::body::{Body, BodyKind, ConsumedBody, GuestBody};
use crate::p3::{HttpError, HttpResult, WasiHttp, WasiHttpCtxView, get_content_length};
use anyhow::Context as _;
use core::pin::Pin;
use core::task::{Context, Poll, Waker};
use http::header::HOST;
use http::{HeaderValue, Uri};
use http_body_util::BodyExt as _;
use std::sync::Arc;
use tokio::sync::oneshot;
use tracing::debug;
use wasmtime::component::{Accessor, AccessorTask, Resource};
use wasmtime::component::{Accessor, AccessorTask, JoinHandle, Resource};

/// A wrapper around [`JoinHandle`], which will [`JoinHandle::abort`] the task
/// when dropped
struct AbortOnDropJoinHandle(JoinHandle);

impl Drop for AbortOnDropJoinHandle {
fn drop(&mut self) {
self.0.abort();
}
}

/// A wrapper around [http_body::Body], which allows attaching arbitrary state to it
struct BodyWithState<T, U> {
body: T,
_state: U,
}

impl<T, U> http_body::Body for BodyWithState<T, U>
where
T: http_body::Body + Unpin,
U: Unpin,
{
type Data = T::Data;
type Error = T::Error;

#[inline]
fn poll_frame(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
Pin::new(&mut self.get_mut().body).poll_frame(cx)
}

#[inline]
fn is_end_stream(&self) -> bool {
self.body.is_end_stream()
}

#[inline]
fn size_hint(&self) -> http_body::SizeHint {
self.body.size_hint()
}
}

trait BodyExt {
fn with_state<T>(self, state: T) -> BodyWithState<Self, T>
where
Self: Sized,
{
BodyWithState {
body: self,
_state: state,
}
}
}

impl<T> BodyExt for T {}

struct SendRequestTask {
io: Pin<Box<dyn Future<Output = Result<(), ErrorCode>> + Send>>,
Expand All @@ -26,14 +84,35 @@ impl<T> AccessorTask<T, WasiHttp, wasmtime::Result<()>> for SendRequestTask {
}
}

async fn io_task_result(
rx: oneshot::Receiver<(
Arc<AbortOnDropJoinHandle>,
oneshot::Receiver<Result<(), ErrorCode>>,
)>,
) -> Result<(), ErrorCode> {
let Ok((_io, io_result_rx)) = rx.await else {
return Ok(());
};
io_result_rx.await.unwrap_or(Ok(()))
}

impl HostWithStore for WasiHttp {
async fn handle<T>(
store: &Accessor<T, Self>,
req: Resource<Request>,
) -> HttpResult<Resource<Response>> {
let getter = store.getter();
// A handle to the I/O task, if spawned, will be sent on this channel
// and kept as part of request body state
let (io_task_tx, io_task_rx) = oneshot::channel();

// A handle to the I/O task, if spawned, will be sent on this channel
// along with the result receiver
let (io_result_tx, io_result_rx) = oneshot::channel();

// Response processing result will be sent on this channel
let (res_result_tx, res_result_rx) = oneshot::channel();

let getter = store.getter();
let fut = store.with(|mut store| {
let WasiHttpCtxView { table, .. } = store.get();
let Request {
Expand All @@ -56,30 +135,30 @@ impl HostWithStore for WasiHttp {
result_tx,
} => {
let (http_result_tx, http_result_rx) = oneshot::channel();
// `Content-Length` header value is validated in `fields` implementation
let content_length = get_content_length(&headers)
.map_err(|err| ErrorCode::InternalError(Some(format!("{err:#}"))))?;
_ = result_tx.send(Box::new(async move {
if let Ok(Err(err)) = http_result_rx.await {
return Err(err);
};
io_result_rx.await.unwrap_or(Ok(()))
io_task_result(io_result_rx).await
}));
GuestBody::new(
&mut store,
contents_rx,
trailers_rx,
http_result_tx,
content_length,
GuestBodyKind::Request,
BodyKind::Request,
getter,
)
.with_state(io_task_rx)
.boxed()
}
Body::Host { body, result_tx } => {
_ = result_tx.send(Box::new(
async move { io_result_rx.await.unwrap_or(Ok(())) },
));
body
_ = result_tx.send(Box::new(io_task_result(io_result_rx)));
body.with_state(io_task_rx).boxed()
}
Body::Consumed => ConsumedBody.boxed(),
};
Expand Down Expand Up @@ -121,6 +200,7 @@ impl HostWithStore for WasiHttp {
req,
options.as_deref().copied(),
Box::new(async {
// Forward the response processing result to `WasiHttpCtx` implementation
let Ok(fut) = res_result_rx.await else {
return Ok(());
};
Expand All @@ -129,16 +209,26 @@ impl HostWithStore for WasiHttp {
))
})?;
let (res, io) = Box::into_pin(fut).await?;
store.spawn(SendRequestTask {
io: Box::into_pin(io),
result_tx: io_result_tx,
});
let (
http::response::Parts {
status, headers, ..
},
body,
) = res.into_parts();

let mut io = Box::into_pin(io);
let body = match io.as_mut().poll(&mut Context::from_waker(Waker::noop()))? {
Poll::Ready(()) => body,
Poll::Pending => {
// I/O driver still needs to be polled, spawn a task and send handles to it
let (tx, rx) = oneshot::channel();
let io = store.spawn(SendRequestTask { io, result_tx: tx });
let io = Arc::new(AbortOnDropJoinHandle(io));
_ = io_result_tx.send((Arc::clone(&io), rx));
_ = io_task_tx.send(Arc::clone(&io));
body.with_state(io).boxed()
}
};
let res = Response {
status,
headers: Arc::new(headers),
Expand Down
Loading