[Bug Report] TransformerBridge parameters() returns non-leaf tensors, breaking PyTorch optimizer compatibility #1141

@speediedan

Description

TransformerBridge.parameters() returns non-leaf tensors created by einops.rearrange(), which cannot be optimized by PyTorch optimizers. This breaks a fundamental PyTorch contract and prevents users from fine-tuning TransformerBridge models or relying on the standard PyTorch parameters() API in research that requires optimization.

The Problem

When users attempt to create an optimizer with TransformerBridge parameters, PyTorch raises:

optimizer = torch.optim.AdamW(bridge.parameters(), lr=1e-4)
# ValueError: can't optimize a non-leaf Tensor

Root Cause:

The current implementation overrides named_parameters() so that it returns TransformerLens-style parameter names and tensors:

def named_parameters(self, prefix: str = "", recurse: bool = True, 
                     remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]:
    """Return named parameters in the same format as TransformerLens."""
    params_dict = self.get_params()  # Returns TL-style dict with rearranged tensors
    for name, param in params_dict.items():
        yield (name, param)

The get_params() method creates non-leaf tensors via einops.rearrange() to match TransformerLens conventions (e.g., reshaping attention weights from [d_model, n_heads * d_head] to [n_heads, d_model, d_head]). These rearranged tensors have grad_fn set and are not leaf tensors, making them invalid for optimization.
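
To illustrate the failure mode in isolation, here is a minimal standalone reproduction (not TransformerBridge code; the shapes are illustrative):

import torch
import einops

# A leaf Parameter shaped like a fused attention weight: [d_model, n_heads * d_head]
w = torch.nn.Parameter(torch.randn(512, 8 * 64))

# Rearranging into the TL-style layout [n_heads, d_model, d_head] produces a
# tensor with a grad_fn, i.e. a non-leaf tensor.
w_tl = einops.rearrange(w, "m (n h) -> n m h", n=8)
print(w.is_leaf, w_tl.is_leaf)   # True False

# Handing such tensors to an optimizer reproduces the reported error.
try:
    torch.optim.AdamW([w_tl], lr=1e-4)
except ValueError as e:
    print(e)  # can't optimize a non-leaf Tensor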

Why This Matters:

Adhering to PyTorch's standard parameters()/named_parameters() semantics will help enable expanded and unanticipated use cases for TransformerBridge models and will align with PyTorch user expectations. It can also be accommodated without breaking existing analysis workflows by renaming the existing TL-style parameter access methods (I have a PR ready to address this, linked to this issue).
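
A rough sketch of this direction (illustrative only; the actual class internals, delegation target, and method names in the PR may differ):

import torch

class TransformerBridge(torch.nn.Module):
    """Simplified stand-in for the real bridge class, for illustration only."""

    def __init__(self, original_model: torch.nn.Module):
        super().__init__()
        self.original_model = original_model

    def named_parameters(self, prefix: str = "", recurse: bool = True,
                         remove_duplicate: bool = True):
        # Standard PyTorch semantics: yield the wrapped model's leaf Parameters
        # so torch.optim optimizers accept them.
        yield from self.original_model.named_parameters(
            prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate
        )

    # The existing TL-style accessor (get_params(), returning rearranged non-leaf
    # views) would remain available under an explicitly named method rather than
    # backing named_parameters().

# Usage: optimizers now receive leaf tensors.
bridge = TransformerBridge(torch.nn.Linear(4, 4))
optimizer = torch.optim.AdamW(bridge.parameters(), lr=1e-4)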

Impact

This issue blocks:

  • Mechanistic interpretability training: Research requiring gradient-based updates to bridge models
  • Compatibility with standard PyTorch ecosystems: Integration with training libraries expecting standard parameter APIs

Existing Workarounds

Users could theoretically access bridge.original_model.parameters() directly (see the snippet after this list), but this:

  • Violates the abstraction TransformerBridge provides
  • Returns HuggingFace parameter names instead of TransformerLens conventions
  • Breaks compatibility with tools expecting TL-style parameter access
  • Isn't discoverable—most users expect model.parameters() to "just work"
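
For completeness, that workaround looks like this (usable today, but subject to the drawbacks above):

optimizer = torch.optim.AdamW(bridge.original_model.parameters(), lr=1e-4)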

System Info

* CUDA:
	- GPU:
		- NVIDIA GeForce RTX 4090
		- NVIDIA GeForce RTX 2070 SUPER
	- available:         True
	- version:           12.8
* Packages:
	- circuit_tracer:    0.1.0
	- datasets:          4.4.1
	- finetuning_scheduler: 2.9.1
	- interpretune:      0.1.0.dev249+g84f4d5a9a.d20251129
	- lightning:         2.5.6
	- neuronpedia:       1.2.0
	- numpy:             2.3.5
	- sae_lens:          6.22.3
	- torch:             2.9.1+cu128
	- torch_debug_mode:  False
	- torch_git_version: 5811a8d7da873dd699ff6687092c225caffcf1bb
	- tqdm:              4.67.1
	- transformer_lens:  0.0.0  # using transformer_lens from source, latest commit on dev-3.x-folding
	- transformers:      4.57.1
* System:
	- OS:                Linux
	- architecture:
		- 64bit
		- ELF
	- processor:         x86_64
	- python:            3.12.8
	- version:           #13~22.04.1-Ubuntu SMP Wed Jan 24 23:39:40 UTC 2024

Checklist

  • I have checked that there is no similar issue in the repo (required)

I'll be submitting a PR shortly to address this issue. 🚀
