Commit 04e267f

Fix softmax masking for illegal moves in inference
Multiplying logits by a binary 0/1 mask sets illegal-move logits to 0, but a logit of 0 still contributes exp(0) = 1 to the softmax denominator, so illegal moves keep nonzero probability mass. Use masked_fill with -inf instead, so that softmax assigns exactly 0 probability to illegal moves. Fixes the issue raised in PR #9. https://claude.ai/code/session_01ELpknikQ4vWB4q8hWhxoVG
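The arithmetic behind this fix can be checked without the model. The following is a minimal pure-Python sketch, not code from the repository: the `softmax` helper and the example `logits` and `legal` values are made up for illustration.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats (illustrative helper)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical values, chosen only to illustrate the effect.
logits = [2.0, -1.0, 0.5, 3.0]   # raw move logits
legal = [1, 0, 1, 0]             # 1 = legal move, 0 = illegal move

# Old behaviour: multiplying by the mask turns illegal logits into 0,
# and exp(0) = 1 still adds to the softmax denominator.
multiplied = [x * m for x, m in zip(logits, legal)]
probs_multiplied = softmax(multiplied)

# New behaviour: illegal logits become -inf, and exp(-inf) = 0.
filled = [x if m else float("-inf") for x, m in zip(logits, legal)]
probs_filled = softmax(filled)

print(probs_multiplied)  # illegal entries are still positive
print(probs_filled)      # illegal entries are exactly 0.0
```

With multiplication, the illegal moves also distort the probabilities of the legal ones, because the legal moves have to share mass with them.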
1 parent 16c1383 commit 04e267f

1 file changed

Lines changed: 2 additions & 2 deletions

maia2/inference.py

@@ -60,7 +60,7 @@ def get_preds(model, dataloader, all_moves_dict_reversed):
 legal_moves = legal_moves.to(device)

 logits_maia, _, logits_value = model(boards, elos_self, elos_oppo)
-logits_maia_legal = logits_maia * legal_moves
+logits_maia_legal = logits_maia.masked_fill(legal_moves == 0, float('-inf'))
 probs = logits_maia_legal.softmax(dim=-1).cpu().tolist()

 logits_value = (logits_value / 2 + 0.5).clamp(0, 1).cpu().tolist()
@@ -154,7 +154,7 @@ def inference_each(model, prepared, fen, elo_self, elo_oppo):
 legal_moves = legal_moves.unsqueeze(dim=0).to(device)

 logits_maia, _, logits_value = model(board_input, elo_self, elo_oppo)
-logits_maia_legal = logits_maia * legal_moves
+logits_maia_legal = logits_maia.masked_fill(legal_moves == 0, float('-inf'))
 probs = logits_maia_legal.softmax(dim=-1).cpu().tolist()

 logits_value = (logits_value / 2 + 0.5).clamp(0, 1).item()
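One edge case may be worth keeping in mind with the `-inf` approach: if a row has no legal moves at all, every logit becomes `-inf` and softmax returns NaN for that row. The sketch below uses made-up tensors rather than real model output, and the `has_legal` guard is a hypothetical addition, not part of this commit. In practice a position with no legal moves is a finished game and would presumably not be passed to the model.

```python
import torch

# Hypothetical batch: 2 positions, 4 candidate moves each.
logits = torch.tensor([[2.0, -1.0, 0.5, 3.0],
                       [0.1, 0.2, 0.3, 0.4]])
legal_moves = torch.tensor([[1.0, 0.0, 1.0, 0.0],
                            [0.0, 0.0, 0.0, 0.0]])  # second row: no legal moves

# Same masking pattern as in the diff: illegal entries become -inf.
masked = logits.masked_fill(legal_moves == 0, float("-inf"))
probs = masked.softmax(dim=-1)

# Hypothetical guard: flag rows that have at least one legal move.
has_legal = legal_moves.sum(dim=-1) > 0

print(probs[0])   # illegal entries are exactly 0, legal ones sum to 1
print(probs[1])   # all NaN, because every logit in the row is -inf
print(has_legal)  # tensor([ True, False])
```

A caller that cannot rule out such rows could filter on `has_legal` before reading `probs`.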
