Skip to content

Commit f7909bf

Browse files
committed
Fix quantization
pull/281 broke the quantization for WAN
1 parent 7cbb714 commit f7909bf

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t
6262
@classmethod
6363
def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True):
6464
pipeline , transformer = cls._load_and_init(config, None, vae_only, load_transformer)
65-
transformer = cls.quantize_transformer(config, transformer, pipeline, pipeline.mesh)
65+
pipeline.transformer = cls.quantize_transformer(config, transformer, pipeline, pipeline.mesh)
6666
return pipeline
6767

6868
@classmethod

src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t
7070
@classmethod
7171
def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True):
7272
pipeline, low_noise_transformer, high_noise_transformer = cls._load_and_init(config, None, vae_only, load_transformer)
73-
low_noise_transformer = cls.quantize_transformer(config, low_noise_transformer, pipeline, pipeline.mesh)
74-
high_noise_transformer = cls.quantize_transformer(config, high_noise_transformer, pipeline, pipeline.mesh)
73+
pipeline.low_noise_transformer = cls.quantize_transformer(config, low_noise_transformer, pipeline, pipeline.mesh)
74+
pipeline.high_noise_transformer = cls.quantize_transformer(config, high_noise_transformer, pipeline, pipeline.mesh)
7575
return pipeline
7676

7777
@classmethod

0 commit comments

Comments
 (0)