add SP support for flash_varlen_hub backend #13479
@@ -374,6 +374,8 @@ def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names)
 @is_context_parallel
 @require_torch_multi_accelerator
 class ContextParallelAttentionBackendsTesterMixin:
+    unsupported_attn_backends: list[str] = []
+
     @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"])
     @pytest.mark.parametrize(
         "attention_backend",
@@ -383,6 +385,10 @@ class ContextParallelAttentionBackendsTesterMixin:
                 "flash_hub",
                 marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
             ),
+            pytest.param(
+                "flash_varlen_hub",
Member: Should

Contributor (Author): I think the
+                marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
+            ),
             pytest.param(
                 "_flash_3_hub",
                 marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
@@ -398,9 +404,14 @@ def test_context_parallel_attn_backend_inference(self, cp_type, attention_backend
         if getattr(self.model_class, "_cp_plan", None) is None:
             pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")

+        if attention_backend in self.unsupported_attn_backends:
+            pytest.skip(f"{attention_backend} is not supported for this model.")
+
         if cp_type == "ring_degree":
             if attention_backend == AttentionBackendName.NATIVE:
                 pytest.skip("Skipping test because ring isn't supported with native attention backend.")
+            elif attention_backend in ("flash_varlen_hub",):
+                pytest.skip("`ring_degree` is not yet supported for varlen attention hub kernels.")

         if ulysses_anything and "ulysses" not in cp_type:
             pytest.skip("Skipping test as ulysses anything needs the ulysses degree set.")
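For readers skimming the diff, here is a minimal, self-contained sketch of the mechanism this hunk adds, using only pytest. The class and backend names are illustrative placeholders, not the repository's actual helpers: the mixin declares an empty unsupported_attn_backends list, a model-specific test class overrides it, and the test body skips those entries before doing any work (the QwenImage tests below are the real instance of this pattern).

import pytest

# Hypothetical backend list, standing in for the repository's parametrization.
ALL_BACKENDS = ["native", "flash_hub", "flash_varlen_hub", "_flash_3_hub"]


class ContextParallelBackendsMixin:
    # Default: no backend is excluded for a given model.
    unsupported_attn_backends: list[str] = []

    @pytest.mark.parametrize("attention_backend", ALL_BACKENDS)
    def test_backend(self, attention_backend):
        if attention_backend in self.unsupported_attn_backends:
            pytest.skip(f"{attention_backend} is not supported for this model.")
        # ... run context-parallel inference with the selected backend here ...


class TestMaskedModelBackends(ContextParallelBackendsMixin):
    # Hypothetical model that always passes an attention mask and therefore
    # opts out of the kernels that cannot consume one.
    unsupported_attn_backends = ["flash_hub", "_flash_3_hub"]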
@@ -25,6 +25,7 @@
     AttentionTesterMixin,
     BaseModelTesterConfig,
     BitsAndBytesTesterMixin,
+    ContextParallelAttentionBackendsTesterMixin,
     ContextParallelTesterMixin,
     LoraHotSwappingForModelTesterMixin,
     LoraTesterMixin,
@@ -253,6 +254,16 @@ class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig,
     """Context Parallel inference tests for QwenImage Transformer."""


+class TestQwenImageTransformerContextParallelAttnBackends(
+    QwenImageTransformerTesterConfig, ContextParallelAttentionBackendsTesterMixin
+):
+    """Context Parallel inference x attention backends tests for QwenImage Transformer"""
+
+    # QwenImage always passes a joint attention mask (text + image), which flash_hub and
+    # _flash_3_hub do not support.
+    unsupported_attn_backends = ["flash_hub", "_flash_3_hub"]
Member: Any non-varlen attention backend would fail, no? If so, I would rather do something like:

    if "varlen" not in attention_backend:
        pytest.skip(...)

Contributor (Author): Like FluxPipeline, it can also support ... I'm not sure what the most suitable place is to put this.

Member: Is this because Qwen uses masks? Also, do the underlying tests use non-contiguous masks?

Contributor (Author): Yes, and I've made the comment clearer. No, it is simply an all-True mask. Do I need to update it?
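To make the mask question above concrete, here is a small sketch in plain PyTorch (not the repository's code) of the contract varlen kernels typically expect: the boolean mask is collapsed into per-sequence token counts and cumulative offsets (cu_seqlens), which is only well-defined when each row is contiguous, i.e. valid tokens packed at the front. An all-True mask, as used in these tests, is the trivial contiguous case.

import torch


def mask_to_cu_seqlens(mask: torch.Tensor) -> tuple[torch.Tensor, int]:
    # mask: (batch, seq_len) bool, True = token participates in attention.
    # Assumes each row is contiguous (no padding holes in the middle).
    seqlens = mask.sum(dim=-1).to(torch.int32)  # tokens per sequence
    cu_seqlens = torch.nn.functional.pad(
        torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
    )  # (batch + 1,) prefix-sum offsets, starting at 0
    return cu_seqlens, int(seqlens.max())


# All-True mask (the QwenImage case discussed above): every length equals seq_len.
mask = torch.ones(2, 16, dtype=torch.bool)
cu_seqlens, max_len = mask_to_cu_seqlens(mask)
print(cu_seqlens.tolist(), max_len)  # [0, 16, 32] 16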
+
+
 class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterMixin):
     """LoRA adapter tests for QwenImage Transformer."""
Is this not supposed to be guarded?

It should be guarded implicitly at the top.
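If the "guarded implicitly at the top" reply refers to the decorators placed on the mixin class in the first hunk (@is_context_parallel, @require_torch_multi_accelerator), the pattern being relied on is that marks applied to a base class are inherited by its subclasses, so each model-specific test class does not need to repeat the guard. A minimal sketch of that behavior, with made-up names:

import pytest

HAS_KERNELS = False  # stand-in for a real capability check such as is_kernels_available()


@pytest.mark.skipif(not HAS_KERNELS, reason="`kernels` is not available.")
class GuardedMixin:
    def test_guarded(self):
        # Only runs when HAS_KERNELS is True; otherwise skipped.
        assert True


class TestModelSpecific(GuardedMixin):
    # No explicit guard here: the skipif mark declared on GuardedMixin
    # also applies to tests collected through this subclass.
    pass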