Description
TransformerBridge.parameters() returns non-leaf tensors created by einops.rearrange(), which cannot be optimized by PyTorch optimizers. This breaks a fundamental PyTorch contract and prevents users from fine-tuning TransformerBridge models or from using the standard PyTorch parameters() API for research that requires optimization.
The Problem
When users attempt to create an optimizer with TransformerBridge parameters, PyTorch raises:
optimizer = torch.optim.AdamW(bridge.parameters(), lr=1e-4)
# ValueError: can't optimize a non-leaf Tensor

Root Cause:
The current implementation delegates named_parameters() to return TransformerLens-style parameter names and tensors:
def named_parameters(self, prefix: str = "", recurse: bool = True,
                     remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]:
    """Return named parameters in the same format as TransformerLens."""
    params_dict = self.get_params()  # Returns TL-style dict with rearranged tensors
    for name, param in params_dict.items():
        yield (name, param)

The get_params() method creates non-leaf tensors via einops.rearrange() to match TransformerLens conventions (e.g., reshaping attention weights from [d_model, n_heads * d_head] to [n_heads, d_model, d_head]). These rearranged tensors have grad_fn set and are not leaf tensors, making them invalid for optimization.
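A minimal standalone sketch of the underlying behaviour (the shapes below are illustrative, not the bridge's actual weights): any tensor produced by einops.rearrange() from a registered parameter carries a grad_fn and is no longer a leaf, which is exactly what torch.optim rejects.

```python
import torch
import einops

# A leaf parameter, as stored on the wrapped HuggingFace module
# (illustrative shape: [d_model, n_heads * d_head]).
w = torch.nn.Parameter(torch.randn(8, 4 * 2))
print(w.is_leaf)     # True

# Rearranging into the TransformerLens layout creates a new tensor with a
# grad_fn, i.e. a non-leaf tensor.
w_tl = einops.rearrange(w, "m (n h) -> n m h", n=4)
print(w_tl.is_leaf)  # False
print(w_tl.grad_fn)  # a backward node, not None

# Handing such tensors to an optimizer reproduces the error above:
# torch.optim.AdamW([w_tl], lr=1e-4)
# ValueError: can't optimize a non-leaf Tensor
```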
Why This Matters:
Adhering to PyTorch's standard parameters()/named_parameters() semantics would enable expanded (and unanticipated) uses of TransformerBridge models and align with PyTorch users' expectations. It can also be accommodated without breaking existing analysis workflows, simply by renaming the existing TL-style parameter access methods (I have a PR ready to address this, linked to this issue).
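As a rough illustration of that direction (a sketch only, with hypothetical method names, not the contents of the linked PR): named_parameters() could fall back to the standard nn.Module implementation, which yields the real leaf torch.nn.Parameter objects, while the TL-style view moves to a separately named method.

```python
import torch
from typing import Iterator, Tuple

class TransformerBridge(torch.nn.Module):
    def named_parameters(
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, torch.nn.Parameter]]:
        # Standard PyTorch semantics: yield the leaf parameters registered on
        # this module and its submodules, so optimizers work as expected.
        return super().named_parameters(
            prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate
        )

    def tl_named_parameters(self) -> Iterator[Tuple[str, torch.Tensor]]:
        # Hypothetical rename of the current behaviour: TransformerLens-style
        # names with einops-rearranged (non-leaf) views, for analysis only.
        # get_params() is the existing method described above.
        yield from self.get_params().items()
```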
Impact
This issue blocks:
- Mechanistic interpretability training: Research requiring gradient-based updates to bridge models
- Compatibility with standard PyTorch ecosystems: Integration with training libraries expecting standard parameter APIs (a typical blocked loop is sketched below)
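For concreteness, this is the sort of standard loop that currently fails at the optimizer construction step (bridge, dataloader, and compute_loss are placeholders, not real API):

```python
import torch

# Placeholders: `bridge` is a TransformerBridge instance, `dataloader` yields
# batches, and `compute_loss` is whatever objective the user defines.
optimizer = torch.optim.AdamW(bridge.parameters(), lr=1e-4)  # raises ValueError today

for batch in dataloader:
    loss = compute_loss(bridge, batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```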
Existing Workarounds
Users could theoretically access bridge.original_model.parameters() directly (sketched below), but this:
- Violates the abstraction TransformerBridge provides
- Returns HuggingFace parameter names instead of TransformerLens conventions
- Breaks compatibility with tools expecting TL-style parameter access
- Isn't discoverable: most users expect model.parameters() to "just work"
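For reference, the workaround looks roughly like this (the HF parameter name shown is only an example of the naming mismatch):

```python
import torch

# Bypass the bridge and optimize the wrapped HuggingFace module's leaf
# parameters directly.
optimizer = torch.optim.AdamW(bridge.original_model.parameters(), lr=1e-4)

# The names exposed this way follow HuggingFace conventions rather than
# TransformerLens ones, e.g. something like
#   'transformer.h.0.attn.c_attn.weight'   (HF style)
# instead of TL-style names such as 'blocks.0.attn.W_Q', so tooling written
# against TL naming breaks.
```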
System Info
* CUDA:
- GPU:
- NVIDIA GeForce RTX 4090
- NVIDIA GeForce RTX 2070 SUPER
- available: True
- version: 12.8
* Packages:
- circuit_tracer: 0.1.0
- datasets: 4.4.1
- finetuning_scheduler: 2.9.1
- interpretune: 0.1.0.dev249+g84f4d5a9a.d20251129
- lightning: 2.5.6
- neuronpedia: 1.2.0
- numpy: 2.3.5
- sae_lens: 6.22.3
- torch: 2.9.1+cu128
- torch_debug_mode: False
- torch_git_version: 5811a8d7da873dd699ff6687092c225caffcf1bb
- tqdm: 4.67.1
- transformer_lens: 0.0.0 # using transformer_lens from source, latest commit on dev-3.x-folding
- transformers: 4.57.1
* System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.12.8
- version: #13~22.04.1-Ubuntu SMP Wed Jan 24 23:39:40 UTC 2024
Checklist
- I have checked that there is no similar issue in the repo (required)
I'll be submitting a PR shortly to address this issue. 🚀