Skip to content

Commit d2a5e21

Browse files
committed
refactoring and fixes
1 parent 00d413e commit d2a5e21

5 files changed

Lines changed: 17 additions & 32 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -832,7 +832,6 @@ def transformer_forward_pass_full_cfg(
832832
return noise_pred_merged, noise_cond, noise_uncond
833833

834834

835-
836835
@partial(jax.jit, static_argnames=("guidance_scale",))
837836
def transformer_forward_pass_cfg_cache(
838837
graphdef,
@@ -902,3 +901,13 @@ def transformer_forward_pass_cfg_cache(
902901

903902
noise_pred_merged = noise_uncond_approx + guidance_scale * (noise_cond - noise_uncond_approx)
904903
return noise_pred_merged, noise_cond
904+
905+
def nearest_interp(src, target_len):
    """Resample ``src`` to ``target_len`` entries by nearest-neighbor selection.

    ``target_len`` evenly spaced positions are laid out over the index range
    ``[0, len(src) - 1]``; each position is rounded to an integer index and the
    corresponding elements of ``src`` are returned.

    Args:
      src: 1-D sequence of values. Array types that support integer-array
        indexing are used as-is; a plain ``list``/``tuple`` is converted to a
        NumPy array first so it can be indexed the same way.
      target_len: Number of samples to produce.

    Returns:
      An array with ``target_len`` elements taken from ``src``. When
      ``target_len`` is 1 the result holds only the last element of ``src``.
    """
    # Imported locally so the helper does not depend on a module-level alias.
    import numpy as np

    src_len = len(src)
    if target_len == 1:
        # A single sample is pinned to the final entry rather than the first.
        return np.array([src[-1]])

    # Builtin sequences do not support fancy indexing; array inputs are left
    # untouched so their type and behaviour are preserved.
    if isinstance(src, (list, tuple)):
        src = np.asarray(src)

    # np.round rounds exact halves to the nearest even integer.
    indices = np.round(np.linspace(0, src_len - 1, target_len)).astype(np.int32)
    return src[indices]

src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache
15+
from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache, nearest_interp
1616
from ...models.wan.transformers.transformer_wan import WanModel
1717
from typing import List, Union, Optional, Any
1818
from ...pyconfig import HyperParameters
@@ -233,7 +233,7 @@ def run_inference_2_1(
233233
cached_noise_cond = None
234234
cached_noise_uncond = None
235235

236-
if use_magcache:
236+
if use_magcache and do_classifier_free_guidance:
237237
# ── MagCache Execution Path ──
238238
accumulated_ratio_cond = 1.0
239239
accumulated_ratio_uncond = 1.0
@@ -249,12 +249,6 @@ def run_inference_2_1(
249249
mag_ratios_base = np.array([1.0]*2+[1.02504, 1.03017, 1.00025, 1.00251, 0.9985, 0.99962, 0.99779, 0.99771, 0.9966, 0.99658, 0.99482, 0.99476, 0.99467, 0.99451, 0.99664, 0.99656, 0.99434, 0.99431, 0.99533, 0.99545, 0.99468, 0.99465, 0.99438, 0.99434, 0.99516, 0.99517, 0.99384, 0.9938, 0.99404, 0.99401, 0.99517, 0.99516, 0.99409, 0.99408, 0.99428, 0.99426, 0.99347, 0.99343, 0.99418, 0.99416, 0.99271, 0.99269, 0.99313, 0.99311, 0.99215, 0.99215, 0.99218, 0.99215, 0.99216, 0.99217, 0.99163, 0.99161, 0.99138, 0.99135, 0.98982, 0.9898, 0.98996, 0.98995, 0.9887, 0.98866, 0.98772, 0.9877, 0.98767, 0.98765, 0.98573, 0.9857, 0.98501, 0.98498, 0.9838, 0.98376, 0.98177, 0.98173, 0.98037, 0.98035, 0.97678, 0.97677, 0.97546, 0.97543, 0.97184, 0.97183, 0.96711, 0.96708, 0.96349, 0.96345, 0.95629, 0.95625, 0.94926, 0.94929, 0.93964, 0.93961, 0.92511, 0.92504, 0.90693, 0.90678, 0.8796, 0.87945, 0.86111, 0.86189])
250250

251251
if len(mag_ratios_base) != num_inference_steps * 2:
252-
def nearest_interp(src, target_len):
253-
src_len = len(src)
254-
if target_len == 1: return np.array([src[-1]])
255-
scale = (src_len - 1) / (target_len - 1)
256-
idx = np.round(np.arange(target_len) * scale).astype(int)
257-
return src[idx]
258252
mag_cond = nearest_interp(mag_ratios_base[0::2], num_inference_steps)
259253
mag_uncond = nearest_interp(mag_ratios_base[1::2], num_inference_steps)
260254
mag_ratios = np.concatenate([mag_cond.reshape(-1, 1), mag_uncond.reshape(-1, 1)], axis=1).reshape(-1)

src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache
15+
from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache, nearest_interp
1616
from ...models.wan.transformers.transformer_wan import WanModel
1717
from typing import List, Union, Optional
1818
from ...pyconfig import HyperParameters
@@ -466,12 +466,6 @@ def run_inference_2_2(
466466
mag_ratios_base = np.array([1.0]*2+[1.00124, 1.00155, 0.99822, 0.99851, 0.99696, 0.99687, 0.99703, 0.99732, 0.9966, 0.99679, 0.99602, 0.99658, 0.99578, 0.99664, 0.99484, 0.9949, 0.99633, 0.996, 0.99659, 0.99683, 0.99534, 0.99549, 0.99584, 0.99577, 0.99681, 0.99694, 0.99563, 0.99554, 0.9944, 0.99473, 0.99594, 0.9964, 0.99466, 0.99461, 0.99453, 0.99481, 0.99389, 0.99365, 0.99391, 0.99406, 0.99354, 0.99361, 0.99283, 0.99278, 0.99268, 0.99263, 0.99057, 0.99091, 0.99125, 0.99126, 0.65523, 0.65252, 0.98808, 0.98852, 0.98765, 0.98736, 0.9851, 0.98535, 0.98311, 0.98339, 0.9805, 0.9806, 0.97776, 0.97771, 0.97278, 0.97286, 0.96731, 0.96728, 0.95857, 0.95855, 0.94385, 0.94385, 0.92118, 0.921, 0.88108, 0.88076, 0.80263, 0.80181])
467467

468468
if len(mag_ratios_base) != num_inference_steps * 2:
469-
def nearest_interp(src, target_len):
470-
src_len = len(src)
471-
if target_len <= 1: return np.array([src[-1]])
472-
scale = (src_len - 1) / (max(1, target_len - 1))
473-
idx = np.round(np.arange(target_len) * scale).astype(int)
474-
return src[idx]
475469
mag_cond = nearest_interp(mag_ratios_base[0::2], num_inference_steps)
476470
mag_uncond = nearest_interp(mag_ratios_base[1::2], num_inference_steps)
477471
mag_ratios = np.concatenate([mag_cond.reshape(-1, 1), mag_uncond.reshape(-1, 1)], axis=1).reshape(-1)
@@ -534,7 +528,7 @@ def nearest_interp(src, target_len):
534528
cached_residual=cached_residual,
535529
return_residual=True,
536530
)
537-
noise_pred, _, residual_x_cur = outputs
531+
noise_pred, latents, residual_x_cur = outputs
538532
if not skip_blocks:
539533
cached_residual = residual_x_cur
540534

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from maxdiffusion import max_logging
1616
from maxdiffusion.image_processor import PipelineImageInput
17-
from .wan_pipeline import WanPipeline, transformer_forward_pass
17+
from .wan_pipeline import WanPipeline, transformer_forward_pass, nearest_interp
1818
from ...models.wan.transformers.transformer_wan import WanModel
1919
from typing import List, Union, Optional, Tuple
2020
from ...pyconfig import HyperParameters
@@ -289,12 +289,6 @@ def run_inference_2_1_i2v(
289289
mag_ratios_base = np.array([1.0]*2+[0.98783, 0.98993, 0.97559, 0.97593, 0.98311, 0.98319, 0.98202, 0.98225, 0.9888, 0.98878, 0.98762, 0.98759, 0.98957, 0.98971, 0.99052, 0.99043, 0.99383, 0.99384, 0.98857, 0.9886, 0.99065, 0.99068, 0.98845, 0.98847, 0.99057, 0.99057, 0.98957, 0.98961, 0.98601, 0.9861, 0.98823, 0.98823, 0.98756, 0.98759, 0.98808, 0.98814, 0.98721, 0.98724, 0.98571, 0.98572, 0.98543, 0.98544, 0.98157, 0.98165, 0.98411, 0.98413, 0.97952, 0.97953, 0.98149, 0.9815, 0.9774, 0.97742, 0.97825, 0.97826, 0.97355, 0.97361, 0.97085, 0.97087, 0.97056, 0.97055, 0.96588, 0.96587, 0.96113, 0.96124, 0.9567, 0.95681, 0.94961, 0.94969, 0.93973, 0.93988, 0.93217, 0.93224, 0.91878, 0.91896, 0.90955, 0.90954, 0.92617, 0.92616])
290290

291291
if len(mag_ratios_base) != num_inference_steps * 2:
292-
def nearest_interp(src, target_len):
293-
src_len = len(src)
294-
if target_len <= 1: return np.array([src[-1]])
295-
scale = (src_len - 1) / (max(1, target_len - 1))
296-
idx = np.round(np.arange(target_len) * scale).astype(int)
297-
return src[idx]
298292
mag_cond = nearest_interp(mag_ratios_base[0::2], num_inference_steps)
299293
mag_uncond = nearest_interp(mag_ratios_base[1::2], num_inference_steps)
300294
mag_ratios = np.concatenate([mag_cond.reshape(-1, 1), mag_uncond.reshape(-1, 1)], axis=1).reshape(-1)

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from maxdiffusion.image_processor import PipelineImageInput
1616
from maxdiffusion import max_logging
17-
from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache
17+
from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache, nearest_interp
1818
from ...models.wan.transformers.transformer_wan import WanModel
1919
from typing import List, Union, Optional, Tuple
2020
from ...pyconfig import HyperParameters
@@ -459,12 +459,6 @@ def run_inference_2_2_i2v(
459459
mag_ratios_base = np.array([1.0]*2+[0.99512, 0.99559, 0.99559, 0.99561, 0.99595, 0.99577, 0.99512, 0.99512, 0.99546, 0.99534, 0.99543, 0.99531, 0.99496, 0.99491, 0.99504, 0.99499, 0.99444, 0.99449, 0.99481, 0.99481, 0.99435, 0.99435, 0.9943, 0.99431, 0.99411, 0.99406, 0.99373, 0.99376, 0.99413, 0.99405, 0.99363, 0.99359, 0.99335, 0.99331, 0.99244, 0.99243, 0.99229, 0.99229, 0.99239, 0.99236, 0.99163, 0.9916, 0.99149, 0.99151, 0.99191, 0.99192, 0.9898, 0.98981, 0.9899, 0.98987, 0.98849, 0.98849, 0.98846, 0.98846, 0.98861, 0.98861, 0.9874, 0.98738, 0.98588, 0.98589, 0.98539, 0.98534, 0.98444, 0.98439, 0.9831, 0.98309, 0.98119, 0.98118, 0.98001, 0.98, 0.97862, 0.97859, 0.97555, 0.97558, 0.97392, 0.97388, 0.97152, 0.97145, 0.96871, 0.9687, 0.96435, 0.96434, 0.96129, 0.96127, 0.95639, 0.95638, 0.95176, 0.95175, 0.94446, 0.94452, 0.93972, 0.93974, 0.93575, 0.9359, 0.93537, 0.93552, 0.96655, 0.96616])
460460

461461
if len(mag_ratios_base) != num_inference_steps * 2:
462-
def nearest_interp(src, target_len):
463-
src_len = len(src)
464-
if target_len <= 1: return np.array([src[-1]])
465-
scale = (src_len - 1) / (max(1, target_len - 1))
466-
idx = np.round(np.arange(target_len) * scale).astype(int)
467-
return src[idx]
468462
mag_cond = nearest_interp(mag_ratios_base[0::2], num_inference_steps)
469463
mag_uncond = nearest_interp(mag_ratios_base[1::2], num_inference_steps)
470464
mag_ratios = np.concatenate([mag_cond.reshape(-1, 1), mag_uncond.reshape(-1, 1)], axis=1).reshape(-1)
@@ -538,7 +532,7 @@ def nearest_interp(src, target_len):
538532
cached_residual=cached_residual,
539533
return_residual=True,
540534
)
541-
noise_pred, _, residual_x_cur = outputs
535+
noise_pred, latents, residual_x_cur = outputs
542536
if not skip_blocks:
543537
cached_residual = residual_x_cur
544538

0 commit comments

Comments
 (0)