Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions invokeai/app/api/routers/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,7 @@ async def list_model_installs() -> List[ModelInstallJob]:
* "waiting" -- Job is waiting in the queue to run
* "downloading" -- Model file(s) are downloading
* "running" -- Model has downloaded and the model probing and registration process is running
* "paused" -- Job is paused and can be resumed
* "completed" -- Installation completed successfully
* "error" -- An error occurred. Details will be in the "error_type" and "error" fields.
* "cancelled" -- Job was cancelled before completion.
Expand Down Expand Up @@ -818,6 +819,89 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
installer.cancel_job(job)


@model_manager_router.post(
    "/install/{id}/pause",
    operation_id="pause_model_install_job",
    responses={
        201: {"description": "The job was paused successfully"},
        415: {"description": "No such job"},
    },
    status_code=201,
)
async def pause_model_install_job(id: int = Path(description="Model install job ID")) -> ModelInstallJob:
    """Pause the model install job corresponding to the given job ID.

    Partial downloads are preserved so the job can later be resumed via the
    `/install/{id}/resume` route. Responds with HTTP 415 when no job with the
    given ID exists (matches the error convention of the cancel route).
    """
    installer = ApiDependencies.invoker.services.model_manager.install
    try:
        job = installer.get_job_by_id(id)
    except ValueError as e:
        # Chain the original error so tracebacks show the root cause (ruff B904).
        raise HTTPException(status_code=415, detail=str(e)) from e
    installer.pause_job(job)
    return job


@model_manager_router.post(
    "/install/{id}/resume",
    operation_id="resume_model_install_job",
    responses={
        201: {"description": "The job was resumed successfully"},
        415: {"description": "No such job"},
    },
    status_code=201,
)
async def resume_model_install_job(id: int = Path(description="Model install job ID")) -> ModelInstallJob:
    """Resume a paused model install job corresponding to the given job ID.

    Responds with HTTP 415 when no job with the given ID exists (matches the
    error convention of the cancel route).
    """
    installer = ApiDependencies.invoker.services.model_manager.install
    try:
        job = installer.get_job_by_id(id)
    except ValueError as e:
        # Chain the original error so tracebacks show the root cause (ruff B904).
        raise HTTPException(status_code=415, detail=str(e)) from e
    installer.resume_job(job)
    return job


@model_manager_router.post(
    "/install/{id}/restart_failed",
    operation_id="restart_failed_model_install_job",
    responses={
        201: {"description": "Failed files restarted successfully"},
        415: {"description": "No such job"},
    },
    status_code=201,
)
async def restart_failed_model_install_job(id: int = Path(description="Model install job ID")) -> ModelInstallJob:
    """Restart failed or non-resumable file downloads for the given job.

    Responds with HTTP 415 when no job with the given ID exists (matches the
    error convention of the cancel route).
    """
    installer = ApiDependencies.invoker.services.model_manager.install
    try:
        job = installer.get_job_by_id(id)
    except ValueError as e:
        # Chain the original error so tracebacks show the root cause (ruff B904).
        raise HTTPException(status_code=415, detail=str(e)) from e
    installer.restart_failed(job)
    return job


@model_manager_router.post(
    "/install/{id}/restart_file",
    operation_id="restart_model_install_file",
    responses={
        201: {"description": "File restarted successfully"},
        415: {"description": "No such job"},
    },
    status_code=201,
)
async def restart_model_install_file(
    id: int = Path(description="Model install job ID"),
    file_source: AnyHttpUrl = Body(description="File download URL to restart"),
) -> ModelInstallJob:
    """Restart a specific file download for the given job.

    `file_source` identifies which file within the (possibly multi-file) job
    to restart; it is stringified before being handed to the installer.
    Responds with HTTP 415 when no job with the given ID exists.
    """
    installer = ApiDependencies.invoker.services.model_manager.install
    try:
        job = installer.get_job_by_id(id)
    except ValueError as e:
        # Chain the original error so tracebacks show the root cause (ruff B904).
        raise HTTPException(status_code=415, detail=str(e)) from e
    installer.restart_file(job, str(file_source))
    return job


@model_manager_router.delete(
"/install",
operation_id="prune_model_install_jobs",
Expand Down
13 changes: 11 additions & 2 deletions invokeai/app/invocations/flux2_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
)
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_bfl_peft_lora_conversion_utils import (
convert_bfl_lora_patch_to_diffusers,
)
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
Expand Down Expand Up @@ -503,11 +506,17 @@ def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor)
return mask.expand_as(latents)

def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
    """Iterate over LoRA models to apply.

    Converts BFL-format LoRA keys to diffusers format if needed, since FLUX.2 Klein
    uses Flux2Transformer2DModel (diffusers naming) but LoRAs may have been loaded
    with BFL naming (e.g. when a Klein 4B LoRA is misidentified as FLUX.1).

    Yields:
        (patch, weight) tuples, one per LoRA attached to the transformer field.
    """
    for lora in self.transformer.loras:
        # Load the raw patch from the model manager; must be a ModelPatchRaw.
        lora_info = context.models.load(lora.lora)
        assert isinstance(lora_info.model, ModelPatchRaw)
        converted = convert_bfl_lora_patch_to_diffusers(lora_info.model)
        yield (converted, lora.weight)
        # Drop the load handle after the consumer resumes us — presumably so
        # the model cache can release the LoRA between iterations; TODO confirm.
        del lora_info

def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
Expand Down
182 changes: 182 additions & 0 deletions invokeai/app/invocations/flux2_klein_lora_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
"""FLUX.2 Klein LoRA Loader Invocation.
Applies LoRA models to a FLUX.2 Klein transformer and/or Qwen3 text encoder.
Unlike standard FLUX which uses CLIP+T5, Klein uses only Qwen3 for text encoding.
"""

from typing import Optional

from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, Qwen3EncoderField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType


@invocation_output("flux2_klein_lora_loader_output")
class Flux2KleinLoRALoaderOutput(BaseInvocationOutput):
    """FLUX.2 Klein LoRA Loader Output.

    Each field is populated only when the corresponding input was connected on
    the loader invocation; otherwise it stays None.
    """

    # Transformer with the LoRA(s) appended to its `loras` list.
    transformer: Optional[TransformerField] = OutputField(
        default=None, description=FieldDescriptions.transformer, title="Transformer"
    )
    # Qwen3 text encoder with the LoRA(s) appended to its `loras` list.
    qwen3_encoder: Optional[Qwen3EncoderField] = OutputField(
        default=None, description=FieldDescriptions.qwen3_encoder, title="Qwen3 Encoder"
    )


@invocation(
    "flux2_klein_lora_loader",
    title="Apply LoRA - Flux2 Klein",
    tags=["lora", "model", "flux", "klein", "flux2"],
    category="model",
    version="1.0.0",
    classification=Classification.Prototype,
)
class Flux2KleinLoRALoaderInvocation(BaseInvocation):
    """Apply a LoRA model to a FLUX.2 Klein transformer and/or Qwen3 text encoder."""

    lora: ModelIdentifierField = InputField(
        description=FieldDescriptions.lora_model,
        title="LoRA",
        ui_model_base=BaseModelType.Flux2,
        ui_model_type=ModelType.LoRA,
    )
    weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
    transformer: TransformerField | None = InputField(
        default=None,
        description=FieldDescriptions.transformer,
        input=Input.Connection,
        title="Transformer",
    )
    qwen3_encoder: Qwen3EncoderField | None = InputField(
        default=None,
        title="Qwen3 Encoder",
        description=FieldDescriptions.qwen3_encoder,
        input=Input.Connection,
    )

    def invoke(self, context: InvocationContext) -> Flux2KleinLoRALoaderOutput:
        """Validate the LoRA, then attach it to each connected model field."""
        key = self.lora.key

        if not context.models.exists(key):
            raise ValueError(f"Unknown lora: {key}!")

        # Warn (but do not fail) when the LoRA's Klein variant differs from
        # the connected transformer's variant — shape errors are likely later.
        config = context.models.get_config(key)
        lora_variant = getattr(config, "variant", None)
        if lora_variant and self.transformer is not None:
            t_config = context.models.get_config(self.transformer.transformer.key)
            transformer_variant = getattr(t_config, "variant", None)
            if transformer_variant and lora_variant != transformer_variant:
                context.logger.warning(
                    f"LoRA variant mismatch: LoRA '{config.name}' is for {lora_variant.value} "
                    f"but transformer is {transformer_variant.value}. This may cause shape errors."
                )

        # Refuse to stack the same LoRA key twice on either model.
        if self.transformer and any(existing.lora.key == key for existing in self.transformer.loras):
            raise ValueError(f'LoRA "{key}" already applied to transformer.')
        if self.qwen3_encoder and any(existing.lora.key == key for existing in self.qwen3_encoder.loras):
            raise ValueError(f'LoRA "{key}" already applied to Qwen3 encoder.')

        output = Flux2KleinLoRALoaderOutput()

        # Deep-copy each connected field so upstream graph state is untouched,
        # then record the LoRA reference and weight on the copy.
        if self.transformer is not None:
            output.transformer = self.transformer.model_copy(deep=True)
            output.transformer.loras.append(LoRAField(lora=self.lora, weight=self.weight))
        if self.qwen3_encoder is not None:
            output.qwen3_encoder = self.qwen3_encoder.model_copy(deep=True)
            output.qwen3_encoder.loras.append(LoRAField(lora=self.lora, weight=self.weight))

        return output


@invocation(
    "flux2_klein_lora_collection_loader",
    title="Apply LoRA Collection - Flux2 Klein",
    tags=["lora", "model", "flux", "klein", "flux2"],
    category="model",
    version="1.0.0",
    classification=Classification.Prototype,
)
class Flux2KleinLoRACollectionLoader(BaseInvocation):
    """Applies a collection of LoRAs to a FLUX.2 Klein transformer and/or Qwen3 text encoder.

    Duplicate LoRA keys within the collection are skipped (first occurrence
    wins). An unknown LoRA key raises ValueError.
    """

    loras: Optional[LoRAField | list[LoRAField]] = InputField(
        default=None, description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
    )

    transformer: Optional[TransformerField] = InputField(
        default=None,
        description=FieldDescriptions.transformer,
        input=Input.Connection,
        title="Transformer",
    )
    qwen3_encoder: Qwen3EncoderField | None = InputField(
        default=None,
        title="Qwen3 Encoder",
        description=FieldDescriptions.qwen3_encoder,
        input=Input.Connection,
    )

    def invoke(self, context: InvocationContext) -> Flux2KleinLoRALoaderOutput:
        """Attach every LoRA in the collection to each connected model field."""
        output = Flux2KleinLoRALoaderOutput()
        # Normalize input: may be None, a single LoRAField, or a list of them.
        loras = self.loras if isinstance(self.loras, list) else [self.loras]
        # Set gives O(1) duplicate detection; ordering of appends is preserved
        # by iterating `loras` itself.
        added_loras: set[str] = set()

        if self.transformer is not None:
            output.transformer = self.transformer.model_copy(deep=True)

        if self.qwen3_encoder is not None:
            output.qwen3_encoder = self.qwen3_encoder.model_copy(deep=True)

        for lora in loras:
            if lora is None:
                continue
            if lora.lora.key in added_loras:
                continue

            if not context.models.exists(lora.lora.key):
                # ValueError (not bare Exception) for consistency with
                # Flux2KleinLoRALoaderInvocation; still backward-compatible
                # with callers that catch Exception.
                raise ValueError(f"Unknown lora: {lora.lora.key}!")

            assert lora.lora.base in (BaseModelType.Flux, BaseModelType.Flux2)

            # Warn if LoRA variant doesn't match transformer variant.
            lora_config = context.models.get_config(lora.lora.key)
            lora_variant = getattr(lora_config, "variant", None)
            if lora_variant and self.transformer is not None:
                transformer_config = context.models.get_config(self.transformer.transformer.key)
                transformer_variant = getattr(transformer_config, "variant", None)
                if transformer_variant and lora_variant != transformer_variant:
                    context.logger.warning(
                        f"LoRA variant mismatch: LoRA '{lora_config.name}' is for {lora_variant.value} "
                        f"but transformer is {transformer_variant.value}. This may cause shape errors."
                    )

            added_loras.add(lora.lora.key)

            if self.transformer is not None and output.transformer is not None:
                output.transformer.loras.append(lora)

            if self.qwen3_encoder is not None and output.qwen3_encoder is not None:
                output.qwen3_encoder.loras.append(lora)

        return output
2 changes: 1 addition & 1 deletion invokeai/app/services/boards/boards_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def create(
board_name: str,
) -> BoardDTO:
board_record = self.__invoker.services.board_records.save(board_name)
return board_record_to_dto(board_record, None, 0, 0, 0)
return board_record_to_dto(board_record, None, 0, 0)

def get_dto(self, board_id: str) -> BoardDTO:
board_record = self.__invoker.services.board_records.get(board_id)
Expand Down
28 changes: 28 additions & 0 deletions invokeai/app/services/download/download_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class DownloadJobStatus(str, Enum):

WAITING = "waiting" # not enqueued, will not run
RUNNING = "running" # actively downloading
PAUSED = "paused" # paused, can be resumed
COMPLETED = "completed" # finished running
CANCELLED = "cancelled" # user cancelled
ERROR = "error" # terminated with an error message
Expand Down Expand Up @@ -61,6 +62,7 @@ class DownloadJobBase(BaseModel):

# internal flag
_cancelled: bool = PrivateAttr(default=False)
_paused: bool = PrivateAttr(default=False)

# optional event handlers passed in on creation
_on_start: Optional[DownloadEventHandler] = PrivateAttr(default=None)
Expand All @@ -72,6 +74,12 @@ class DownloadJobBase(BaseModel):
def cancel(self) -> None:
    """Call to cancel the job.

    Also clears any pending pause request so that a cancel always takes
    precedence over a pause issued earlier.
    """
    self._cancelled = True
    self._paused = False

def pause(self) -> None:
    """Pause the job, preserving partial downloads.

    NOTE(review): this also sets ``_cancelled`` — presumably the worker
    watches the cancelled flag to stop transferring, while ``_paused``
    distinguishes a resumable pause from a true cancel. TODO confirm
    against the download queue implementation.
    """
    self._paused = True
    self._cancelled = True

# cancelled and the callbacks are private attributes in order to prevent
# them from being serialized and/or used in the Json Schema
Expand All @@ -80,6 +88,11 @@ def cancelled(self) -> bool:
"""Call to cancel the job."""
return self._cancelled

@property
def paused(self) -> bool:
    """Return True if a pause has been requested for this job."""
    return self._paused

@property
def complete(self) -> bool:
"""Return true if job completed without errors."""
Expand Down Expand Up @@ -161,6 +174,17 @@ class DownloadJob(DownloadJobBase):
default=None, description="Timestamp for when the download job ende1d (completed or errored)"
)
content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")
canonical_url: Optional[str] = Field(default=None, description="Canonical URL to request on resume")
etag: Optional[str] = Field(default=None, description="ETag from the remote server, if available")
last_modified: Optional[str] = Field(default=None, description="Last-Modified from the remote server, if available")
final_url: Optional[str] = Field(default=None, description="Final resolved URL after redirects, if available")
expected_total_bytes: Optional[int] = Field(default=None, description="Expected total size of the download")
resume_required: bool = Field(default=False, description="True if server refused resume; restart required")
resume_message: Optional[str] = Field(default=None, description="Message explaining why resume is required")
resume_from_scratch: bool = Field(
default=False,
description="True if resume metadata existed but the partial file was missing and the download restarted from the beginning",
)

def __hash__(self) -> int:
"""Return hash of the string representation of this object, for indexing."""
Expand Down Expand Up @@ -321,6 +345,10 @@ def cancel_job(self, job: DownloadJobBase) -> None:
"""Cancel the job, clearing partial downloads and putting it into ERROR state."""
pass

def pause_job(self, job: DownloadJobBase) -> None:  # noqa D401
    """Pause the job, preserving partial downloads.

    NOTE(review): deliberately not ``@abstractmethod`` — the default raises
    NotImplementedError instead, presumably so pre-existing queue subclasses
    without pause support still instantiate. TODO confirm.
    """
    raise NotImplementedError

@abstractmethod
def join(self) -> None:
"""Wait until all jobs are off the queue."""
Expand Down
Loading