Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion infisical_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,29 @@ def set_token(self, token: str):

def get_token(self):
    """
    Get the access token used for future requests.

    Returns:
        The access token previously stored on this client.
    """
    return self.access_token

def close(self):
    """
    Close the client and release resources.

    Stops the background cache cleanup thread.  Calling this explicitly is
    unnecessary when the client is used as a context manager (``with``
    statement), since cleanup then happens automatically on exit.
    """
    cache = self.cache
    if cache:
        cache.close()

# Context-manager support: __enter__/__exit__ are called automatically when
# the client is used in a `with` statement:
#   with InfisicalSDKClient(...) as client:
#       ...
# (__exit__ runs close() when the block is left.)
def __enter__(self) -> "InfisicalSDKClient":
    """Enter the context manager; returns this client unchanged."""
    return self

def __exit__(self, exc_type, exc_val, exc_tb) -> None:
    """Exit the context manager: always calls close(); returns None so any exception propagates."""
    self.close()
122 changes: 104 additions & 18 deletions infisical_sdk/util/secrets_cache.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,66 @@
from typing import Dict, Tuple, Any

from infisical_sdk.api_types import BaseSecret
import json
import time
import threading
import weakref
import atexit
from hashlib import sha256
import pickle

MAX_CACHE_SIZE = 1000

# Registry of live caches, consulted at interpreter shutdown to stop their
# cleanup threads.  A WeakSet so registration never extends a cache's lifetime.
_active_caches: weakref.WeakSet = weakref.WeakSet()
_atexit_registered = False


def _cleanup_all_caches():
    """Best-effort close() of every still-alive cache at interpreter shutdown."""
    for cache in tuple(_active_caches):
        try:
            cache.close()
        except Exception:
            # Never let shutdown-time cleanup raise.
            pass


class SecretsCache:
def __init__(self, ttl_seconds: int = 60) -> None:
    """
    Create a secrets cache with the given time-to-live.

    Args:
        ttl_seconds: How long entries stay valid, in seconds.  ``None`` or a
            value <= 0 disables caching entirely (all operations become
            no-ops and no cleanup thread is started).
    """
    if ttl_seconds is None or ttl_seconds <= 0:
        # Caching disabled: no thread, no state beyond the flags.
        self.enabled = False
        self._closed = True
        return

    self.enabled = True
    self._closed = False
    self.ttl = ttl_seconds
    self.cleanup_interval = 60

    self.cache: Dict[str, Tuple[bytes, float]] = {}

    self.lock = threading.RLock()

    # An Event gives cleaner thread signaling than a polled bool flag.
    self._stop_event = threading.Event()

    # Hand the worker only a weak reference to self so the background
    # thread does not keep the cache object alive.
    self.cleanup_thread = threading.Thread(
        target=self._cleanup_worker_static,
        args=(weakref.ref(self),),
        daemon=True,
        name=f"SecretsCache-cleanup-{id(self)}",
    )
    self.cleanup_thread.start()

    # Track this cache so it can be closed at interpreter shutdown.
    _active_caches.add(self)

    # Register the atexit hook once per process.
    global _atexit_registered
    if not _atexit_registered:
        atexit.register(_cleanup_all_caches)
        _atexit_registered = True

def compute_cache_key(self, operation_name: str, **kwargs) -> str:
sorted_kwargs = sorted(kwargs.items())
Expand All @@ -34,7 +69,7 @@ def compute_cache_key(self, operation_name: str, **kwargs) -> str:
return f"{operation_name}-{sha256(json_str.encode()).hexdigest()}"

def get(self, cache_key: str) -> Any:
if not self.enabled:
if not self.enabled or self._closed:
return None

with self.lock:
Expand All @@ -50,7 +85,7 @@ def get(self, cache_key: str) -> Any:


def set(self, cache_key: str, value: Any) -> None:
if not self.enabled:
if not self.enabled or self._closed:
return

with self.lock:
Expand All @@ -64,14 +99,14 @@ def set(self, cache_key: str, value: Any) -> None:


def unset(self, cache_key: str) -> None:
    """
    Remove *cache_key* from the cache if present.

    No-op when the cache is disabled or already closed; removing an
    absent key is also a no-op.
    """
    if not self.enabled or self._closed:
        return

    with self.lock:
        self.cache.pop(cache_key, None)

def invalidate_operation(self, operation_name: str) -> None:
if not self.enabled:
if not self.enabled or self._closed:
return

with self.lock:
Expand All @@ -91,16 +126,67 @@ def _cleanup_expired_items(self) -> None:
for key in expired_keys:
self.cache.pop(key, None)

def _cleanup_worker(self) -> None:
    """Background worker that periodically cleans up expired items.

    NOTE(review): legacy polling worker — sleeps a full interval between
    checks of the plain bool flag, and the thread holds a strong reference
    to ``self`` for its whole lifetime, so a running cache can never be
    garbage collected.  Superseded by an event-signaled, weakref-based
    worker elsewhere in this file.
    """
    while not self.stop_cleanup_thread:
        time.sleep(self.cleanup_interval)
        self._cleanup_expired_items()
@staticmethod
def _cleanup_worker_static(cache_ref: weakref.ref) -> None:
    """
    Periodically purge expired entries from the cache behind *cache_ref*.

    Only a weak reference is held while waiting between passes, so this
    thread never keeps the cache alive: it exits as soon as the cache is
    garbage collected or its stop event is set.
    """
    while True:
        cache = cache_ref()
        if cache is None:
            # Cache has been garbage collected — nothing left to clean.
            return

        # Copy what the wait needs, then drop the strong reference so the
        # GC can reclaim the cache while this thread sleeps.
        stop_event = cache._stop_event
        interval = cache.cleanup_interval
        cache = None

        if stop_event.wait(timeout=interval):
            # Stop was requested (close() set the event).
            return

        # Re-acquire a strong reference to run the cleanup pass.
        cache = cache_ref()
        if cache is None:
            # Cache was collected during the wait.
            return

        cache._cleanup_expired_items()

def close(self) -> None:
    """
    Stop the cleanup thread and drop all cached entries.

    Call this when the cache is no longer needed.  Safe to call more
    than once; later calls are no-ops.
    """
    if self._closed or not self.enabled:
        return

    self._closed = True
    self._stop_event.set()

    worker = self.cleanup_thread
    if worker.is_alive():
        worker.join(timeout=2.0)

    # Drop every cached entry.
    with self.lock:
        self.cache.clear()

def __enter__(self) -> "SecretsCache":
    """Enter the context manager; returns this cache unchanged."""
    return self

def __exit__(self, exc_type, exc_val, exc_tb) -> None:
    """Exit the context manager: always calls close(); returns None so any exception propagates."""
    self.close()

def __del__(self) -> None:
    """Fallback cleanup when the object is garbage collected.

    Delegates to close(), which is idempotent; errors are swallowed
    because __del__ may run during interpreter shutdown when module
    state is already torn down.
    """
    try:
        self.close()
    except Exception:
        # Ignore errors raised during interpreter shutdown.
        pass