Skip to content
Open

Add VSA #1053

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 142 additions & 1 deletion modelopt/torch/sparsity/attention_sparsity/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig):


# Configuration with RULER calibration
# Note: threshold field is omitted - calibration determines dynamic threshold λ = a / length
# Note: threshold field is omitted - calibration determines dynamic threshold lambda = a / length
# The calibrated threshold adapts to sequence length for optimal sparsity
SKIP_SOFTMAX_CALIB = {
"sparse_cfg": {
Expand All @@ -432,13 +432,154 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig):
}


def _validate_positive_triple(value, field_name):
    """Normalize a (T, H, W) triple to a tuple and check its contents.

    Shared by the ``block_size_3d`` and ``video_shape`` validators so both
    fields enforce the same rules and emit consistently worded errors.

    Args:
        value: A list or tuple of numbers.
        field_name: Name of the field being validated; used in error messages.

    Returns:
        ``value`` as given, except that a list is converted to a tuple.

    Raises:
        ValueError: If ``value`` does not have exactly 3 elements, or if any
            element is less than or equal to zero.
    """
    if isinstance(value, list):
        value = tuple(value)
    if len(value) != 3:
        raise ValueError(f"{field_name} must have 3 elements (T, H, W), got {len(value)}")
    if any(x <= 0 for x in value):
        raise ValueError(f"All {field_name} values must be positive, got {value}")
    return value


class VSAAttributeConfig(ModeloptBaseConfig):
    """Video Sparse Attention (VSA) attribute configuration.

    VSA uses a two-branch architecture optimized for video diffusion models:
    1. Compression branch: Block-averaged coarse attention
    2. Sparse branch: Top-K block selection for fine-grained attention

    Validation performed by this class:
        * ``method`` must equal ``"vsa"``.
        * ``block_size_3d`` must have 3 positive elements; lists become tuples.
        * ``top_k_ratio`` must satisfy ``0 < top_k_ratio <= 1``.
        * ``video_shape`` is either ``None`` or 3 positive elements; lists
          become tuples.
    """

    method: str = ModeloptField(
        default="vsa",
        title="Sparse attention method.",
        description="Must be 'vsa' for Video Sparse Attention.",
    )

    enable: bool = ModeloptField(
        default=True,
        title="Enable VSA.",
        description="If True, enables Video Sparse Attention. If False, bypasses sparsity.",
    )

    block_size_3d: tuple[int, int, int] | list[int] = ModeloptField(
        default=(4, 4, 4),
        title="3D block size.",
        description=(
            "Video block dimensions (T, H, W) for spatial-temporal tiling. "
            "Default (4, 4, 4) creates 64-token blocks."
        ),
    )

    top_k_ratio: float = ModeloptField(
        default=0.5,
        title="Top-K selection ratio.",
        description=(
            "Ratio of blocks to keep in sparse branch (0.0 to 1.0). "
            "Lower values mean more sparsity. Default 0.5 keeps 50% of blocks."
        ),
    )

    video_shape: tuple[int, int, int] | list[int] | None = ModeloptField(
        default=None,
        title="Video shape.",
        description=(
            "Video dimensions (T, H, W) after patchification. Required unless a "
            "model-specific plugin computes it from the model's patchifier. "
            "If None and no plugin provides a value, VSA will raise an error at "
            "forward time."
        ),
    )

    collect_stats: bool = ModeloptField(
        default=False,
        title="Collect statistics.",
        description="Whether to collect sparsity statistics during forward pass.",
    )

    @field_validator("method")
    @classmethod
    def validate_vsa_method(cls, v):
        """Validate method is 'vsa'.

        Raises:
            ValueError: If ``v`` is anything other than ``"vsa"``.
        """
        if v != "vsa":
            raise ValueError(f"VSAAttributeConfig method must be 'vsa', got '{v}'")
        return v

    @field_validator("block_size_3d")
    @classmethod
    def validate_block_size_3d(cls, v):
        """Validate 3D block size: exactly 3 positive elements, returned as a tuple if a list."""
        return _validate_positive_triple(v, "block_size_3d")

    @field_validator("top_k_ratio")
    @classmethod
    def validate_top_k_ratio(cls, v):
        """Validate top-K ratio is in the half-open range (0, 1].

        Raises:
            ValueError: If ``v`` is not greater than 0 and at most 1.
        """
        if not 0.0 < v <= 1.0:
            raise ValueError(f"top_k_ratio must be in range (0, 1], got {v}")
        return v

    @field_validator("video_shape")
    @classmethod
    def validate_video_shape(cls, v):
        """Validate video shape if provided; ``None`` is passed through unchanged."""
        if v is None:
            return v
        return _validate_positive_triple(v, "video_shape")


class VSAConfig(SparseAttentionConfig):
    """Sparse-attention configuration preset for Video Sparse Attention (VSA).

    VSA targets video diffusion models whose attention layers carry learned
    gate_compress parameters. The default ``sparse_cfg`` maps the ``"*attn*"``
    pattern to a VSA setup (4x4x4 blocks, top-K ratio 0.5, enabled) and gives
    the ``"default"`` entry ``enable=False``.
    """

    sparse_cfg: SparseAttentionCfgType = ModeloptField(
        title="VSA configuration",
        description=(
            "Pattern-based configuration for Video Sparse Attention. "
            "Keys are patterns to match module names, values are VSA configs."
        ),
        validate_default=True,
        default={
            "*attn*": {
                "method": "vsa",
                "block_size_3d": (4, 4, 4),
                "top_k_ratio": 0.5,
                "enable": True,
            },
            "default": {"enable": False},
        },
    )


# Pre-defined VSA configuration for video diffusion models.
# The "*attn*" entry turns VSA on with (4, 4, 4) blocks and a 0.5 top-K ratio;
# the pattern is meant to match attention module names by naming convention.
# The "default" entry has enable=False -- presumably it applies to modules that
# match no other pattern (TODO confirm against the pattern-matching code).
VSA_DEFAULT = {
    "sparse_cfg": {
        "*attn*": {
            "method": "vsa",
            "block_size_3d": (4, 4, 4),
            "top_k_ratio": 0.5,
            "enable": True,
        },
        "default": {"enable": False},
    },
}

# Public names of this module (what `from <module> import *` exports):
# preset dicts first, then the config classes and type alias.
__all__ = [
    "SKIP_SOFTMAX_CALIB",
    "SKIP_SOFTMAX_DEFAULT",
    "VSA_DEFAULT",
    "CalibrationConfig",
    "FlashSkipSoftmaxConfig",
    "SparseAttentionAttributeConfig",
    "SparseAttentionCfgType",
    "SparseAttentionConfig",
    "SparseAttributeConfig",
    "VSAAttributeConfig",
    "VSAConfig",
]
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@
]

# Import method implementations to trigger registration
from . import flash_skip_softmax
from . import flash_skip_softmax, vsa
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def get_threshold_info(self) -> dict[str, Any]:
"""
return {"type": "none", "value": None}

def set_calibration_mode(self, enabled: bool) -> None:
    """Set calibration mode. Override in subclasses that support calibration.

    This base implementation only records the flag on the instance.

    Args:
        enabled: Value stored as ``self._calibration_mode``.
    """
    self._calibration_mode = enabled

@property
@abstractmethod
def name(self) -> str:
Expand Down
Loading
Loading