Skip to content

Conversation

@CalamitousFelicitousness
Copy link
Contributor

@CalamitousFelicitousness CalamitousFelicitousness commented Jan 21, 2026

What does this PR do?

This PR adds an inpainting pipeline for Z-Image. The summary of changes are below:

  • Implemented the ZImageInpaintPipeline class for mask-based inpainting
  • Updated the pipeline structure to include ZImageInpaintPipeline alongside ZImagePipeline and ZImageImg2ImgPipeline
  • Mapped the new ZImageInpaintPipeline in AUTO_INPAINT_PIPELINES_MAPPING
  • Added unit tests for ZImageInpaintPipeline with torch.empty fix for test stability
  • Updated dummy objects to include ZImageInpaintPipeline
  • Added documentation with usage example

Closes issue #12752

Tested using a simple script:

Testing script
  #!/usr/bin/env python
  """Test script for ZImage inpaint support."""

  import sys
  sys.path.insert(0, '/home/ohiom/diffusers/src')

  import torch
  import numpy as np
  from PIL import Image
  from diffusers import ZImageInpaintPipeline

  # Paths
  MODEL_PATH =
  "database/models/huggingface/models--Tongyi-MAI--Z-Image-Turbo/snapshots/78771b7e11b922c868dd766476bda1f4fc6bfc96"
  INPUT_IMAGE_PATH = "death_remix_1024.png"

  print("Loading ZImageInpaintPipeline...")
  pipe = ZImageInpaintPipeline.from_pretrained(
      MODEL_PATH,
      torch_dtype=torch.bfloat16,
      local_files_only=True,
  )
  pipe.to("cuda")
  print("Pipeline loaded.")

  # Load input image
  print(f"\nLoading input image from {INPUT_IMAGE_PATH}...")
  input_image = Image.open(INPUT_IMAGE_PATH).convert("RGB")
  print(f"Input image size: {input_image.size}")

  # Create a mask (white = inpaint, black = preserve)
  width, height = input_image.size
  mask = np.zeros((height, width), dtype=np.uint8)
  h_start, h_end = height // 4, 3 * height // 4
  w_start, w_end = width // 4, 3 * width // 4
  mask[h_start:h_end, w_start:w_end] = 255
  mask_image = Image.fromarray(mask)

  # Generate an inpainted image
  prompt = "a woman with pale skin in a black shirt, oil painting style"
  strength = 0.75

  print(f"\nGenerating inpainted image with prompt: {prompt}")
  print(f"Strength: {strength}")

  image = pipe(
      prompt=prompt,
      image=input_image,
      mask_image=mask_image,
      strength=strength,
      num_inference_steps=8,
      guidance_scale=1.0,
      generator=torch.Generator(device="cuda").manual_seed(42),
  ).images[0]

  output_path = "test_zimage_inpaint_output.png"
  image.save(output_path)
  print(f"\nImage saved to {output_path}")

LoRA functionality is also supported (inherited from ZImageLoraLoaderMixin).

Clipboard3

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@yiyixuxu @asomoza @sayakpaul

Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adds a comprehensive inpainting pipeline for Z-Image, extending the existing Z-Image family of pipelines (text-to-image, img2img, controlnet) with mask-based inpainting capabilities. The implementation follows established patterns from other inpainting pipelines in the diffusers library while adapting to Z-Image's specific requirements (e.g., complex64 RoPE embeddings, flow matching scheduler).

Changes:

  • Implemented ZImageInpaintPipeline with full inpainting support including mask blending and strength-based denoising control
  • Added comprehensive test suite covering inference, batch processing, strength validation, mask functionality, VAE tiling, and device offloading
  • Integrated the pipeline into auto_pipeline infrastructure with proper model mapping for "z-image" model type
  • Updated all necessary init.py files and dummy objects for proper exports
  • Added documentation with usage examples

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.

Show a summary per file
File Description
src/diffusers/pipelines/z_image/pipeline_z_image_inpaint.py New inpainting pipeline implementation with prepare_mask_latents, prepare_latents, and main call method for mask-based image inpainting
tests/pipelines/z_image/test_z_image_inpaint.py Comprehensive test suite including inference tests, strength parameter validation, mask functionality tests, and compatibility tests
src/diffusers/pipelines/z_image/init.py Added ZImageInpaintPipeline to module exports
src/diffusers/pipelines/init.py Added ZImageInpaintPipeline to main pipelines module exports
src/diffusers/init.py Added ZImageInpaintPipeline to top-level diffusers exports
src/diffusers/pipelines/auto_pipeline.py Mapped ZImageInpaintPipeline to "z-image" in AUTO_INPAINT_PIPELINES_MAPPING
src/diffusers/utils/dummy_torch_and_transformers_objects.py Added dummy ZImageInpaintPipeline class for when dependencies are not available
docs/source/en/api/pipelines/z_image.md Added inpainting section with usage example and API reference

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@CalamitousFelicitousness
Copy link
Contributor Author

@yiyixuxu Ready for re-review.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Member

@asomoza asomoza left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, left some initial comments.

did you test it with num_images_per_prompt since you're not using them in the mask and latent generations like the original z-image pipeline.

Also you're not checking the inputs at the start at the pipeline call.

@CalamitousFelicitousness
Copy link
Contributor Author

CalamitousFelicitousness commented Jan 23, 2026

@asomoza How does this look? I still have to do cleanup, but in general structure.

@asomoza
Copy link
Member

asomoza commented Jan 23, 2026

looks good to me, let me know when you do the cleanup, also I wonder if we should do the fix for this issue here too.

ccing: @yiyixuxu

@CalamitousFelicitousness
Copy link
Contributor Author

CalamitousFelicitousness commented Jan 23, 2026

Might as well, I'll integrate it, can always revert and separate into another PR if preferred.

Updated the pipeline structure to include ZImageInpaintPipeline
    alongside ZImagePipeline and ZImageImg2ImgPipeline.
Implemented the ZImageInpaintPipeline class for inpainting
    tasks, including necessary methods for encoding prompts,
    preparing masked latents, and denoising.
Enhanced the auto_pipeline to map the new ZImageInpaintPipeline
    for inpainting generation tasks.
Added unit tests for ZImageInpaintPipeline to ensure
    functionality and performance.
Updated dummy objects to include ZImageInpaintPipeline for
    testing purposes.
- Add torch.empty fix for x_pad_token and cap_pad_token in test
- Add # Copied from annotations for encode_prompt methods
- Add documentation with usage example and autodoc directive
Add batch size validation and callback handling fixes per review,
using diffusers conventions rather than suggested code verbatim.
Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
- Add missing is_torch_xla_available import for TPU support
- Add xm.mark_step() in denoising loop for proper XLA execution
- Add check_inputs() method for comprehensive input validation
- Call check_inputs() at the start of __call__

Addresses PR review feedback from @asomoza.
Z-Image uses a different CFG formula than standard diffusers pipelines:
- Standard: pred = neg + guidance_scale * (pos - neg), where scale=1 means no CFG
- Z-Image: pred = pos + guidance_scale * (pos - neg), where scale=0 means no CFG

The do_classifier_free_guidance property was checking > 1, which prevented
CFG from being applied when guidance_scale was between 0 and 1. Changed
to > 0 to match Z-Image's CFG semantics.

Fixes huggingface#12905
@CalamitousFelicitousness CalamitousFelicitousness force-pushed the feature/zimage-inpaint-pipeline branch from 8d419db to 9a74d2f Compare January 27, 2026 01:26
@CalamitousFelicitousness
Copy link
Contributor Author

@asomoza Apologies for the delay, cleanup and changes requested done.

@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
return self._guidance_scale > 0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JerryWu-code
is this change ok? see more context here #12905

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants