Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions core/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions core/services/hf/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }

[dev-dependencies]
base64 = { workspace = true }
futures = { workspace = true }
opendal-core = { path = "../../core", version = "0.57.0", features = [
"reqwest-rustls-tls",
Expand Down
100 changes: 74 additions & 26 deletions core/services/hf/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::sync::Arc;
use log::debug;

use super::HF_SCHEME;
use super::config::HfConfig;
use super::config::{HfConfig, HfDownloadMode};
use super::core::HfCore;
use super::deleter::HfDeleter;
use super::lister::HfLister;
Expand Down Expand Up @@ -52,7 +52,7 @@ impl HfBuilder {
pub fn repo_type(mut self, repo_type: &str) -> Self {
if !repo_type.is_empty() {
if let Ok(rt) = HfRepoType::parse(repo_type) {
self.config.repo_type = rt;
self.config.repo_type = Some(rt);
}
}
self
Expand Down Expand Up @@ -111,6 +111,19 @@ impl HfBuilder {
self
}

/// Select how file contents are downloaded: `xet` (default) or `http`.
///
/// - `xet`: download through the XET protocol (default).
/// - `http`: plain HTTP download, following the redirect from the server.
///
/// An empty `mode`, or a value that `HfDownloadMode::parse` rejects, leaves
/// the currently configured download mode untouched.
pub fn download_mode(mut self, mode: &str) -> Self {
    // Guard clause: an empty string is never handed to the parser.
    if mode.is_empty() {
        return self;
    }

    // Parse failures are deliberately ignored so the builder chain keeps going.
    if let Ok(parsed) = HfDownloadMode::parse(mode) {
        self.config.download_mode = parsed;
    }

    self
}

/// Configure the Hub base URL. You might want to set this variable if your
/// organization is using a Private Hub: https://huggingface.co/enterprise
///
Expand Down Expand Up @@ -178,15 +191,18 @@ impl Builder for HfBuilder {
let token = self.hf_token();
let endpoint = self.hf_endpoint();

let repo_type = self.config.repo_type;
let repo_type = self.config.repo_type.ok_or_else(|| {
Error::new(ErrorKind::ConfigInvalid, "repo_type is required")
.with_operation("Builder::build")
.with_context("service", HF_SCHEME)
})?;
debug!("backend use repo_type: {:?}", &repo_type);

let repo_id = match &self.config.repo_id {
Some(repo_id) => Ok(repo_id.clone()),
None => Err(Error::new(ErrorKind::ConfigInvalid, "repo_id is empty")
let repo_id = self.config.repo_id.ok_or_else(|| {
Error::new(ErrorKind::ConfigInvalid, "repo_id is required")
.with_operation("Builder::build")
.with_context("service", HF_SCHEME)),
}?;
.with_context("service", HF_SCHEME)
})?;
debug!("backend use repo_id: {}", &repo_id);

let revision = match &self.config.revision {
Expand Down Expand Up @@ -220,7 +236,14 @@ impl Builder for HfBuilder {
debug!("backend repo uri: {:?}", repo.uri(&root, ""));

Ok(HfBackend {
core: Arc::new(HfCore::build(info, repo, root, token, endpoint)?),
core: Arc::new(HfCore::build(
info,
repo,
root,
token,
endpoint,
self.config.download_mode,
)?),
})
}
}
Expand Down Expand Up @@ -248,28 +271,18 @@ impl Access for HfBackend {
return Ok(RpStat::new(Metadata::new(EntryMode::DIR)));
}

if self.core.repo.is_bucket() {
if path.ends_with('/') {
return Ok(RpStat::new(Metadata::new(EntryMode::DIR)));
}
return match self.core.maybe_xet_file(path).await? {
Some(file_info) => {
let size = file_info.file_size().unwrap_or(0);
Ok(RpStat::new(
Metadata::new(EntryMode::FILE).with_content_length(size),
))
}
None => Err(Error::new(ErrorKind::NotFound, "path not found")),
};
// Buckets have no git directory entries; treat any trailing-slash path as a virtual dir.
if self.core.repo.is_bucket() && path.ends_with('/') {
return Ok(RpStat::new(Metadata::new(EntryMode::DIR)));
}

let info = self.core.path_info(path).await?;
Ok(RpStat::new(info.metadata()?))
}

async fn read(&self, path: &str, args: OpRead) -> Result<(RpRead, Self::Reader)> {
let (metadata, reader) = HfReader::try_new(&self.core, path, args.range()).await?;
Ok((RpRead::new(metadata), reader))
let reader = HfReader::try_new(&self.core, path, args.range()).await?;
Ok((RpRead::default(), reader))
}

async fn list(&self, path: &str, args: OpList) -> Result<(RpList, Self::Lister)> {
Expand All @@ -294,10 +307,16 @@ impl Access for HfBackend {

#[cfg(test)]
pub(super) mod test_utils {
use std::sync::Arc;

use super::super::config::HfDownloadMode;
use super::super::core::HfCore;
use super::super::uri::{HfRepo, HfRepoType};
use super::HfBuilder;
use opendal_core::Capability;
use opendal_core::Operator;
use opendal_core::layers::HttpClientLayer;
use opendal_core::raw::HttpClient;
use opendal_core::raw::{AccessorInfo, HttpClient};

fn finish_operator(op: Operator) -> Operator {
let client = HttpClient::with(reqwest::Client::new());
Expand All @@ -308,7 +327,8 @@ pub(super) mod test_utils {
let op = Operator::new(
HfBuilder::default()
.repo_type("model")
.repo_id("openai-community/gpt2"),
.repo_id("openai-community/gpt2")
.download_mode("http"),
)
.unwrap()
.finish();
Expand All @@ -326,6 +346,34 @@ pub(super) mod test_utils {
finish_operator(op)
}

/// Build an `HfCore` pointing at the dataset repository used by the tests.
///
/// The repository id is read from the `HF_OPENDAL_DATASET` environment
/// variable and the access token from `HF_OPENDAL_TOKEN`. The core is built
/// for revision `main`, root `/`, endpoint `https://huggingface.co`, and the
/// `Xet` download mode.
///
/// # Panics
///
/// Panics if either environment variable cannot be read, or if
/// `HfCore::build` returns an error.
pub fn testing_dataset_core() -> Arc<HfCore> {
    let dataset = std::env::var("HF_OPENDAL_DATASET").expect("HF_OPENDAL_DATASET must be set");
    let hf_token = std::env::var("HF_OPENDAL_TOKEN").expect("HF_OPENDAL_TOKEN must be set");

    // Capabilities advertised by the accessor info: read, write and delete only.
    let capability = Capability {
        read: true,
        write: true,
        delete: true,
        ..Default::default()
    };

    let accessor_info = AccessorInfo::default();
    accessor_info
        .set_scheme("hf")
        .set_native_capability(capability);
    accessor_info.update_http_client(|_| HttpClient::with(reqwest::Client::new()));

    let dataset_repo = HfRepo::new(HfRepoType::Dataset, dataset, Some("main".to_string()));

    let core = HfCore::build(
        Arc::new(accessor_info),
        dataset_repo,
        "/".to_string(),
        Some(hf_token),
        "https://huggingface.co".to_string(),
        HfDownloadMode::Xet,
    )
    .expect("failed to build HfCore");

    Arc::new(core)
}

pub fn testing_bucket_operator() -> Operator {
let repo_id = std::env::var("HF_OPENDAL_BUCKET").expect("HF_OPENDAL_BUCKET must be set");
let token = std::env::var("HF_OPENDAL_TOKEN").expect("HF_OPENDAL_TOKEN must be set");
Expand Down
Loading
Loading