Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 31 additions & 14 deletions integrations/dav-server/src/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ use dav_server::fs::{DavFile, OpenOptions};
use dav_server::fs::{DavMetaData, FsResult};
use dav_server::fs::{FsError, FsFuture};
use futures::FutureExt;
use futures::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt};
use opendal::{FuturesAsyncReader, FuturesAsyncWriter, Operator};
use futures::{AsyncReadExt, AsyncSeekExt};
use opendal::{Buffer, FuturesAsyncReader, Operator, Writer};

use super::metadata::OpendalMetaData;
use super::utils::*;
Expand Down Expand Up @@ -54,7 +54,7 @@ impl Debug for OpendalFile {

enum State {
Read(FuturesAsyncReader),
Write(FuturesAsyncWriter),
Write(Option<Writer>),
}

impl OpendalFile {
Expand All @@ -74,9 +74,8 @@ impl OpendalFile {
.writer_with(&path)
.append(options.append)
.await
.map_err(convert_error)?
.into_futures_async_write();
State::Write(w)
.map_err(convert_error)?;
State::Write(Some(w))
} else {
return Err(FsError::NotImplemented);
};
Expand Down Expand Up @@ -106,25 +105,37 @@ impl DavFile for OpendalFile {

fn write_buf(&mut self, mut buf: Box<dyn Buf + Send>) -> FsFuture<'_, ()> {
async move {
let State::Write(w) = &mut self.state else {
let State::Write(Some(w)) = &mut self.state else {
return Err(FsError::GeneralFailure);
};

w.write_all(&buf.copy_to_bytes(buf.remaining()))
if w.write(Buffer::from(buf.copy_to_bytes(buf.remaining())))
.await
.map_err(|_| FsError::GeneralFailure)?;
.is_err()
{
let _ = w.abort().await;
self.state = State::Write(None);
return Err(FsError::GeneralFailure);
}

Ok(())
}
.boxed()
}

fn write_bytes(&mut self, buf: Bytes) -> FsFuture<'_, ()> {
async move {
let State::Write(w) = &mut self.state else {
let State::Write(Some(w)) = &mut self.state else {
return Err(FsError::GeneralFailure);
};

w.write_all(&buf).await.map_err(|_| FsError::GeneralFailure)
if w.write(Buffer::from(buf)).await.is_err() {
let _ = w.abort().await;
self.state = State::Write(None);
return Err(FsError::GeneralFailure);
}

Ok(())
}
.boxed()
}
Expand Down Expand Up @@ -158,12 +169,18 @@ impl DavFile for OpendalFile {

fn flush(&mut self) -> FsFuture<'_, ()> {
async move {
let State::Write(w) = &mut self.state else {
let State::Write(Some(w)) = &mut self.state else {
return Err(FsError::GeneralFailure);
};

w.flush().await.map_err(|_| FsError::GeneralFailure)?;
w.close().await.map_err(|_| FsError::GeneralFailure)
if w.close().await.is_err() {
let _ = w.abort().await;
self.state = State::Write(None);
return Err(FsError::GeneralFailure);
}

self.state = State::Write(None);
Ok(())
}
.boxed()
}
Expand Down
165 changes: 165 additions & 0 deletions integrations/dav-server/tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,17 @@ use dav_server::fs::OpenOptions;
use dav_server::fs::{DavFileSystem, ReadDirMeta};
use dav_server_opendalfs::OpendalFs;
use futures::StreamExt;
use opendal::Buffer;
use opendal::Operator;
use opendal::raw::oio;
use opendal::raw::{Access, AccessorInfo, OpWrite, RpWrite};
use opendal::services::Fs;
use opendal::{Capability, Error, ErrorKind, Metadata};
use std::fmt::{Debug, Formatter};
use std::fs;
use std::path::Path;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};

#[tokio::test]
async fn test() -> Result<()> {
Expand Down Expand Up @@ -213,3 +220,161 @@ async fn test_read_dir() {

fs::remove_dir_all(TMP_PATH).unwrap();
}

#[derive(Clone)]
struct AbortTrackingAccess {
info: Arc<AccessorInfo>,
aborted: Arc<AtomicBool>,
fail_on_write: bool,
fail_on_close: bool,
}

impl Debug for AbortTrackingAccess {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AbortTrackingAccess").finish()
}
}

impl AbortTrackingAccess {
fn with_failures(aborted: Arc<AtomicBool>, fail_on_write: bool, fail_on_close: bool) -> Self {
let info = AccessorInfo::default();
info.set_scheme("memory")
.set_root("/")
.set_name("abort-tracking")
.set_native_capability(Capability {
write: true,
..Default::default()
});

Self {
info: info.into(),
aborted,
fail_on_write,
fail_on_close,
}
}

fn write_failure(aborted: Arc<AtomicBool>) -> Self {
Self::with_failures(aborted, true, false)
}

fn close_failure(aborted: Arc<AtomicBool>) -> Self {
Self::with_failures(aborted, false, true)
}
}

impl Access for AbortTrackingAccess {
type Reader = oio::Reader;
type Writer = oio::Writer;
type Lister = oio::Lister;
type Deleter = oio::Deleter;
type Copier = oio::Copier;

fn info(&self) -> Arc<AccessorInfo> {
self.info.clone()
}

async fn write(&self, _: &str, _: OpWrite) -> opendal::Result<(RpWrite, Self::Writer)> {
Ok((
RpWrite::new(),
Box::new(AbortTrackingWriter {
aborted: self.aborted.clone(),
fail_on_write: self.fail_on_write,
fail_on_close: self.fail_on_close,
}),
))
}
}

struct AbortTrackingWriter {
aborted: Arc<AtomicBool>,
fail_on_write: bool,
fail_on_close: bool,
}

impl oio::Write for AbortTrackingWriter {
async fn write(&mut self, _: Buffer) -> opendal::Result<()> {
if self.fail_on_write {
return Err(Error::new(ErrorKind::Unexpected, "injected write failure"));
}

Ok(())
}

async fn close(&mut self) -> opendal::Result<Metadata> {
if self.fail_on_close {
return Err(Error::new(ErrorKind::Unexpected, "injected close failure"));
}

Ok(Metadata::default())
}

async fn abort(&mut self) -> opendal::Result<()> {
self.aborted.store(true, Ordering::SeqCst);
Ok(())
}
}

#[tokio::test]
async fn test_failed_write_aborts_before_drop() {
let aborted = Arc::new(AtomicBool::new(false));
let op = Operator::from_inner(Arc::new(AbortTrackingAccess::write_failure(
aborted.clone(),
)));
let webdavfs = OpendalFs::new(op);

let mut file = webdavfs
.open(
&DavPath::new("/failed-write").unwrap(),
OpenOptions {
write: true,
..OpenOptions::default()
},
)
.await
.unwrap();

let err = file.write_bytes(Bytes::from(vec![1; 300 * 1024])).await;
assert!(err.is_err());

drop(file);

assert!(
aborted.load(Ordering::SeqCst),
"writer.abort() should be called when a write fails before close()"
);
}

#[tokio::test]
async fn test_failed_close_aborts_before_drop() {
let aborted = Arc::new(AtomicBool::new(false));
let op = Operator::from_inner(Arc::new(AbortTrackingAccess::close_failure(
aborted.clone(),
)));
let webdavfs = OpendalFs::new(op);

let mut file = webdavfs
.open(
&DavPath::new("/failed-close").unwrap(),
OpenOptions {
write: true,
..OpenOptions::default()
},
)
.await
.unwrap();

file.write_bytes(Bytes::from(vec![1; 300 * 1024]))
.await
.unwrap();

let err = file.flush().await;
assert!(err.is_err());

drop(file);

assert!(
aborted.load(Ordering::SeqCst),
"writer.abort() should be called when close() fails during flush()"
);
}
Loading