Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,28 @@ def run_full_pipeline(
explain_with_captum=True,
)

# Test label attention assertions
if label_attention_enabled:
assert predictions["label_attention_attributions"] is not None, (
"Label attention attributions should not be None when label_attention_enabled is True"
)
label_attention_attributions = predictions["label_attention_attributions"]
expected_shape = (
len(sample_text_data), # batch_size
model_params["n_head"], # n_head
model_params["num_classes"], # num_classes
tokenizer.output_dim, # seq_len
)
assert label_attention_attributions.shape == expected_shape, (
f"Label attention attributions shape mismatch. "
f"Expected {expected_shape}, got {label_attention_attributions.shape}"
)
else:
# When label attention is not enabled, the attributions should be None
assert predictions.get("label_attention_attributions") is None, (
"Label attention attributions should be None when label_attention_enabled is False"
)

# Test explainability functions
text_idx = 0
text = sample_text_data[text_idx]
Expand Down