36 changes: 29 additions & 7 deletions torchTextClassifiers/model/components/text_embedder.py
@@ -235,7 +235,9 @@ def _get_sentence_embedding(

if self.enable_label_attention:
label_attention_result = self.label_attention_module(
token_embeddings, compute_attention_matrix=return_label_attention_matrix
token_embeddings,
attention_mask=attention_mask,
compute_attention_matrix=return_label_attention_matrix,
)
sentence_embedding = label_attention_result[
"sentence_embedding"
@@ -320,10 +322,11 @@ def __init__(self, config: TextEmbedderConfig):
self.c_v = nn.Linear(self.embedding_dim, self.n_kv_head * self.head_dim, bias=False)
self.c_proj = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)

def forward(self, token_embeddings, compute_attention_matrix: Optional[bool] = False):
def forward(self, token_embeddings, attention_mask: Optional[torch.Tensor] = None, compute_attention_matrix: Optional[bool] = False):
"""
Args:
token_embeddings (torch.Tensor), shape (batch, seq_len, d_model): Embedded tokens from the text input.
attention_mask (torch.Tensor, optional), shape (batch, seq_len): Attention mask indicating non-pad tokens (1 for real tokens, 0 for padding).
compute_attention_matrix (bool): Whether to compute and return the attention matrix.
Returns:
dict: {
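
For reference, the `attention_mask` described in this docstring is the usual 0/1 padding mask over the token dimension. Below is a minimal sketch, not part of the diff, of building such a mask from padded token ids; the `pad_token_id` value and the token ids are made up for illustration and are not taken from this repository.

# Sketch (not part of the diff): a (batch, seq_len) padding mask with the
# semantics documented above: 1 = real token, 0 = padding.
import torch

pad_token_id = 0  # hypothetical padding id
token_ids = torch.tensor(
    [
        [12, 47, 5, pad_token_id, pad_token_id],  # 3 real tokens, 2 pad
        [8, 3, 21, 9, 14],                        # no padding
    ]
)
attention_mask = (token_ids != pad_token_id).long()
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1]])

This is the same convention Hugging Face tokenizers use for their `attention_mask` output, so such a mask can be passed through unchanged.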
@@ -358,17 +361,36 @@ def forward(self, token_embeddings, compute_attention_matrix: Optional[bool] = False):
v.transpose(1, 2),
) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)

y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=self.enable_gqa)
# Prepare the attention mask for scaled_dot_product_attention
# attention_mask: (B, T) with 1 for real tokens, 0 for padding
# F.scaled_dot_product_attention takes a boolean attn_mask broadcastable to
# (B, H, Q, K), where True means "attend to" and False means "mask out"
attn_mask = None
pad_mask = None
if attention_mask is not None:
    # (B, T) -> (B, 1, 1, T) so the mask broadcasts across heads and label queries
    pad_mask = (attention_mask == 0).unsqueeze(1).unsqueeze(2)  # True at padding
    attn_mask = ~pad_mask  # True at real tokens, as expected by SDPA

y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False, enable_gqa=self.enable_gqa)

# Re-assemble the heads side by side and project back to residual stream
y = y.transpose(1, 2).contiguous().view(B, self.num_classes, -1) # (bs, n_labels, d_model)
y = self.c_proj(y)

attention_matrix = None
if compute_attention_matrix:
# size (B, n_head, n_labels, seq_len) - we let the user handle aggregation over heads if desired
attention_matrix = torch.softmax(
torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5), dim=-1
)
# Compute attention scores
# size (B, n_head, n_labels, seq_len) - aggregation over heads is left to the caller
attention_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5)

# Mask padded key positions before the softmax so they receive zero weight
if attention_mask is not None:
    # pad_mask has shape (B, 1, 1, T) and broadcasts over the
    # (B, n_head, n_labels, T) scores; True entries become -inf
    attention_scores = attention_scores.masked_fill(pad_mask, float("-inf"))

attention_matrix = torch.softmax(attention_scores, dim=-1)

return {"sentence_embedding": y, "attention_matrix": attention_matrix}
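
As a sanity check on the masking conventions above, here is a small self-contained sketch, not part of the PR, that mirrors the module's two masking paths outside the class. Shapes and values are illustrative only. It confirms that a boolean `attn_mask` for `F.scaled_dot_product_attention` must be True where attention is allowed, and that the explicit `masked_fill(..., -inf)` path yields matching attention weights with zeros on padded keys.

# Sketch (not part of the diff): the two masking paths used in the module.
import torch
import torch.nn.functional as F

B, H, L, T, D = 2, 4, 3, 6, 8   # batch, heads, labels (queries), tokens (keys), head_dim
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, T, D)
v = torch.randn(B, H, T, D)

attention_mask = torch.ones(B, T, dtype=torch.long)
attention_mask[:, 4:] = 0                          # last two tokens are padding
pad_mask = (attention_mask == 0).view(B, 1, 1, T)  # True at padding
attn_mask = ~pad_mask                              # True = attend, as SDPA expects

# Path 1: fused attention with the boolean mask
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False)

# Path 2: explicit scores + masked_fill + softmax, as in compute_attention_matrix
scores = torch.matmul(q, k.transpose(-2, -1)) / (D ** 0.5)
scores = scores.masked_fill(pad_mask, float("-inf"))
weights = torch.softmax(scores, dim=-1)            # (B, H, L, T), zero on padded keys

assert torch.allclose(y, weights @ v, atol=1e-5)   # both paths agree
assert weights[..., 4:].abs().max() == 0           # padded keys get no attention

For contrast, `nn.MultiheadAttention`'s `key_padding_mask` uses the opposite convention (True marks positions to ignore), which is why the padding mask is inverted before the fused call but used directly in `masked_fill`.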