
Commit 0db2e16

Feat: Make data load internal
1 parent 5b4da1b commit 0db2e16

File tree

8 files changed (+439 / -143 lines)


libcachesim/__init__.py

Lines changed: 0 additions & 3 deletions
@@ -59,7 +59,6 @@
 from .trace_analyzer import TraceAnalyzer
 from .synthetic_reader import SyntheticReader, create_zipf_requests, create_uniform_requests
 from .util import Util
-from .data_loader import DataLoader

 __all__ = [
     # Core classes
@@ -118,8 +117,6 @@
     "create_uniform_requests",
     # Utilities
     "Util",
-    # Data loader
-    "DataLoader",
     # Metadata
     "__doc__",
     "__version__",

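Note: with DataLoader dropped from the public __all__, callers are expected to go through the package-level cache helpers rather than constructing a loader themselves. A minimal sketch of that migration, assuming set_cache_dir, get_cache_dir, get_cache_size, and clear_cache are re-exported from libcachesim as the docstrings in _s3_cache.py below suggest:

import libcachesim as lcs

# Assumed re-exports; see _s3_cache.py below for the underlying functions.
lcs.set_cache_dir("/tmp/my_cache")                        # redirect the hub cache
print(lcs.get_cache_dir())                                # /tmp/my_cache
print(f"cached: {lcs.get_cache_size() / 1024**2:.2f} MB")
lcs.clear_cache()                                         # wipe the whole cache
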
libcachesim/_s3_cache.py

Lines changed: 347 additions & 0 deletions
@@ -0,0 +1,347 @@
"""S3 Bucket data loader with local caching (HuggingFace-style)."""

from __future__ import annotations

import hashlib
import logging
import os
import re
import shutil
from pathlib import Path
from typing import Optional, Union
from urllib.parse import urlparse

logger = logging.getLogger(__name__)


class _DataLoader:
    """Internal S3 data loader with local caching."""

    DEFAULT_BUCKET = "cache-datasets"
    DEFAULT_CACHE_DIR = Path(os.environ.get("LCS_HUB_CACHE", Path.home() / ".cache/libcachesim/hub"))

    # Characters that are problematic on various filesystems
    INVALID_CHARS = set('<>:"|?*\x00')
    # Reserved names on Windows
    RESERVED_NAMES = {
        'CON', 'PRN', 'AUX', 'NUL',
        'COM1', 'COM2', 'COM3', 'COM4', 'COM5', 'COM6', 'COM7', 'COM8', 'COM9',
        'LPT1', 'LPT2', 'LPT3', 'LPT4', 'LPT5', 'LPT6', 'LPT7', 'LPT8', 'LPT9'
    }

    def __init__(
        self, bucket_name: str = DEFAULT_BUCKET, cache_dir: Optional[Union[str, Path]] = None, use_auth: bool = False
    ):
        self.bucket_name = self._validate_bucket_name(bucket_name)
        self.cache_dir = Path(cache_dir) if cache_dir else self.DEFAULT_CACHE_DIR
        self.use_auth = use_auth
        self._s3_client = None
        self._ensure_cache_dir()

    def _validate_bucket_name(self, bucket_name: str) -> str:
        """Validate S3 bucket name according to AWS rules."""
        if not bucket_name:
            raise ValueError("Bucket name cannot be empty")

        if len(bucket_name) < 3 or len(bucket_name) > 63:
            raise ValueError("Bucket name must be between 3 and 63 characters")

        if not re.match(r'^[a-z0-9.-]+$', bucket_name):
            raise ValueError("Bucket name can only contain lowercase letters, numbers, periods, and hyphens")

        if bucket_name.startswith('.') or bucket_name.endswith('.'):
            raise ValueError("Bucket name cannot start or end with a period")

        if bucket_name.startswith('-') or bucket_name.endswith('-'):
            raise ValueError("Bucket name cannot start or end with a hyphen")

        if '..' in bucket_name:
            raise ValueError("Bucket name cannot contain consecutive periods")

        return bucket_name

    def _validate_and_sanitize_key(self, key: str) -> str:
        """Validate and sanitize S3 key for safe local filesystem usage."""
        if not key:
            raise ValueError("S3 key cannot be empty")

        if len(key) > 1024:  # S3 limit is 1024 bytes
            raise ValueError("S3 key is too long (max 1024 characters)")

        # Check for path traversal attempts
        if '..' in key:
            raise ValueError("S3 key cannot contain '..' (path traversal not allowed)")

        if key.startswith('/'):
            raise ValueError("S3 key cannot start with '/'")

        # Split key into parts and validate each part
        parts = key.split('/')
        sanitized_parts = []

        for part in parts:
            if not part:  # Empty part (double slash)
                continue

            # Check for reserved names (case insensitive)
            if part.upper() in self.RESERVED_NAMES:
                raise ValueError(f"S3 key contains reserved name: {part}")

            # Check for invalid characters
            if any(c in self.INVALID_CHARS for c in part):
                raise ValueError(f"S3 key contains invalid characters in part: {part}")

            # Check if part is too long for filesystem
            if len(part) > 255:  # Most filesystems have 255 char limit per component
                raise ValueError(f"S3 key component too long: {part}")

            sanitized_parts.append(part)

        if not sanitized_parts:
            raise ValueError("S3 key resulted in empty path after sanitization")

        return '/'.join(sanitized_parts)

    def _ensure_cache_dir(self) -> None:
        (self.cache_dir / self.bucket_name).mkdir(parents=True, exist_ok=True)

    def _get_available_disk_space(self, path: Path) -> int:
        """Get available disk space in bytes."""
        try:
            stat = os.statvfs(path)
            return stat.f_bavail * stat.f_frsize
        except (OSError, AttributeError):
            # Fallback for Windows or other systems
            try:
                import shutil
                return shutil.disk_usage(path).free
            except Exception:
                logger.warning("Could not determine available disk space")
                return float('inf')  # Assume unlimited space if we can't check

    @property
    def s3_client(self):
        if self._s3_client is None:
            try:
                import boto3
                from botocore.config import Config
                from botocore import UNSIGNED

                self._s3_client = boto3.client(
                    "s3", config=None if self.use_auth else Config(signature_version=UNSIGNED)
                )
            except ImportError:
                raise ImportError("Install boto3: pip install boto3")
        return self._s3_client

    def _cache_path(self, key: str) -> Path:
        """Create cache path that mirrors S3 structure after validation."""
        sanitized_key = self._validate_and_sanitize_key(key)
        cache_path = self.cache_dir / self.bucket_name / sanitized_key

        # Double-check that the resolved path is still within cache directory
        try:
            cache_path.resolve().relative_to(self.cache_dir.resolve())
        except ValueError:
            raise ValueError(f"S3 key resolves outside cache directory: {key}")

        return cache_path

    def _get_object_size(self, key: str) -> int:
        """Get the size of an S3 object without downloading it."""
        try:
            response = self.s3_client.head_object(Bucket=self.bucket_name, Key=key)
            return response['ContentLength']
        except Exception as e:
            logger.warning(f"Could not determine object size for s3://{self.bucket_name}/{key}: {e}")
            return 0

    def _download(self, key: str, dest: Path) -> None:
        temp = dest.with_suffix(dest.suffix + ".tmp")
        temp.parent.mkdir(parents=True, exist_ok=True)

        try:
            # Check available disk space before downloading
            object_size = self._get_object_size(key)
            if object_size > 0:
                available_space = self._get_available_disk_space(temp.parent)
                if object_size > available_space:
                    raise RuntimeError(
                        f"Insufficient disk space. Need {object_size / (1024**3):.2f} GB, "
                        f"but only {available_space / (1024**3):.2f} GB available"
                    )

            logger.info(f"Downloading s3://{self.bucket_name}/{key}")
            obj = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
            with open(temp, "wb") as f:
                f.write(obj["Body"].read())
            shutil.move(str(temp), str(dest))
            logger.info(f"Saved to: {dest}")
        except Exception as e:
            if temp.exists():
                temp.unlink()
            raise RuntimeError(f"Download failed for s3://{self.bucket_name}/{key}: {e}")

    def load(self, key: str, force: bool = False, mode: str = "rb") -> Union[bytes, str]:
        path = self._cache_path(key)
        if not path.exists() or force:
            self._download(key, path)
        with open(path, mode) as f:
            return f.read()

    def get_cached_path(self, key: str) -> str:
        """Get the local cached file path, downloading if necessary."""
        path = self._cache_path(key)
        if not path.exists():
            self._download(key, path)
        return str(path)

    def is_cached(self, key: str) -> bool:
        try:
            return self._cache_path(key).exists()
        except ValueError:
            return False

    def clear_cache(self, key: Optional[str] = None) -> None:
        if key:
            try:
                path = self._cache_path(key)
                if path.exists():
                    path.unlink()
                    logger.info(f"Cleared: {path}")
            except ValueError as e:
                logger.warning(f"Cannot clear cache for invalid key {key}: {e}")
        else:
            shutil.rmtree(self.cache_dir, ignore_errors=True)
            logger.info(f"Cleared entire cache: {self.cache_dir}")

    def list_cached_files(self) -> list[str]:
        if not self.cache_dir.exists():
            return []
        return [str(p) for p in self.cache_dir.rglob("*") if p.is_file() and not p.name.endswith(".tmp")]

    def get_cache_size(self) -> int:
        return sum(p.stat().st_size for p in self.cache_dir.rglob("*") if p.is_file())

    def list_s3_objects(self, prefix: str = "", delimiter: str = "/") -> dict:
        """
        List S3 objects and pseudo-folders under a prefix.

        Args:
            prefix: The S3 prefix to list under (like folder path)
            delimiter: Use "/" to simulate folder structure

        Returns:
            A dict with two keys:
            - "folders": list of sub-prefixes (folders)
            - "files": list of object keys (files)
        """
        paginator = self.s3_client.get_paginator("list_objects_v2")
        result = {"folders": [], "files": []}

        for page in paginator.paginate(Bucket=self.bucket_name, Prefix=prefix, Delimiter=delimiter):
            # CommonPrefixes are like subdirectories
            result["folders"].extend(cp["Prefix"] for cp in page.get("CommonPrefixes", []))
            result["files"].extend(obj["Key"] for obj in page.get("Contents", []))

        return result


# Global data loader instance
_data_loader = _DataLoader()


def set_cache_dir(cache_dir: Union[str, Path]) -> None:
    """
    Set the global cache directory for S3 downloads.

    Args:
        cache_dir: Path to the cache directory

    Example:
        >>> import libcachesim as lcs
        >>> lcs.set_cache_dir("/tmp/my_cache")
    """
    global _data_loader
    _data_loader = _DataLoader(cache_dir=cache_dir)


def get_cache_dir() -> Path:
    """
    Get the current cache directory.

    Returns:
        Path to the current cache directory

    Example:
        >>> import libcachesim as lcs
        >>> print(lcs.get_cache_dir())
        /home/user/.cache/libcachesim/hub
    """
    return _data_loader.cache_dir


def clear_cache(s3_path: Optional[str] = None) -> None:
    """
    Clear cached files.

    Args:
        s3_path: Specific S3 path to clear, or None to clear all cache

    Example:
        >>> import libcachesim as lcs
        >>> # Clear specific file
        >>> lcs.clear_cache("s3://cache-datasets/trace1.lcs.zst")
        >>> # Clear all cache
        >>> lcs.clear_cache()
    """
    if s3_path and s3_path.startswith("s3://"):
        parsed = urlparse(s3_path)
        bucket = parsed.netloc
        key = parsed.path.lstrip('/')
        if bucket == _data_loader.bucket_name:
            _data_loader.clear_cache(key)
        else:
            logger.warning(f"Cannot clear cache for different bucket: {bucket}")
    else:
        _data_loader.clear_cache(s3_path)


def get_cache_size() -> int:
    """
    Get total size of cached files in bytes.

    Returns:
        Total cache size in bytes

    Example:
        >>> import libcachesim as lcs
        >>> size_mb = lcs.get_cache_size() / (1024**2)
        >>> print(f"Cache size: {size_mb:.2f} MB")
    """
    return _data_loader.get_cache_size()


def list_cached_files() -> list[str]:
    """
    List all cached files.

    Returns:
        List of cached file paths

    Example:
        >>> import libcachesim as lcs
        >>> files = lcs.list_cached_files()
        >>> for file in files:
        ...     print(file)
    """
    return _data_loader.list_cached_files()


def get_data_loader(bucket_name: str = None) -> _DataLoader:
    """Get data loader instance for a specific bucket or the global one."""
    global _data_loader
    if bucket_name is None or bucket_name == _data_loader.bucket_name:
        return _data_loader
    else:
        return _DataLoader(bucket_name=bucket_name, cache_dir=_data_loader.cache_dir.parent)
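
A rough usage sketch of the internal loader added above; the object key used here is hypothetical, and boto3 must be installed for the unsigned S3 client:

from libcachesim._s3_cache import get_data_loader

loader = get_data_loader()                    # global loader for the "cache-datasets" bucket
listing = loader.list_s3_objects(prefix="")   # {"folders": [...], "files": [...]}
print(listing["folders"])

# Hypothetical key: downloads on first use, then reuses the hub cache.
local_path = loader.get_cached_path("example/trace.lcs.zst")
print(local_path, loader.is_cached("example/trace.lcs.zst"))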
