
[Bug] "assert param.grad is not None and param.grad.abs().sum() > 0" failed for LoRA-based DMD distillation with the Wan2.1-1.3B model #903

@wymhfut

Describe the bug

I am using the Wan2.1-1.3B model for DMD LoRA distillation, and after running several steps the following error appears:

INFO 11-26 14:21:31 [composed_pipeline_base.py:398] Running pipeline stages: dict_keys(['input_validation_stage', 'prompt_encoding_stage', 'conditioning_stage', 'timestep_preparation_stage', 'latent_preparation_stage', 'denoising_stage', 'decoding_stage'])
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.19it/s]
INFO 11-26 14:21:44 [distillation_pipeline.py:1259] rank: 2: rank_in_sp_group: 2, batch.prompt: The video shows a large, industrial press flattening objects as if they were under a hydraulic press. The press is shown in action, compressing a pile of pink objects into a pile of crumbs. The press is large and metallic, with a yellow and black striped pattern on its side. The background is a green wall with a yellow warning sign.
INFO 11-26 14:21:44 [distillation_pipeline.py:1259] rank: 0: rank_in_sp_group: 0, batch.prompt: The video shows a large, industrial press flattening objects as if they were under a hydraulic press. The press is shown in action, compressing a pile of pink objects into a pile of crumbs. The press is large and metallic, with a yellow and black striped pattern on its side. The background is a green wall with a yellow warning sign.
INFO 11-26 14:21:44 [composed_pipeline_base.py:398] Running pipeline stages: dict_keys(['input_validation_stage', 'prompt_encoding_stage', 'conditioning_stage', 'timestep_preparation_stage', 'latent_preparation_stage', 'denoising_stage', 'decoding_stage'])
INFO 11-26 14:21:44 [distillation_pipeline.py:1259] rank: 1: rank_in_sp_group: 1, batch.prompt: The video shows a large, industrial press flattening objects as if they were under a hydraulic press. The press is shown in action, compressing a pile of pink objects into a pile of crumbs. The press is large and metallic, with a yellow and black striped pattern on its side. The background is a green wall with a yellow warning sign.
INFO 11-26 14:21:44 [distillation_pipeline.py:1259] rank: 3: rank_in_sp_group: 3, batch.prompt: The video shows a large, industrial press flattening objects as if they were under a hydraulic press. The press is shown in action, compressing a pile of pink objects into a pile of crumbs. The press is large and metallic, with a yellow and black striped pattern on its side. The background is a green wall with a yellow warning sign.
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.53it/s]
Steps:   0%| | 4/30000 [03:04<284:17:28, 34.12s/it, total_loss=0.1581, generator_loss=0.0000, fake_score_loss=0.1581, step_time=17.94s, grad_norm=None, ema=✓, ema2INFO 11-26 14:25:13 [training_pipeline.py:269] Starting epoch 1
Traceback (most recent call last):
  File "/ym/code/open_source/FastVideo/fastvideo/training/wan_distillation_pipeline.py", line 76, in <module>
    main(args)
  File "/ym/code/open_source/FastVideo/fastvideo/training/wan_distillation_pipeline.py", line 64, in main
    pipeline.train()
  File "/ym/code/open_source/FastVideo/fastvideo/training/distillation_pipeline.py", line 1526, in train
    training_batch = self.train_one_step(training_batch)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ym/code/open_source/FastVideo/fastvideo/training/distillation_pipeline.py", line 995, in train_one_step
    assert param.grad is not None and param.grad.abs().sum() > 0
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
[rank1]: Traceback (most recent call last):
[rank1]:   File "/ym/code/open_source/FastVideo/fastvideo/training/wan_distillation_pipeline.py", line 76, in <module>
[rank1]:     main(args)
[rank1]:   File "/ym/code/open_source/FastVideo/fastvideo/training/wan_distillation_pipeline.py", line 64, in main
[rank1]:     pipeline.train()
[rank1]:   File "/ym/code/open_source/FastVideo/fastvideo/training/distillation_pipeline.py", line 1526, in train
[rank1]:     training_batch = self.train_one_step(training_batch)
[rank1]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/ym/code/open_source/FastVideo/fastvideo/training/distillation_pipeline.py", line 995, in train_one_step
[rank1]:     assert param.grad is not None and param.grad.abs().sum() > 0
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: AssertionError
[rank0]: Traceback (most recent call last):
[rank0]:   File "/ym/code/open_source/FastVideo/fastvideo/training/wan_distillation_pipeline.py", line 76, in <module>
[rank0]:     main(args)
[rank0]:   File "/ym/code/open_source/FastVideo/fastvideo/training/wan_distillation_pipeline.py", line 64, in main
[rank0]:     pipeline.train()
[rank0]:   File "/ym/code/open_source/FastVideo/fastvideo/training/distillation_pipeline.py", line 1526, in train
[rank0]:     training_batch = self.train_one_step(training_batch)
[rank0]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/ym/code/open_source/FastVideo/fastvideo/training/distillation_pipeline.py", line 995, in train_one_step
[rank0]:     assert param.grad is not None and param.grad.abs().sum() > 0
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: AssertionError
[rank3]: Traceback (most recent call last):
[rank3]:   File "/ym/code/open_source/FastVideo/fastvideo/training/wan_distillation_pipeline.py", line 76, in <module>
[rank3]:     main(args)
[rank3]:   File "/ym/code/open_source/FastVideo/fastvideo/training/wan_distillation_pipeline.py", line 64, in main
[rank3]:     pipeline.train()
[rank3]:   File "/ym/code/open_source/FastVideo/fastvideo/training/distillation_pipeline.py", line 1526, in train
[rank3]:     training_batch = self.train_one_step(training_batch)
[rank3]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/ym/code/open_source/FastVideo/fastvideo/training/distillation_pipeline.py", line 995, in train_one_step
[rank3]:     assert param.grad is not None and param.grad.abs().sum() > 0
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: AssertionError
[rank2]: Traceback (most recent call last):
[rank2]:   File "/ym/code/open_source/FastVideo/fastvideo/training/wan_distillation_pipeline.py", line 76, in <module>
[rank2]:     main(args)
[rank2]:   File "/ym/code/open_source/FastVideo/fastvideo/training/wan_distillation_pipeline.py", line 64, in main
[rank2]:     pipeline.train()
[rank2]:   File "/ym/code/open_source/FastVideo/fastvideo/training/distillation_pipeline.py", line 1526, in train
[rank2]:     training_batch = self.train_one_step(training_batch)
[rank2]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/ym/code/open_source/FastVideo/fastvideo/training/distillation_pipeline.py", line 995, in train_one_step
[rank2]:     assert param.grad is not None and param.grad.abs().sum() > 0
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: AssertionError
wandb: 
wandb: You can sync this run to the cloud by running:
wandb: wandb sync /ym/code/open_source/FastVideo/outputs_dmd/wan_t2v_finetune_lora/tracker/wandb/offline-run-20251126_141800-lca8ivyb
wandb: Find logs at: outputs_dmd/wan_t2v_finetune_lora/tracker/wandb/offline-run-20251126_141800-lca8ivyb/logs
[rank0]:[W1126 14:25:42.383054998 ProcessGroupNCCL.cpp:1479] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank3]:[W1126 14:25:43.855292502 ProcessGroupNCCL.cpp:1479] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank2]:[W1126 14:25:44.799390630 ProcessGroupNCCL.cpp:1479] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank1]:[W1126 14:25:45.762405038 ProcessGroupNCCL.cpp:1479] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
W1126 14:25:48.547000 929752 site-packages/torch/distributed/elastic/multiprocessing/api.py:900] Sending process 929825 closing signal SIGTERM
W1126 14:25:48.547000 929752 site-packages/torch/distributed/elastic/multiprocessing/api.py:900] Sending process 929826 closing signal SIGTERM
W1126 14:25:48.549000 929752 site-packages/torch/distributed/elastic/multiprocessing/api.py:900] Sending process 929827 closing signal SIGTERM
E1126 14:25:52.789000 929752 site-packages/torch/distributed/elastic/multiprocessing/api.py:874] failed (exitcode: 1) local_rank: 0 (pid: 929824) of binary: /data/miniconda3/envs/fs-ym/bin/python3.1
Traceback (most recent call last):
  File "/data/miniconda3/envs/fs-ym/bin/torchrun", line 7, in <module>
    sys.exit(main())
             ^^^^^^
  File "/data/miniconda3/envs/fs-ym/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/data/miniconda3/envs/fs-ym/lib/python3.12/site-packages/torch/distributed/run.py", line 892, in main
    run(args)
  File "/data/miniconda3/envs/fs-ym/lib/python3.12/site-packages/torch/distributed/run.py", line 883, in run
    elastic_launch(
  File "/data/miniconda3/envs/fs-ym/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 139, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/miniconda3/envs/fs-ym/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 270, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
fastvideo/training/wan_distillation_pipeline.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2025-11-26_14:25:48
  host      : TENCENT64.site
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 929824)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
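
All four ranks hit the same assertion in train_one_step. In case it helps triage, a small debug loop like the sketch below (my own hypothetical helper, not part of FastVideo; `module` stands for whichever module the loop at distillation_pipeline.py:995 iterates over, which I have not verified) could be inserted just before the assert to report which parameters end up with a missing or all-zero gradient:

def report_missing_grads(module):
    # List trainable parameters whose grad is None or identically zero,
    # i.e. the two conditions checked by the failing assertion.
    missing, zero = [], []
    for name, param in module.named_parameters():
        if not param.requires_grad:
            continue
        if param.grad is None:
            missing.append(name)
        elif param.grad.abs().sum() == 0:
            zero.append(name)
    print(f"grad is None for {len(missing)} params, e.g. {missing[:5]}")
    print(f"grad is all-zero for {len(zero)} params, e.g. {zero[:5]}")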

Reproduction

The script I am using is as follows:

export WANDB_BASE_URL="https://api.wandb.ai"
export WANDB_MODE=offline
export WANDB_API_KEY=
export TRITON_CACHE_DIR=/tmp/triton_cache
DATA_DIR=data/mini_i2v_dataset/crush-smol_preprocessed/combined_parquet_dataset/
VALIDATION_DIR=data/mini_i2v_dataset/crush-smol_raw/validation.json
NUM_GPUS=4
export CUDA_VISIBLE_DEVICES=0,1,2,3
export FASTVIDEO_ATTENTION_BACKEND=FLASH_ATTN
export TOKENIZERS_PARALLELISM=false
MODEL_PATH="Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
# MODEL_PATH="Wan-AI/Wan2.1-T2V-14B-Diffusers"

# make sure that num_latent_t is a multiple of sp_size
torchrun --nnodes 1 --nproc_per_node $NUM_GPUS \
    --master_port 29502 \
    fastvideo/training/wan_distillation_pipeline.py \
    --model_path $MODEL_PATH \
    --real_score_model_path $MODEL_PATH \
    --fake_score_model_path $MODEL_PATH \
    --inference_mode False \
    --pretrained_model_name_or_path $MODEL_PATH \
    --cache_dir "/home/ray/.cache" \
    --data_path "$DATA_DIR" \
    --validation_dataset_file  "$VALIDATION_DIR" \
    --train_batch_size 1 \
    --num_latent_t 16 \
    --sp_size $NUM_GPUS \
    --tp_size 1 \
    --num_gpus $NUM_GPUS \
    --hsdp_replicate_dim $NUM_GPUS  \
    --hsdp-shard-dim 1 \
    --train_sp_batch_size 1 \
    --dataloader_num_workers 0 \
    --gradient_accumulation_steps 8 \
    --max_train_steps 30000 \
    --learning_rate 1e-4 \
    --mixed_precision "bf16" \
    --training_state_checkpointing_steps 400 \
    --weight_only_checkpointing_steps 400 \
    --validation_steps 100 \
    --validation_sampling_steps "3" \
    --log_validation \
    --checkpoints_total_limit 3 \
    --ema_start_step 0 \
    --training_cfg_rate 0.0 \
    --output_dir "outputs_dmd/wan_t2v_finetune_lora" \
    --tracker_project_name Wan_distillation \
    --num_height 448 \
    --num_width 832 \
    --num_frames 61 \
    --lora_rank 32 \
    --lora_training True \
    --flow_shift 8 \
    --validation_guidance_scale "6.0" \
    --master_weight_type "fp32" \
    --dit_precision "fp32" \
    --vae_precision "bf16" \
    --weight_decay 0.01 \
    --max_grad_norm 1.0 \
    --generator_update_interval 5 \
    --dmd_denoising_steps '1000,757,522' \
    --min_timestep_ratio 0.02 \
    --max_timestep_ratio 0.98 \
    --real_score_guidance_scale 3.5 \
    --seed 1024

Environment

My environment is as follows:

PyTorch version: 2.7.1+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

Python version: 3.12.12 | packaged by Anaconda, Inc. | (main, Oct 21 2025, 20:16:04) [GCC 11.2.0] (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 12.8.61
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA H20
GPU 1: NVIDIA H20
GPU 2: NVIDIA H20
GPU 3: NVIDIA H20

Nvidia driver version: 535.161.08
cuDNN version: /usr/local/cudnn/lib/libcudnn.so.9.5.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] accelerate==1.0.1
[pip3] numpy==2.2.6
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.5.1.17
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-cufile-cu12==1.11.1.6
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-ml-py==13.580.82
[pip3] nvidia-nccl-cu12==2.26.2
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] peft==0.18.0
[pip3] torch==2.7.1
[pip3] torchcodec==0.5
[pip3] torchdata==0.11.0
[pip3] torchvision==0.22.1
[pip3] transformers==4.57.1
[pip3] triton==3.3.1
[conda] accelerate                1.0.1                    pypi_0    pypi
[conda] numpy                     2.2.6                    pypi_0    pypi
[conda] nvidia-cublas-cu12        12.6.4.1                 pypi_0    pypi
[conda] nvidia-cuda-cupti-cu12    12.6.80                  pypi_0    pypi
[conda] nvidia-cuda-nvrtc-cu12    12.6.77                  pypi_0    pypi
[conda] nvidia-cuda-runtime-cu12  12.6.77                  pypi_0    pypi
[conda] nvidia-cudnn-cu12         9.5.1.17                 pypi_0    pypi
[conda] nvidia-cufft-cu12         11.3.0.4                 pypi_0    pypi
[conda] nvidia-cufile-cu12        1.11.1.6                 pypi_0    pypi
[conda] nvidia-curand-cu12        10.3.7.77                pypi_0    pypi
[conda] nvidia-cusolver-cu12      11.7.1.2                 pypi_0    pypi
[conda] nvidia-cusparse-cu12      12.5.4.2                 pypi_0    pypi
[conda] nvidia-cusparselt-cu12    0.6.3                    pypi_0    pypi
[conda] nvidia-ml-py              13.580.82                pypi_0    pypi
[conda] nvidia-nccl-cu12          2.26.2                   pypi_0    pypi
[conda] nvidia-nvjitlink-cu12     12.6.85                  pypi_0    pypi
[conda] nvidia-nvtx-cu12          12.6.77                  pypi_0    pypi
[conda] peft                      0.18.0                   pypi_0    pypi
[conda] torch                     2.7.1                    pypi_0    pypi
[conda] torchcodec                0.5                      pypi_0    pypi
[conda] torchdata                 0.11.0                   pypi_0    pypi
[conda] torchvision               0.22.1                   pypi_0    pypi
[conda] transformers              4.57.1                   pypi_0    pypi
[conda] triton                    3.3.1                    pypi_0    pypi
