-
Notifications
You must be signed in to change notification settings - Fork 7k
fix(ddpm): use _execution_device, validate inputs, free hooks (#13649) #13671
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -73,7 +73,8 @@ def __call__( | |
| The number of denoising steps. More denoising steps usually lead to a higher quality image at the | ||
| expense of slower inference. | ||
| output_type (`str`, *optional*, defaults to `"pil"`): | ||
| The output format of the generated image. Choose between `PIL.Image` or `np.array`. | ||
| The output format of the generated image. Choose between `PIL.Image`, `np.array` or | ||
| `torch.Tensor`. | ||
| return_dict (`bool`, *optional*, defaults to `True`): | ||
| Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. | ||
|
|
||
|
|
@@ -97,6 +98,17 @@ def __call__( | |
| If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is | ||
| returned where the first element is a list with the generated images | ||
| """ | ||
| if output_type not in ["pt", "np", "pil"]: | ||
| raise ValueError(f"output_type must be one of ['pt', 'np', 'pil'], got '{output_type}'.") | ||
|
|
||
| if isinstance(generator, list) and len(generator) != batch_size: | ||
| raise ValueError( | ||
| f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" | ||
| f" size of {batch_size}. Make sure the batch size matches the length of the generators." | ||
| ) | ||
|
Comment on lines
+104
to
+108
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Great, this resolves issue 2 |
||
|
|
||
| device = self._execution_device | ||
|
|
||
| # Sample gaussian noise to begin loop | ||
| if isinstance(self.unet.config.sample_size, int): | ||
| image_shape = ( | ||
|
|
@@ -108,12 +120,12 @@ def __call__( | |
| else: | ||
| image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size) | ||
|
|
||
| if self.device.type == "mps": | ||
| if device.type == "mps": | ||
| # randn does not work reproducibly on mps | ||
| image = randn_tensor(image_shape, generator=generator, dtype=self.unet.dtype) | ||
| image = image.to(self.device) | ||
| image = image.to(device) | ||
| else: | ||
| image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype) | ||
| image = randn_tensor(image_shape, generator=generator, device=device, dtype=self.unet.dtype) | ||
|
Comment on lines
110
to
+128
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Great, this + a later line resolves issue 1 |
||
|
|
||
| # set step values | ||
| self.scheduler.set_timesteps(num_inference_steps) | ||
|
|
@@ -129,9 +141,12 @@ def __call__( | |
| xm.mark_step() | ||
|
|
||
| image = (image / 2 + 0.5).clamp(0, 1) | ||
| image = image.cpu().permute(0, 2, 3, 1).numpy() | ||
| if output_type == "pil": | ||
| image = self.numpy_to_pil(image) | ||
| if output_type != "pt": | ||
| image = image.cpu().permute(0, 2, 3, 1).numpy() | ||
| if output_type == "pil": | ||
| image = self.numpy_to_pil(image) | ||
|
Comment on lines
143
to
+147
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. As per additional review in #13663 (comment) this could use |
||
|
|
||
| self.maybe_free_model_hooks() | ||
|
|
||
| if not return_dict: | ||
| return (image,) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
cc @yiyixuxu Should output_type validation be a review rule?