-
Notifications
You must be signed in to change notification settings - Fork 72
[LTX-2] Run Gemma-3 Text Encoder natively in JAX via TorchAX #398
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
mbohlool
wants to merge
1
commit into
main
Choose a base branch
from
text_encoder_tpu3
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
75 changes: 75 additions & 0 deletions
75
src/maxdiffusion/models/ltx2/text_encoders/torchax_text_encoder.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,75 @@ | ||
| """ | ||
| Copyright 2026 Google LLC | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| https://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. | ||
| """ | ||
|
|
||
| from typing import Tuple | ||
|
|
||
| import torch | ||
| import jax | ||
| from torchax import interop, default_env | ||
|
|
||
| # --- Monkeypatch transformers masking_utils to avoid torchax integer tracing bug --- | ||
| import transformers.masking_utils | ||
|
|
||
| _orig_sliding_window_overlay = transformers.masking_utils.sliding_window_overlay | ||
|
|
||
|
|
||
| def _patched_sliding_window_overlay(sliding_window: int): | ||
| # pylint: disable=unused-argument | ||
|
|
||
| def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: | ||
| # Since sequence length < sliding window (e.g. 256 < 4096), this mask is always True. | ||
| # We return a standard boolean tensor using new_ones to guarantee Torchax compatibility | ||
| # and prevent any implicit tracing crashes. | ||
| return q_idx.new_ones((), dtype=torch.bool) | ||
|
|
||
| return inner_mask | ||
|
|
||
|
|
||
| transformers.masking_utils.sliding_window_overlay = _patched_sliding_window_overlay | ||
| # ----------------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| class TorchaxGemma3TextEncoder(interop.JittableModule): | ||
| """ | ||
| A jittable Torchax module for wrapping the HuggingFace PyTorch | ||
| Gemma3ForConditionalGeneration text encoder. | ||
| """ | ||
|
|
||
| def __init__(self, text_encoder): | ||
| super().__init__(text_encoder, extra_jit_args={"static_argnames": ["output_hidden_states"]}) | ||
|
|
||
| def __call__( | ||
| self, input_ids: jax.Array, attention_mask: jax.Array, output_hidden_states: bool = True | ||
| ) -> Tuple[jax.Array, ...]: | ||
| with default_env(): | ||
| input_ids = interop.torch_view(input_ids) | ||
| attention_mask = interop.torch_view(attention_mask) | ||
|
|
||
| output = self.functional_call( | ||
| self._forward_inner, | ||
| params=self.params, | ||
| buffers=self.buffers, | ||
| input_ids=input_ids, | ||
| attention_mask=attention_mask, | ||
| output_hidden_states=output_hidden_states, | ||
| ) | ||
| return interop.jax_view(output) | ||
|
|
||
| @staticmethod | ||
| def _forward_inner(model, input_ids, attention_mask, output_hidden_states=True): | ||
| # We only return hidden states as a tuple of tensors. | ||
| # That allows interop.jax_view to convert them into a tuple of jax Arrays | ||
| return model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=output_hidden_states).hidden_states | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🟡 Global monkeypatching of
transformers.masking_utilscan have unintended side effects if other models in the same process rely on the original behavior. While this is a necessary workaround for TorchAX + Gemma-3, consider documenting the sequence length assumption more explicitly or ensuring this patch doesn't break other potential future Gemma-based models in the same environment.