"""S3 Bucket data loader with local caching (HuggingFace-style)."""

from __future__ import annotations

import logging
import os
import re
import shutil
from pathlib import Path
from typing import Optional, Union
from urllib.parse import urlparse

logger = logging.getLogger(__name__)


class _DataLoader:
    """Internal S3 data loader with local caching."""

    DEFAULT_BUCKET = "cache-datasets"
    DEFAULT_CACHE_DIR = Path(os.environ.get("LCS_HUB_CACHE", Path.home() / ".cache/libcachesim/hub"))

    # Characters that are problematic on various filesystems
    INVALID_CHARS = set('<>:"|?*\x00')
    # Reserved names on Windows
    RESERVED_NAMES = {
        'CON', 'PRN', 'AUX', 'NUL',
        'COM1', 'COM2', 'COM3', 'COM4', 'COM5', 'COM6', 'COM7', 'COM8', 'COM9',
        'LPT1', 'LPT2', 'LPT3', 'LPT4', 'LPT5', 'LPT6', 'LPT7', 'LPT8', 'LPT9',
    }

    def __init__(
        self, bucket_name: str = DEFAULT_BUCKET, cache_dir: Optional[Union[str, Path]] = None, use_auth: bool = False
    ):
        self.bucket_name = self._validate_bucket_name(bucket_name)
        self.cache_dir = Path(cache_dir) if cache_dir else self.DEFAULT_CACHE_DIR
        self.use_auth = use_auth
        self._s3_client = None
        self._ensure_cache_dir()

    def _validate_bucket_name(self, bucket_name: str) -> str:
        """Validate S3 bucket name according to AWS rules."""
        if not bucket_name:
            raise ValueError("Bucket name cannot be empty")

        if len(bucket_name) < 3 or len(bucket_name) > 63:
            raise ValueError("Bucket name must be between 3 and 63 characters")

        if not re.match(r'^[a-z0-9.-]+$', bucket_name):
            raise ValueError("Bucket name can only contain lowercase letters, numbers, periods, and hyphens")

        if bucket_name.startswith('.') or bucket_name.endswith('.'):
            raise ValueError("Bucket name cannot start or end with a period")

        if bucket_name.startswith('-') or bucket_name.endswith('-'):
            raise ValueError("Bucket name cannot start or end with a hyphen")

        if '..' in bucket_name:
            raise ValueError("Bucket name cannot contain consecutive periods")

        return bucket_name

    def _validate_and_sanitize_key(self, key: str) -> str:
        """Validate and sanitize an S3 key for safe local filesystem usage."""
        if not key:
            raise ValueError("S3 key cannot be empty")

        if len(key.encode("utf-8")) > 1024:  # S3 limits keys to 1024 bytes
            raise ValueError("S3 key is too long (max 1024 bytes)")

        # Check for path traversal attempts
        if '..' in key:
            raise ValueError("S3 key cannot contain '..' (path traversal not allowed)")

        if key.startswith('/'):
            raise ValueError("S3 key cannot start with '/'")

        # Split key into parts and validate each part
        parts = key.split('/')
        sanitized_parts = []

        for part in parts:
            if not part:  # Empty part (double slash)
                continue

            # Check for reserved names (case insensitive)
            if part.upper() in self.RESERVED_NAMES:
                raise ValueError(f"S3 key contains reserved name: {part}")

            # Check for invalid characters
            if any(c in self.INVALID_CHARS for c in part):
                raise ValueError(f"S3 key contains invalid characters in part: {part}")

            # Check if part is too long for the filesystem
            if len(part) > 255:  # Most filesystems limit each path component to 255 chars
                raise ValueError(f"S3 key component too long: {part}")

            sanitized_parts.append(part)

        if not sanitized_parts:
            raise ValueError("S3 key resulted in empty path after sanitization")

        return '/'.join(sanitized_parts)
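
    # Note: _validate_and_sanitize_key collapses redundant slashes (e.g.
    # "traces//foo.lcs.zst" -> "traces/foo.lcs.zst") and rejects path-traversal
    # keys such as "../secret" or "/absolute/key" with a ValueError. The example
    # keys here are illustrative only.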

    def _ensure_cache_dir(self) -> None:
        """Create the per-bucket cache directory if it does not exist."""
        (self.cache_dir / self.bucket_name).mkdir(parents=True, exist_ok=True)

    def _get_available_disk_space(self, path: Path) -> float:
        """Get available disk space in bytes."""
        try:
            stat = os.statvfs(path)
            return stat.f_bavail * stat.f_frsize
        except (OSError, AttributeError):
            # Fallback for Windows or other systems without os.statvfs
            try:
                return shutil.disk_usage(path).free
            except Exception:
                logger.warning("Could not determine available disk space")
                return float('inf')  # Assume unlimited space if we can't check

    @property
    def s3_client(self):
        """Lazily create the boto3 S3 client (unsigned requests unless use_auth is set)."""
        if self._s3_client is None:
            try:
                import boto3
                from botocore import UNSIGNED
                from botocore.config import Config

                self._s3_client = boto3.client(
                    "s3", config=None if self.use_auth else Config(signature_version=UNSIGNED)
                )
            except ImportError:
                raise ImportError("boto3 is required for S3 access. Install it with: pip install boto3")
        return self._s3_client

    def _cache_path(self, key: str) -> Path:
        """Create a cache path that mirrors the S3 structure after validation."""
        sanitized_key = self._validate_and_sanitize_key(key)
        cache_path = self.cache_dir / self.bucket_name / sanitized_key

        # Double-check that the resolved path is still within the cache directory
        try:
            cache_path.resolve().relative_to(self.cache_dir.resolve())
        except ValueError:
            raise ValueError(f"S3 key resolves outside cache directory: {key}")

        return cache_path

    def _get_object_size(self, key: str) -> int:
        """Get the size of an S3 object without downloading it."""
        try:
            response = self.s3_client.head_object(Bucket=self.bucket_name, Key=key)
            return response['ContentLength']
        except Exception as e:
            logger.warning(f"Could not determine object size for s3://{self.bucket_name}/{key}: {e}")
            return 0

    def _download(self, key: str, dest: Path) -> None:
        """Download an S3 object to dest, writing to a temporary file first."""
        temp = dest.with_suffix(dest.suffix + ".tmp")
        temp.parent.mkdir(parents=True, exist_ok=True)

        try:
            # Check available disk space before downloading
            object_size = self._get_object_size(key)
            if object_size > 0:
                available_space = self._get_available_disk_space(temp.parent)
                if object_size > available_space:
                    raise RuntimeError(
                        f"Insufficient disk space. Need {object_size / (1024 ** 3):.2f} GB, "
                        f"but only {available_space / (1024 ** 3):.2f} GB available"
                    )

            logger.info(f"Downloading s3://{self.bucket_name}/{key}")
            obj = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
            with open(temp, "wb") as f:
                # Stream the body to disk rather than buffering the whole object in memory
                shutil.copyfileobj(obj["Body"], f)
            shutil.move(str(temp), str(dest))
            logger.info(f"Saved to: {dest}")
        except Exception as e:
            if temp.exists():
                temp.unlink()
            raise RuntimeError(f"Download failed for s3://{self.bucket_name}/{key}: {e}") from e

    def load(self, key: str, force: bool = False, mode: str = "rb") -> Union[bytes, str]:
        """Load an object's contents, downloading it to the cache if needed."""
        path = self._cache_path(key)
        if not path.exists() or force:
            self._download(key, path)
        with open(path, mode) as f:
            return f.read()

    def get_cached_path(self, key: str) -> str:
        """Get the local cached file path, downloading if necessary."""
        path = self._cache_path(key)
        if not path.exists():
            self._download(key, path)
        return str(path)

    def is_cached(self, key: str) -> bool:
        """Return True if the key is already present in the local cache."""
        try:
            return self._cache_path(key).exists()
        except ValueError:
            return False

    def clear_cache(self, key: Optional[str] = None) -> None:
        """Remove a single cached object, or the entire cache directory if key is None."""
        if key:
            try:
                path = self._cache_path(key)
                if path.exists():
                    path.unlink()
                    logger.info(f"Cleared: {path}")
            except ValueError as e:
                logger.warning(f"Cannot clear cache for invalid key {key}: {e}")
        else:
            shutil.rmtree(self.cache_dir, ignore_errors=True)
            logger.info(f"Cleared entire cache: {self.cache_dir}")

    def list_cached_files(self) -> list[str]:
        """List all cached file paths, excluding in-progress .tmp downloads."""
        if not self.cache_dir.exists():
            return []
        return [str(p) for p in self.cache_dir.rglob("*") if p.is_file() and not p.name.endswith(".tmp")]

    def get_cache_size(self) -> int:
        """Return the total size of cached files in bytes."""
        if not self.cache_dir.exists():
            return 0
        return sum(p.stat().st_size for p in self.cache_dir.rglob("*") if p.is_file())

    def list_s3_objects(self, prefix: str = "", delimiter: str = "/") -> dict:
        """
        List S3 objects and pseudo-folders under a prefix.

        Args:
            prefix: The S3 prefix to list under (like a folder path)
            delimiter: Use "/" to simulate folder structure

        Returns:
            A dict with two keys:
            - "folders": list of sub-prefixes (folders)
            - "files": list of object keys (files)
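
        Example (illustrative; requires boto3 and network access to the bucket):
            >>> loader = _DataLoader()
            >>> listing = loader.list_s3_objects(prefix="")
            >>> sorted(listing)
            ['files', 'folders']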
        """
        paginator = self.s3_client.get_paginator("list_objects_v2")
        result = {"folders": [], "files": []}

        for page in paginator.paginate(Bucket=self.bucket_name, Prefix=prefix, Delimiter=delimiter):
            # CommonPrefixes are like subdirectories
            result["folders"].extend(cp["Prefix"] for cp in page.get("CommonPrefixes", []))
            result["files"].extend(obj["Key"] for obj in page.get("Contents", []))

        return result


# Global data loader instance
_data_loader = _DataLoader()


def set_cache_dir(cache_dir: Union[str, Path]) -> None:
    """
    Set the global cache directory for S3 downloads.

    Args:
        cache_dir: Path to the cache directory

    Example:
        >>> import libcachesim as lcs
        >>> lcs.set_cache_dir("/tmp/my_cache")
    """
    global _data_loader
    _data_loader = _DataLoader(cache_dir=cache_dir)


def get_cache_dir() -> Path:
    """
    Get the current cache directory.

    Returns:
        Path to the current cache directory

    Example:
        >>> import libcachesim as lcs
        >>> print(lcs.get_cache_dir())
        /home/user/.cache/libcachesim/hub
    """
    return _data_loader.cache_dir


def clear_cache(s3_path: Optional[str] = None) -> None:
    """
    Clear cached files.

    Args:
        s3_path: Specific S3 path to clear, or None to clear all cache

    Example:
        >>> import libcachesim as lcs
        >>> # Clear a specific file
        >>> lcs.clear_cache("s3://cache-datasets/trace1.lcs.zst")
        >>> # Clear all cache
        >>> lcs.clear_cache()
    """
    if s3_path and s3_path.startswith("s3://"):
        parsed = urlparse(s3_path)
        bucket = parsed.netloc
        key = parsed.path.lstrip('/')
        if bucket == _data_loader.bucket_name:
            _data_loader.clear_cache(key)
        else:
            logger.warning(f"Cannot clear cache for different bucket: {bucket}")
    else:
        _data_loader.clear_cache(s3_path)


def get_cache_size() -> int:
    """
    Get total size of cached files in bytes.

    Returns:
        Total cache size in bytes

    Example:
        >>> import libcachesim as lcs
        >>> size_mb = lcs.get_cache_size() / (1024**2)
        >>> print(f"Cache size: {size_mb:.2f} MB")
    """
    return _data_loader.get_cache_size()


def list_cached_files() -> list[str]:
    """
    List all cached files.

    Returns:
        List of cached file paths

    Example:
        >>> import libcachesim as lcs
        >>> files = lcs.list_cached_files()
        >>> for file in files:
        ...     print(file)
    """
    return _data_loader.list_cached_files()


def get_data_loader(bucket_name: Optional[str] = None) -> _DataLoader:
    """Get a data loader instance for a specific bucket, or the global one."""
    global _data_loader
    if bucket_name is None or bucket_name == _data_loader.bucket_name:
        return _data_loader
    else:
        # Reuse the same cache root; each bucket gets its own subdirectory under it
        return _DataLoader(bucket_name=bucket_name, cache_dir=_data_loader.cache_dir)
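

if __name__ == "__main__":
    # Minimal usage sketch (not part of the library API): exercises the public
    # helpers above. It assumes boto3 is installed and the default public bucket
    # is reachable; "trace1.lcs.zst" is the key used in the docstring examples
    # and may not exist in practice.
    logging.basicConfig(level=logging.INFO)

    set_cache_dir("/tmp/libcachesim_cache_demo")
    loader = get_data_loader()

    # Browse the top level of the bucket.
    listing = loader.list_s3_objects(prefix="")
    print("folders:", listing["folders"][:5])
    print("files:", listing["files"][:5])

    # Download (or reuse) one object and report cache state.
    local_path = loader.get_cached_path("trace1.lcs.zst")
    print("cached at:", local_path)
    print("cache size (MB):", get_cache_size() / (1024 ** 2))

    # Clean up the demo cache.
    clear_cache()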