Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 37 additions & 5 deletions src/diffusers/pipelines/flux2/image_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

import math

import numpy as np
import PIL.Image
import torch

from ...configuration_utils import register_to_config
from ...image_processor import VaeImageProcessor
Expand Down Expand Up @@ -56,27 +58,57 @@ def __init__(
do_convert_rgb=do_convert_rgb,
)

@staticmethod
def to_pil(image) -> PIL.Image.Image:
"""Convert torch.Tensor or np.ndarray to PIL.Image.Image.

Accepts:
- PIL.Image.Image → returned as-is
- torch.Tensor → shape (C, H, W) or (B, C, H, W), values in [0, 1]
- np.ndarray → shape (H, W, C) or (B, H, W, C), values in [0, 1]
"""
if isinstance(image, PIL.Image.Image):
return image

if isinstance(image, torch.Tensor):
image = image.detach().cpu().float()
if image.ndim == 4:
image = image[0]
image = image.permute(1, 2, 0).numpy()
elif isinstance(image, np.ndarray):
if image.ndim == 4:
image = image[0]
else:
raise ValueError(
f"Expected PIL.Image.Image, torch.Tensor, or np.ndarray, got {type(image)}"
)

if image.dtype != np.uint8:
image = (np.clip(image, 0, 1) * 255).astype(np.uint8)

return PIL.Image.fromarray(image)

@staticmethod
def check_image_input(
image: PIL.Image.Image, max_aspect_ratio: int = 8, min_side_length: int = 64, max_area: int = 1024 * 1024
image, max_aspect_ratio: int = 8, min_side_length: int = 64, max_area: int = 1024 * 1024
) -> PIL.Image.Image:
"""
Check if image meets minimum size and aspect ratio requirements.
Accepts PIL.Image.Image, torch.Tensor, or np.ndarray and converts to PIL.

Args:
image: PIL Image to validate
image: Image to validate (PIL, tensor, or numpy array)
max_aspect_ratio: Maximum allowed aspect ratio (width/height or height/width)
min_side_length: Minimum pixels required for width and height
max_area: Maximum allowed area in pixels²

Returns:
The input image if valid
The image as PIL.Image.Image

Raises:
ValueError: If image is too small or aspect ratio is too extreme
"""
if not isinstance(image, PIL.Image.Image):
raise ValueError(f"Image must be a PIL.Image.Image, got {type(image)}")
image = Flux2ImageProcessor.to_pil(image)

width, height = image.size

Expand Down
7 changes: 4 additions & 3 deletions src/diffusers/pipelines/flux2/pipeline_flux2_klein.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...image_processor import PipelineImageInput
from ..pipeline_utils import DiffusionPipeline
from .image_processor import Flux2ImageProcessor
from .pipeline_output import Flux2PipelineOutput
Expand Down Expand Up @@ -608,7 +609,7 @@ def interrupt(self):
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
image: list[PIL.Image.Image] | PIL.Image.Image | None = None,
image: PipelineImageInput | None = None,
prompt: str | list[str] = None,
height: int | None = None,
width: int | None = None,
Expand Down Expand Up @@ -758,8 +759,8 @@ def __call__(

condition_images = None
if image is not None:
for img in image:
self.image_processor.check_image_input(img)
# Convert each image to PIL (handles tensor/numpy/PIL uniformly)
image = [self.image_processor.check_image_input(img) for img in image]

condition_images = []
for img in image:
Expand Down
30 changes: 30 additions & 0 deletions tests/pipelines/flux2/test_pipeline_flux2_klein.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,36 @@ def test_image_input(self):
# fmt: on
assert np.allclose(expected_slice, generated_slice, atol=1e-4, rtol=1e-4)

def test_image_input_tensor(self):
    """Regression test for issue #13177: a single torch.Tensor image must be accepted."""
    device = "cpu"
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components).to(device)

    inputs = self.get_dummy_inputs(device)
    inputs["image"] = torch.rand(3, 64, 64)

    output = pipe(**inputs).images
    assert output is not None and output.shape[-1] == 3

def test_image_input_numpy(self):
    """Regression test for issue #13177: a single np.ndarray image must be accepted."""
    device = "cpu"
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components).to(device)

    inputs = self.get_dummy_inputs(device)
    inputs["image"] = np.random.rand(64, 64, 3).astype(np.float32)

    output = pipe(**inputs).images
    assert output is not None and output.shape[-1] == 3

def test_image_input_tensor_list(self):
    """Regression test for issue #13177: a list of torch.Tensor images must be accepted."""
    device = "cpu"
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components).to(device)

    inputs = self.get_dummy_inputs(device)
    inputs["image"] = [torch.rand(3, 64, 64), torch.rand(3, 64, 64)]

    output = pipe(**inputs).images
    assert output is not None and output.shape[-1] == 3

# Intentionally disabled: the shared mixin test for isolated prompt encoding
# does not yet apply cleanly to this pipeline. NOTE(review): revisit and
# either adapt or remove once the encode_prompt path is stabilized.
@unittest.skip("Needs to be revisited")
def test_encode_prompt_works_in_isolation(self):
    pass