29 changes: 22 additions & 7 deletions src/diffusers/pipelines/ddpm/pipeline_ddpm.py
@@ -73,7 +73,8 @@ def __call__(
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            output_type (`str`, *optional*, defaults to `"pil"`):
-                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+                The output format of the generated image. Choose between `PIL.Image`, `np.array` or
+                `torch.Tensor`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
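For reference, a minimal usage sketch of the three output formats the updated docstring lists; the checkpoint id is only illustrative, and the call otherwise uses the documented arguments:

```python
from diffusers import DDPMPipeline

# Illustrative checkpoint; any DDPM checkpoint behaves the same way.
pipe = DDPMPipeline.from_pretrained("google/ddpm-cat-256")

pil_images = pipe(num_inference_steps=50, output_type="pil").images  # list of PIL.Image
np_images = pipe(num_inference_steps=50, output_type="np").images    # np.ndarray, NHWC, values in [0, 1]
pt_images = pipe(num_inference_steps=50, output_type="pt").images    # torch.Tensor, NCHW, values in [0, 1]
```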

@@ -97,6 +98,17 @@ def __call__(
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images
        """
+        if output_type not in ["pt", "np", "pil"]:
+            raise ValueError(f"output_type must be one of ['pt', 'np', 'pil'], got '{output_type}'.")
Comment on lines +101 to +102 (Contributor):
cc @yiyixuxu Should output_type validation be a review rule?
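A minimal sketch of what such a review rule could standardize on, assuming a shared helper (`validate_output_type` is hypothetical, not an existing diffusers utility):

```python
SUPPORTED_OUTPUT_TYPES = ("pt", "np", "pil")

def validate_output_type(output_type: str) -> None:
    # Same check as the lines above, factored out so every pipeline
    # __call__ could reuse it.
    if output_type not in SUPPORTED_OUTPUT_TYPES:
        raise ValueError(
            f"output_type must be one of {list(SUPPORTED_OUTPUT_TYPES)}, got '{output_type}'."
        )
```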


+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
Comment on lines +104 to +108 (Contributor):
Great, this resolves issue 2
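A short sketch of the case this check guards, with an illustrative checkpoint id: one torch.Generator per sample gives per-image reproducibility, and a mismatched list length now fails fast with the error above.

```python
import torch
from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("google/ddpm-cat-256")  # illustrative checkpoint

batch_size = 4
generators = [torch.Generator("cpu").manual_seed(seed) for seed in range(batch_size)]

# OK: len(generators) == batch_size
images = pipe(batch_size=batch_size, generator=generators, num_inference_steps=50).images

# Would raise the ValueError added above: 3 generators for a batch of 4
# pipe(batch_size=4, generator=generators[:3], num_inference_steps=50)
```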


+        device = self._execution_device

        # Sample gaussian noise to begin loop
        if isinstance(self.unet.config.sample_size, int):
            image_shape = (
@@ -108,12 +120,12 @@ def __call__(
        else:
            image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)

-        if self.device.type == "mps":
+        if device.type == "mps":
            # randn does not work reproducibly on mps
            image = randn_tensor(image_shape, generator=generator, dtype=self.unet.dtype)
-            image = image.to(self.device)
+            image = image.to(device)
        else:
-            image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)
+            image = randn_tensor(image_shape, generator=generator, device=device, dtype=self.unet.dtype)
Comment on lines 110 to +128 (Contributor):
Great, this + a later line resolves issue 1
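A standalone sketch of the pattern this branch keeps (the helper name is hypothetical; the pipeline itself uses diffusers' randn_tensor): on mps, sample with the generator on the default CPU device and then move the tensor, because generator-seeded randn is not reproducible on mps.

```python
import torch

def reproducible_noise(shape, generator, device, dtype):
    # Mirrors the branch above: sample on CPU when targeting mps so the same
    # seed always yields the same noise, then move the tensor to the device.
    if device.type == "mps":
        noise = torch.randn(shape, generator=generator, dtype=dtype)
        return noise.to(device)
    # Elsewhere, sampling directly on the target device is reproducible.
    return torch.randn(shape, generator=generator, device=device, dtype=dtype)
```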


        # set step values
        self.scheduler.set_timesteps(num_inference_steps)
@@ -129,9 +141,12 @@ def __call__(
                xm.mark_step()

        image = (image / 2 + 0.5).clamp(0, 1)
-        image = image.cpu().permute(0, 2, 3, 1).numpy()
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        if output_type != "pt":
+            image = image.cpu().permute(0, 2, 3, 1).numpy()
+            if output_type == "pil":
+                image = self.numpy_to_pil(image)
Comment on lines 143 to +147 (Contributor):
As per the additional review in #13663 (comment), this could use VaeImageProcessor, which would handle all output_types without the if statements. cc @yiyixuxu For awareness, this is another case that supports introducing VaeImageProcessor usage as a review rule.
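A sketch of the VaeImageProcessor route suggested here (an assumed refactor, not part of this diff): postprocess owns the denormalization from [-1, 1] and all three output_type branches, so the manual clamp/permute/numpy/PIL code above would collapse into a single call.

```python
import torch
from diffusers.image_processor import VaeImageProcessor

image_processor = VaeImageProcessor(do_normalize=True)

# Stand-in for the denoised model output, still in [-1, 1] and NCHW.
image = torch.randn(1, 3, 64, 64).clamp(-1, 1)

pil_out = image_processor.postprocess(image, output_type="pil")  # list of PIL.Image
np_out = image_processor.postprocess(image, output_type="np")    # np.ndarray, NHWC in [0, 1]
pt_out = image_processor.postprocess(image, output_type="pt")    # torch.Tensor in [0, 1]
```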


+        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)