58 changes: 54 additions & 4 deletions backend/python/diffusers/backend.py
@@ -400,6 +400,12 @@ def _load_pipeline(self, request, modelFile, fromSingleFile, torchType, variant)
         # Build kwargs for dynamic loading
         load_kwargs = {"torch_dtype": torchType}
 
+        # For large models (e.g., >80GB), enable low_cpu_mem_usage and device_map
+        # to avoid OOM during loading by distributing across multiple GPUs
+        if request.LowVRAM:
+            load_kwargs["low_cpu_mem_usage"] = True
+            load_kwargs["device_map"] = "balanced"
+
         # Add variant if not loading from single file
         if not fromSingleFile and variant:
             load_kwargs["variant"] = variant
@@ -428,14 +434,56 @@ def _load_pipeline(self, request, modelFile, fromSingleFile, torchType, variant)
             ) from e
 
         # Apply LowVRAM optimization if supported and requested
-        if request.LowVRAM and hasattr(pipe, 'enable_model_cpu_offload'):
+        # Skip if device_map was used (they conflict with each other)
+        if request.LowVRAM and hasattr(pipe, 'enable_model_cpu_offload') and "device_map" not in load_kwargs:
             pipe.enable_model_cpu_offload()
 
         return pipe
 
     def Health(self, request, context):
         return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
 
+    def Shutdown(self, request, context):
Owner commented on this line: This is unused?
"""
Shutdown and release GPU memory for the loaded model.
This allows dynamic model reloading with different configurations (e.g., different LoRA adapters).
"""
try:
print("Shutting down diffusers backend...", file=sys.stderr)

# Release pipeline
if hasattr(self, 'pipe') and self.pipe is not None:
del self.pipe
self.pipe = None

# Release controlnet
if hasattr(self, 'controlnet') and self.controlnet is not None:
del self.controlnet
self.controlnet = None

# Release compel
if hasattr(self, 'compel') and self.compel is not None:
del self.compel
self.compel = None

# Clear CUDA cache to release GPU memory
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
print("CUDA cache cleared", file=sys.stderr)

# Reset state flags
self.img2vid = False
self.txt2vid = False
self.ltx2_pipeline = False
self.options = {}

print("Diffusers backend shutdown complete", file=sys.stderr)
return backend_pb2.Result(message="Model unloaded successfully", success=True)
except Exception as err:
print(f"Error during shutdown: {err}", file=sys.stderr)
return backend_pb2.Result(success=False, message=f"Shutdown error: {err}")

def LoadModel(self, request, context):
try:
print(f"Loading model {request.Model}...", file=sys.stderr)
@@ -582,9 +630,11 @@ def LoadModel(self, request, context):
                 self.pipe.set_adapters(adapters_name, adapter_weights=adapters_weights)
 
             if device != "cpu":
-                self.pipe.to(device)
-                if self.controlnet:
-                    self.controlnet.to(device)
+                # Skip .to(device) if device_map was used (they conflict with each other)
+                if not hasattr(self.pipe, "hf_device_map") or self.pipe.hf_device_map is None:
+                    self.pipe.to(device)
+                    if self.controlnet:
+                        self.controlnet.to(device)
 
         except Exception as err:
             return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
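Note: both guards added by this PR enforce the same invariant: once a pipeline was loaded with device_map, accelerate owns device placement, and neither `.to(device)` nor `enable_model_cpu_offload()` should be called on it. One way to express the combined check, as a sketch reusing `pipe` and `low_vram` from the loading sketch above:

```python
# Sketch of the placement guard; assumes pipe and low_vram from the
# loading sketch above, and a target device chosen for illustration.
# Pipelines loaded with device_map expose an hf_device_map attribute.
device = "cuda:0"

if getattr(pipe, "hf_device_map", None) is None:
    # accelerate did not shard this pipeline, so manual placement is safe
    if low_vram and hasattr(pipe, "enable_model_cpu_offload"):
        pipe.enable_model_cpu_offload()  # swap sub-models to GPU one at a time
    elif device != "cpu":
        pipe.to(device)  # the whole pipeline fits on one device
# else: device_map already placed every component, so calling .to(device)
# or enable_model_cpu_offload() here would conflict with existing hooks.
```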